smart_augmentation/Old/salvador/cams.py

99 lines
3.3 KiB
Python
Raw Normal View History

2019-12-12 16:38:13 -05:00
import torch
import numpy as np
import torchvision
from PIL import Image
from torch import topk
import torch.nn.functional as F
from torch import topk
import cv2
from torchvision import transforms
import os
class SaveFeatures():
    """Capture the most recent forward output of a module via a forward hook.

    After construction, every forward pass of the hooked module overwrites
    ``features`` with that pass's output converted to a NumPy array on the
    CPU. Call ``remove()`` to detach the hook once it is no longer needed.
    """

    # Last captured output as a NumPy array; stays None until the hooked
    # module has run at least once.
    features = None

    def __init__(self, m):
        """Register the capturing forward hook on module ``m``."""
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Forward-hook callback: store ``output`` as a CPU NumPy array."""
        # ``detach()`` is the supported way to drop autograd tracking; the
        # legacy ``.data`` attribute bypasses autograd's safety checks.
        self.features = output.detach().cpu().numpy()

    def remove(self):
        """Unregister the forward hook from the module."""
        self.hook.remove()
def getCAM(feature_conv, weight_fc, class_idx):
    """Compute a class activation map scaled to the range [0, 1].

    Args:
        feature_conv: 4-D array of shape ``(1, nc, h, w)``; it is reshaped
            to ``(nc, h * w)``, so the leading dimension must be 1.
        weight_fc: Array indexed by class whose selected row(s) are dotted
            with the ``nc`` feature channels.
        class_idx: Index (or single-element index array) selecting the
            class row of ``weight_fc``.

    Returns:
        An ``(h, w)`` array: the weighted channel sum, shifted so its
        minimum is 0 and divided by its maximum. A constant map (maximum
        of 0 after shifting) is returned as all zeros instead of NaN.
    """
    _, nc, h, w = feature_conv.shape
    flat_features = feature_conv.reshape((nc, h * w))
    cam = weight_fc[class_idx].dot(flat_features)
    cam = cam.reshape(h, w)
    cam = cam - np.min(cam)
    peak = np.max(cam)
    # A flat map has peak == 0 after the shift; dividing would yield 0/0 = NaN.
    scale = peak if peak != 0 else 1
    cam_img = cam / scale
    return cam_img
def main(cam, *, root='NEW_SS', checkpoint='checkpoint.pt', device=None):
    """Run the classifier over an image folder and optionally save CAM overlays.

    A ResNet-50 with a 2-way final layer is loaded from ``checkpoint`` and
    evaluated on every image found by ``ImageFolder(root)``. When ``cam`` is
    truthy, the class activation map of the predicted class is colour-mapped,
    blended onto the original image and written to
    ``<root>_CAM/<class folder>/<image name>.jpg``.

    Args:
        cam: If truthy, compute and save a CAM overlay per image; otherwise
            only the forward passes are run and nothing is written.
        root: Dataset directory laid out as ``root/<class>/<image>``.
        checkpoint: Path to the saved ``state_dict`` of the model.
        device: Torch device string. Defaults to ``'cuda:0'`` when CUDA is
            available and ``'cpu'`` otherwise.
    """
    model_name = 'resnet50'
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    out_root = root + '_CAM'

    dataset = torchvision.datasets.ImageFolder(
        root=root,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    # One output folder per dataset class, mirroring the input layout.
    for class_name in dataset.classes:
        os.makedirs(os.path.join(out_root, class_name), exist_ok=True)

    # batch_size=1 without shuffling keeps loader index i aligned with
    # dataset.imgs[i], which is relied on below to recover the file path.
    loader = torch.utils.data.DataLoader(dataset, batch_size=1)

    model = torchvision.models.__dict__[model_name](pretrained=False)
    model.fc = torch.nn.Linear(model.fc.in_features, 2)
    # Load onto the CPU first so the checkpoint works regardless of the
    # device it was saved from; the model is moved afterwards.
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
    model = model.to(device)
    model.eval()

    # Weights of the final linear layer: one row of channel weights per class.
    weight_softmax = model.fc.weight.detach().cpu().numpy()
    activated_features = SaveFeatures(model.layer4)

    try:
        with torch.no_grad():
            for i, (img, _target) in enumerate(loader):
                prediction = model(img.to(device))
                pred_probabilities = F.softmax(prediction, dim=1).squeeze()
                # Plain Python int so it can index the NumPy weight matrix
                # even when the prediction lives on a CUDA device.
                class_idx = int(topk(pred_probabilities, 1)[1].item())
                if not cam:
                    continue

                path, class_index = dataset.imgs[i]
                image = cv2.imread(path)
                if image is None:
                    print(f'{path} could not be read, skipped')
                    continue

                overlay = getCAM(
                    activated_features.features, weight_softmax, class_idx
                )
                height, width, _ = image.shape
                overlay = cv2.resize(overlay, (width, height))
                # applyColorMap only accepts 8-bit input, so scale the
                # [0, 1] map to [0, 255]; NaNs (flat map) become 0.
                overlay = np.clip(255 * np.nan_to_num(overlay), 0, 255)
                heatmap = cv2.applyColorMap(
                    overlay.astype(np.uint8), cv2.COLORMAP_JET
                )
                result = heatmap * 0.3 + image * 0.5

                class_name = dataset.classes[class_index]
                name = os.path.splitext(os.path.basename(path))[0]
                out_path = os.path.join(out_root, class_name, name + '.jpg')
                cv2.imwrite(out_path, result)
                print(f'{out_path} saved')
    finally:
        # Always release the forward hook, even if an image fails midway.
        activated_features.remove()
# Script entry point: when executed directly, run the pipeline with CAM
# overlay generation enabled.
if __name__ == "__main__":
    main(cam=True)