mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Merge branch 'master' of http://frd-git/scm/axon/smart_augmentation
This commit is contained in:
commit
e75fb96716
57 changed files with 29210 additions and 0 deletions
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
98
salvador/cams.py
Normal file
98
salvador/cams.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import torchvision
|
||||||
|
from PIL import Image
|
||||||
|
from torch import topk
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import topk
|
||||||
|
import cv2
|
||||||
|
from torchvision import transforms
|
||||||
|
import os
|
||||||
|
|
||||||
|
class SaveFeatures():
    """Keep the most recent output of a module.

    A forward hook is registered on the module at construction time; every
    forward pass then overwrites ``features`` with the module output,
    moved to the CPU and converted to a NumPy array.
    """

    # Last captured output; stays None until the hook has fired once.
    features = None

    def __init__(self, m):
        """Register the forward hook on module ``m``."""
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Forward-hook callback: store ``output`` as a NumPy array."""
        on_cpu = output.cpu()
        self.features = on_cpu.data.numpy()

    def remove(self):
        """Unregister the hook from the module."""
        self.hook.remove()
|
||||||
|
|
||||||
|
def getCAM(feature_conv, weight_fc, class_idx):
    """Compute a class activation map scaled to the range [0, 1].

    Args:
        feature_conv: array of shape (1, nc, h, w); the reshape to
            (nc, h*w) below only works for a single sample.
        weight_fc: array indexable by ``class_idx`` whose selected row
            holds ``nc`` weights.
        class_idx: index of the class whose map is computed.

    Returns:
        Floating-point array of shape (h, w), min-max normalised. A
        constant map is returned as all zeros instead of NaN.
    """
    _, nc, h, w = feature_conv.shape
    # Weighted sum of the nc feature maps for the selected class.
    cam = weight_fc[class_idx].dot(feature_conv.reshape((nc, h*w)))
    cam = cam.reshape(h, w)
    cam = cam - np.min(cam)
    peak = np.max(cam)
    # After the shift the minimum is 0; a zero peak means the map is
    # constant, and dividing by it would produce 0/0 = NaN everywhere.
    scale = peak if peak != 0 else 1
    cam_img = cam / scale
    return cam_img
|
||||||
|
|
||||||
|
def main(cam):
    """Run the classifier over the dataset and save CAM overlays.

    A ResNet-50 with a 2-output ``fc`` layer is loaded from
    ``checkpoint.pt`` and evaluated, one image at a time, on the
    ``ImageFolder`` rooted at ``NEW_SS``. When ``cam`` is truthy, the class
    activation map of the predicted class is blended over each image and
    written to ``NEW_SS_CAM/<class>/<name>.jpg``.

    Args:
        cam: when falsy the model is still run but nothing is written.
    """
    device = 'cuda:0'
    model_name = 'resnet50'
    root = 'NEW_SS'
    out_root = root + '_CAM'

    os.makedirs(os.path.join(out_root, 'OK'), exist_ok=True)
    os.makedirs(os.path.join(out_root, 'NOK'), exist_ok=True)

    train_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    dataset = torchvision.datasets.ImageFolder(
        root=root, transform=train_transform,
    )

    loader = torch.utils.data.DataLoader(dataset, batch_size=1)

    model = torchvision.models.__dict__[model_name](pretrained=False)
    model.fc = torch.nn.Linear(model.fc.in_features, 2)

    model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
    model = model.to(device)
    model.eval()

    # Weights of the final linear layer, used to combine the feature maps.
    weight_softmax_params = list(model._modules.get('fc').parameters())
    weight_softmax = np.squeeze(weight_softmax_params[0].cpu().data.numpy())

    final_layer = model._modules.get('layer4')
    activated_features = SaveFeatures(final_layer)

    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for i, (img, target) in enumerate(loader):
                img = img.to(device)
                prediction = model(img)
                pred_probabilities = F.softmax(prediction, dim=1).squeeze()
                # Plain Python int so it can index the NumPy weight matrix.
                class_idx = int(topk(pred_probabilities, 1)[1].item())

                if not cam:
                    continue

                overlay = getCAM(activated_features.features, weight_softmax, class_idx)

                img_path = dataset.imgs[i][0]
                image = cv2.imread(img_path)
                height, width, _ = image.shape
                overlay = cv2.resize(overlay, (width, height))
                # applyColorMap needs an 8-bit image; getCAM returns floats
                # in [0, 1], so rescale before colouring.
                heatmap = cv2.applyColorMap(np.uint8(255 * overlay), cv2.COLORMAP_JET)
                result = heatmap * 0.3 + image * 0.5

                # Class folder and file stem are derived from the path
                # itself rather than from fixed split positions.
                clss = os.path.basename(os.path.dirname(img_path))
                name = os.path.basename(img_path).split('.')[0]
                out_path = os.path.join(out_root, clss, name + '.jpg')
                cv2.imwrite(out_path, result)
                print(f'{out_path} saved')
    finally:
        # Always detach the hook, even if an image fails.
        activated_features.remove()


if __name__ == "__main__":
    main(cam=True)
|
BIN
salvador/checkpoint.pt
Normal file
BIN
salvador/checkpoint.pt
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054503377.tif
Normal file
BIN
salvador/data/NOK/nok054503377.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054503736.tif
Normal file
BIN
salvador/data/NOK/nok054503736.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054504079.tif
Normal file
BIN
salvador/data/NOK/nok054504079.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054506185.tif
Normal file
BIN
salvador/data/NOK/nok054506185.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054507230.tif
Normal file
BIN
salvador/data/NOK/nok054507230.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054507589.tif
Normal file
BIN
salvador/data/NOK/nok054507589.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054507932.tif
Normal file
BIN
salvador/data/NOK/nok054507932.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054508634.tif
Normal file
BIN
salvador/data/NOK/nok054508634.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054510382.tif
Normal file
BIN
salvador/data/NOK/nok054510382.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054510740.tif
Normal file
BIN
salvador/data/NOK/nok054510740.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054513533.tif
Normal file
BIN
salvador/data/NOK/nok054513533.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054513892.tif
Normal file
BIN
salvador/data/NOK/nok054513892.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054519508.tif
Normal file
BIN
salvador/data/NOK/nok054519508.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054521957.tif
Normal file
BIN
salvador/data/NOK/nok054521957.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054522659.tif
Normal file
BIN
salvador/data/NOK/nok054522659.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054527916.tif
Normal file
BIN
salvador/data/NOK/nok054527916.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054531083.tif
Normal file
BIN
salvador/data/NOK/nok054531083.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054532846.tif
Normal file
BIN
salvador/data/NOK/nok054532846.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054533891.tif
Normal file
BIN
salvador/data/NOK/nok054533891.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok054538118.tif
Normal file
BIN
salvador/data/NOK/nok054538118.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok_054501630.tif
Normal file
BIN
salvador/data/NOK/nok_054501630.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok_054502332.tif
Normal file
BIN
salvador/data/NOK/nok_054502332.tif
Normal file
Binary file not shown.
BIN
salvador/data/NOK/nok_054509336.tif
Normal file
BIN
salvador/data/NOK/nok_054509336.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054501989.tif
Normal file
BIN
salvador/data/OK/054501989.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054502675.tif
Normal file
BIN
salvador/data/OK/054502675.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054503034.tif
Normal file
BIN
salvador/data/OK/054503034.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054504438.tif
Normal file
BIN
salvador/data/OK/054504438.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054504781.tif
Normal file
BIN
salvador/data/OK/054504781.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054505124.tif
Normal file
BIN
salvador/data/OK/054505124.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054505483.tif
Normal file
BIN
salvador/data/OK/054505483.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054505826.tif
Normal file
BIN
salvador/data/OK/054505826.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054506528.tif
Normal file
BIN
salvador/data/OK/054506528.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054506887.tif
Normal file
BIN
salvador/data/OK/054506887.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054508276.tif
Normal file
BIN
salvador/data/OK/054508276.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054508978.tif
Normal file
BIN
salvador/data/OK/054508978.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054509680.tif
Normal file
BIN
salvador/data/OK/054509680.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054510038.tif
Normal file
BIN
salvador/data/OK/054510038.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054511084.tif
Normal file
BIN
salvador/data/OK/054511084.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054511427.tif
Normal file
BIN
salvador/data/OK/054511427.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054511786.tif
Normal file
BIN
salvador/data/OK/054511786.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054512129.tif
Normal file
BIN
salvador/data/OK/054512129.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054512488.tif
Normal file
BIN
salvador/data/OK/054512488.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054512831.tif
Normal file
BIN
salvador/data/OK/054512831.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054513190.tif
Normal file
BIN
salvador/data/OK/054513190.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054514235.tif
Normal file
BIN
salvador/data/OK/054514235.tif
Normal file
Binary file not shown.
BIN
salvador/data/OK/054514578.tif
Normal file
BIN
salvador/data/OK/054514578.tif
Normal file
Binary file not shown.
1136
salvador/dataug.py
Executable file
1136
salvador/dataug.py
Executable file
File diff suppressed because it is too large
Load diff
314
salvador/dataug_utils.py
Normal file
314
salvador/dataug_utils.py
Normal file
|
@ -0,0 +1,314 @@
|
||||||
|
import numpy as np
|
||||||
|
import json, math, time, os
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import copy
|
||||||
|
import gc
|
||||||
|
|
||||||
|
from torchviz import make_dot
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
class timer():
    """Stopwatch measuring the time between successive ``exec_time`` calls."""

    def __init__(self):
        """Start the clock."""
        self._start_time = time.time()

    def exec_time(self):
        """Return the seconds elapsed since the last call (or construction).

        The reference time is moved to now, so consecutive calls measure
        consecutive intervals.
        """
        now = time.time()
        elapsed, self._start_time = now - self._start_time, now
        return elapsed
|
||||||
|
|
||||||
|
def print_graph(PyTorch_obj, fig_name='graph'):
    """Render the graph built by ``make_dot`` for ``PyTorch_obj``.

    Args:
        PyTorch_obj: object handed to ``make_dot`` (passing a loss gives the
            whole graph).
        fig_name: name passed to the graph's ``render`` call.
    """
    dot = make_dot(PyTorch_obj)
    # Output format; see https://graphviz.readthedocs.io/en/stable/manual.html#formats
    dot.format = 'svg'
    dot.render(fig_name)
|
||||||
|
|
||||||
|
def plot_res(log, fig_name='res', param_names=None):
    """Plot loss, accuracy and parameter curves from a training log.

    Args:
        log: non-empty sequence of dicts with keys ``"epoch"``,
            ``"train_loss"``, ``"val_loss"``, ``"acc"`` and ``"param"``.
            ``"param"`` is either None, a float (plotted as 'Mag') or an
            indexable collection (plotted as a stacked 'Prob' chart).
        fig_name: output name; every '.' is replaced by ',' before saving.
        param_names: labels for the stacked chart; defaults to
            'P0', 'P1', ...
    """
    epochs = [x["epoch"] for x in log]

    fig, ax = plt.subplots(ncols=3, figsize=(15, 3))

    ax[0].set_title('Loss')
    ax[0].plot(epochs, [x["train_loss"] for x in log], label='Train')
    ax[0].plot(epochs, [x["val_loss"] for x in log], label='Val')
    ax[0].legend()

    ax[1].set_title('Acc')
    ax[1].plot(epochs, [x["acc"] for x in log])

    first_param = log[0]["param"]
    # Identity test: `!= None` is ambiguous for array-like parameters.
    if first_param is not None:
        if isinstance(first_param, float):
            ax[2].set_title('Mag')
            ax[2].plot(epochs, [x["param"] for x in log], label='Mag')
            ax[2].legend()
        else:
            ax[2].set_title('Prob')
            if not param_names:
                param_names = ['P'+str(idx) for idx, _ in enumerate(first_param)]
            # One series per parameter, each spanning all epochs.
            proba = [[x["param"][idx] for x in log] for idx, _ in enumerate(first_param)]
            ax[2].stackplot(epochs, proba, labels=param_names)
            ax[2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))

    fig_name = fig_name.replace('.', ',')
    # Save and close this figure explicitly instead of relying on pyplot's
    # notion of the "current" figure.
    fig.savefig(fig_name)
    plt.close(fig)
|
||||||
|
|
||||||
|
def plot_resV2(log, fig_name='res', param_names=None):
    """Plot loss, accuracy and per-parameter 'p' / 'm' statistics of a log.

    Args:
        log: non-empty sequence of dicts with keys ``"epoch"``,
            ``"train_loss"``, ``"val_loss"``, ``"acc"`` and ``"param"``.
            When not None, ``"param"`` is a collection whose items are
            mappings with keys ``'p'`` and ``'m'``.
        fig_name: output name; every '.' is replaced by ',' before saving.
        param_names: labels for the parameters; defaults to 'P0', 'P1', ...
    """
    epochs = [x["epoch"] for x in log]

    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(30, 15))

    ax[0, 0].set_title('Loss')
    ax[0, 0].plot(epochs, [x["train_loss"] for x in log], label='Train')
    ax[0, 0].plot(epochs, [x["val_loss"] for x in log], label='Val')
    ax[0, 0].legend()

    ax[1, 0].set_title('Acc')
    ax[1, 0].plot(epochs, [x["acc"] for x in log])

    first_param = log[0]["param"]
    # Identity test: `!= None` is ambiguous for array-like parameters.
    if first_param is not None:
        if not param_names:
            param_names = ['P'+str(idx) for idx, _ in enumerate(first_param)]
        # One series per parameter, each spanning all epochs.
        proba = [[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(first_param)]
        mag = [[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(first_param)]

        ax[0, 1].set_title('Prob =f(epoch)')
        ax[0, 1].stackplot(epochs, proba, labels=param_names)

        ax[1, 1].set_title('Prob =f(TF)')
        mean = np.mean(proba, axis=1)
        std = np.std(proba, axis=1)
        ax[1, 1].bar(param_names, mean, yerr=std)
        plt.sca(ax[1, 1])
        plt.xticks(rotation=90)

        ax[0, 2].set_title('Mag =f(epoch)')
        ax[0, 2].stackplot(epochs, mag, labels=param_names)
        ax[0, 2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))

        ax[1, 2].set_title('Mag =f(TF)')
        mean = np.mean(mag, axis=1)
        std = np.std(mag, axis=1)
        ax[1, 2].bar(param_names, mean, yerr=std)
        plt.sca(ax[1, 2])
        plt.xticks(rotation=90)

    fig_name = fig_name.replace('.', ',')
    # Save and close this figure explicitly instead of relying on pyplot's
    # notion of the "current" figure.
    fig.savefig(fig_name, bbox_inches='tight')
    plt.close(fig)
|
||||||
|
|
||||||
|
def plot_compare(filenames, fig_name='res'):
    """Overlay the loss, accuracy and parameter curves of several runs.

    Args:
        filenames: paths of JSON files; each must contain a ``'Log'`` entry
            holding a non-empty list of dicts with keys ``"epoch"``,
            ``"train_loss"``, ``"val_loss"``, ``"acc"`` and ``"param"``.
        fig_name: output name; every '.' is replaced by ',' before saving.
    """
    all_data = []
    for file in filenames:
        with open(file) as json_file:
            all_data.append(json.load(json_file))

    # Figure title mapping each curve index to its source file.
    legend = "".join(f"{idx}-{file}\n" for idx, file in enumerate(filenames))

    fig, ax = plt.subplots(ncols=3, figsize=(30, 8))

    for data_idx, data in enumerate(all_data):
        log = data['Log']
        epochs = [x["epoch"] for x in log]

        ax[0].plot(epochs, [x["train_loss"] for x in log], label=str(data_idx)+'-Train')
        ax[0].plot(epochs, [x["val_loss"] for x in log], label=str(data_idx)+'-Val')

        ax[1].plot(epochs, [x["acc"] for x in log], label=str(data_idx))

        first_param = log[0]["param"]
        # Identity test: `!= None` is ambiguous for array-like parameters.
        if first_param is not None:
            if isinstance(first_param, float):
                ax[2].plot(epochs, [x["param"] for x in log], label=str(data_idx)+'-Mag')
            else:
                for idx, _ in enumerate(first_param):
                    ax[2].plot(epochs, [x["param"][idx] for x in log], label=str(data_idx)+'-P'+str(idx))

    fig.suptitle(legend)
    ax[0].set_title('Loss')
    ax[1].set_title('Acc')
    ax[2].set_title('Param')
    for a in ax:
        a.legend()

    fig_name = fig_name.replace('.', ',')
    # Save and close this figure explicitly instead of relying on pyplot's
    # notion of the "current" figure.
    fig.savefig(fig_name, bbox_inches='tight')
    plt.close(fig)
|
||||||
|
|
||||||
|
def plot_res_compare(filenames, fig_name='res'):
    """Plot accuracy and time against the number of parameter names.

    Args:
        filenames: paths of JSON files, each providing ``"Param_names"``,
            ``"Accuracy"`` and ``"Time"`` (only the first element of
            ``"Time"`` is used).
        fig_name: output name; every '.' is replaced by ',' before saving.
    """
    results = []
    for path in filenames:
        with open(path) as json_file:
            loaded = json.load(json_file)
        results.append(loaded)

    # x axis: how many parameter names each run declares.
    n_tf = [len(run["Param_names"]) for run in results]
    accuracies = [run["Accuracy"] for run in results]
    durations = [run["Time"][0] for run in results]

    fig, ax = plt.subplots(ncols=3, figsize=(30, 8))

    acc_axis, time_axis = ax[0], ax[1]
    acc_axis.plot(n_tf, accuracies)
    time_axis.plot(n_tf, durations)

    acc_axis.set_title('Acc')
    time_axis.set_title('Time')

    target = fig_name.replace('.', ',')
    plt.savefig(target, bbox_inches='tight')
    plt.close()
|
||||||
|
|
||||||
|
def plot_TF_res(log, tf_names, fig_name='res'):
    """Draw a bar chart of the mean of each parameter over the log.

    Args:
        log: sequence of dicts whose ``"param"`` entries are stacked and
            averaged along the first axis; the standard deviation is drawn
            as the error bar.
        tf_names: bar labels.
        fig_name: output name; every '.' is replaced by ',' before saving.
    """
    averages = np.mean([entry["param"] for entry in log], axis=0)
    spreads = np.std([entry["param"] for entry in log], axis=0)

    fig, ax = plt.subplots(1, 1, figsize=(30, 8), sharey=True)
    ax.bar(tf_names, averages, yerr=spreads)

    target = fig_name.replace('.', ',')
    plt.savefig(target, bbox_inches='tight')
    plt.close()
|
||||||
|
|
||||||
|
def viz_sample_data(imgs, labels, fig_name='data_sample'):
    """Save a 5x5 grid showing up to the first 25 images with their labels.

    Args:
        imgs: tensor indexed as (N, C, H, W); it is permuted to
            (N, H, W, C) before display.
        labels: indexable collection whose items provide ``.item()``.
        fig_name: name passed to ``plt.savefig``.
    """
    sample = imgs[0:25].permute(0, 2, 3, 1)
    # Drop only a singleton channel axis. A bare squeeze() would also drop
    # the batch axis when a single image is given.
    sample = sample.squeeze(-1).cpu()

    plt.figure(figsize=(10, 10))
    # Iterate over the images actually available; fewer than 25 used to
    # raise an IndexError.
    for i in range(sample.shape[0]):
        plt.subplot(5, 5, i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(sample[i].detach().numpy(), cmap=plt.cm.binary)
        plt.xlabel(labels[i].item())

    plt.savefig(fig_name)
    print("Sample saved :", fig_name)
    plt.close()
|
||||||
|
|
||||||
|
def model_copy(src, dst, patch_copy=True, copy_grad=True):
    """Copy weights, and optionally gradients, from ``src`` into ``dst``.

    A deepcopy is deliberately not used: only weights and gradients are
    wanted, not the whole source object and its state.

    Args:
        src: source model.
        dst: destination model, modified in place.
        patch_copy: also reload the state dicts of the ``'model'`` and
            ``'data_aug'`` sub-modules, both accessed by indexing.
        copy_grad: make every parameter of ``dst`` share the ``.grad`` of
            the parameter with the same name in ``src``.
    """
    dst.load_state_dict(src.state_dict())  # does not copy gradients

    if patch_copy:
        # Original note asks whether this copies otherwise-missing data.
        dst['model'].load_state_dict(src['model'].state_dict())
        dst['data_aug'].load_state_dict(src['data_aug'].state_dict())

    # Copy the gradients.
    if copy_grad:
        # Name lookup instead of a nested scan over both parameter lists.
        dst_params = dict(dst.named_parameters())
        for param_name, param_value in src.named_parameters():
            target = dst_params.get(param_name)
            if target is not None:
                target.grad = param_value.grad

    # Extra attributes (original comment: "Data_augV4"). Best effort:
    # models that are not indexable or lack these attributes are left
    # untouched.
    try:
        dst['data_aug']._input_info = src['data_aug']._input_info
        dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix
    except (AttributeError, KeyError, IndexError, TypeError):
        pass
|
||||||
|
|
||||||
|
def optim_copy(dopt, opt):
    """Copy per-parameter optimizer state from ``dopt`` into ``opt``.

    ``dopt.state`` is indexed by (group index, parameter index) while
    ``opt.state`` is keyed by the parameter object itself. Loading the state
    dict directly was noted in the original not to copy the state
    (momentum, etc.).

    Args:
        dopt: source optimizer exposing ``state[group][param_index]``.
        opt: destination optimizer with ``param_groups`` and ``state``.
    """
    for g_pos, param_group in enumerate(opt.param_groups):
        members = param_group['params']
        for p_pos, member in enumerate(members):
            opt.state[member] = dopt.state[g_pos][p_pos]
|
||||||
|
|
||||||
|
def print_torch_mem(add_info=''):
    """Print a report on live tensors and on CUDA memory usage.

    Two lines are printed: the number of tensor-like objects tracked by the
    garbage collector with the largest sum of dimensions among them, then a
    simple GPU memory report in megabytes.

    Args:
        add_info: text prefixed to both printed lines.
    """
    nb = 0
    max_size = 0
    for obj in gc.get_objects():
        # Inspecting arbitrary objects can fail in many ways; such objects
        # are skipped. Only Exception is caught so that KeyboardInterrupt
        # and SystemExit still propagate.
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                # Sum of the dimension sizes, not the number of elements.
                size = np.sum(obj.size())
                if size > max_size:
                    max_size = size
                nb += 1
        except Exception:
            pass
    print(add_info, "-Pytroch tensor nb:", nb, " / Max dim:", max_size)

    # memory_cached / max_memory_cached are deprecated in favour of the
    # *_reserved functions; fall back to the old names on older PyTorch.
    if hasattr(torch.cuda, 'memory_reserved'):
        cached_fn = torch.cuda.memory_reserved
        max_cached_fn = torch.cuda.max_memory_reserved
    else:
        cached_fn = torch.cuda.memory_cached
        max_cached_fn = torch.cuda.max_memory_cached

    mega_bytes = 1024.0 * 1024.0
    string = add_info + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | cached: {}'.format(cached_fn() / mega_bytes)
    string += ' | max cached: {}'.format(
        max_cached_fn() / mega_bytes)
    print(string)
|
||||||
|
|
||||||
|
def plot_TF_influence(log, fig_name='TF_influence', param_names=None):
    """Draw a bar chart of mean('p') * mean('m') for every parameter.

    Args:
        log: non-empty sequence of dicts whose ``"param"`` entry is a
            collection of mappings with keys ``'p'`` and ``'m'``.
        fig_name: output name; every '.' is replaced by ',' before saving.
        param_names: bar labels; defaults to 'P0', 'P1', ... so the
            function is usable with its default arguments.
    """
    first_param = log[0]["param"]
    if not param_names:
        param_names = ['P'+str(idx) for idx, _ in enumerate(first_param)]

    # One series per parameter, each spanning all log entries.
    proba = [[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(first_param)]
    mag = [[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(first_param)]

    plt.figure()

    # It could be interesting to multiply before taking the mean.
    mean = np.mean(proba, axis=1)*np.mean(mag, axis=1)
    std = np.std(proba, axis=1)*np.std(mag, axis=1)
    plt.bar(param_names, mean, yerr=std)

    plt.xticks(rotation=90)
    fig_name = fig_name.replace('.', ',')
    plt.savefig(fig_name, bbox_inches='tight')
    plt.close()
|
||||||
|
|
||||||
|
class loss_monitor():  # See https://github.com/pytorch/ignite
    """Track a loss and count how often it stopped improving.

    ``counter`` counts consecutive registered losses that were greater than
    the best one seen. When it reaches ``patience``, ``reached_limit`` is
    incremented and the tracking restarts. Training should end once
    ``reached_limit`` reaches ``end_train``.
    """

    def __init__(self, patience, end_train=1):
        """Store the thresholds and clear all counters."""
        self.patience, self.end_train = patience, end_train
        self.counter = 0
        self.best_score = None
        self.reached_limit = 0

    def register(self, loss):
        """Record a new loss value."""
        had_best = self.best_score is not None
        if had_best and loss > self.best_score:
            self.counter += 1
            print("loss no improve counter", self.counter, self.reached_limit)
            return
        self.best_score = loss
        # The counter is only reset when an existing best score was matched
        # or beaten, not when the very first score is stored.
        if had_best:
            self.counter = 0

    def limit_reached(self):
        """Return how many times the patience limit has been reached.

        If the counter has reached ``patience``, the counter and best score
        are reset and the limit count is incremented first.
        """
        if not self.counter < self.patience:
            self.counter = 0
            self.reached_limit += 1
            self.best_score = None
        return self.reached_limit

    def end_training(self):
        """Return True once the limit was reached ``end_train`` times."""
        return bool(self.limit_reached() >= self.end_train)

    def reset(self):
        """Restore the initial state, keeping patience and end_train."""
        self.__init__(self.patience, self.end_train)
|
102
salvador/grad_cam.py
Normal file
102
salvador/grad_cam.py
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import torchvision
|
||||||
|
from PIL import Image
|
||||||
|
from torch import topk
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import topk
|
||||||
|
import cv2
|
||||||
|
from torchvision import transforms
|
||||||
|
import os
|
||||||
|
|
||||||
|
class Lambda(nn.Module):
    """Module wrapper around an arbitrary callable.

    The forward pass returns ``func(x)``.
    """

    def __init__(self, func):
        """Store the callable applied in the forward pass."""
        super().__init__()
        self.func = func

    def forward(self, x):
        """Apply the stored callable to ``x`` and return the result."""
        result = self.func(x)
        return result
|
||||||
|
|
||||||
|
class SaveFeatures():
    """Capture the output of a module and the gradient flowing back to it.

    After a forward pass ``activations`` holds the module output; after a
    backward pass ``gradients`` holds the first element of the hook's
    ``grad_output``. Both are detached CPU tensors.
    """

    # Last captured values; None until the corresponding hook has fired.
    activations, gradients = None, None

    def __init__(self, m):
        """Register the forward and backward hooks on module ``m``."""
        self.forward = m.register_forward_hook(self.forward_hook_fn)
        # register_backward_hook is deprecated; use the full backward hook
        # when this PyTorch version provides it.
        register_backward = getattr(m, 'register_full_backward_hook', None)
        if register_backward is None:
            register_backward = m.register_backward_hook
        self.backward = register_backward(self.backward_hook_fn)

    def forward_hook_fn(self, module, input, output):
        """Forward-hook callback: store the module output."""
        self.activations = output.cpu().detach()

    def backward_hook_fn(self, module, grad_input, grad_output):
        """Backward-hook callback: store the gradient w.r.t. the output."""
        self.gradients = grad_output[0].cpu().detach()

    def remove(self):
        """Unregister both hooks."""
        self.forward.remove()
        self.backward.remove()
|
||||||
|
|
||||||
|
def main(cam):
    """Save a gradient-weighted activation heatmap for every dataset image.

    A pretrained ResNet-50 is split into a body (all children but the last
    two) and a head (second-to-last child, flatten, new linear layer),
    loaded from ``checkpoint.pt`` and run one image at a time. For each
    image the score of its target class is back-propagated, the gradients
    captured on the body are averaged per channel, and the weighted sum of
    the body activations is written as a JET colour map to
    ``cifar_CAM/<name>.jpg``.

    Args:
        cam: not used by this function.
    """
    device = 'cuda:0'
    model_name = 'resnet50'
    root = '/mnt/md0/data/cifar10/tmp/cifar/train'
    _root = 'cifar'
    out_dir = _root + '_CAM'

    os.makedirs(out_dir, exist_ok=True)

    train_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    dataset = torchvision.datasets.ImageFolder(
        root=root, transform=train_transform,
    )

    loader = torch.utils.data.DataLoader(dataset, batch_size=1)
    model = torchvision.models.__dict__[model_name](pretrained=True)
    flat = list(model.children())
    body = nn.Sequential(*flat[:-2])
    head = nn.Sequential(
        flat[-2],
        Lambda(func=lambda x: torch.flatten(x, 1)),
        nn.Linear(flat[-1].in_features, len(loader.dataset.classes)),
    )
    model = nn.Sequential(body, head)

    model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
    model = model.to(device)
    model.eval()

    activated_features = SaveFeatures(model[0])

    try:
        for i, (img, target) in enumerate(loader):
            img = img.to(device)
            # Parameter gradients are not needed; clear them so they do not
            # accumulate over the whole dataset.
            model.zero_grad()
            pred = model(img)

            # Back-propagate the score of the target class.
            pred[:, target.item()].backward()

            # Gradients captured by the hook, averaged over both spatial
            # axes to get one weight per channel.
            gradients = activated_features.gradients[0]
            pooled_gradients = gradients.mean(1).mean(1)

            # Activations captured on the body for this image.
            activations = activated_features.activations[0]

            heatmap = F.relu((activations*pooled_gradients[..., None, None]).sum(0))
            peak = torch.max(heatmap)
            # An all-zero map would otherwise become NaN (0/0).
            if peak > 0:
                heatmap /= peak

            heatmap = heatmap.numpy()

            img_path = dataset.imgs[i][0]
            image = cv2.imread(img_path)
            heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
            heatmap = np.uint8(255 * heatmap)
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
            # Only the heatmap is saved; blending with the image is disabled.
            superimposed_img = heatmap

            # File stem taken from the path itself: fixed split positions
            # gave the same name for every image under an absolute root.
            name = os.path.basename(img_path).split('.')[0]
            out_path = os.path.join(out_dir, name + '.jpg')
            cv2.imwrite(out_path, superimposed_img)
            print(f'{out_path} saved')
    finally:
        # Always detach the hooks, even if an image fails.
        activated_features.remove()


if __name__ == "__main__":
    main(cam=True)
|
375
salvador/train.py
Normal file
375
salvador/train.py
Normal file
|
@ -0,0 +1,375 @@
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
from torch import nn
|
||||||
|
import torchvision
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import ImageEnhance
|
||||||
|
import random
|
||||||
|
|
||||||
|
import utils
|
||||||
|
from fastprogress import master_bar, progress_bar
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
## DATA AUG ##
|
||||||
|
import higher
|
||||||
|
from dataug import *
|
||||||
|
from dataug_utils import *
|
||||||
|
# Transformation (TF) names, grouped as in the section comments below.
tf_names = [
    ## Geometric TF ##
    'Identity',
    'FlipUD',
    'FlipLR',
    'Rotate',
    'TranslateX',
    'TranslateY',
    'ShearX',
    'ShearY',

    ## Color TF (Expect image in the range of [0, 1]) ##
    'Contrast',
    'Color',
    'Brightness',
    'Sharpness',
    'Posterize',
    'Solarize',  # => image in [0, 1]; not optimised for batches
]
|
||||||
|
|
||||||
|
class Lambda(nn.Module):
    """Module wrapper around an arbitrary callable.

    The forward pass returns ``func(x)``.
    """

    def __init__(self, func):
        """Store the callable applied in the forward pass."""
        super().__init__()
        self.func = func

    def forward(self, x):
        """Apply the stored callable to ``x`` and return the result."""
        result = self.func(x)
        return result
|
||||||
|
|
||||||
|
class SubsetSampler(torch.utils.data.SubsetRandomSampler):
    """Sampler yielding the given indices in their stored order."""

    def __init__(self, indices):
        """Forward ``indices`` to the parent sampler."""
        super().__init__(indices)

    def __iter__(self):
        """Yield the indices one by one, without shuffling."""
        for position, _ in enumerate(self.indices):
            yield self.indices[position]

    def __len__(self):
        """Return the number of indices."""
        return len(self.indices)
|
||||||
|
|
||||||
|
def sharpness(img, factor):
    """Return ``img`` enhanced with a random sharpness factor.

    Args:
        img: image handed to ``ImageEnhance.Sharpness``.
        factor: upper bound of the uniform draw; the lower bound is 1.
    """
    # Draw the factor first, then build the enhancer.
    amount = random.uniform(1, factor)
    return ImageEnhance.Sharpness(img).enhance(amount)
|
||||||
|
|
||||||
|
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, master_bar, Kldiv=False):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: network to train; with ``Kldiv`` it must also provide
            ``augment(mode=...)``.
        criterion: loss used when ``Kldiv`` is False.
        optimizer: optimizer stepped once per batch.
        data_loader: loader whose dataset exposes ``classes``.
        device: device the batches are moved to.
        epoch: epoch number, used in the progress header.
        master_bar: parent progress bar passed to the metric logger.
        Kldiv: when True, the loss is the cross-entropy with augmentation
            switched off, plus the cross-entropy and a KL term computed with
            augmentation switched on (noted in the original as consuming
            twice the memory).

    Returns:
        Tuple ``(global average loss, confusion matrix)``.
    """
    model.train()
    logger = utils.MetricLogger(delimiter=" ")
    confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
    header = 'Epoch: {}'.format(epoch)

    batches = logger.log_every(data_loader, header=header, parent=master_bar)
    for _, (image, target) in batches:
        image, target = image.to(device), target.to(device)

        if Kldiv:
            # First pass with augmentation switched off.
            model.augment(mode=False)
            output = model(image)
            model.augment(mode=True)
            log_sup = F.log_softmax(output, dim=1)
            sup_loss = F.cross_entropy(log_sup, target)

            # Second pass with augmentation switched on.
            aug_output = model(image)
            log_aug = F.log_softmax(aug_output, dim=1)
            aug_loss = F.cross_entropy(log_aug, target)

            # KL divergence between the two predicted distributions,
            # summed over classes and averaged over the batch.
            per_sample_kl = (F.softmax(output, dim=1)*(log_sup-log_aug)).sum(dim=-1)
            KL_loss = per_sample_kl.mean()

            unsupp_coeff = 1
            loss = sup_loss + (aug_loss + KL_loss) * unsupp_coeff
        else:
            output = model(image)
            # Original note asks whether a softmax is missing here.
            loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        top1 = utils.accuracy(output, target)[0]
        n_samples = image.shape[0]
        logger.meters['acc1'].update(top1.item(), n=n_samples)
        logger.update(loss=loss.item())

        confmat.update(target.flatten(), output.argmax(1).flatten())

    return logger.loss.global_avg, confmat
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(model, criterion, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` without computing gradients.

    Each batch must hold a single sample (``target.item()`` is called), and
    the loader's sampler must expose ``indices`` so that misclassified
    samples can be looked up in ``dataset.imgs``.

    Args:
        model: network to evaluate.
        criterion: loss function.
        data_loader: loader whose dataset exposes ``classes`` and ``imgs``.
        device: device the batches are moved to.

    Returns:
        Tuple ``(global average loss, missed, confusion matrix)`` where
        ``missed`` lists the ``dataset.imgs`` entries of the misclassified
        samples.
    """
    model.eval()
    logger = utils.MetricLogger(delimiter=" ")
    confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
    header = 'Test:'
    missed = []

    with torch.no_grad():
        batches = logger.log_every(data_loader, leave=False, header=header, parent=None)
        for i, (image, target) in batches:
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)

            expected = target.item()
            predicted = output.topk(1)[1].item()
            if expected != predicted:
                # Map the position in the loader back to the dataset entry.
                dataset_index = data_loader.sampler.indices[i]
                missed.append(data_loader.dataset.imgs[dataset_index])

            confmat.update(target.flatten(), output.argmax(1).flatten())

            top1 = utils.accuracy(output, target)[0]
            n_samples = image.shape[0]
            logger.meters['acc1'].update(top1.item(), n=n_samples)
            logger.update(loss=loss.item())

    return logger.loss.global_avg, missed, confmat
|
||||||
|
|
||||||
|
def get_train_valid_loader(args, augment, random_seed, valid_size=0.1, shuffle=True, num_workers=4, pin_memory=True):
    """Build the training and validation loaders over one ImageFolder dataset.

    The folder `args.data_path` is opened twice (once per transform) and its
    sample indices are split into a validation part (the first
    `valid_size` fraction) and a training part (the rest).

    Args:
        args: namespace providing `data_path`, `batch_size` and `test_only`.
        augment: when true, a random horizontal flip is applied to the
            training images; validation images are only converted to tensors.
        random_seed: currently unused -- seeding of the shuffle is disabled.
        valid_size: fraction of the dataset used for validation, in [0, 1].
        shuffle: whether to shuffle the indices before splitting.
        num_workers: number of subprocesses used by both loaders.
        pin_memory: forwarded to both loaders.

    Returns:
        Tuple `(train_loader, valid_loader)`. The validation loader always
        uses a batch size of 1; the training loader does so only when
        `args.test_only` is set.

    Raises:
        AssertionError: if `valid_size` is outside [0, 1].
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert ((valid_size >= 0) and (valid_size <= 1)), error_msg

    # Only the training transform depends on `augment`.
    train_steps = [transforms.ToTensor()]
    if augment:
        train_steps.insert(0, transforms.RandomHorizontalFlip())
    train_transform = transforms.Compose(train_steps)
    valid_transform = transforms.Compose([transforms.ToTensor()])

    # Same folder, two views: one per transform.
    train_dataset = torchvision.datasets.ImageFolder(
        root=args.data_path, transform=train_transform
    )
    valid_dataset = torchvision.datasets.ImageFolder(
        root=args.data_path, transform=valid_transform
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        # np.random.seed(random_seed) was disabled by the author, so the
        # split differs between runs.
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    if args.test_only:
        train_sampler = SubsetSampler(train_idx)
    else:
        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
    valid_sampler = SubsetSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size if not args.test_only else 1, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=1, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    def _class_counts(sample_indices):
        # Count labels from the (path, class_index) records instead of
        # iterating a loader, which would decode every image just to read
        # its label. Sized from the dataset so any number of classes works.
        counts = [0] * len(train_dataset.classes)
        for sample_index in sample_indices:
            counts[train_dataset.imgs[sample_index][1]] += 1
        return counts

    print("Train targets :", _class_counts(train_idx))
    print("Valid targets :", _class_counts(valid_idx))

    return (train_loader, valid_loader)
|
||||||
|
|
||||||
|
def main(args):
    """Train (or, with `args.test_only`, just evaluate) the baseline classifier.

    A pretrained torchvision model is split into a body and a new head whose
    final linear layer matches the number of dataset classes. Depending on
    `args.augment` the images are flipped by the loader ('flip') or the model
    is wrapped with a `RandAug` augmentation module ('Rand' / 'RandKL').
    """
    print(args)

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Loader-side flipping is only used while training.
    augment = (not args.test_only) and args.augment == 'flip'
    print("Augment", augment)

    data_loader, data_loader_test = get_train_valid_loader(
        args=args, pin_memory=True, augment=augment,
        num_workers=args.workers, valid_size=0.3, random_seed=999)

    print("Creating model")
    pretrained = torchvision.models.__dict__[args.model](pretrained=True)
    layers = list(pretrained.children())

    # Everything but the last two children is the feature extractor; the
    # head reuses the second-to-last child and replaces the last one.
    body = nn.Sequential(*layers[:-2])
    head = nn.Sequential(
        layers[-2],
        Lambda(func=lambda x: torch.flatten(x, 1)),
        nn.Linear(layers[-1].in_features, len(data_loader.dataset.classes)),
    )
    model = nn.Sequential(body, head)

    Kldiv = False
    if not args.test_only and args.augment in ('Rand', 'RandKL'):
        tf_dict = {name: TF.TF_dict[name] for name in tf_names}
        model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
        Kldiv = args.augment == 'RandKL'
        print("Augmodel")

    criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    total_steps = len(data_loader) * args.epochs

    def poly_decay(step):
        # Polynomial decay of the learning-rate factor over the whole run.
        return (1 - step / total_steps) ** 0.9

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, poly_decay)

    es = utils.EarlyStopping()

    if args.test_only:
        # Load on CPU first, then move the whole model to the target device.
        model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
        model = model.to(device)
        for title, loader in (('TEST', data_loader_test), ('TRAIN', data_loader)):
            print(title)
            _, missed, _ = evaluate(model, criterion, loader, device=device)
            print(missed)
        return

    model = model.to(device)

    print("Start training")
    start_time = time.time()
    mb = master_bar(range(args.epochs))

    for epoch in mb:
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, mb, Kldiv)
        lr_scheduler.step((epoch + 1) * len(data_loader))
        val_loss, _, valid_confmat = evaluate(model, criterion, data_loader_test, device=device)
        es(val_loss, model)

        print('Valid')
        print(valid_confmat)

    elapsed = int(time.time() - start_time)
    print('Training time {}'.format(str(datetime.timedelta(seconds=elapsed))))
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: optional list of argument strings. When None (the default)
            argparse reads the process command line, as before.

    Returns:
        argparse.Namespace with `data_path`, `model`, `device`, `batch_size`,
        `epochs`, `workers`, `lr`, `momentum`, `weight_decay`, `test_only`
        and `augment`.
    """
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--data-path', default='/Salvador', help='dataset')
    parser.add_argument('--model', default='resnet18', help='model')
    parser.add_argument('--device', default='cuda:1', help='device')
    parser.add_argument('-b', '--batch-size', default=8, type=int)
    parser.add_argument('--epochs', default=3, type=int, metavar='N',
                        help='number of total epochs to run')
    # The help texts below previously advertised defaults (16 and 1e-4)
    # that did not match the actual values.
    parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                        help='number of data loading workers (default: 0)')
    parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=4e-5, type=float,
                        metavar='W', help='weight decay (default: 4e-5)',
                        dest='weight_decay')

    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )

    parser.add_argument('-a', '--augment', default='None', type=str,
                        metavar='N', help='Data augment',
                        dest='augment')

    return parser.parse_args(argv)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: read the command line, then train or evaluate.
    args = parse_args()
    main(args)
|
585
salvador/train_dataug.py
Normal file
585
salvador/train_dataug.py
Normal file
|
@ -0,0 +1,585 @@
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
from torch import nn
|
||||||
|
import torchvision
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import ImageEnhance
|
||||||
|
import random
|
||||||
|
|
||||||
|
import utils
|
||||||
|
from fastprogress import master_bar, progress_bar
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
## DATA AUG ##
|
||||||
|
import higher
|
||||||
|
from dataug import *
|
||||||
|
from dataug_utils import *
|
||||||
|
# Names of the transformations looked up in TF.TF_dict to build the
# augmentation module.
tf_names = [
    # Geometric transformations
    'Identity', 'FlipUD', 'FlipLR', 'Rotate',
    'TranslateX', 'TranslateY', 'ShearX', 'ShearY',
] + [
    # Color transformations (expect images in the range [0, 1])
    'Contrast', 'Color', 'Brightness', 'Sharpness', 'Posterize',
    'Solarize',  # image in [0, 1]; not optimised for batches
]
|
||||||
|
|
||||||
|
def compute_vaLoss(model, dl_it, dl):
    """Return the cross-entropy of `model` on the next validation batch.

    Args:
        model: module to evaluate; it is left in eval mode.
        dl_it: iterator over `dl` from which the batch is drawn.
        dl: the validation loader, used to restart once `dl_it` is exhausted.

    Returns:
        The loss tensor `F.cross_entropy(model(xs), ys)`.

    NOTE(review): the restarted iterator is local to this call, so once
    `dl_it` is exhausted every later call draws the first batch of a fresh
    iterator; returning the iterator to the caller would fix this.
    """
    device = next(model.parameters()).device

    exhausted = object()
    batch = next(dl_it, exhausted)
    if batch is exhausted:  # end of a validation epoch: start over
        batch = next(iter(dl))
    inputs, labels = batch
    inputs = inputs.to(device)
    labels = labels.to(device)

    model.eval()  # validation is done without transformations

    predictions = model(inputs)
    return F.cross_entropy(predictions, labels)
|
||||||
|
|
||||||
|
def model_copy(src, dst, patch_copy=True, copy_grad=True):
    """Copy weights (and optionally gradients) from `src` into `dst`.

    A deep copy of `src` is deliberately avoided: only the weights and
    gradients are wanted, not the whole object and its state.

    Args:
        src: source model.
        dst: destination model, updated in place.
        patch_copy: also reload the 'model' and 'data_aug' sub-modules
            individually, in case the top-level state dict missed data.
        copy_grad: make each parameter of `dst` share the `.grad` of the
            parameter of `src` with the same name.
    """
    # A state dict carries values only; gradients are handled below.
    dst.load_state_dict(src.state_dict())

    if patch_copy:
        dst['model'].load_state_dict(src['model'].state_dict())
        dst['data_aug'].load_state_dict(src['data_aug'].state_dict())

    if copy_grad:
        # Index the destination once rather than rescanning all of its
        # parameters for every source parameter.
        dst_params = dict(dst.named_parameters())
        for name, src_param in src.named_parameters():
            dst_param = dst_params.get(name)
            if dst_param is not None:
                dst_param.grad = src_param.grad

    # Best-effort copy of attributes that only some augmentation modules
    # have (the original comment mentions Data_augV4).
    try:
        dst['data_aug']._input_info = src['data_aug']._input_info
        dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix
    except (AttributeError, KeyError, TypeError):
        pass
|
||||||
|
|
||||||
|
def optim_copy(dopt, opt):
    """Copy the per-parameter state of `dopt` into the optimizer `opt`.

    `dopt.state` is read by position (`dopt.state[group][param]`), whereas
    `opt.state` is keyed by the parameter objects listed in
    `opt.param_groups`. Per the original author's note, loading the state
    dict of the differentiable optimizer does not carry over this state
    (momentum, etc.), hence the manual copy.
    """
    for group_no, group in enumerate(opt.param_groups):
        opt.state.update({
            param: dopt.state[group_no][param_no]
            for param_no, param in enumerate(group['params'])
        })
|
||||||
|
|
||||||
|
|
||||||
|
#############
|
||||||
|
|
||||||
|
class Lambda(nn.Module):
    """Layer whose forward pass simply applies a user-supplied callable."""

    def __init__(self, func):
        super(Lambda, self).__init__()
        self.func = func

    def forward(self, x):
        """Return `func(x)`."""
        result = self.func(x)
        return result
|
||||||
|
|
||||||
|
class SubsetSampler(torch.utils.data.SubsetRandomSampler):
    """Sampler yielding the given indices in their stored order.

    It overrides the iteration of `SubsetRandomSampler` so that no
    shuffling takes place.
    """

    def __init__(self, indices):
        super(SubsetSampler, self).__init__(indices)

    def __iter__(self):
        # Walk the indices front to back.
        yield from self.indices

    def __len__(self):
        count = len(self.indices)
        return count
|
||||||
|
|
||||||
|
def sharpness(img, factor):
    """Return `img` sharpened by a random amount drawn from [1, factor].

    The amount comes from `random.uniform(1, factor)` and is applied with
    PIL's `ImageEnhance.Sharpness`.
    """
    amount = random.uniform(1, factor)
    return ImageEnhance.Sharpness(img).enhance(amount)
|
||||||
|
|
||||||
|
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, master_bar):
    """Train `model` for one pass over `data_loader`.

    Args:
        model: network to train; it is switched to train mode.
        criterion: callable computing the loss from (output, target).
        optimizer: optimizer stepped once per batch.
        data_loader: training loader whose `dataset` exposes `classes`.
        device: device the batches are moved to.
        epoch: epoch number, used only in the progress header.
        master_bar: parent progress bar handed to the metric logger.

    Returns:
        Tuple `(average loss, confusion matrix)` for the epoch.
    """
    model.train()
    logger = utils.MetricLogger(delimiter=" ")
    confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
    title = 'Epoch: {}'.format(epoch)

    for _, (image, target) in logger.log_every(data_loader, header=title, parent=master_bar):
        image = image.to(device)
        target = target.to(device)

        output = model(image)
        loss = criterion(output, target)

        # Standard optimisation step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Book-keeping: accuracy, loss and confusion matrix.
        top1 = utils.accuracy(output, target)[0]
        logger.meters['acc1'].update(top1.item(), n=image.shape[0])
        logger.update(loss=loss.item())
        confmat.update(target.flatten(), output.argmax(1).flatten())

    return logger.loss.global_avg, confmat
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(model, criterion, data_loader, device):
    """Run `model` over `data_loader` without gradients and collect metrics.

    Args:
        model: network to evaluate; it is switched to eval mode.
        criterion: callable computing the loss from (output, target).
        data_loader: loader whose `dataset` exposes `classes` and `imgs` and
            whose `sampler` exposes `indices`.
        device: device the batches are moved to.

    Returns:
        Tuple `(average loss, missed, confusion matrix)` where `missed`
        holds the `dataset.imgs` entries of the misclassified samples.

    NOTE(review): `target.item()` and the `sampler.indices[step]` lookup only
    work with one sample per batch; this also assumes the first value yielded
    by `log_every` is the batch position -- confirm in `utils`.
    """
    model.eval()
    dataset = data_loader.dataset
    logger = utils.MetricLogger(delimiter=" ")
    confmat = utils.ConfusionMatrix(num_classes=len(dataset.classes))
    missed = []

    with torch.no_grad():
        for step, (image, target) in logger.log_every(
                data_loader, leave=False, header='Test:', parent=None):
            image = image.to(device)
            target = target.to(device)
            output = model(image)
            loss = criterion(output, target)

            # Remember which file was misclassified.
            expected = target.item()
            predicted = output.topk(1)[1].item()
            if expected != predicted:
                sample_index = data_loader.sampler.indices[step]
                missed.append(dataset.imgs[sample_index])

            confmat.update(target.flatten(), output.argmax(1).flatten())

            top1 = utils.accuracy(output, target)[0]
            logger.meters['acc1'].update(top1.item(), n=image.shape[0])
            logger.update(loss=loss.item())

    return logger.loss.global_avg, missed, confmat
|
||||||
|
|
||||||
|
def get_train_valid_loader(args, augment, random_seed, train_size=0.5, test_size=0.1, shuffle=True, num_workers=4, pin_memory=True):
    """Build train, validation and test loaders over one ImageFolder dataset.

    The sample indices of `args.data_path` are split in two stages: the
    first `test_size` fraction becomes the test set, then the remainder is
    divided into a training part (its first `train_size` fraction) and a
    validation part (the rest).

    Args:
        args: namespace providing `data_path`, `batch_size` and `test_only`.
        augment: when true, a random horizontal flip is applied by the
            training transform.
        random_seed: seed given to NumPy before shuffling the indices.
        train_size: fraction of the non-test samples used for training.
        test_size: fraction of the dataset held out for testing, in [0, 1].
        shuffle: whether to shuffle the indices before splitting.
        num_workers: number of subprocesses used by every loader.
        pin_memory: forwarded to every loader.

    Returns:
        Tuple `(train_loader, valid_loader, test_loader)`. The test loader
        always uses a batch size of 1; the other two do so only when
        `args.test_only` is set.

    Raises:
        AssertionError: if `test_size` is outside [0, 1].

    NOTE(review): the validation loader reads from the training dataset and
    therefore uses the training transform -- confirm this is intended.
    """
    error_msg = "[!] test_size should be in the range [0, 1]."
    assert ((test_size >= 0) and (test_size <= 1)), error_msg

    # Only the training transform depends on `augment`.
    train_steps = [transforms.ToTensor()]
    if augment:
        train_steps.insert(0, transforms.RandomHorizontalFlip())
    train_transform = transforms.Compose(train_steps)
    valid_transform = transforms.Compose([transforms.ToTensor()])

    # Same folder, two views: one per transform.
    train_dataset = torchvision.datasets.ImageFolder(
        root=args.data_path, transform=train_transform
    )
    test_dataset = torchvision.datasets.ImageFolder(
        root=args.data_path, transform=valid_transform
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(test_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    # Stage 1: hold out the test set. Stage 2: split the rest.
    train_idx, test_idx = indices[split:], indices[:split]
    n_train = int(len(train_idx) * train_size)
    train_idx, valid_idx = train_idx[:n_train], train_idx[n_train:]
    print("\nTrain", len(train_idx), "\nValid", len(valid_idx), "\nTest", len(test_idx))

    # Sequential sampling in test-only mode, random sampling otherwise.
    if args.test_only:
        train_sampler = SubsetSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx)
    else:
        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
        valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx)
    test_sampler = SubsetSampler(test_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size if not args.test_only else 1, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size if not args.test_only else 1, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, sampler=test_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return (train_loader, valid_loader, test_loader)
|
||||||
|
|
||||||
|
def main(args):
    """Train a classifier together with a learnable augmentation module.

    A pretrained torchvision model is given a new head, wrapped with a
    `Data_augV5` module inside an `Augmented_model`, and trained through a
    `higher` functional copy: the classifier weights are updated by the
    differentiable inner optimizer, and every `args.inner_it` iterations
    (when non-zero) a meta step back-propagates a validation loss to the
    augmentation parameters. With `args.test_only` the checkpoint is simply
    evaluated on the test and training splits.

    NOTE(review): `model.augment(mode=False)` is called before training and
    no later call switching it back on is visible in this function --
    confirm that augmentation is actually meant to stay off.
    NOTE(review): `evaluate` puts `model` in eval mode at the end of each
    epoch and `model.train()` is only called once, before the loop --
    confirm how this interacts with the functional copies rebuilt later.
    """
    print(args)

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # No loader-side augmentation here.
    augment = False

    data_loader, dl_val, data_loader_test = get_train_valid_loader(
        args=args, pin_memory=True, augment=augment,
        num_workers=args.workers, train_size=0.99, test_size=0.2, random_seed=999)

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=True)
    flat = list(model.children())

    # Everything but the last two children is the feature extractor; the
    # head reuses the second-to-last child and replaces the last one.
    body = nn.Sequential(*flat[:-2])
    head = nn.Sequential(
        flat[-2],
        Lambda(func=lambda x: torch.flatten(x, 1)),
        nn.Linear(flat[-1].in_features, len(data_loader.dataset.classes)),
    )
    model = nn.Sequential(body, head)

    criterion = nn.CrossEntropyLoss().to(device)
    es = utils.EarlyStopping()

    if args.test_only:
        # Load on CPU first, then move the whole model to the target device.
        model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
        model = model.to(device)
        print('TEST')
        _, missed, _ = evaluate(model, criterion, data_loader_test, device=device)
        print(missed)
        print('TRAIN')
        _, missed, _ = evaluate(model, criterion, data_loader, device=device)
        print(missed)
        return

    model = model.to(device)

    print("Start training")
    start_time = time.time()
    mb = master_bar(range(args.epochs))

    # The plain (non-meta) training loop that used to be kept here as a
    # disabled string literal has been dropped; train_one_epoch covers it.

    inner_it = args.inner_it
    dataug_epoch_start = 0
    print_freq = 1
    KLdiv = False

    tf_dict = {k: TF.TF_dict[k] for k in tf_names}
    model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)

    val_loss = torch.tensor(0)  # needed when an epoch has no meta step
    dl_val_it = iter(dl_val)
    countcopy = 0

    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=args.lr)
    inner_opt = torch.optim.Adam(model['model'].parameters(), lr=args.lr, weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        inner_opt,
        lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

    # Without inner iterations there is no meta step, hence no need to
    # track higher-order gradients.
    high_grad_track = inner_it != 0

    model.train()
    model.augment(mode=False)

    def _functional_copy():
        # Build a fresh functional model and differentiable optimizer from
        # the current `model`. Per the original author's note this reset is
        # required because higher accumulates fast parameters even with
        # track_higher_grads=False.
        patched = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
        optim = higher.optim.get_diff_optim(
            inner_opt, model.parameters(), fmodel=patched, track_higher_grads=high_grad_track)
        return patched, optim

    fmodel, diffopt = _functional_copy()

    i = 0

    for epoch in mb:

        metric_logger = utils.MetricLogger(delimiter=" ")
        confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
        header = 'Epoch: {}'.format(epoch)

        t0 = time.process_time()
        for _, (image, target) in metric_logger.log_every(data_loader, header=header, parent=mb):
            i += 1
            image, target = image.to(device), target.to(device)

            if not KLdiv:
                # Uniform method: per-sample loss, optionally re-weighted.
                logits = fmodel(image)
                output = F.log_softmax(logits, dim=1)
                loss = F.cross_entropy(output, target, reduction='none')

                if fmodel._data_augmentation:  # weight the loss
                    w_loss = fmodel['data_aug'].loss_weight()
                    loss = loss * w_loss
                loss = loss.mean()

            else:
                # KL-divergence method: supervised loss on clean inputs plus
                # a consistency loss on augmented inputs.
                # This branch used the undefined names `xs`/`ys`, never set
                # `output`, and read `w_loss` before assignment for
                # epoch <= 50; it now uses the current batch.
                fmodel.augment(mode=False)
                sup_logits = fmodel(image)
                log_sup = F.log_softmax(sup_logits, dim=1)
                fmodel.augment(mode=True)
                loss = F.cross_entropy(log_sup, target)
                # Metrics are computed on the clean predictions.
                # NOTE(review): confirm this is the intended choice.
                output = log_sup

                if fmodel._data_augmentation:
                    aug_logits = fmodel(image)
                    log_aug = F.log_softmax(aug_logits, dim=1)
                    w_loss = fmodel['data_aug'].loss_weight()  # weight the loss
                    aug_loss = 0
                    if epoch > 50:  # delayed start?
                        # KL divergence between clean and augmented
                        # predicted distributions.
                        aug_loss = F.softmax(sup_logits, dim=1) * (log_sup - log_aug)
                        aug_loss = aug_loss.sum(dim=-1)
                        aug_loss = (w_loss * aug_loss).mean()

                    aug_loss += (F.cross_entropy(log_aug, target, reduction='none') * w_loss).mean()
                    unsupp_coeff = 1
                    loss += aug_loss * unsupp_coeff

            diffopt.step(loss)  # replaces zero_grad / backward / step

            # The short-circuit on high_grad_track also guards the modulo
            # against inner_it == 0.
            if high_grad_track and i % inner_it == 0:  # meta step
                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
                val_loss.backward()

                countcopy += 1
                model_copy(src=fmodel, dst=model)
                optim_copy(dopt=diffopt, opt=inner_opt)

                meta_opt.step()
                model['data_aug'].adjust_param(soft=False)  # constraint: sum(proba) = 1

                fmodel, diffopt = _functional_copy()

            acc1 = utils.accuracy(output, target)[0]
            batch_size = image.shape[0]
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.update(loss=loss.item())

            confmat.update(target.flatten(), output.argmax(1).flatten())

            # Without meta steps, sync and reset when cached GPU memory
            # grows past the threshold.
            if not high_grad_track and (torch.cuda.memory_cached() / 1024.0**2) > 20000:
                countcopy += 1
                print_torch_mem("copy")
                model_copy(src=fmodel, dst=model)
                optim_copy(dopt=diffopt, opt=inner_opt)
                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)

                fmodel, diffopt = _functional_copy()
                print_torch_mem("copy")

        # Without meta steps, sync the trained weights back once per epoch.
        if not high_grad_track:
            countcopy += 1
            print_torch_mem("end copy")
            model_copy(src=fmodel, dst=model)
            optim_copy(dopt=diffopt, opt=inner_opt)
            val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)

            fmodel, diffopt = _functional_copy()
            print_torch_mem("end copy")

        tf = time.process_time()

        #### Print ####
        if print_freq and epoch % print_freq == 0:
            print('-'*9)
            print('Epoch : %d'%(epoch))
            print('Time : %.00f'%(tf - t0))
            print('Train loss :',loss.item(), '/ val loss', val_loss.item())
            print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
            print('TF Proba :', model['data_aug']['prob'].data)
            print('TF Mag :', model['data_aug']['mag'].data)
        # TODO: the per-epoch log record (previously a disabled string
        # literal) is not written anywhere.

        lr_scheduler.step((epoch + 1) * len(data_loader))

        test_loss, _, test_confmat = evaluate(model, criterion, data_loader_test, device=device)
        es(test_loss, model)

        print('Test')
        print(test_confmat)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args(argv=None):
    """Parse the command-line options of the classification training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what the original
            zero-argument version did.

    Returns:
        argparse.Namespace with the attributes ``data_path``, ``model``,
        ``device``, ``batch_size``, ``epochs``, ``workers``, ``lr``,
        ``momentum``, ``weight_decay``, ``test_only`` and ``inner_it``.
    """
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--data-path', default='/Salvador', help='dataset')
    parser.add_argument('--model', default='resnet50', help='model')
    parser.add_argument('--device', default='cuda:1', help='device')
    parser.add_argument('-b', '--batch-size', default=4, type=int)
    parser.add_argument('--epochs', default=3, type=int, metavar='N',
                        help='number of total epochs to run')
    # The help texts used to hard-code defaults (16 and 1e-4) that did not match
    # the real ones (0 and 4e-5); %(default)s keeps them in sync automatically.
    parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                        help='number of data loading workers (default: %(default)s)')
    parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=4e-5, type=float,
                        metavar='W', help='weight decay (default: %(default)s)',
                        dest='weight_decay')

    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )

    parser.add_argument('--in_it', '--inner_it', default=0, type=int,
                        metavar='N', help='higher inner_it',
                        dest='inner_it')

    args = parser.parse_args(argv)

    return args
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: parse the command line and hand the options to main().
if __name__ == "__main__":
    main(parse_args())
|
346
salvador/transformations.py
Executable file
346
salvador/transformations.py
Executable file
|
@ -0,0 +1,346 @@
|
||||||
|
import torch
|
||||||
|
import kornia
|
||||||
|
import random
|
||||||
|
|
||||||
|
### Available TF for Dataug ###
|
||||||
|
'''
|
||||||
|
TF_dict={ #Dataugv4
|
||||||
|
## Geometric TF ##
|
||||||
|
'Identity' : (lambda x, mag: x),
|
||||||
|
'FlipUD' : (lambda x, mag: flipUD(x)),
|
||||||
|
'FlipLR' : (lambda x, mag: flipLR(x)),
|
||||||
|
'Rotate': (lambda x, mag: rotate(x, angle=torch.tensor([rand_int(mag, maxval=30)for _ in x], device=x.device))),
|
||||||
|
'TranslateX': (lambda x, mag: translate(x, translation=torch.tensor([[rand_int(mag, maxval=20), 0] for _ in x], device=x.device))),
|
||||||
|
'TranslateY': (lambda x, mag: translate(x, translation=torch.tensor([[0, rand_int(mag, maxval=20)] for _ in x], device=x.device))),
|
||||||
|
'ShearX': (lambda x, mag: shear(x, shear=torch.tensor([[rand_float(mag, maxval=0.3), 0] for _ in x], device=x.device))),
|
||||||
|
'ShearY': (lambda x, mag: shear(x, shear=torch.tensor([[0, rand_float(mag, maxval=0.3)] for _ in x], device=x.device))),
|
||||||
|
|
||||||
|
## Color TF (Expect image in the range of [0, 1]) ##
|
||||||
|
'Contrast': (lambda x, mag: contrast(x, contrast_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
||||||
|
'Color':(lambda x, mag: color(x, color_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
||||||
|
'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
||||||
|
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
||||||
|
'Posterize': (lambda x, mag: posterize(x, bits=torch.tensor([rand_int(mag, minval=4, maxval=8) for _ in x], device=x.device))),
|
||||||
|
'Solarize': (lambda x, mag: solarize(x, thresholds=torch.tensor([rand_int(mag,minval=1, maxval=256)/256. for _ in x], device=x.device))) , #=>Image entre [0,1] #Pas opti pour des batch
|
||||||
|
|
||||||
|
#Non fonctionnel
|
||||||
|
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
||||||
|
#'Equalize': (lambda mag: None),
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
'''
|
||||||
|
TF_dict={ #Dataugv5 #AutoAugment
|
||||||
|
## Geometric TF ##
|
||||||
|
'Identity' : (lambda x, mag: x),
|
||||||
|
'FlipUD' : (lambda x, mag: flipUD(x)),
|
||||||
|
'FlipLR' : (lambda x, mag: flipLR(x)),
|
||||||
|
'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
|
||||||
|
'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
|
||||||
|
'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
|
||||||
|
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
|
||||||
|
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
|
||||||
|
|
||||||
|
## Color TF (Expect image in the range of [0, 1]) ##
|
||||||
|
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||||
|
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
|
||||||
|
|
||||||
|
#Non fonctionnel
|
||||||
|
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
||||||
|
#'Equalize': (lambda mag: None),
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
# Registry of the available transformations: name -> callable taking (x, mag)
# and returning the transformed x. The random parameters are drawn at every call
# by rand_floats / invScale_rand_floats, with x.shape[0] as the sample count.
#   '+' entries : rand_floats with the lower bound fixed at 1.0.
#   '-' entries : invScale_rand_floats between 0.1 and 1.0.
#   '=' entries : invScale_rand_floats variants of Posterize / Solarize.
#   'B' entries : geometric transforms whose maxval is multiplied by 3.
#   'Bad' entries: ranges starting at or beyond the regular bounds
#                  (presumably deliberately harmful transforms - TODO confirm).
TF_dict={ #Dataugv5
    ## Geometric TF ##
    'Identity' : (lambda x, mag: x),
    'FlipUD' : (lambda x, mag: flipUD(x)),
    'FlipLR' : (lambda x, mag: flipLR(x)),
    'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
    'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
    'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
    'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
    'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),

    ## Color TF (Expect image in the range of [0, 1]) ##
    'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
    'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
    'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
    'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
    'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Gradient is lost
    'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient is lost #=> Image in [0,1]

    #Color TF (Common mag scale)
    '+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
    '+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
    '+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
    '+Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
    '-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
    '-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
    '-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
    '-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
    '=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Gradient is lost
    '=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient is lost #=> Image in [0,1]


    'BRotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30*3))),
    'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20*3), zero_pos=0))),
    'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20*3), zero_pos=1))),
    'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3*3), zero_pos=0))),
    'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3*3), zero_pos=1))),

    'BadTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=20*2, maxval=20*3), zero_pos=0))),
    'BadTranslateX_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=0))),
    'BadTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=20*2, maxval=20*3), zero_pos=1))),
    'BadTranslateY_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=1))),

    'BadColor':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
    'BadSharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
    'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
    'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),

    #Not functional
    #'Auto_Contrast': (lambda mag: None), #Not optimized for batches (very slow)
    #'Equalize': (lambda mag: None),
}
|
||||||
|
|
||||||
|
# Names of the TF_dict entries whose lambdas never use their `mag` argument.
TF_no_mag={'Identity', 'FlipUD', 'FlipLR'}
# TF_no_mag plus 'Solarize' and 'Posterize', the two entries annotated in TF_dict
# as losing the gradient.
# NOTE(review): the '=Posterize' / '=Solarize' variants are not listed here -
# confirm with the consumers of this set whether that is intended.
TF_ignore_mag= TF_no_mag | {'Solarize', 'Posterize'}
|
||||||
|
|
||||||
|
def int_image(float_image):
    """Scale a float image tensor by 255 and cast it to ``torch.uint8``.

    Warning: slight loss of information (granularity: 1/256 = 0.0039).
    """
    scaled = float_image * 255.
    return scaled.to(torch.uint8)
|
||||||
|
|
||||||
|
def float_image(int_image):
    """Cast an integer image tensor to ``torch.float`` and divide it by 255."""
    as_float = int_image.to(torch.float)
    return as_float / 255.
|
||||||
|
|
||||||
|
#def rand_inverse(value):
|
||||||
|
# return value if random.random() < 0.5 else -value
|
||||||
|
|
||||||
|
#def rand_int(mag, maxval, minval=None): #[(-maxval,minval), maxval]
|
||||||
|
# real_max = int_parameter(mag, maxval=maxval)
|
||||||
|
# if not minval : minval = -real_max
|
||||||
|
# return random.randint(minval, real_max)
|
||||||
|
|
||||||
|
#def rand_float(mag, maxval, minval=None): #[(-maxval,minval), maxval]
|
||||||
|
# real_max = float_parameter(mag, maxval=maxval)
|
||||||
|
# if not minval : minval = -real_max
|
||||||
|
# return random.uniform(minval, real_max)
|
||||||
|
|
||||||
|
def rand_floats(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
    """Draw uniform random floats whose upper bound is scaled by the magnitude.

    Args:
        size: Shape handed to ``torch.rand``.
        mag: Magnitude; must expose ``.device`` (the samples are created on it).
        maxval: Value the upper bound takes when ``mag`` equals PARAMETER_MAX.
        minval: Lower bound of the range. ``None`` (default) selects the
            symmetric range ``[-real_mag, real_mag]``.

    Returns:
        Tensor of shape ``size`` with values between ``minval`` and
        ``real_mag = float_parameter(mag, maxval)``.
    """
    real_mag = float_parameter(mag, maxval=maxval)
    # Compare against None explicitly: the former `if not minval` also replaced
    # a legitimate lower bound of 0 by -real_mag.
    if minval is None:
        minval = -real_mag
    return minval + (real_mag - minval) * torch.rand(size, device=mag.device) #[min_val, real_mag]
|
||||||
|
|
||||||
|
def invScale_rand_floats(size, mag, maxval, minval):
    """Draw uniform random floats on an inverted magnitude scale.

    The magnitude range [0, PARAMETER_MAX] is mapped onto [maxval, minval]:
    at ``mag == 0`` the lower bound equals ``maxval``, at
    ``mag == PARAMETER_MAX`` it equals ``minval``. Samples are drawn between
    that lower bound and ``maxval``, on the device of ``mag``.
    """
    inverted_mag = float(PARAMETER_MAX) - mag
    lower = float_parameter(inverted_mag, maxval=maxval - minval) + minval
    noise = torch.rand(size, device=mag.device)
    return lower + (maxval - lower) * noise #[lower, maxval]
|
||||||
|
|
||||||
|
def zero_stack(tensor, zero_pos):
    """Pair every value of a 1-D tensor with a zero, stacking along dim 1.

    ``zero_pos == 0`` gives rows ``(value, 0)``; ``zero_pos == 1`` gives rows
    ``(0, value)``. Any other ``zero_pos`` raises ``Exception``.
    """
    if zero_pos not in (0, 1):
        raise Exception("Invalid zero_pos : ", zero_pos)
    zeros = torch.zeros((tensor.shape[0],), device=tensor.device)
    columns = (tensor, zeros) if zero_pos == 0 else (zeros, tensor)
    return torch.stack(columns, dim=1)
|
||||||
|
|
||||||
|
#https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
|
||||||
|
PARAMETER_MAX = 1 # Largest 'level' a transform can be given

def float_parameter(level, maxval):
    """Scale ``maxval`` by the fraction ``level / PARAMETER_MAX``.

    Args:
      level: Level of the operation, expected in [0, `PARAMETER_MAX`].
      maxval: Value returned when ``level`` equals `PARAMETER_MAX`.

    Returns:
      ``level * maxval / PARAMETER_MAX``. No cast is applied, so the result
      keeps whatever type the arithmetic on ``level`` produces.
    """
    scaled = level * maxval
    return scaled / PARAMETER_MAX
|
||||||
|
|
||||||
|
#def int_parameter(level, maxval): #Perte de gradient
|
||||||
|
"""Helper function to scale `val` between 0 and maxval .
|
||||||
|
Args:
|
||||||
|
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
|
||||||
|
maxval: Maximum value that the operation can have. This will be scaled
|
||||||
|
to level/PARAMETER_MAX.
|
||||||
|
Returns:
|
||||||
|
An int that results from scaling `maxval` according to `level`.
|
||||||
|
"""
|
||||||
|
#return int(level * maxval / PARAMETER_MAX)
|
||||||
|
# return (level * maxval / PARAMETER_MAX)
|
||||||
|
|
||||||
|
def flipLR(x):
    """Mirror a batch of images left-right (reverse the width axis).

    The previous implementation built the matrix [[-1, 0, w-1], [0, 1, 0],
    [0, 0, 1]] and ran ``kornia.warp_perspective``, i.e. a resampling warp, to
    express a pure pixel permutation. ``torch.flip`` performs the same mirror
    exactly, without interpolation, and remains differentiable w.r.t. ``x``.

    Args:
        x: Tensor of shape (batch_size, channels, h, w).

    Returns:
        Tensor of the same shape with the last axis reversed.

    Raises:
        ValueError: if ``x`` is not 4-dimensional (the original failed with a
            ValueError while unpacking ``x.shape``).
    """
    if x.dim() != 4:
        raise ValueError("expected a (batch_size, channels, h, w) tensor. Got shape {}".format(tuple(x.shape)))
    return torch.flip(x, dims=(3,))
|
||||||
|
|
||||||
|
def flipUD(x):
    """Mirror a batch of images upside-down (reverse the height axis).

    The previous implementation built the matrix [[1, 0, 0], [0, -1, h-1],
    [0, 0, 1]] and ran ``kornia.warp_perspective``, i.e. a resampling warp, to
    express a pure pixel permutation. ``torch.flip`` performs the same mirror
    exactly, without interpolation, and remains differentiable w.r.t. ``x``.

    Args:
        x: Tensor of shape (batch_size, channels, h, w).

    Returns:
        Tensor of the same shape with axis 2 reversed.

    Raises:
        ValueError: if ``x`` is not 4-dimensional (the original failed with a
            ValueError while unpacking ``x.shape``).
    """
    if x.dim() != 4:
        raise ValueError("expected a (batch_size, channels, h, w) tensor. Got shape {}".format(tuple(x.shape)))
    return torch.flip(x, dims=(2,))
|
||||||
|
|
||||||
|
def rotate(x, angle):
    """Rotate the batch ``x`` with ``kornia.rotate``.

    ``angle`` is cast to ``torch.float`` first: kornia does not support
    integer angles.
    """
    float_angle = angle.to(torch.float)
    return kornia.rotate(x, angle=float_angle)
|
||||||
|
|
||||||
|
def translate(x, translation):
    """Translate the batch ``x`` with ``kornia.translate``.

    ``translation`` is cast to ``torch.float`` first: kornia does not support
    integer translations.
    """
    float_translation = translation.to(torch.float)
    return kornia.translate(x, translation=float_translation)
|
||||||
|
|
||||||
|
def shear(x, shear):
    """Shear the batch ``x`` with ``kornia.shear``; ``shear`` is forwarded unchanged."""
    sheared = kornia.shear(x, shear=shear)
    return sheared
|
||||||
|
|
||||||
|
def contrast(x, contrast_factor):
    """Adjust the contrast of ``x`` with ``kornia.adjust_contrast``.

    Expects the image in the range of [0, 1].
    """
    adjusted = kornia.adjust_contrast(x, contrast_factor=contrast_factor)
    return adjusted
|
||||||
|
|
||||||
|
#https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageEnhance.py
|
||||||
|
def color(x, color_factor):
    """Blend the grayscale version of ``x`` with ``x`` itself.

    The grayscale image from ``kornia.rgb_to_grayscale`` is repeated along the
    channel axis to match ``x``, blended with ``x`` using ``color_factor`` as
    the per-image weight of ``x``, and clamped to [0, 1].
    Expects the image in the range of [0, 1].
    """
    _, n_channels, _, _ = x.shape

    gray = kornia.rgb_to_grayscale(x).repeat_interleave(n_channels, dim=1)
    blended = blend(gray, x, color_factor)
    return blended.clamp(min=0.0, max=1.0)
|
||||||
|
|
||||||
|
def brightness(x, brightness_factor):
    """Blend an all-zero image with ``x`` and clamp the result to [0, 1].

    ``brightness_factor`` is the per-image weight given to ``x`` by ``blend``.
    Expects the image in the range of [0, 1].
    """
    black = torch.zeros(x.size(), device=x.device)
    blended = blend(black, x, brightness_factor)
    return blended.clamp(min=0.0, max=1.0)
|
||||||
|
|
||||||
|
def sharpeness(x, sharpness_factor):
    """Blend a smoothed copy of ``x`` with ``x`` and clamp the result to [0, 1].

    The smoothed copy comes from ``kornia.filter2D`` with a normalized 3x3
    kernel and reflect borders; ``sharpness_factor`` is the per-image weight
    given to ``x`` by ``blend``. Expects the image in the range of [0, 1].
    """
    batch_size, channels, h, w = x.shape  # unpacking also rejects non 4-D input

    # Smooth filter: https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageFilter.py
    smooth_kernel = torch.tensor([[[1., 1., 1.],
                                   [1., 5., 1.],
                                   [1., 1., 1.]]], device=x.device)
    # The alpha channel may need to be handled differently.
    blurred = kornia.filter2D(x, kernel=smooth_kernel, border_type='reflect', normalized=True)

    blended = blend(blurred, x, sharpness_factor)
    return blended.clamp(min=0.0, max=1.0)
|
||||||
|
|
||||||
|
#https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
|
||||||
|
def posterize(x, bits):
    """Clear the ``8 - bits`` lowest bits of every 8-bit pixel value.

    ``x`` is converted with ``int_image`` (expects the image in the range of
    [0, 1]), masked with one bit mask per entry of ``bits``, and converted
    back with ``float_image``.
    """
    n_bits = bits.type(torch.uint8)  # integer cast: the gradient is lost here
    pixels = int_image(x)

    keep_mask = ~(2 ** (8 - n_bits) - 1).type(torch.uint8)

    _, channels, h, w = pixels.shape
    # One mask value per image, repeated over channels, rows and columns.
    keep_mask = keep_mask[:, None, None, None].expand(-1, channels, h, w)

    return float_image(pixels & keep_mask)
|
||||||
|
|
||||||
|
def auto_contrast(x): #NOT OPTIMIZED FOR BATCHES #VERY SLOW
    """Stretch every channel of every image to the full 0..255 range.

    Works on the ``int_image`` version of ``x`` (expects the image in the range
    of [0, 1]), one image and one channel at a time, and returns the result
    through ``float_image``. Prints a warning because the author marked the
    function as not verified yet; it is also listed as not functional in the
    comments of TF_dict.
    """
    # Possible optimisation: efficient LUT application / histogram computed per batch/channel
    print("Warning : Pas encore check !")
    (batch_size, channels, h, w) = x.shape
    x = int_image(x) #Expect image in the range of [0, 1]
    #print('Start',x[0])
    for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Per-image operation
        #print(img.shape)
        for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): #Per-channel operation
            #print(chan.shape)
            hist = torch.histc(chan, bins=256, min=0, max=255) #NOT DIFFERENTIABLE

            # find lowest/highest samples after preprocessing
            for lo in range(256):
                if hist[lo]:
                    break
            for hi in range(255, -1, -1):
                if hist[hi]:
                    break
            if hi <= lo:
                # don't bother
                pass
            else:
                # Linear map sending lo to 0 and hi to 255.
                scale = 255.0 / (hi - lo)
                offset = -lo * scale
                for ix in range(256):
                    n_ix = int(ix * scale + offset)
                    if n_ix < 0: n_ix = 0
                    elif n_ix > 255: n_ix = 255

                    # NOTE(review): the remap is written into `chan` while ix
                    # increases, so a pixel already moved to a larger value can
                    # match a later ix and be remapped again - verify.
                    chan[chan==ix]=n_ix
                # NOTE(review): `chan` presumably keeps the (1, 1, h, w) chunk
                # shape while the target slice is (h, w) - confirm this
                # assignment behaves as intended.
                x[im_idx, chan_idx]=chan

    #print('End',x[0])
    return float_image(x)
|
||||||
|
|
||||||
|
def equalize(x): #NOT OPTIMIZED FOR BATCHES
    """Placeholder for histogram equalisation; always raises.

    The original body raised ``Exception(self, "not implemented")``. There is
    no ``self`` in a module-level function, so the call actually failed with a
    ``NameError``. Everything after the raise was unreachable and has been
    removed.

    Planned approach, taken from the removed draft: convert ``x`` with
    ``int_image`` (expects the image in the range of [0, 1]), compute a 256-bin
    ``torch.histc`` histogram per image and per channel (not differentiable),
    apply the resulting look-up table, and convert back with ``float_image``.

    Args:
        x: Unused.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("equalize: not implemented")
|
||||||
|
|
||||||
|
def solarize(x, thresholds):
    """Invert (``1 - value``) every pixel strictly above its image's threshold.

    ``thresholds`` holds one value per image; it is expanded to the
    (batch_size, channels, h, w) shape of ``x`` before the comparison. Pixels
    at or below the threshold are returned unchanged. Earlier per-image and
    masked-scatter variants of this function were left commented out in the
    original and are not reproduced here.
    """
    _, channels, h, w = x.shape

    per_pixel = thresholds[:, None, None, None].expand(-1, channels, h, w)
    above = x > per_pixel
    return torch.where(above, 1 - x, x)
|
||||||
|
|
||||||
|
#https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818
|
||||||
|
def blend(x,y,alpha): #out = image1 * (1.0 - alpha) + image2 * alpha
    """Blend two image batches with one weight per image.

    Computes ``x * (1 - alpha) + y * alpha`` after expanding the 1-D ``alpha``
    to the (batch_size, channels, h, w) shape of ``x``.
    ``kornia.add_weighted`` is not used because it does not work with a batch
    of alpha values.

    Raises:
        TypeError: if ``x`` or ``y`` is not a ``torch.Tensor``.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError("x should be a tensor. Got {}".format(type(x)))
    if not isinstance(y, torch.Tensor):
        raise TypeError("y should be a tensor. Got {}".format(type(y)))

    _, channels, h, w = x.shape
    weight = alpha[:, None, None, None].expand(-1, channels, h, w)

    return x*(1-weight) + y*weight
|
200
salvador/utils.py
Normal file
200
salvador/utils.py
Normal file
|
@ -0,0 +1,200 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
import datetime
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import os
|
||||||
|
from fastprogress import progress_bar
|
||||||
|
|
||||||
|
class SmoothedValue(object):
    """Keep the last ``window_size`` values of a series plus a running total,
    and expose statistics over the window or over the whole series.
    """

    def __init__(self, window_size=20, fmt=None):
        self.fmt = "{global_avg:.4f}" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Store ``value`` once in the window and count it ``n`` times globally."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        """Median of the values currently in the window."""
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (computed in float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted total divided by the weighted count of all updates."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently added value."""
        return self.deque[-1]

    def __str__(self):
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfusionMatrix(object):
    """Accumulate an ``num_classes x num_classes`` count matrix.

    The row index comes from the first argument of ``update`` and the column
    index from the second one.
    """

    def __init__(self, num_classes):
        self.mat = None  # allocated by the first update(), on the device of `a`
        self.num_classes = num_classes

    def update(self, a, b):
        """Add the pairs ``(a[i], b[i])`` whose ``a[i]`` lies in [0, num_classes)."""
        size = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((size, size), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            valid = (a >= 0) & (a < size)
            # Flatten (row, column) into one index so a single bincount counts every cell.
            flat_index = size * a[valid].to(torch.int64) + b[valid]
            counts = torch.bincount(flat_index, minlength=size**2)
            self.mat += counts.reshape(size, size)

    def reset(self):
        """Zero the matrix in place."""
        self.mat.zero_()

    def compute(self):
        """Return ``(global accuracy, per-row accuracy)`` as float tensors."""
        counts = self.mat.float()
        diagonal = torch.diag(counts)
        acc_global = diagonal.sum() / counts.sum()
        acc = diagonal / counts.sum(1)
        return acc_global, acc

    def __str__(self):
        acc_global, acc = self.compute()
        per_row = ['{:.1f}'.format(i) for i in (acc * 100).tolist()]
        template = 'global correct: {:.1f}\naverage row correct: {}'
        return template.format(acc_global.item() * 100, per_row)
|
||||||
|
|
||||||
|
|
||||||
|
class MetricLogger(object):
    """Named collection of ``SmoothedValue`` meters with progress-bar logging."""

    def __init__(self, delimiter="\t"):
        self.delimiter = delimiter
        self.meters = defaultdict(SmoothedValue)

    def update(self, **kwargs):
        """Feed each keyword value (number or tensor) to the meter of that name."""
        for name, value in kwargs.items():
            scalar = value.item() if isinstance(value, torch.Tensor) else value
            assert isinstance(scalar, (float, int))
            self.meters[name].update(scalar)

    def __getattr__(self, attr):
        # Only called when normal lookup fails: meters first, then instance attributes.
        for namespace in (self.meters, self.__dict__):
            if attr in namespace:
                return namespace[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        parts = ["{}: {}".format(name, str(meter)) for name, meter in self.meters.items()]
        return self.delimiter.join(parts)

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, parent, header=None, **kwargs):
        """Yield ``(index, item)`` pairs while showing the meters on a progress bar.

        ``iterable`` is wrapped in a fastprogress ``progress_bar`` (``parent``
        and ``kwargs`` are forwarded to it). After each yielded pair the bar
        comment is set to the current meters; once the iterable is exhausted,
        ``header`` and the meters are printed.
        """
        header = header or ''
        comment_template = self.delimiter.join(['{meters}'])

        bar = progress_bar(iterable, parent=parent, **kwargs)

        for step in enumerate(bar):
            yield step
            bar.comment = comment_template.format(meters=str(self))

        print('{header} {meters}'.format(header=header, meters=str(self)))
|
||||||
|
|
||||||
|
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for each k in ``topk``.

    Returns a list with one float32 scalar tensor per requested k: the number
    of rows of ``output`` whose ``target`` is among the k highest scores,
    multiplied by ``100 / batch_size``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        _, top_classes = output.topk(largest_k, 1, True, True)
        # Transposed to (largest_k, batch) so that row i holds the i-th best guess.
        hits = top_classes.t().eq(target[None])

        percent_per_hit = 100.0 / n_samples
        return [
            hits[:k].flatten().sum(dtype=torch.float32) * percent_per_hit
            for k in topk
        ]
|
||||||
|
|
||||||
|
class EarlyStopping:
    """Track the best validation loss and checkpoint the model when it improves.

    Note: the patience test is disabled in this file. ``counter`` is
    maintained, but ``early_stop`` is never set to True by this class.
    """
    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf is the spelling that works everywhere.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model):
        """Record ``val_loss``; save ``model`` when it improves on the best loss.

        An improvement requires the score (``-val_loss``) to reach at least
        ``best_score + delta``. The previous test used ``best_score - delta``,
        which treated a loss up to ``delta`` worse than the best one as an
        improvement and overwrote the checkpoint with a worse model.
        """
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            # Stopping itself is intentionally left disabled:
            # print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            # if self.counter >= self.patience:
            #     self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save ``model.state_dict()`` to 'checkpoint.pt' and remember ``val_loss``.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss
|
Loading…
Add table
Add a link
Reference in a new issue