This commit is contained in:
Harle, Antoine (Contracteur) 2019-12-12 16:40:05 -05:00
commit e75fb96716
57 changed files with 29210 additions and 0 deletions

98
salvador/cams.py Normal file

@@ -0,0 +1,98 @@
import torch
import numpy as np
import torchvision
from PIL import Image
from torch import topk
import torch.nn.functional as F
import cv2
from torchvision import transforms
import os
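# SaveFeatures registers a forward hook on a module and keeps a numpy copy of its output
# feature maps, which getCAM() below turns into a class activation map.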
class SaveFeatures():
features=None
def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)
def hook_fn(self, module, input, output): self.features = ((output.cpu()).data).numpy()
def remove(self): self.hook.remove()
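# CAM (Zhou et al., 2016): weight the final conv feature maps by the fully-connected
# weights of the chosen class, sum over channels, then min-max normalize to [0, 1].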
def getCAM(feature_conv, weight_fc, class_idx):
_, nc, h, w = feature_conv.shape
cam = weight_fc[class_idx].dot(feature_conv.reshape((nc, h*w)))
cam = cam.reshape(h, w)
cam = cam - np.min(cam)
cam_img = cam / np.max(cam)
# cam_img = np.uint8(255 * cam_img)
return cam_img
def main(cam):
device = 'cuda:0'
model_name = 'resnet50'
root = 'NEW_SS'
os.makedirs(os.path.join(root + '_CAM', 'OK'), exist_ok=True)
os.makedirs(os.path.join(root + '_CAM', 'NOK'), exist_ok=True)
train_transform = transforms.Compose([
transforms.ToTensor(),
])
dataset = torchvision.datasets.ImageFolder(
root=root, transform=train_transform,
)
loader = torch.utils.data.DataLoader(dataset, batch_size=1)
model = torchvision.models.__dict__[model_name](pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
model = model.to(device)
model.eval()
weight_softmax_params = list(model._modules.get('fc').parameters())
weight_softmax = np.squeeze(weight_softmax_params[0].cpu().data.numpy())
final_layer = model._modules.get('layer4')
activated_features = SaveFeatures(final_layer)
for i, (img, target ) in enumerate(loader):
img = img.to(device)
prediction = model(img)
pred_probabilities = F.softmax(prediction, dim=1).data.squeeze()
class_idx = topk(pred_probabilities,1)[1].int()
# if target.item() != class_idx:
# print(dataset.imgs[i][0])
if cam:
overlay = getCAM(activated_features.features, weight_softmax, class_idx )
# import ipdb; ipdb.set_trace()  # debug breakpoint
import PIL
from torchvision.transforms import ToPILImage
img = ToPILImage()(overlay).resize(size=(1280, 1024), resample=PIL.Image.BILINEAR)
img.save('heat-pil.jpg')
img = cv2.imread(dataset.imgs[i][0])
height, width, _ = img.shape
overlay = cv2.resize(overlay, (width, height))
heatmap = cv2.applyColorMap(overlay, cv2.COLORMAP_JET)
cv2.imwrite('heat-cv2.jpg', heatmap)
img = cv2.imread(dataset.imgs[i][0])
height, width, _ = img.shape
overlay = cv2.resize(overlay, (width, height))
heatmap = cv2.applyColorMap(overlay, cv2.COLORMAP_JET)
result = heatmap * 0.3 + img * 0.5
clss = dataset.imgs[i][0].split(os.sep)[1]
name = dataset.imgs[i][0].split(os.sep)[2].split('.')[0]
cv2.imwrite(os.path.join(root+"_CAM", clss, name + '.jpg'), result)
print(f'{os.path.join(root+"_CAM", clss, name + ".jpg")} saved')
activated_features.remove()
if __name__ == "__main__":
main(cam=True)

BIN
salvador/checkpoint.pt Normal file

Binary file not shown.

[additional binary files in this commit are not shown]

1136
salvador/dataug.py Executable file

File diff suppressed because it is too large

314
salvador/dataug_utils.py Normal file

@@ -0,0 +1,314 @@
import numpy as np
import json, math, time, os
import matplotlib.pyplot as plt
import copy
import gc
from torchviz import make_dot
import torch
import torch.nn.functional as F
import time
class timer():
def __init__(self):
self._start_time=time.time()
def exec_time(self):
end = time.time()
res = end-self._start_time
self._start_time=end
return res
def print_graph(PyTorch_obj, fig_name='graph'):
graph=make_dot(PyTorch_obj) #Passing the loss gives the whole graph
graph.format = 'svg' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
graph.render(fig_name)
def plot_res(log, fig_name='res', param_names=None):
epochs = [x["epoch"] for x in log]
fig, ax = plt.subplots(ncols=3, figsize=(15, 3))
ax[0].set_title('Loss')
ax[0].plot(epochs,[x["train_loss"] for x in log], label='Train')
ax[0].plot(epochs,[x["val_loss"] for x in log], label='Val')
ax[0].legend()
ax[1].set_title('Acc')
ax[1].plot(epochs,[x["acc"] for x in log])
if log[0]["param"]!= None:
if isinstance(log[0]["param"],float):
ax[2].set_title('Mag')
ax[2].plot(epochs,[x["param"] for x in log], label='Mag')
ax[2].legend()
else :
ax[2].set_title('Prob')
#for idx, _ in enumerate(log[0]["param"]):
#ax[2].plot(epochs,[x["param"][idx] for x in log], label='P'+str(idx))
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
ax[2].stackplot(epochs, proba, labels=param_names)
ax[2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name)
plt.close()
def plot_resV2(log, fig_name='res', param_names=None):
epochs = [x["epoch"] for x in log]
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(30, 15))
ax[0, 0].set_title('Loss')
ax[0, 0].plot(epochs,[x["train_loss"] for x in log], label='Train')
ax[0, 0].plot(epochs,[x["val_loss"] for x in log], label='Val')
ax[0, 0].legend()
ax[1, 0].set_title('Acc')
ax[1, 0].plot(epochs,[x["acc"] for x in log])
if log[0]["param"]!= None:
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
#proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
ax[0, 1].set_title('Prob =f(epoch)')
ax[0, 1].stackplot(epochs, proba, labels=param_names)
#ax[0, 1].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
ax[1, 1].set_title('Prob =f(TF)')
mean = np.mean(proba, axis=1)
std = np.std(proba, axis=1)
ax[1, 1].bar(param_names, mean, yerr=std)
plt.sca(ax[1, 1]), plt.xticks(rotation=90)
ax[0, 2].set_title('Mag =f(epoch)')
ax[0, 2].stackplot(epochs, mag, labels=param_names)
ax[0, 2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
ax[1, 2].set_title('Mag =f(TF)')
mean = np.mean(mag, axis=1)
std = np.std(mag, axis=1)
ax[1, 2].bar(param_names, mean, yerr=std)
plt.sca(ax[1, 2]), plt.xticks(rotation=90)
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def plot_compare(filenames, fig_name='res'):
all_data=[]
legend=""
for idx, file in enumerate(filenames):
legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
all_data.append(data)
fig, ax = plt.subplots(ncols=3, figsize=(30, 8))
for data_idx, log in enumerate(all_data):
log=log['Log']
epochs = [x["epoch"] for x in log]
ax[0].plot(epochs,[x["train_loss"] for x in log], label=str(data_idx)+'-Train')
ax[0].plot(epochs,[x["val_loss"] for x in log], label=str(data_idx)+'-Val')
ax[1].plot(epochs,[x["acc"] for x in log], label=str(data_idx))
#ax[1].text(x=0.5,y=0,s=str(data_idx)+'-'+filenames[data_idx], transform=ax[1].transAxes)
if log[0]["param"]!= None:
if isinstance(log[0]["param"],float):
ax[2].plot(epochs,[x["param"] for x in log], label=str(data_idx)+'-Mag')
else :
for idx, _ in enumerate(log[0]["param"]):
ax[2].plot(epochs,[x["param"][idx] for x in log], label=str(data_idx)+'-P'+str(idx))
fig.suptitle(legend)
ax[0].set_title('Loss')
ax[1].set_title('Acc')
ax[2].set_title('Param')
for a in ax: a.legend()
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def plot_res_compare(filenames, fig_name='res'):
all_data=[]
#legend=""
for idx, file in enumerate(filenames):
#legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
all_data.append(data)
n_tf = [len(x["Param_names"]) for x in all_data]
acc = [x["Accuracy"] for x in all_data]
time = [x["Time"][0] for x in all_data]
fig, ax = plt.subplots(ncols=3, figsize=(30, 8))
ax[0].plot(n_tf, acc)
ax[1].plot(n_tf, time)
ax[0].set_title('Acc')
ax[1].set_title('Time')
#for a in ax: a.legend()
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def plot_TF_res(log, tf_names, fig_name='res'):
mean = np.mean([x["param"] for x in log], axis=0)
std = np.std([x["param"] for x in log], axis=0)
fig, ax = plt.subplots(1, 1, figsize=(30, 8), sharey=True)
ax.bar(tf_names, mean, yerr=std)
#ax.bar(tf_names, log[-1]["param"])
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def viz_sample_data(imgs, labels, fig_name='data_sample'):
sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()
plt.figure(figsize=(10,10))
for i in range(25):
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
plt.xlabel(labels[i].item())
plt.savefig(fig_name)
print("Sample saved :", fig_name)
plt.close()
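# model_copy / optim_copy: typically used to push the weights (and optionally the gradients /
# optimizer state) of a `higher`-patched copy (fmodel) back into the regular model and optimizer.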
def model_copy(src,dst, patch_copy=True, copy_grad=True):
#model=copy.deepcopy(fmodel) #Not appropriate, we only want the weights/grads (not the whole fmodel and its states)
dst.load_state_dict(src.state_dict()) #Do not copy gradient !
if patch_copy:
dst['model'].load_state_dict(src['model'].state_dict()) #Copy missing data?
dst['data_aug'].load_state_dict(src['data_aug'].state_dict())
#Copy the gradients
if copy_grad:
for paramName, paramValue, in src.named_parameters():
for netCopyName, netCopyValue, in dst.named_parameters():
if paramName == netCopyName:
netCopyValue.grad = paramValue.grad
#netCopyValue=copy.deepcopy(paramValue)
try: #Data_augV4
dst['data_aug']._input_info = src['data_aug']._input_info
dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix
except:
pass
def optim_copy(dopt, opt):
#inner_opt.load_state_dict(diffopt.state_dict()) #Need to save the optimizer state (momentum, etc.) => does not copy the state...
#opt_param=higher.optim.get_trainable_opt_params(diffopt)
for group_idx, group in enumerate(opt.param_groups):
# print('gp idx',group_idx)
for p_idx, p in enumerate(group['params']):
opt.state[p]=dopt.state[group_idx][p_idx]
def print_torch_mem(add_info=''):
nb=0
max_size=0
for obj in gc.get_objects():
#print(type(obj))
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): # and len(obj.size())>1:
#print(i, type(obj), obj.size())
size = np.sum(obj.size())
if(size>max_size): max_size=size
nb+=1
except:
pass
print(add_info, "-Pytroch tensor nb:",nb," / Max dim:", max_size)
#print(add_info, "-Garbage size :",len(gc.garbage))
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
string = add_info + ' memory (MB)'
string += ' | allocated: {}'.format(
torch.cuda.memory_allocated() / mega_bytes)
string += ' | max allocated: {}'.format(
torch.cuda.max_memory_allocated() / mega_bytes)
string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
string += ' | max cached: {}'.format(
torch.cuda.max_memory_cached()/ mega_bytes)
print(string)
def plot_TF_influence(log, fig_name='TF_influence', param_names=None):
proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
plt.figure()
mean = np.mean(proba, axis=1)*np.mean(mag, axis=1) #Could be interesting to multiply before taking the mean
std = np.std(proba, axis=1)*np.std(mag, axis=1)
plt.bar(param_names, mean, yerr=std)
plt.xticks(rotation=90)
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
class loss_monitor(): #See https://github.com/pytorch/ignite
def __init__(self, patience, end_train=1):
self.patience = patience
self.end_train = end_train
self.counter = 0
self.best_score = None
self.reached_limit = 0
def register(self, loss):
if self.best_score is None:
self.best_score = loss
elif loss > self.best_score:
self.counter += 1
#if not self.reached_limit:
print("loss no improve counter", self.counter, self.reached_limit)
else:
self.best_score = loss
self.counter = 0
def limit_reached(self):
if self.counter >= self.patience:
self.counter = 0
self.reached_limit +=1
self.best_score = None
return self.reached_limit
def end_training(self):
if self.limit_reached() >= self.end_train:
return True
else:
return False
def reset(self):
self.__init__(self.patience, self.end_train)

102
salvador/grad_cam.py Normal file

@@ -0,0 +1,102 @@
import torch
import numpy as np
import torchvision
from PIL import Image
from torch import topk
from torch import nn
import torch.nn.functional as F
import cv2
from torchvision import transforms
import os
class Lambda(nn.Module):
"Create a layer that simply calls `func` with `x`"
def __init__(self, func):
super().__init__()
self.func=func
def forward(self, x): return self.func(x)
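# Grad-CAM hooks: the forward hook stores the layer's activations, the backward hook stores
# the gradient of the class score with respect to those activations.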
class SaveFeatures():
activations, gradients = None, None
def __init__(self, m):
self.forward = m.register_forward_hook(self.forward_hook_fn)
self.backward = m.register_backward_hook(self.backward_hook_fn)
def forward_hook_fn(self, module, input, output):
self.activations = output.cpu().detach()
def backward_hook_fn(self, module, grad_input, grad_output):
self.gradients = grad_output[0].cpu().detach()
def remove(self):
self.forward.remove()
self.backward.remove()
def main(cam):
device = 'cuda:0'
model_name = 'resnet50'
root = '/mnt/md0/data/cifar10/tmp/cifar/train'
_root = 'cifar'
os.makedirs(os.path.join(_root + '_CAM'), exist_ok=True)
train_transform = transforms.Compose([
transforms.ToTensor(),
])
dataset = torchvision.datasets.ImageFolder(
root=root, transform=train_transform,
)
loader = torch.utils.data.DataLoader(dataset, batch_size=1)
model = torchvision.models.__dict__[model_name](pretrained=True)
flat = list(model.children())
body, head = nn.Sequential(*flat[:-2]), nn.Sequential(flat[-2], Lambda(func=lambda x: torch.flatten(x, 1)), nn.Linear(flat[-1].in_features, len(loader.dataset.classes)))
model = nn.Sequential(body, head)
model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
model = model.to(device)
model.eval()
activated_features = SaveFeatures(model[0])
for i, (img, target ) in enumerate(loader):
img = img.to(device)
pred = model(img)
# import ipdb; ipdb.set_trace()  # debug breakpoint
# backpropagate the target-class score so the backward hook captures the gradients w.r.t. the layer activations
pred[:, target.item()].backward()
# import ipdb; ipdb.set_trace()
# pull the gradients out of the model
gradients = activated_features.gradients[0]
pooled_gradients = gradients.mean(1).mean(1)
# get the activations of the last convolutional layer
activations = activated_features.activations[0]
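# Grad-CAM: weight each activation map by its spatially-pooled gradient, sum over channels,
# keep only positive contributions (ReLU), then normalize by the maximum.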
heatmap = F.relu(((activations*pooled_gradients[...,None,None])).sum(0))
heatmap /= torch.max(heatmap)
heatmap = heatmap.numpy()
image = cv2.imread(dataset.imgs[i][0])
heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
heatmap = np.uint8(255 * heatmap)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
# superimposed_img = heatmap * 0.3 + image * 0.5
superimposed_img = heatmap
clss = dataset.imgs[i][0].split(os.sep)[1]
name = dataset.imgs[i][0].split(os.sep)[2].split('.')[0]
cv2.imwrite(os.path.join(_root+"_CAM", name + '.jpg'), superimposed_img)
print(f'{os.path.join(_root+"_CAM", name + ".jpg")} saved')
activated_features.remove()
if __name__ == "__main__":
main(cam=True)

375
salvador/train.py Normal file

@@ -0,0 +1,375 @@
import datetime
import os
import time
import sys
import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms
from PIL import ImageEnhance
import random
import utils
from fastprogress import master_bar, progress_bar
import numpy as np
## DATA AUG ##
import higher
from dataug import *
from dataug_utils import *
tf_names = [
## Geometric TF ##
'Identity',
'FlipUD',
'FlipLR',
'Rotate',
'TranslateX',
'TranslateY',
'ShearX',
'ShearY',
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast',
'Color',
'Brightness',
'Sharpness',
'Posterize',
'Solarize', #=>Expects image in [0,1] #Not optimized for batches
]
class Lambda(nn.Module):
"Create a layer that simply calls `func` with `x`"
def __init__(self, func):
super().__init__()
self.func=func
def forward(self, x): return self.func(x)
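# SubsetSampler: same indices as SubsetRandomSampler but iterated in a fixed order, so
# evaluate() can map the batch index back to the image path via data_loader.sampler.indices.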
class SubsetSampler(torch.utils.data.SubsetRandomSampler):
def __init__(self, indices):
super().__init__(indices)
def __iter__(self):
return (self.indices[i] for i in range(len(self.indices)))
def __len__(self):
return len(self.indices)
def sharpness(img, factor):
sharpness_factor = random.uniform(1, factor)
sharp = ImageEnhance.Sharpness(img)
sharped = sharp.enhance(sharpness_factor)
return sharped
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, master_bar, Kldiv=False):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
header = 'Epoch: {}'.format(epoch)
for _, (image, target) in metric_logger.log_every(data_loader, header=header, parent=master_bar):
image, target = image.to(device), target.to(device)
if not Kldiv :
output = model(image)
#output = F.log_softmax(output, dim=1)
loss = criterion(output, target) #No softmax?
else: #Consumes 2x the memory
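# Consistency training: one clean forward pass gives the supervised loss, one augmented pass
# gives an extra CE term plus a KL divergence between the clean and augmented predictions,
# both added with weight `unsupp_coeff`.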
model.augment(mode=False)
output = model(image)
model.augment(mode=True)
log_sup=F.log_softmax(output, dim=1)
sup_loss = F.cross_entropy(log_sup, target)
aug_output = model(image)
log_aug=F.log_softmax(aug_output, dim=1)
aug_loss=F.cross_entropy(log_aug, target)
#KL div w/ logits - similarity of the predictions (distributions)
KL_loss = F.softmax(output, dim=1)*(log_sup-log_aug)
KL_loss = KL_loss.sum(dim=-1)
#KL_loss = F.kl_div(aug_logits, sup_logits, reduction='none')
KL_loss = KL_loss.mean()
unsupp_coeff = 1
loss = sup_loss + (aug_loss + KL_loss) * unsupp_coeff
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc1 = utils.accuracy(output, target)[0]
batch_size = image.shape[0]
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.update(loss=loss.item())
confmat.update(target.flatten(), output.argmax(1).flatten())
return metric_logger.loss.global_avg, confmat
def evaluate(model, criterion, data_loader, device):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
header = 'Test:'
missed = []
with torch.no_grad():
for i, (image, target) in metric_logger.log_every(data_loader, leave=False, header=header, parent=None):
image, target = image.to(device), target.to(device)
output = model(image)
loss = criterion(output, target)
if target.item() != output.topk(1)[1].item():
missed.append(data_loader.dataset.imgs[data_loader.sampler.indices[i]])
confmat.update(target.flatten(), output.argmax(1).flatten())
acc1 = utils.accuracy(output, target)[0]
batch_size = image.shape[0]
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.update(loss=loss.item())
return metric_logger.loss.global_avg, missed, confmat
def get_train_valid_loader(args, augment, random_seed, valid_size=0.1, shuffle=True, num_workers=4, pin_memory=True):
"""
Utility function for loading and returning train and valid
multi-process iterators over the CIFAR-10 dataset. A sample
9x9 grid of the images can be optionally displayed.
If using CUDA, num_workers should be set to 1 and pin_memory to True.
Params
------
- data_dir: path directory to the dataset.
- batch_size: how many samples per batch to load.
- augment: whether to apply the data augmentation scheme
mentioned in the paper. Only applied on the train split.
- random_seed: fix seed for reproducibility.
- valid_size: percentage split of the training set used for
the validation set. Should be a float in the range [0, 1].
- shuffle: whether to shuffle the train/validation indices.
- show_sample: plot 9x9 sample grid of the dataset.
- num_workers: number of subprocesses to use when loading the dataset.
- pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
True if using GPU.
Returns
-------
- train_loader: training set iterator.
- valid_loader: validation set iterator.
"""
error_msg = "[!] valid_size should be in the range [0, 1]."
assert ((valid_size >= 0) and (valid_size <= 1)), error_msg
# normalize = transforms.Normalize(
# mean=[0.4914, 0.4822, 0.4465],
# std=[0.2023, 0.1994, 0.2010],
# )
# define transforms
if augment:
train_transform = transforms.Compose([
# transforms.ColorJitter(brightness=0.3),
# transforms.Lambda(lambda img: sharpness(img, 5)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# normalize,
])
valid_transform = transforms.Compose([
# transforms.ColorJitter(brightness=0.3),
# transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# normalize,
])
else:
train_transform = transforms.Compose([
transforms.ToTensor(),
# normalize,
])
valid_transform = transforms.Compose([
transforms.ToTensor(),
# normalize,
])
# load the dataset
train_dataset = torchvision.datasets.ImageFolder(
root=args.data_path, transform=train_transform
)
valid_dataset = torchvision.datasets.ImageFolder(
root=args.data_path, transform=valid_transform
)
num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))
if shuffle:
#np.random.seed(random_seed)
np.random.shuffle(indices)
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) if not args.test_only else SubsetSampler(train_idx)
valid_sampler = SubsetSampler(valid_idx)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size if not args.test_only else 1, sampler=train_sampler,
num_workers=num_workers, pin_memory=pin_memory,
)
valid_loader = torch.utils.data.DataLoader(
valid_dataset, batch_size=1, sampler=valid_sampler,
num_workers=num_workers, pin_memory=pin_memory,
)
imgs = np.asarray(train_dataset.imgs)
# print('Train')
# print(imgs[train_idx])
#print('Valid')
#print(imgs[valid_idx])
tgt = [0,0]
for _, targets in train_loader:
for target in targets:
tgt[target]+=1
print("Train targets :", tgt)
tgt = [0,0]
for _, targets in valid_loader:
for target in targets:
tgt[target]+=1
print("Valid targets :", tgt)
return (train_loader, valid_loader)
def main(args):
print(args)
device = torch.device(args.device)
torch.backends.cudnn.benchmark = True
#augment = True if not args.test_only else False
if not args.test_only and args.augment=='flip' : augment = True
else : augment = False
print("Augment", augment)
data_loader, data_loader_test = get_train_valid_loader(args=args, pin_memory=True, augment=augment,
num_workers=args.workers, valid_size=0.3, random_seed=999)
print("Creating model")
model = torchvision.models.__dict__[args.model](pretrained=True)
flat = list(model.children())
body, head = nn.Sequential(*flat[:-2]), nn.Sequential(flat[-2], Lambda(func=lambda x: torch.flatten(x, 1)), nn.Linear(flat[-1].in_features, len(data_loader.dataset.classes)))
model = nn.Sequential(body, head)
Kldiv=False
if not args.test_only and (args.augment=='Rand' or args.augment=='RandKL'):
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
if args.augment=='RandKL': Kldiv=True
print("Augmodel")
# model.fc = nn.Linear(model.fc.in_features, 2)
# import ipdb; ipdb.set_trace()
criterion = nn.CrossEntropyLoss().to(device)
# optimizer = torch.optim.SGD(
# model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
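# Polynomial learning-rate decay: the LR is scaled by (1 - step / total_steps) ** 0.9,
# with total_steps = len(data_loader) * args.epochs.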
es = utils.EarlyStopping()
if args.test_only:
model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
model = model.to(device)
print('TEST')
_, missed, _ = evaluate(model, criterion, data_loader_test, device=device)
print(missed)
print('TRAIN')
_, missed, _ = evaluate(model, criterion, data_loader, device=device)
print(missed)
return
model = model.to(device)
print("Start training")
start_time = time.time()
mb = master_bar(range(args.epochs))
for epoch in mb:
_, train_confmat = train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, mb, Kldiv)
lr_scheduler.step( (epoch+1)*len(data_loader) )
val_loss, _, valid_confmat = evaluate(model, criterion, data_loader_test, device=device)
es(val_loss, model)
# print('Valid Missed')
# print(valid_missed)
# print('Train')
# print(train_confmat)
print('Valid')
print(valid_confmat)
# if es.early_stop:
# break
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
def parse_args():
import argparse
parser = argparse.ArgumentParser(description='PyTorch Classification Training')
parser.add_argument('--data-path', default='/Salvador', help='dataset')
parser.add_argument('--model', default='resnet18', help='model') #'resnet18'
parser.add_argument('--device', default='cuda:1', help='device')
parser.add_argument('-b', '--batch-size', default=8, type=int)
parser.add_argument('--epochs', default=3, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
help='number of data loading workers (default: 0)')
parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=4e-5, type=float,
metavar='W', help='weight decay (default: 4e-5)',
dest='weight_decay')
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
parser.add_argument('-a', '--augment', default='None', type=str,
metavar='N', help='Data augment',
dest='augment')
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
main(args)

585
salvador/train_dataug.py Normal file

@@ -0,0 +1,585 @@
import datetime
import os
import time
import sys
import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms
from PIL import ImageEnhance
import random
import utils
from fastprogress import master_bar, progress_bar
import numpy as np
## DATA AUG ##
import higher
from dataug import *
from dataug_utils import *
tf_names = [
## Geometric TF ##
'Identity',
'FlipUD',
'FlipLR',
'Rotate',
'TranslateX',
'TranslateY',
'ShearX',
'ShearY',
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast',
'Color',
'Brightness',
'Sharpness',
'Posterize',
'Solarize', #=>Expects image in [0,1] #Not optimized for batches
]
def compute_vaLoss(model, dl_it, dl):
device = next(model.parameters()).device
try:
xs, ys = next(dl_it)
except StopIteration: #End of the validation epoch
dl_it = iter(dl)
xs, ys = next(dl_it)
xs, ys = xs.to(device), ys.to(device)
model.eval() #Validation without transformations!
return F.cross_entropy(model(xs), ys)
def model_copy(src,dst, patch_copy=True, copy_grad=True):
#model=copy.deepcopy(fmodel) #Not appropriate, we only want the weights/grads (not the whole fmodel and its states)
dst.load_state_dict(src.state_dict()) #Do not copy gradient !
if patch_copy:
dst['model'].load_state_dict(src['model'].state_dict()) #Copy missing data?
dst['data_aug'].load_state_dict(src['data_aug'].state_dict())
#Copy the gradients
if copy_grad:
for paramName, paramValue, in src.named_parameters():
for netCopyName, netCopyValue, in dst.named_parameters():
if paramName == netCopyName:
netCopyValue.grad = paramValue.grad
#netCopyValue=copy.deepcopy(paramValue)
try: #Data_augV4
dst['data_aug']._input_info = src['data_aug']._input_info
dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix
except:
pass
def optim_copy(dopt, opt):
#inner_opt.load_state_dict(diffopt.state_dict()) #Need to save the optimizer state (momentum, etc.) => does not copy the state...
#opt_param=higher.optim.get_trainable_opt_params(diffopt)
for group_idx, group in enumerate(opt.param_groups):
# print('gp idx',group_idx)
for p_idx, p in enumerate(group['params']):
opt.state[p]=dopt.state[group_idx][p_idx]
#############
class Lambda(nn.Module):
"Create a layer that simply calls `func` with `x`"
def __init__(self, func):
super().__init__()
self.func=func
def forward(self, x): return self.func(x)
class SubsetSampler(torch.utils.data.SubsetRandomSampler):
def __init__(self, indices):
super().__init__(indices)
def __iter__(self):
return (self.indices[i] for i in range(len(self.indices)))
def __len__(self):
return len(self.indices)
def sharpness(img, factor):
sharpness_factor = random.uniform(1, factor)
sharp = ImageEnhance.Sharpness(img)
sharped = sharp.enhance(sharpness_factor)
return sharped
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, master_bar):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
header = 'Epoch: {}'.format(epoch)
for _, (image, target) in metric_logger.log_every(data_loader, header=header, parent=master_bar):
image, target = image.to(device), target.to(device)
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc1 = utils.accuracy(output, target)[0]
batch_size = image.shape[0]
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.update(loss=loss.item())
confmat.update(target.flatten(), output.argmax(1).flatten())
return metric_logger.loss.global_avg, confmat
def evaluate(model, criterion, data_loader, device):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
header = 'Test:'
missed = []
with torch.no_grad():
for i, (image, target) in metric_logger.log_every(data_loader, leave=False, header=header, parent=None):
image, target = image.to(device), target.to(device)
output = model(image)
loss = criterion(output, target)
if target.item() != output.topk(1)[1].item():
missed.append(data_loader.dataset.imgs[data_loader.sampler.indices[i]])
confmat.update(target.flatten(), output.argmax(1).flatten())
acc1 = utils.accuracy(output, target)[0]
batch_size = image.shape[0]
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.update(loss=loss.item())
return metric_logger.loss.global_avg, missed, confmat
def get_train_valid_loader(args, augment, random_seed, train_size=0.5, test_size=0.1, shuffle=True, num_workers=4, pin_memory=True):
"""
Utility function for loading and returning train and valid
multi-process iterators over the CIFAR-10 dataset. A sample
9x9 grid of the images can be optionally displayed.
If using CUDA, num_workers should be set to 1 and pin_memory to True.
Params
------
- data_dir: path directory to the dataset.
- batch_size: how many samples per batch to load.
- augment: whether to apply the data augmentation scheme
mentioned in the paper. Only applied on the train split.
- random_seed: fix seed for reproducibility.
- valid_size: percentage split of the training set used for
the validation set. Should be a float in the range [0, 1].
- shuffle: whether to shuffle the train/validation indices.
- show_sample: plot 9x9 sample grid of the dataset.
- num_workers: number of subprocesses to use when loading the dataset.
- pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
True if using GPU.
Returns
-------
- train_loader: training set iterator.
- valid_loader: validation set iterator.
"""
error_msg = "[!] test_size should be in the range [0, 1]."
assert ((test_size >= 0) and (test_size <= 1)), error_msg
# normalize = transforms.Normalize(
# mean=[0.4914, 0.4822, 0.4465],
# std=[0.2023, 0.1994, 0.2010],
# )
# define transforms
if augment:
train_transform = transforms.Compose([
# transforms.ColorJitter(brightness=0.3),
# transforms.Lambda(lambda img: sharpness(img, 5)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# normalize,
])
valid_transform = transforms.Compose([
# transforms.ColorJitter(brightness=0.3),
# transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# normalize,
])
else:
train_transform = transforms.Compose([
transforms.ToTensor(),
# normalize,
])
valid_transform = transforms.Compose([
transforms.ToTensor(),
# normalize,
])
# load the dataset
train_dataset = torchvision.datasets.ImageFolder(
root=args.data_path, transform=train_transform
)
test_dataset = torchvision.datasets.ImageFolder(
root=args.data_path, transform=valid_transform
)
num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(test_size * num_train))
if shuffle:
np.random.seed(random_seed)
np.random.shuffle(indices)
train_idx, test_idx = indices[split:], indices[:split]
train_idx, valid_idx = train_idx[:int(len(train_idx)*train_size)], train_idx[int(len(train_idx)*train_size):]
print("\nTrain", len(train_idx), "\nValid", len(valid_idx), "\nTest", len(test_idx))
train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) if not args.test_only else SubsetSampler(train_idx)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx) if not args.test_only else SubsetSampler(valid_idx)
test_sampler = SubsetSampler(test_idx)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size if not args.test_only else 1, sampler=train_sampler,
num_workers=num_workers, pin_memory=pin_memory,
)
valid_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size if not args.test_only else 1, sampler=valid_sampler,
num_workers=num_workers, pin_memory=pin_memory,
)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=1, sampler=test_sampler,
num_workers=num_workers, pin_memory=pin_memory,
)
imgs = np.asarray(train_dataset.imgs)
# print('Train')
# print(imgs[train_idx])
#print('Valid')
#print(imgs[valid_idx])
return (train_loader, valid_loader, test_loader)
def main(args):
print(args)
device = torch.device(args.device)
torch.backends.cudnn.benchmark = True
#augment = True if not args.test_only else False
augment = False
data_loader, dl_val, data_loader_test = get_train_valid_loader(args=args, pin_memory=True, augment=augment,
num_workers=args.workers, train_size=0.99, test_size=0.2, random_seed=999)
print("Creating model")
model = torchvision.models.__dict__[args.model](pretrained=True)
flat = list(model.children())
body, head = nn.Sequential(*flat[:-2]), nn.Sequential(flat[-2], Lambda(func=lambda x: torch.flatten(x, 1)), nn.Linear(flat[-1].in_features, len(data_loader.dataset.classes)))
model = nn.Sequential(body, head)
# model.fc = nn.Linear(model.fc.in_features, 2)
# import ipdb; ipdb.set_trace()
criterion = nn.CrossEntropyLoss().to(device)
# optimizer = torch.optim.SGD(
# model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
'''
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
'''
es = utils.EarlyStopping()
if args.test_only:
model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
model = model.to(device)
print('TEST')
_, missed, _ = evaluate(model, criterion, data_loader_test, device=device)
print(missed)
print('TRAIN')
_, missed, _ = evaluate(model, criterion, data_loader, device=device)
print(missed)
return
model = model.to(device)
print("Start training")
start_time = time.time()
mb = master_bar(range(args.epochs))
"""
for epoch in mb:
_, train_confmat = train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, mb)
lr_scheduler.step( (epoch+1)*len(data_loader) )
val_loss, _, valid_confmat = evaluate(model, criterion, data_loader_test, device=device)
es(val_loss, model)
# print('Valid Missed')
# print(valid_missed)
# print('Train')
# print(train_confmat)
print('Valid')
print(valid_confmat)
# if es.early_stop:
# break
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
"""
#######
inner_it = args.inner_it
dataug_epoch_start=0
print_freq=1
KLdiv=False
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
#model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
val_loss=torch.tensor(0) #Needed if no meta-step happens during an epoch
dl_val_it = iter(dl_val)
countcopy=0
#if inner_it!=0:
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=args.lr) #lr=1e-2
#inner_opt = torch.optim.SGD(model['model'].parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) #lr=1e-2 / momentum=0.9
inner_opt = torch.optim.Adam(model['model'].parameters(), lr=args.lr, weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
inner_opt,
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
high_grad_track = True
if inner_it == 0:
high_grad_track=False
model.train()
model.augment(mode=False)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
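# Bi-level setup with `higher`: fmodel is a differentiable (monkey-patched) copy of the model and
# diffopt a differentiable optimizer, so the inner training steps stay in the graph and the
# validation loss can be backpropagated to the augmentation parameters when high_grad_track is True.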
i=0
for epoch in mb:
metric_logger = utils.MetricLogger(delimiter=" ")
confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
header = 'Epoch: {}'.format(epoch)
t0 = time.process_time()
for _, (image, target) in metric_logger.log_every(data_loader, header=header, parent=mb):
#for i, (xs, ys) in enumerate(dl_train):
#print_torch_mem("it"+str(i))
i+=1
image, target = image.to(device), target.to(device)
if(not KLdiv):
#Uniform method
logits = fmodel(image) # modified `params` can also be passed as a kwarg
output = F.log_softmax(logits, dim=1)
loss = F.cross_entropy(output, target, reduction='none') # no need to call loss.backwards()
if fmodel._data_augmentation: #Weight loss
w_loss = fmodel['data_aug'].loss_weight()#.to(device)
loss = loss * w_loss
loss = loss.mean()
else:
#KL div method
fmodel.augment(mode=False)
sup_logits = fmodel(image)
log_sup=F.log_softmax(sup_logits, dim=1)
fmodel.augment(mode=True)
loss = F.cross_entropy(log_sup, target)
if fmodel._data_augmentation:
aug_logits = fmodel(image)
log_aug=F.log_softmax(aug_logits, dim=1)
aug_loss=0
if epoch>50: #delayed start?
#KL div w/ logits - similarity of the predictions (distributions)
aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
aug_loss=aug_loss.sum(dim=-1)
#aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none')
w_loss = fmodel['data_aug'].loss_weight() #Weight loss
aug_loss = (w_loss * aug_loss).mean()
aug_loss += (F.cross_entropy(log_aug, target, reduction='none') * w_loss).mean()
#print(aug_loss)
unsupp_coeff = 1
loss += aug_loss * unsupp_coeff
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
if(high_grad_track and i%inner_it==0): #Perform Meta step
#print("meta")
#Not very useful if high_grad_track = False
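# Meta step: backprop the validation loss (plus the augmentation regularizer) through the
# recorded inner steps, update the augmentation probabilities/magnitudes with meta_opt,
# then re-patch the model so a fresh inner loop starts from the copied-back weights.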
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
#print_graph(val_loss)
val_loss.backward()
countcopy+=1
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
#if epoch>50:
meta_opt.step()
model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
#model['data_aug'].next_TF_set()
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
acc1 = utils.accuracy(output, target)[0]
batch_size = image.shape[0]
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.update(loss=loss.item())
confmat.update(target.flatten(), output.argmax(1).flatten())
if(not high_grad_track and (torch.cuda.memory_cached()/1024.0**2)>20000):
countcopy+=1
print_torch_mem("copy")
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
#Needed to reset higher (it accumulates fast_params even with track_higher_grads = False)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
print_torch_mem("copy")
if(not high_grad_track):
countcopy+=1
print_torch_mem("end copy")
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
#Needed to reset higher (it accumulates fast_params even with track_higher_grads = False)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
print_torch_mem("end copy")
tf = time.process_time()
#### Print ####
if(print_freq and epoch%print_freq==0):
print('-'*9)
print('Epoch : %d'%(epoch))
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Data Augmentation : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
print('TF Proba :', model['data_aug']['prob'].data)
#print('proba grad',model['data_aug']['prob'].grad)
print('TF Mag :', model['data_aug']['mag'].data)
#print('Mag grad',model['data_aug']['mag'].grad)
#print('Reg loss:', model['data_aug'].reg_loss().item())
#print('Aug loss', aug_loss.item())
#############
#### Log ####
#print(type(model['data_aug']) is dataug.Data_augV5)
'''
param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": param #if isinstance(model['data_aug'], Data_augV5)
#else [p.item() for p in model['data_aug']['prob']],
}
log.append(data)
'''
#############
train_confmat=confmat
lr_scheduler.step( (epoch+1)*len(data_loader) )
test_loss, _, test_confmat = evaluate(model, criterion, data_loader_test, device=device)
es(test_loss, model)
# print('Valid Missed')
# print(valid_missed)
# print('Train')
# print(train_confmat)
print('Test')
print(test_confmat)
# if es.early_stop:
# break
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
def parse_args():
import argparse
parser = argparse.ArgumentParser(description='PyTorch Classification Training')
parser.add_argument('--data-path', default='/Salvador', help='dataset')
parser.add_argument('--model', default='resnet50', help='model')
parser.add_argument('--device', default='cuda:1', help='device')
parser.add_argument('-b', '--batch-size', default=4, type=int)
parser.add_argument('--epochs', default=3, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
help='number of data loading workers (default: 0)')
parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=4e-5, type=float,
metavar='W', help='weight decay (default: 4e-5)',
dest='weight_decay')
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
parser.add_argument('--in_it', '--inner_it', default=0, type=int,
metavar='N', help='higher inner_it',
dest='inner_it')
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
main(args)

346
salvador/transformations.py Executable file

@@ -0,0 +1,346 @@
import torch
import kornia
import random
### Available TF for Dataug ###
'''
TF_dict={ #Dataugv4
## Geometric TF ##
'Identity' : (lambda x, mag: x),
'FlipUD' : (lambda x, mag: flipUD(x)),
'FlipLR' : (lambda x, mag: flipLR(x)),
'Rotate': (lambda x, mag: rotate(x, angle=torch.tensor([rand_int(mag, maxval=30)for _ in x], device=x.device))),
'TranslateX': (lambda x, mag: translate(x, translation=torch.tensor([[rand_int(mag, maxval=20), 0] for _ in x], device=x.device))),
'TranslateY': (lambda x, mag: translate(x, translation=torch.tensor([[0, rand_int(mag, maxval=20)] for _ in x], device=x.device))),
'ShearX': (lambda x, mag: shear(x, shear=torch.tensor([[rand_float(mag, maxval=0.3), 0] for _ in x], device=x.device))),
'ShearY': (lambda x, mag: shear(x, shear=torch.tensor([[0, rand_float(mag, maxval=0.3)] for _ in x], device=x.device))),
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast': (lambda x, mag: contrast(x, contrast_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
'Color':(lambda x, mag: color(x, color_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
'Posterize': (lambda x, mag: posterize(x, bits=torch.tensor([rand_int(mag, minval=4, maxval=8) for _ in x], device=x.device))),
'Solarize': (lambda x, mag: solarize(x, thresholds=torch.tensor([rand_int(mag,minval=1, maxval=256)/256. for _ in x], device=x.device))) , #=>Image in [0,1] #Not optimized for batches
#Not functional
#'Auto_Contrast': (lambda mag: None), #Not optimized for batches (very slow)
#'Equalize': (lambda mag: None),
}
'''
'''
TF_dict={ #Dataugv5 #AutoAugment
## Geometric TF ##
'Identity' : (lambda x, mag: x),
'FlipUD' : (lambda x, mag: flipUD(x)),
'FlipLR' : (lambda x, mag: flipLR(x)),
'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Breaks the gradient
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Breaks the gradient #=>Image in [0,1]
#Not functional
#'Auto_Contrast': (lambda mag: None), #Not optimized for batches (very slow)
#'Equalize': (lambda mag: None),
}
'''
TF_dict={ #Dataugv5
## Geometric TF ##
'Identity' : (lambda x, mag: x),
'FlipUD' : (lambda x, mag: flipUD(x)),
'FlipLR' : (lambda x, mag: flipLR(x)),
'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Breaks the gradient
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Breaks the gradient #=>Image in [0,1]
#Color TF (Common mag scale)
'+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
'+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
'+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
'+Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
'-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
'-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
'-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
'-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
'=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Breaks the gradient
'=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Breaks the gradient #=>Image in [0,1]
'BRotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30*3))),
'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20*3), zero_pos=0))),
'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20*3), zero_pos=1))),
'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3*3), zero_pos=0))),
'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3*3), zero_pos=1))),
'BadTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=20*2, maxval=20*3), zero_pos=0))),
'BadTranslateX_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=0))),
'BadTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=20*2, maxval=20*3), zero_pos=1))),
'BadTranslateY_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=1))),
'BadColor':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
'BadSharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
#Not functional
#'Auto_Contrast': (lambda mag: None), #Not optimized for batches (very slow)
#'Equalize': (lambda mag: None),
}
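# Naming convention (as far as the ranges above suggest): '+'/'-' variants restrict the factor
# above/below 1 on a shared magnitude scale, '=' variants use the inverted magnitude scale,
# 'B' variants triple the usual range, and 'Bad*' variants appear to sample deliberately
# harmful ranges.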
TF_no_mag={'Identity', 'FlipUD', 'FlipLR'}
TF_ignore_mag= TF_no_mag | {'Solarize', 'Posterize'}
def int_image(float_image): #WARNING: slight loss of information (granularity: 1/256 = 0.0039)
return (float_image*255.).type(torch.uint8)
def float_image(int_image):
return int_image.type(torch.float)/255.
#def rand_inverse(value):
# return value if random.random() < 0.5 else -value
#def rand_int(mag, maxval, minval=None): #[(-maxval,minval), maxval]
# real_max = int_parameter(mag, maxval=maxval)
# if not minval : minval = -real_max
# return random.randint(minval, real_max)
#def rand_float(mag, maxval, minval=None): #[(-maxval,minval), maxval]
# real_max = float_parameter(mag, maxval=maxval)
# if not minval : minval = -real_max
# return random.uniform(minval, real_max)
def rand_floats(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
real_mag = float_parameter(mag, maxval=maxval)
if not minval : minval = -real_mag
#return random.uniform(minval, real_max)
return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[min_val, real_mag]
def invScale_rand_floats(size, mag, maxval, minval):
#Mag=[0,PARAMETER_MAX] => [PARAMETER_MAX, 0] = [maxval, minval]
real_mag = float_parameter(float(PARAMETER_MAX) - mag, maxval=maxval-minval)+minval
return real_mag + (maxval-real_mag) * torch.rand(size, device=mag.device) #[real_mag, max_val]
def zero_stack(tensor, zero_pos):
if zero_pos==0:
return torch.stack((tensor, torch.zeros((tensor.shape[0],), device=tensor.device)), dim=1)
if zero_pos==1:
return torch.stack((torch.zeros((tensor.shape[0],), device=tensor.device), tensor), dim=1)
else:
raise Exception("Invalid zero_pos : ", zero_pos)
#https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
PARAMETER_MAX = 1 # What is the max 'level' a transform could be predicted
def float_parameter(level, maxval):
"""Helper function to scale `val` between 0 and maxval .
Args:
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
maxval: Maximum value that the operation can have. This will be scaled
to level/PARAMETER_MAX.
Returns:
A float that results from scaling `maxval` according to `level`.
"""
#return float(level) * maxval / PARAMETER_MAX
return (level * maxval / PARAMETER_MAX)#.to(torch.float)
#def int_parameter(level, maxval): #Breaks the gradient
"""Helper function to scale `val` between 0 and maxval .
Args:
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
maxval: Maximum value that the operation can have. This will be scaled
to level/PARAMETER_MAX.
Returns:
An int that results from scaling `maxval` according to `level`.
"""
#return int(level * maxval / PARAMETER_MAX)
# return (level * maxval / PARAMETER_MAX)
def flipLR(x):
device = x.device
(batch_size, channels, h, w) = x.shape
M =torch.tensor( [[[-1., 0., w-1],
[ 0., 1., 0.],
[ 0., 0., 1.]]], device=device).expand(batch_size,-1,-1)
# warp the original image by the found transform
return kornia.warp_perspective(x, M, dsize=(h, w))
def flipUD(x):
device = x.device
(batch_size, channels, h, w) = x.shape
M =torch.tensor( [[[ 1., 0., 0.],
[ 0., -1., h-1],
[ 0., 0., 1.]]], device=device).expand(batch_size,-1,-1)
# warp the original image by the found transform
return kornia.warp_perspective(x, M, dsize=(h, w))
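#Added note: both flips are expressed as constant 3x3 homographies (x -> w-1-x, resp. y -> h-1-y)
#so they stay on the GPU and are applied to the whole batch through kornia.warp_perspective.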
def rotate(x, angle):
return kornia.rotate(x, angle=angle.type(torch.float)) #Kornia does not support int angles
def translate(x, translation):
#print(translation)
return kornia.translate(x, translation=translation.type(torch.float)) #Kornia does not support int translations
def shear(x, shear):
return kornia.shear(x, shear=shear)
def contrast(x, contrast_factor):
return kornia.adjust_contrast(x, contrast_factor=contrast_factor) #Expect image in the range of [0, 1]
#https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageEnhance.py
def color(x, color_factor):
(batch_size, channels, h, w) = x.shape
gray_x = kornia.rgb_to_grayscale(x)
gray_x = gray_x.repeat_interleave(channels, dim=1)
return blend(gray_x, x, color_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1]
def brightness(x, brightness_factor):
device = x.device
return blend(torch.zeros(x.size(), device=device), x, brightness_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1]
def sharpeness(x, sharpness_factor):
device = x.device
(batch_size, channels, h, w) = x.shape
k = torch.tensor([[[ 1., 1., 1.],
[ 1., 5., 1.],
[ 1., 1., 1.]]], device=device) #Smooth Filter : https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageFilter.py
smooth_x = kornia.filter2D(x, kernel=k, border_type='reflect', normalized=True) #May need to handle the alpha channel differently
return blend(smooth_x, x, sharpness_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1]
#https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
def posterize(x, bits):
bits = bits.type(torch.uint8) #Loses the gradient
x = int_image(x) #Expect image in the range of [0, 1]
mask = ~(2 ** (8 - bits) - 1).type(torch.uint8)
(batch_size, channels, h, w) = x.shape
mask = mask.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There is surely a simpler way ...
return float_image(x & mask)
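#Illustrative example (added): with bits=2 the mask is ~(2**6 - 1) = 0b11000000,
#so only the 2 most significant bits of each 8-bit channel value are kept.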
def auto_contrast(x): #NOT OPTIMIZED FOR BATCHES #VERY SLOW
# Possible optimization: efficient LUT application / per-batch, per-channel histogram computation
print("Warning: not yet verified!")
(batch_size, channels, h, w) = x.shape
x = int_image(x) #Expect image in the range of [0, 1]
#print('Start',x[0])
for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Per-image operation
#print(img.shape)
for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Per-channel operation
#print(chan.shape)
hist = torch.histc(chan, bins=256, min=0, max=255) #NOT DIFFERENTIABLE
# find lowest/highest samples after preprocessing
for lo in range(256):
if hist[lo]:
break
for hi in range(255, -1, -1):
if hist[hi]:
break
if hi <= lo:
# don't bother
pass
else:
scale = 255.0 / (hi - lo)
offset = -lo * scale
for ix in range(256):
n_ix = int(ix * scale + offset)
if n_ix < 0: n_ix = 0
elif n_ix > 255: n_ix = 255
chan[chan==ix]=n_ix
x[im_idx, chan_idx]=chan
#print('End',x[0])
return float_image(x)
def equalize(x): #NOT OPTIMIZED FOR BATCHES
raise NotImplementedError("equalize is not implemented yet")
# Possible optimization: efficient LUT application / per-batch, per-channel histogram computation
(batch_size, channels, h, w) = x.shape
x = int_image(x) #Expect image in the range of [0, 1]
#print('Start',x[0])
for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Per-image operation
#print(img.shape)
for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Per-channel operation
#print(chan.shape)
hist = torch.histc(chan, bins=256, min=0, max=255) #NOT DIFFERENTIABLE
return float_image(x)
def solarize(x, thresholds):
batch_size, channels, h, w = x.shape
#imgs=[]
#for idx, t in enumerate(thresholds): #Per-image operation
# mask = x[idx] > t #Loses the gradient
#In place
# inv_x = 1-x[idx][mask]
# x[idx][mask]=inv_x
#
#Out of place
# im = x[idx]
# inv_x = 1-im[mask]
# imgs.append(im.masked_scatter(mask,inv_x))
#idxs=torch.tensor(range(x.shape[0]), device=x.device)
#idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There is surely a simpler way ...
#x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
#
thresholds = thresholds.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There is surely a simpler way ...
#print(thresholds.grad_fn)
x=torch.where(x>thresholds,1-x, x)
#print(mask.grad_fn)
#x=x.min(thresholds)
#inv_x = 1-x[mask]
#x=x.where(x<thresholds,1-x)
#x[mask]=inv_x
#x=x.masked_scatter(mask, inv_x)
return x
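#Added note: torch.where keeps solarize batched and differentiable w.r.t. x on both branches;
#the comparison against thresholds itself does not propagate a gradient to the thresholds.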
#https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818
def blend(x,y,alpha): #out = image1 * (1.0 - alpha) + image2 * alpha
#return kornia.add_weighted(src1=x, alpha=(1-alpha), src2=y, beta=alpha, gamma=0) #out = src1*alpha + src2*beta + gamma #Does not work with batched alpha
if not isinstance(x, torch.Tensor):
raise TypeError("x should be a tensor. Got {}".format(type(x)))
if not isinstance(y, torch.Tensor):
raise TypeError("y should be a tensor. Got {}".format(type(y)))
(batch_size, channels, h, w) = x.shape
alpha = alpha.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There is surely a simpler way ...
res = x*(1-alpha) + y*alpha
return res
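#Minimal usage sketch (added for illustration; not part of the original training pipeline).
#It only exercises the pure-PyTorch ops defined above (rand_floats, solarize, blend) on a dummy batch.
if __name__ == "__main__":
    x = torch.rand(4, 3, 32, 32)                    #Dummy batch of 4 RGB images in [0, 1]
    mag = torch.tensor(0.5)                         #Magnitude in [0, PARAMETER_MAX]
    thresholds = rand_floats(size=x.shape[0], mag=mag, maxval=1., minval=0.1)   #One threshold per image, in [0.1, 0.5]
    sol = solarize(x, thresholds)                   #Invert every pixel above its per-image threshold
    mixed = blend(x, sol, alpha=torch.full((x.shape[0],), 0.3))                 #30% solarized / 70% original
    print(sol.shape, mixed.min().item(), mixed.max().item())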

200
salvador/utils.py Normal file
View file

@ -0,0 +1,200 @@
from __future__ import print_function
from collections import defaultdict, deque
import datetime
import math
import time
import torch
import numpy as np
import os
from fastprogress import progress_bar
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{global_avg:.4f}"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
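# Added usage sketch: meter = SmoothedValue(window_size=20); meter.update(loss.item(), n=batch_size);
# str(meter) then renders the running global average (median/avg/max/value are also exposed via fmt).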
class ConfusionMatrix(object):
def __init__(self, num_classes):
self.num_classes = num_classes
self.mat = None
def update(self, a, b):
n = self.num_classes
if self.mat is None:
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
with torch.no_grad():
k = (a >= 0) & (a < n)
inds = n * a[k].to(torch.int64) + b[k]
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
def reset(self):
self.mat.zero_()
def compute(self):
h = self.mat.float()
acc_global = torch.diag(h).sum() / h.sum()
acc = torch.diag(h) / h.sum(1)
return acc_global, acc
def __str__(self):
acc_global, acc = self.compute()
return (
'global correct: {:.1f}\n'
'average row correct: {}').format(
acc_global.item() * 100,
['{:.1f}'.format(i) for i in (acc * 100).tolist()])
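# Added note: update() expects flattened integer class tensors (a = targets, b = predictions);
# compute() returns the overall accuracy and the per-true-class (per-row) accuracy.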
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, parent, header=None, **kwargs):
if not header:
header = ''
log_msg = self.delimiter.join([
'{meters}'
])
progress = progress_bar(iterable, parent=parent, **kwargs)
for idx, obj in enumerate(progress):
yield idx, obj
progress.comment = log_msg.format(
meters=str(self))
print('{header} {meters}'.format(header=header, meters=str(self)))
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target[None])
res = []
for k in topk:
correct_k = correct[:k].flatten().sum(dtype=torch.float32)
res.append(correct_k * (100.0 / batch_size))
return res
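# Added example: for logits of shape (N, num_classes) and integer targets of shape (N,),
# accuracy(output, target, topk=(1, 5)) returns [top-1 %, top-5 %] as scalar tensors.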
class EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience."""
def __init__(self, patience=7, verbose=False, delta=0):
"""
Args:
patience (int): How long to wait after last time validation loss improved.
Default: 7
verbose (bool): If True, prints a message for each validation loss improvement.
Default: False
delta (float): Minimum change in the monitored quantity to qualify as an improvement.
Default: 0
"""
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
def __call__(self, val_loss, model):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score - self.delta:
self.counter += 1
# print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
# if self.counter >= self.patience:
# self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
def save_checkpoint(self, val_loss, model):
'''Saves the model when the validation loss decreases.'''
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), 'checkpoint.pt')
self.val_loss_min = val_loss
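# Minimal usage sketch (added for illustration; like the real class, it writes 'checkpoint.pt').
if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)                   # Dummy model
    stopper = EarlyStopping(patience=7, verbose=True)
    for val_loss in [1.0, 0.8, 0.85, 0.7]:          # Fake validation losses
        stopper(val_loss, model)                    # Saves a checkpoint whenever the loss improves
    print('lowest validation loss:', stopper.val_loss_min)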