mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Salvador tests
This commit is contained in:
parent
6c0597e7ea
commit
aade27011a
57 changed files with 29210 additions and 0 deletions
98
salvador/cams.py
Normal file
98
salvador/cams.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from torch import topk
|
||||
import torch.nn.functional as F
|
||||
from torch import topk
|
||||
import cv2
|
||||
from torchvision import transforms
|
||||
import os
|
||||
|
||||
class SaveFeatures():
|
||||
features=None
|
||||
def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)
|
||||
def hook_fn(self, module, input, output): self.features = ((output.cpu()).data).numpy()
|
||||
def remove(self): self.hook.remove()
|
||||
|
||||
def getCAM(feature_conv, weight_fc, class_idx):
|
||||
_, nc, h, w = feature_conv.shape
|
||||
cam = weight_fc[class_idx].dot(feature_conv.reshape((nc, h*w)))
|
||||
cam = cam.reshape(h, w)
|
||||
cam = cam - np.min(cam)
|
||||
cam_img = cam / np.max(cam)
|
||||
# cam_img = np.uint8(255 * cam_img)
|
||||
return cam_img
|
||||
|
||||
def main(cam):
|
||||
device = 'cuda:0'
|
||||
model_name = 'resnet50'
|
||||
root = 'NEW_SS'
|
||||
|
||||
os.makedirs(os.path.join(root + '_CAM', 'OK'), exist_ok=True)
|
||||
os.makedirs(os.path.join(root + '_CAM', 'NOK'), exist_ok=True)
|
||||
|
||||
train_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
dataset = torchvision.datasets.ImageFolder(
|
||||
root=root, transform=train_transform,
|
||||
)
|
||||
|
||||
loader = torch.utils.data.DataLoader(dataset, batch_size=1)
|
||||
|
||||
model = torchvision.models.__dict__[model_name](pretrained=False)
|
||||
model.fc = torch.nn.Linear(model.fc.in_features, 2)
|
||||
|
||||
model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
weight_softmax_params = list(model._modules.get('fc').parameters())
|
||||
weight_softmax = np.squeeze(weight_softmax_params[0].cpu().data.numpy())
|
||||
|
||||
final_layer = model._modules.get('layer4')
|
||||
|
||||
activated_features = SaveFeatures(final_layer)
|
||||
|
||||
for i, (img, target ) in enumerate(loader):
|
||||
img = img.to(device)
|
||||
prediction = model(img)
|
||||
pred_probabilities = F.softmax(prediction, dim=1).data.squeeze()
|
||||
class_idx = topk(pred_probabilities,1)[1].int()
|
||||
# if target.item() != class_idx:
|
||||
# print(dataset.imgs[i][0])
|
||||
|
||||
if cam:
|
||||
overlay = getCAM(activated_features.features, weight_softmax, class_idx )
|
||||
|
||||
import ipdb; ipdb.set_trace()
|
||||
import PIL
|
||||
from torchvision.transforms import ToPILImage
|
||||
|
||||
img = ToPILImage()(overlay).resize(size=(1280, 1024), resample=PIL.Image.BILINEAR)
|
||||
img.save('heat-pil.jpg')
|
||||
|
||||
|
||||
img = cv2.imread(dataset.imgs[i][0])
|
||||
height, width, _ = img.shape
|
||||
overlay = cv2.resize(overlay, (width, height))
|
||||
heatmap = cv2.applyColorMap(overlay, cv2.COLORMAP_JET)
|
||||
cv2.imwrite('heat-cv2.jpg', heatmap)
|
||||
|
||||
img = cv2.imread(dataset.imgs[i][0])
|
||||
height, width, _ = img.shape
|
||||
overlay = cv2.resize(overlay, (width, height))
|
||||
heatmap = cv2.applyColorMap(overlay, cv2.COLORMAP_JET)
|
||||
result = heatmap * 0.3 + img * 0.5
|
||||
|
||||
clss = dataset.imgs[i][0].split(os.sep)[1]
|
||||
name = dataset.imgs[i][0].split(os.sep)[2].split('.')[0]
|
||||
cv2.imwrite(os.path.join(root+"_CAM", clss, name + '.jpg'), result)
|
||||
print(f'{os.path.join(root+"_CAM", clss, name + ".jpg")} saved')
|
||||
|
||||
activated_features.remove()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(cam=True)
|
Loading…
Add table
Add a link
Reference in a new issue