Tidying up

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-28 16:46:37 -05:00
parent ca3367d19f
commit 4166922c34
453 changed files with 9797 additions and 7 deletions

Old/salvador/grad_cam.py Executable file

@@ -0,0 +1,102 @@
import torch
import numpy as np
import torchvision
from PIL import Image
from torch import topk
from torch import nn
import torch.nn.functional as F
import cv2
from torchvision import transforms
import os
class Lambda(nn.Module):
    "Create a layer that simply calls `func` with `x`"
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)
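# For illustration only: Lambda is used to splice a functional step into an
# nn.Sequential, e.g.
#   flatten = Lambda(lambda x: torch.flatten(x, 1))
# which is how it is employed when rebuilding the classifier head in main().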
class SaveFeatures():
    activations, gradients = None, None

    def __init__(self, m):
        self.forward = m.register_forward_hook(self.forward_hook_fn)
        self.backward = m.register_backward_hook(self.backward_hook_fn)

    def forward_hook_fn(self, module, input, output):
        self.activations = output.cpu().detach()

    def backward_hook_fn(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].cpu().detach()

    def remove(self):
        self.forward.remove()
        self.backward.remove()
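# Minimal usage sketch of SaveFeatures (assumptions: `block` is the convolutional
# module to inspect and `cls` the class index to explain); main() below follows
# the same pattern:
#   feats = SaveFeatures(block)
#   out = model(x)
#   out[:, cls].backward()                          # populates feats.gradients
#   weights = feats.gradients[0].mean(dim=(1, 2))   # per-channel importance
#   cam = F.relu((feats.activations[0] * weights[:, None, None]).sum(0))
#   feats.remove()
# Note: register_backward_hook is deprecated in more recent PyTorch releases in
# favour of register_full_backward_hook; the older call is kept to match the
# PyTorch version this script was written against.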
def main(cam):
    device = 'cuda:0'
    model_name = 'resnet50'
    root = '/mnt/md0/data/cifar10/tmp/cifar/train'
    _root = 'cifar'
    os.makedirs(_root + '_CAM', exist_ok=True)

    train_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    dataset = torchvision.datasets.ImageFolder(
        root=root, transform=train_transform,
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=1)

    # Rebuild the ResNet as body (conv stages) + head (pooling, flatten, linear)
    # so that a hook can be attached to the output of the last convolutional stage.
    model = torchvision.models.__dict__[model_name](pretrained=True)
    flat = list(model.children())
    body = nn.Sequential(*flat[:-2])
    head = nn.Sequential(flat[-2],
                         Lambda(func=lambda x: torch.flatten(x, 1)),
                         nn.Linear(flat[-1].in_features, len(loader.dataset.classes)))
    model = nn.Sequential(body, head)
    model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
    model = model.to(device)
    model.eval()
    # Hook the body: its output activations and gradients are what Grad-CAM combines.
    activated_features = SaveFeatures(model[0])

    for i, (img, target) in enumerate(loader):
        img = img.to(device)
        pred = model(img)
        # Backpropagate the score of the target class to populate the gradient hook.
        pred[:, target.item()].backward()
        # Pull the feature-map gradients out of the hook and global-average-pool
        # them into per-channel weights.
        gradients = activated_features.gradients[0]
        pooled_gradients = gradients.mean(1).mean(1)
        # Activations of the last convolutional stage.
        activations = activated_features.activations[0]
        # Grad-CAM: ReLU of the weighted sum of the feature maps, normalised to [0, 1].
        heatmap = F.relu((activations * pooled_gradients[..., None, None]).sum(0))
        heatmap /= torch.max(heatmap)
        heatmap = heatmap.numpy()

        image = cv2.imread(dataset.imgs[i][0])
        heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        # superimposed_img = heatmap * 0.3 + image * 0.5
        superimposed_img = heatmap

        clss = os.path.basename(os.path.dirname(dataset.imgs[i][0]))  # class folder name (currently unused)
        name = os.path.splitext(os.path.basename(dataset.imgs[i][0]))[0]
        cv2.imwrite(os.path.join(_root + "_CAM", name + '.jpg'), superimposed_img)
        print(f'{os.path.join(_root + "_CAM", name + ".jpg")} saved')

    activated_features.remove()
if __name__ == "__main__":
    main(cam=True)
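# How to run (a sketch under assumptions, not verified): place the fine-tuned
# weights as `checkpoint.pt` next to this script, point `root` at an ImageFolder
# training split, then run `python grad_cam.py`. One JPEG heatmap per input image
# is written to the `cifar_CAM/` directory; swap in the commented
# `superimposed_img` blend line to overlay the heatmap on the original image instead.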