mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
98 lines
3.3 KiB
Python
98 lines
3.3 KiB
Python
import torch
|
|
import numpy as np
|
|
import torchvision
|
|
from PIL import Image
|
|
from torch import topk
|
|
import torch.nn.functional as F
|
|
from torch import topk
|
|
import cv2
|
|
from torchvision import transforms
|
|
import os
|
|
|
|
class SaveFeatures():
    """Capture the output of a module through a PyTorch forward hook.

    After each forward pass through the hooked module, ``features`` holds
    that pass's output as a NumPy array on the CPU, detached from autograd.
    Call :meth:`remove` (or use the instance as a context manager) to
    unregister the hook.
    """

    # Output of the most recent forward pass; None until the first pass.
    features = None

    def __init__(self, m):
        """Register the forward hook on module ``m``."""
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, inputs, output):
        """Store ``output`` as a NumPy array (signature imposed by PyTorch)."""
        # detach() is the supported way to drop autograd history; `.data`
        # bypasses autograd's correctness checks.
        self.features = output.detach().cpu().numpy()

    def remove(self):
        """Unregister the hook; ``features`` keeps its last captured value."""
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always unregister the hook, even if the body raised.
        self.remove()
|
|
|
|
def getCAM(feature_conv, weight_fc, class_idx, sample=0):
    """Compute a min-max normalized Class Activation Map for one class.

    The map is the sum of the feature-map channels of one batch element,
    weighted by row ``class_idx`` of ``weight_fc``.

    Args:
        feature_conv: Array of shape ``(batch, nc, h, w)``.
        weight_fc: Array of shape ``(num_classes, nc)``.
        class_idx: Target class: a Python int, or a single-element integer
            tensor/array (anything exposing ``.item()``, possibly on GPU).
        sample: Index of the batch element to use (default: the first).

    Returns:
        Array of shape ``(h, w)`` with values in ``[0, 1]``. A constant
        activation map yields all zeros instead of NaNs from a 0/0 division.
    """
    _, nc, h, w = feature_conv.shape
    if hasattr(class_idx, 'item'):
        # Tensors / NumPy scalars (possibly on CUDA) -> plain Python number.
        class_idx = class_idx.item()
    cam = weight_fc[int(class_idx)].dot(feature_conv[sample].reshape(nc, h * w))
    cam = cam.reshape(h, w)
    cam = cam - cam.min()
    peak = cam.max()
    # Guard against a flat map: dividing by zero would fill it with NaNs.
    if peak > 0:
        cam = cam / peak
    # cam_img = np.uint8(255 * cam_img)  # callers quantize if they need uint8
    return cam
|
|
|
|
def _load_classifier(model_name, num_classes, checkpoint, device):
    """Build a torchvision classifier, load its weights and set eval mode."""
    model = torchvision.models.__dict__[model_name](pretrained=False)
    # Replace the ImageNet head with one matching the fine-tuned checkpoint.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
    model = model.to(device)
    model.eval()
    return model


def _blend_cam(image_path, cam_map, heat_weight=0.3, image_weight=0.5):
    """Return the image at ``image_path`` blended with a JET-coloured CAM.

    ``cam_map`` is a 2-D map with values in ``[0, 1]``. Returns None when
    OpenCV cannot read the image.
    """
    image = cv2.imread(image_path)
    if image is None:
        return None
    height, width = image.shape[:2]
    # Upsample the low-resolution CAM to the image size, then quantize:
    # cv2.applyColorMap only accepts 8-bit input.
    cam_resized = cv2.resize(cam_map, (width, height))
    heatmap = cv2.applyColorMap(np.uint8(255 * np.nan_to_num(cam_resized)),
                                cv2.COLORMAP_JET)
    # addWeighted rounds and saturates the weighted sum into uint8.
    return cv2.addWeighted(heatmap, heat_weight, image, image_weight, 0)


def main(cam, root='NEW_SS', checkpoint='checkpoint.pt', model_name='resnet50',
         device=None):
    """Classify every image under ``root`` and optionally save CAM overlays.

    Images are loaded with ``torchvision.datasets.ImageFolder`` (one
    sub-folder per class). When ``cam`` is true, the CAM of the predicted
    class is blended onto each image and written to
    ``<root>_CAM/<class folder>/<file stem>.jpg``.

    Args:
        cam: Whether to compute and save CAM overlays.
        root: Dataset directory in ImageFolder layout.
        checkpoint: Path of the fine-tuned 2-class state dict.
        model_name: torchvision constructor name; the model must expose
            ``fc`` and ``layer4`` (ResNet family).
        device: Torch device; defaults to ``cuda:0`` when available, else CPU.
    """
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    out_root = root + '_CAM'

    dataset = torchvision.datasets.ImageFolder(root=root, transform=transforms.ToTensor())
    # Default DataLoader does not shuffle, so batch i is dataset.imgs[i].
    loader = torch.utils.data.DataLoader(dataset, batch_size=1)

    # Mirror the dataset's class folders in the output tree.
    for class_name in dataset.classes:
        os.makedirs(os.path.join(out_root, class_name), exist_ok=True)

    model = _load_classifier(model_name, 2, checkpoint, device)
    # One weight row per class, one column per layer4 channel.
    fc_weights = model.fc.weight.detach().cpu().numpy()

    activated_features = SaveFeatures(model.layer4)
    try:
        with torch.no_grad():
            for i, (img, _target) in enumerate(loader):
                prediction = model(img.to(device))
                class_idx = prediction.argmax(dim=1).item()
                if not cam:
                    continue

                overlay = getCAM(activated_features.features, fc_weights, class_idx)
                image_path, label = dataset.imgs[i]
                result = _blend_cam(image_path, overlay)
                if result is None:
                    print(f'{image_path} could not be read, skipped')
                    continue

                clss = dataset.classes[label]
                name = os.path.splitext(os.path.basename(image_path))[0]
                out_path = os.path.join(out_root, clss, name + '.jpg')
                if cv2.imwrite(out_path, result):
                    print(f'{out_path} saved')
                else:
                    print(f'{out_path} could not be written')
    finally:
        # Unregister the hook even if inference fails part-way.
        activated_features.remove()
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the classifier and save a CAM overlay per image.
    main(True)
|