From f2019aae4af148643c4e980b13a3a46fb041668d Mon Sep 17 00:00:00 2001 From: "Harle, Antoine (Contracteur)" Date: Wed, 22 Jan 2020 11:15:56 -0500 Subject: [PATCH] Stockage code inutile dans old --- higher/datasets.py | 120 ---- higher/dataug.py | 983 +++--------------------------- higher/model.py | 595 ------------------ higher/old/dataug_old.py | 1065 +++++++++++++++++++++++++++++++++ higher/old/higher_repro.py | 85 +++ higher/old/model_old.py | 502 ++++++++++++++++ higher/{ => old}/test_lr.py | 0 higher/old/train_utils_old.py | 590 ++++++++++++++++++ higher/old/utils_old.py | 161 +++++ higher/train_utils.py | 583 +----------------- higher/transformations.py | 235 +++++--- higher/utils.py | 137 ----- 12 files changed, 2649 insertions(+), 2407 deletions(-) create mode 100644 higher/old/dataug_old.py create mode 100644 higher/old/higher_repro.py create mode 100644 higher/old/model_old.py rename higher/{ => old}/test_lr.py (100%) create mode 100644 higher/old/train_utils_old.py create mode 100644 higher/old/utils_old.py diff --git a/higher/datasets.py b/higher/datasets.py index 2c06c7e..d7e84e2 100755 --- a/higher/datasets.py +++ b/higher/datasets.py @@ -38,126 +38,6 @@ from PIL import Image import augmentation_transforms import numpy as np -class AugmentedDataset(VisionDataset): - def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None): - - super(AugmentedDataset, self).__init__(root, transform=transform, target_transform=target_transform) - - supervised_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=transform) - - self.sup_data = supervised_dataset.data if not subset else supervised_dataset.data[subset[0]:subset[1]] - self.sup_targets = supervised_dataset.targets if not subset else supervised_dataset.targets[subset[0]:subset[1]] - assert len(self.sup_data)==len(self.sup_targets) - - for idx, img in enumerate(self.sup_data): - self.sup_data[idx]= Image.fromarray(img) #to PIL Image - - self.unsup_data=[] - self.unsup_targets=[] - - self.data= self.sup_data - self.targets= self.sup_targets - - self.dataset_info= { - 'name': 'CIFAR10', - 'sup': len(self.sup_data), - 'unsup': len(self.unsup_data), - 'length': len(self.sup_data)+len(self.unsup_data), - } - - - self._TF = [ - ## Geometric TF ## - 'Rotate', - 'TranslateX', - 'TranslateY', - 'ShearX', - 'ShearY', - - 'Cutout', - - ## Color TF ## - 'Contrast', - 'Color', - 'Brightness', - 'Sharpness', - #'Posterize', - #'Solarize', - - 'Invert', - 'AutoContrast', - 'Equalize', - ] - self._op_list =[] - self.prob=0.5 - for tf in self._TF: - for mag in range(1, 10): - self._op_list+=[(tf, self.prob, mag)] - self._nb_op = len(self._op_list) - - def __getitem__(self, index): - """ - Args: - index (int): Index - - Returns: - tuple: (image, target) where target is index of the target class. - """ - img, target = self.data[index], self.targets[index] - - # doing this so that it is consistent with all other datasets - # to return a PIL Image - #img = Image.fromarray(img) - - if self.transform is not None: - img = self.transform(img) - - if self.target_transform is not None: - target = self.target_transform(target) - - return img, target - - def augement_data(self, aug_copy=1): - - policies = [] - for op_1 in self._op_list: - for op_2 in self._op_list: - policies += [[op_1, op_2]] - - for idx, image in enumerate(self.sup_data): - if (idx/self.dataset_info['sup'])%0.2==0: print("Augmenting data... 
", idx,"/", self.dataset_info['sup']) - #if idx==10000:break - - for _ in range(aug_copy): - chosen_policy = policies[np.random.choice(len(policies))] - aug_image = augmentation_transforms.apply_policy(chosen_policy, image, use_mean_std=False) #Cast en float image - #aug_image = augmentation_transforms.cutout_numpy(aug_image) - - self.unsup_data+=[(aug_image*255.).astype(self.sup_data.dtype)]#Cast float image to uint8 - self.unsup_targets+=[self.sup_targets[idx]] - - #self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8 - self.unsup_data=np.array(self.unsup_data) - self.data= np.concatenate((self.sup_data, self.unsup_data), axis=0) - self.targets= np.concatenate((self.sup_targets, self.unsup_targets), axis=0) - - assert len(self.unsup_data)==len(self.unsup_targets) - assert len(self.data)==len(self.targets) - self.dataset_info['unsup']=len(self.unsup_data) - self.dataset_info['length']=self.dataset_info['sup']+self.dataset_info['unsup'] - - def len_supervised(self): - return self.dataset_info['sup'] - - def len_unsupervised(self): - return self.dataset_info['unsup'] - - def __len__(self): - return self.dataset_info['length'] - - def __str__(self): - return "CIFAR10(Sup:{}-Unsup:{}-{}TF)".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF)) - class AugmentedDatasetV2(VisionDataset): def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None): diff --git a/higher/dataug.py b/higher/dataug.py index b33256a..9a0db02 100755 --- a/higher/dataug.py +++ b/higher/dataug.py @@ -1,3 +1,12 @@ +""" Data augmentation modules. + + Features a custom implementaiton of RandAugment (RandAug), as well as a data augmentation modules allowing gradient propagation. 
+ + Typical usage: + + aug_model = Augmented_model(Data_AugV5, model) +""" + import torch import torch.nn as nn import torch.nn.functional as F @@ -10,526 +19,6 @@ import copy import transformations as TF -class Data_aug(nn.Module): #Rotation parametree - def __init__(self): - super(Data_aug, self).__init__() - self._data_augmentation = True - self._params = nn.ParameterDict({ - "prob": nn.Parameter(torch.tensor(0.5)), - "mag": nn.Parameter(torch.tensor(1.0)) - }) - - #self.params["mag"].register_hook(print) - - def forward(self, x): - - if self._data_augmentation and random.random() < self._params["prob"]: - #print('Aug') - batch_size = x.shape[0] - # create transformation (rotation) - alpha = self._params["mag"]*180 # in degrees - angle = torch.ones(batch_size, device=x.device) * alpha - - # define the rotation center - center = torch.ones(batch_size, 2, device=x.device) - center[..., 0] = x.shape[3] / 2 # x - center[..., 1] = x.shape[2] / 2 # y - - #print(x.shape, center) - # define the scale factor - scale = torch.ones(batch_size, device=x.device) - - # compute the transformation matrix - M = kornia.get_rotation_matrix2d(center, angle, scale) - - # apply the transformation to original image - x = kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w) - - return x - - def eval(self): - self.augment(mode=False) - nn.Module.eval(self) - - def augment(self, mode=True): - self._data_augmentation=mode - - def __getitem__(self, key): - return self._params[key] - - def __str__(self): - return "Data_aug(Mag-1 TF)" - -class Data_augV2(nn.Module): #Methode exacte - def __init__(self): - super(Data_augV2, self).__init__() - self._data_augmentation = True - - self._fixed_transf=[0.0, 45.0, 180.0] #Degree rotation - #self._fixed_transf=[0.0] - self._nb_tf= len(self._fixed_transf) - - self._params = nn.ParameterDict({ - "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme - #"prob2": nn.Parameter(torch.ones(len(self._fixed_transf)).softmax(dim=0)) - }) - - #print(self._params["prob"], self._params["prob2"]) - - self.transf_idx=0 - - def forward(self, x): - - if self._data_augmentation: - #print('Aug',self._fixed_transf[self.transf_idx]) - device = x.device - batch_size = x.shape[0] - - # create transformation (rotation) - #alpha = 180 # in degrees - alpha = self._fixed_transf[self.transf_idx] - angle = torch.ones(batch_size, device=device) * alpha - - x = self.rotate(x,angle) - - return x - - def rotate(self, x, angle): - - device = x.device - batch_size = x.shape[0] - # define the rotation center - center = torch.ones(batch_size, 2, device=device) - center[..., 0] = x.shape[3] / 2 # x - center[..., 1] = x.shape[2] / 2 # y - - #print(x.shape, center) - # define the scale factor - scale = torch.ones(batch_size, device=device) - - # compute the transformation matrix - M = kornia.get_rotation_matrix2d(center, angle, scale) - - # apply the transformation to original image - return kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w) - - - def adjust_param(self): #Detach from gradient ? 
- self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) - #print('proba',self._params['prob']) - self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 - #print('Sum p', sum(self._params['prob'])) - - def eval(self): - self.augment(mode=False) - nn.Module.eval(self) - - def augment(self, mode=True): - self._data_augmentation=mode - - def __getitem__(self, key): - return self._params[key] - - def __str__(self): - return "Data_augV2(Exact-%d TF)" % self._nb_tf - -class Data_augV3(nn.Module): #Echantillonage uniforme/Mixte - def __init__(self, mix_dist=0.0): - super(Data_augV3, self).__init__() - self._data_augmentation = True - - #self._fixed_transf=[0.0, 45.0, 180.0] #Degree rotation - self._fixed_transf=[0.0, 1.0, -1.0] #Flips (Identity,Horizontal,Vertical) - #self._fixed_transf=[0.0] - self._nb_tf= len(self._fixed_transf) - - self._params = nn.ParameterDict({ - "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme - #"prob2": nn.Parameter(torch.ones(len(self._fixed_transf)).softmax(dim=0)) - }) - - #print(self._params["prob"], self._params["prob2"]) - self._sample = [] - - self._mix_dist = False - if mix_dist != 0.0: - self._mix_dist = True - self._mix_factor = max(min(mix_dist, 1.0), 0.0) - - def forward(self, x): - - if self._data_augmentation: - device = x.device - batch_size = x.shape[0] - - - #good_distrib = Uniform(low=torch.zeros(batch_size,1, device=device),high=torch.new_full((batch_size,1),self._params["prob"], device=device)) - #bad_distrib = Uniform(low=torch.zeros(batch_size,1, device=device),high=torch.new_full((batch_size,1), 1-self._params["prob"], device=device)) - - #transform_dist = Categorical(probs=torch.tensor([self._params["prob"], 1-self._params["prob"]], device=device)) - #self._sample = transform_dist._sample(sample_shape=torch.Size([batch_size,1])) - - uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=0) - - if not self._mix_dist: - distrib = uniforme_dist - else: - distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=0) #Mix distrib reel / uniforme avec mix_factor - - cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*distrib) - self._sample = cat_distrib.sample() - - TF_param = torch.tensor([self._fixed_transf[x] for x in self._sample], device=device) #Approche de marco peut-etre plus rapide - - #x = self.rotate(x,angle=TF_param) - x = self.flip(x,flip_mat=TF_param) - - return x - - def rotate(self, x, angle): - - device = x.device - batch_size = x.shape[0] - # define the rotation center - center = torch.ones(batch_size, 2, device=device) - center[..., 0] = x.shape[3] / 2 # x - center[..., 1] = x.shape[2] / 2 # y - - #print(x.shape, center) - # define the scale factor - scale = torch.ones(batch_size, device=device) - - # compute the transformation matrix - M = kornia.get_rotation_matrix2d(center, angle, scale) - - # apply the transformation to original image - return kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w) - - def flip(self, x, flip_mat): - - #print(flip_mat) - device = x.device - batch_size = x.shape[0] - - h, w = x.shape[2], x.shape[3] # destination size - #points_src = torch.ones(batch_size, 4, 2, device=device) - #points_dst = torch.ones(batch_size, 4, 2, device=device) - - #Identity - iM=torch.tensor(np.eye(3)) - - #Horizontal flip - # the source points are the region to crop corners - #points_src = torch.FloatTensor([[ - # [w - 1, 0], [0, 
0], [0, h - 1], [w - 1, h - 1], - #]]) - # the destination points are the image vertexes - #points_dst = torch.FloatTensor([[ - # [0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1], - #]]) - # compute perspective transform - #hM = kornia.get_perspective_transform(points_src, points_dst) - hM =torch.tensor( [[[-1., 0., w-1], - [ 0., 1., 0.], - [ 0., 0., 1.]]]) - - #Vertical flip - # the source points are the region to crop corners - #points_src = torch.FloatTensor([[ - # [0, h - 1], [w - 1, h - 1], [w - 1, 0], [0, 0], - #]]) - # the destination points are the image vertexes - #points_dst = torch.FloatTensor([[ - # [0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1], - #]]) - # compute perspective transform - #vM = kornia.get_perspective_transform(points_src, points_dst) - vM =torch.tensor( [[[ 1., 0., 0.], - [ 0., -1., h-1], - [ 0., 0., 1.]]]) - #print(vM) - - M=torch.ones(batch_size, 3, 3, device=device) - - for i in range(batch_size): # A optimiser - if flip_mat[i]==0.0: - M[i,]=iM - elif flip_mat[i]==1.0: - M[i,]=hM - elif flip_mat[i]==-1.0: - M[i,]=vM - - # warp the original image by the found transform - return kornia.warp_perspective(x, M, dsize=(h, w)) - - def adjust_param(self, soft=False): #Detach from gradient ? - - if soft : - self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible - else: - #self._params['prob'].clamp(min=0.0,max=1.0) - self._params['prob'].data = F.relu(self._params['prob'].data) - #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) - #print('proba',self._params['prob']) - self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 - #print('Sum p', sum(self._params['prob'])) - - def loss_weight(self): - #w_loss = [self._params["prob"][x] for x in self._sample] - #print(self._sample.view(-1,1).shape) - #print(self._sample[:10]) - - w_loss = torch.zeros((self._sample.shape[0],self._nb_tf), device=self._sample.device) - w_loss.scatter_(1, self._sample.view(-1,1), 1) - #print(w_loss.shape) - #print(w_loss[:10,:]) - w_loss = w_loss * self._params["prob"] - #print(w_loss.shape) - #print(w_loss[:10,:]) - w_loss = torch.sum(w_loss,dim=1) - #print(w_loss.shape) - #print(w_loss[:10]) - return w_loss - - def train(self, mode=None): - if mode is None : - mode=self._data_augmentation - self.augment(mode=mode) #Inutile si mode=None - super(Data_augV3, self).train(mode) - - def eval(self): - self.train(mode=False) - #super(Augmented_model, self).eval() - - def augment(self, mode=True): - self._data_augmentation=mode - - def __getitem__(self, key): - return self._params[key] - - def __str__(self): - if not self._mix_dist: - return "Data_augV3(Uniform-%d TF)" % self._nb_tf - else: - return "Data_augV3(Mix %.1f-%d TF)" % (self._mix_factor, self._nb_tf) - -class Data_augV4(nn.Module): #Transformations avec mask - def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0): - super(Data_augV4, self).__init__() - assert len(TF_dict)>0 - - self._data_augmentation = True - - #self._TF_matrix={} - #self._input_info={'h':0, 'w':0, 'device':None} #Input associe a TF_matrix - #self._mag_fct = TF_dict - self._TF_dict = TF_dict - self._TF= list(self._TF_dict.keys()) - self._nb_tf= len(self._TF) - - self._N_seqTF = N_TF - - self._fixed_mag=5 #[0, PARAMETER_MAX] - self._params = nn.ParameterDict({ - "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme - }) - - self._samples = [] - - self._mix_dist = False - if mix_dist != 0.0: - 
self._mix_dist = True - self._mix_factor = max(min(mix_dist, 1.0), 0.0) - - def forward(self, x): - if self._data_augmentation: - device = x.device - batch_size, h, w = x.shape[0], x.shape[2], x.shape[3] - - x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) - self._samples = [] - - for _ in range(self._N_seqTF): - ## Echantillonage ## - uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1) - - if not self._mix_dist: - self._distrib = uniforme_dist - else: - self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor - - cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib) - sample = cat_distrib.sample() - self._samples.append(sample) - - ## Transformations ## - x = self.apply_TF(x, sample) - return x - ''' - def compute_TF_matrix(self, magnitude=None, sample_info= None): - print('Computing TF_matrix...') - if not magnitude : - magnitude=self._fixed_mag - - if sample_info: - self._input_info['h']= sample_info['h'] - self._input_info['w']= sample_info['w'] - self._input_info['device'] = sample_info['device'] - h, w, device= self._input_info['h'], self._input_info['w'], self._input_info['device'] - - self._TF_matrix={} - for tf in self._TF : - if tf=='Id': - self._TF_matrix[tf]=torch.tensor([[[ 1., 0., 0.], - [ 0., 1., 0.], - [ 0., 0., 1.]]], device=device) - elif tf=='Rot': - center = torch.ones(1, 2, device=device) - center[0, 0] = w / 2 # x - center[0, 1] = h / 2 # y - scale = torch.ones(1, device=device) - angle = self._mag_fct[tf](magnitude) * torch.ones(1, device=device) - R = kornia.get_rotation_matrix2d(center, angle, scale) #Rotation matrix (1,2,3) - self._TF_matrix[tf]=torch.cat((R,torch.tensor([[[ 0., 0., 1.]]], device=device)), dim=1) #TF matrix (1,3,3) - elif tf=='FlipLR': - self._TF_matrix[tf]=torch.tensor([[[-1., 0., w-1], - [ 0., 1., 0.], - [ 0., 0., 1.]]], device=device) - elif tf=='FlipUD': - self._TF_matrix[tf]=torch.tensor([[[ 1., 0., 0.], - [ 0., -1., h-1], - [ 0., 0., 1.]]], device=device) - else: - raise Exception("Invalid TF requested") - ''' - def apply_TF(self, x, sampled_TF): - device = x.device - smps_x=[] - masks=[] - for tf_idx in range(self._nb_tf): - mask = sampled_TF==tf_idx #Create selection mask - smp_x = x[mask] #torch.masked_select() ? 
- - if smp_x.shape[0]!=0: #if there's data to TF - magnitude=self._fixed_mag - tf=self._TF[tf_idx] - - ''' - ## Geometric TF ## - if tf=='Identity': - pass - elif tf=='FlipLR': - smp_x = TF.flipLR(smp_x) - elif tf=='FlipUD': - smp_x = TF.flipUD(smp_x) - elif tf=='Rotate': - smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='TranslateX' or tf=='TranslateY': - smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='ShearX' or tf=='ShearY' : - smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - - ## Color TF (Expect image in the range of [0, 1]) ## - elif tf=='Contrast': - smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='Color': - smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='Brightness': - smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='Sharpness': - smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='Posterize': - smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device)) - elif tf=='Solarize': - smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='Equalize': - smp_x = TF.equalize(smp_x) - elif tf=='Auto_Contrast': - smp_x = TF.auto_contrast(smp_x) - else: - raise Exception("Invalid TF requested : ", tf) - - x[mask]=smp_x # Refusionner eviter x[mask] : in place - ''' - x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Refusionner eviter x[mask] : in place - - #idx= mask.nonzero() - #print('-'*8) - #print(idx[0], tf_idx) - #print(smp_x[0,]) - #x=x.view(-1,3*32*32) - #x=x.scatter(dim=0, index=idx, src=smp_x.view(-1,3*32*32)) #Changement des Tensor mais pas visible sur la visualisation... - #x=x.view(-1,3,32,32) - #print(x[0,]) - - ''' - if len(self._TF_matrix)==0 or self._input_info['h']!=h or self._input_info['w']!=w or self._input_info['device']!=device: #Device different:Pas necessaire de tout recalculer - self.compute_TF_matrix(sample_info={'h': x.shape[2], - 'w': x.shape[3], - 'device': x.device}) - - TF_matrix = torch.zeros(batch_size, 3, 3, device=device) #All geom TF - - for tf_idx in range(self._nb_tf): - mask = self._sample==tf_idx #Create selection mask - TF_matrix[mask,]=self._TF_matrix[self._TF[tf_idx]] - - x=kornia.warp_perspective(x, TF_matrix, dsize=(h, w)) - ''' - return x - - def adjust_param(self, soft=False): #Detach from gradient ? 
- - if soft : - self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible - else: - #self._params['prob'].clamp(min=0.0,max=1.0) - self._params['prob'].data = F.relu(self._params['prob'].data) - #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) - - self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 - - def loss_weight(self): - # 1 seule TF - #self._sample = self._samples[-1] - #w_loss = torch.zeros((self._sample.shape[0],self._nb_tf), device=self._sample.device) - #w_loss.scatter_(dim=1, index=self._sample.view(-1,1), value=1) - #w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss) - #w_loss = torch.sum(w_loss,dim=1) - - #Plusieurs TF sequentielles - w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device) - for sample in self._samples: - tmp_w = torch.zeros(w_loss.size(),device=w_loss.device) - tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF) - w_loss += tmp_w - - w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss) - w_loss = torch.sum(w_loss,dim=1) - return w_loss - - - def train(self, mode=None): - if mode is None : - mode=self._data_augmentation - self.augment(mode=mode) #Inutile si mode=None - super(Data_augV4, self).train(mode) - - def eval(self): - self.train(mode=False) - - def augment(self, mode=True): - self._data_augmentation=mode - - def __getitem__(self, key): - return self._params[key] - - def __str__(self): - if not self._mix_dist: - return "Data_augV4(Uniform-%d TF x %d)" % (self._nb_tf, self._N_seqTF) - else: - return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF) - class Data_augV5(nn.Module): #Optimisation jointe (mag, proba) """Data augmentation module with learnable parameters. @@ -800,229 +289,6 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba) else: return "Data_augV5(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param) - -class Data_augV6(nn.Module): #Optimisation sequentielle #Mauvais resultats - def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=None, fixed_mag=True, shared_mag=True): - super(Data_augV6, self).__init__() - assert len(TF_dict)>0 - - self._data_augmentation = True - - self._TF_dict = TF_dict - self._TF= list(self._TF_dict.keys()) - self._nb_tf= len(self._TF) - - self._N_seqTF = N_TF - self._shared_mag = shared_mag - self._fixed_mag = fixed_mag - - self._TF_set_size = prob_set_size if prob_set_size else self._nb_tf - - self._fixed_TF=[0] #Identite - assert self._TF_set_size>=len(self._fixed_TF) - - if self._TF_set_size>self._nb_tf: - print("Warning : TF sets size higher than number of TF. 
Reducing set size to %d"%self._nb_tf) - self._TF_set_size=self._nb_tf - - ## Genenerate TF sets ## - if self._TF_set_size==len(self._fixed_TF): - print("Warning : using only fixed set of TF : ", self._fixed_TF) - self._TF_sets=torch.tensor([self._fixed_TF]) - else: - def generate_TF_sets(n_TF, set_size, idx_prefix=[]): - TF_sets=[] - if len(idx_prefix)!=0: - if set_size>2: - for i in range(idx_prefix[-1]+1, n_TF): - TF_sets += generate_TF_sets(n_TF=n_TF, set_size=set_size-1, idx_prefix=idx_prefix+[i]) - else: - #if i not in idx_prefix: - TF_sets+=[torch.tensor(idx_prefix+[i]) for i in range(idx_prefix[-1]+1, n_TF)] - elif set_size>1: - for i in range(0, n_TF): - TF_sets += generate_TF_sets(n_TF=n_TF, set_size=set_size, idx_prefix=[i]) - else: - TF_sets+=[torch.tensor([i]) for i in range(0, n_TF)] - return TF_sets - - self._TF_sets=generate_TF_sets(self._nb_tf, self._TF_set_size, self._fixed_TF) - - ## Plan TF learning schedule ## - self._TF_schedule = [list(range(len(self._TF_sets))) for _ in range(self._N_seqTF)] - for n_tf in range(self._N_seqTF) : - TF.random.shuffle(self._TF_schedule[n_tf]) - - self._current_TF_idx=0 #random.randint - self._start_prob = 1/self._TF_set_size - - - self._params = nn.ParameterDict({ - "prob": nn.Parameter(torch.tensor(self._start_prob).expand(self._nb_tf)), #Proba independantes - "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)) if self._shared_mag - else torch.tensor(float(TF.PARAMETER_MAX)).expand(self._nb_tf)), #[0, PARAMETER_MAX] - }) - - #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag - - #Distribution - self._fixed_prob=fixed_prob - self._samples = [] - self._mix_dist = False - if mix_dist != 0.0: - self._mix_dist = True - self._mix_factor = max(min(mix_dist, 0.999), 0.0) - - #Mag regularisation - if not self._fixed_mag: - if self._shared_mag : - self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max - else: - self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag] - self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max - - def forward(self, x): - self._samples = [] - if self._data_augmentation:# and TF.random.random() < 0.5: - device = x.device - batch_size, h, w = x.shape[0], x.shape[2], x.shape[3] - - x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) - - for n_tf in range(self._N_seqTF): - - tf_set = self._TF_sets[self._TF_schedule[n_tf][self._current_TF_idx]].to(device) - #print(n_tf, tf_set) - ## Echantillonage ## - uniforme_dist = torch.ones(1,len(tf_set),device=device).softmax(dim=1) - - if not self._mix_dist: - self._distrib = uniforme_dist - else: - prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"] - curr_prob = torch.index_select(prob, 0, tf_set) - curr_prob = curr_prob /sum(curr_prob) #Contrainte sum(p)=1 - self._distrib = (self._mix_factor*curr_prob+(1-self._mix_factor)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor - - cat_distrib= Categorical(probs=torch.ones((batch_size, len(tf_set)), device=device)*self._distrib) - sample = cat_distrib.sample() - self._samples.append(sample) - - ## Transformations ## - x = self.apply_TF(x, sample) - return x - - def apply_TF(self, x, sampled_TF): - device = x.device - batch_size, channels, h, w = x.shape - smps_x=[] - - for sel_idx, tf_idx in 
enumerate(self._TF_sets[self._current_TF_idx]): - mask = sampled_TF==sel_idx #Create selection mask - smp_x = x[mask] #torch.masked_select() ? (NEcessite d'expand le mask au meme dim) - - if smp_x.shape[0]!=0: #if there's data to TF - magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx] - if self._fixed_mag: magnitude=magnitude.detach() #Fmodel tente systematiquement de tracker les gradient de tout les param - - tf=self._TF[tf_idx] - #print(magnitude) - - #In place - #x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) - - #Out of place - smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude) - idx= mask.nonzero() - idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ... - x=x.scatter(dim=0, index=idx, src=smp_x) - - return x - - def adjust_param(self, soft=False): #Detach from gradient ? - if not self._fixed_prob: - if soft : - self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible - else: - self._params['prob'].data = F.relu(self._params['prob'].data) - #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) - #self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 - - self._params['prob'].data[0]=self._start_prob #Fixe p identite - - if not self._fixed_mag: - #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Bloque une fois au extreme - self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX) - - def loss_weight(self): #A verifier - if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation - - prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"] - - #Plusieurs TF sequentielles (Attention ne prend pas en compte ordre !) 
- w_loss = torch.zeros((self._samples[0].shape[0],self._TF_set_size), device=self._samples[0].device) - for n_tf in range(self._N_seqTF): - tmp_w = torch.zeros(w_loss.size(),device=w_loss.device) - tmp_w.scatter_(dim=1, index=self._samples[n_tf].view(-1,1), value=1/self._N_seqTF) - - tf_set = self._TF_sets[self._TF_schedule[n_tf][self._current_TF_idx]].to(prob.device) - curr_prob = torch.index_select(prob, 0, tf_set) - curr_prob = curr_prob /sum(curr_prob) #Contrainte sum(p)=1 - - #ATTENTION DISTRIB DIFFERENTE AVEC MIX - assert not self._mix_dist - w_loss += tmp_w * curr_prob /self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss) - - w_loss = torch.sum(w_loss,dim=1) - return w_loss - - def reg_loss(self, reg_factor=0.005): - if self._fixed_mag: # or self._fixed_prob: #Pas de regularisation si trop peu de DOF - return torch.tensor(0) - else: - #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean') - params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask] - return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean') - - def next_TF_set(self, idx=None): - if idx: - self._current_TF_idx=idx - else: - self._current_TF_idx+=1 - - if self._current_TF_idx>=len(self._TF_schedule[0]): - self._current_TF_idx=0 - for n_tf in range(self._N_seqTF) : - TF.random.shuffle(self._TF_schedule[n_tf]) - #print('-- New schedule --') - - def train(self, mode=None): - if mode is None : - mode=self._data_augmentation - self.augment(mode=mode) #Inutile si mode=None - super(Data_augV6, self).train(mode) - - def eval(self): - self.train(mode=False) - - def augment(self, mode=True): - self._data_augmentation=mode - - def __getitem__(self, key): - return self._params[key] - - def __str__(self): - dist_param='' - if self._fixed_prob: dist_param+='Fx' - mag_param='Mag' - if self._fixed_mag: mag_param+= 'Fx' - if self._shared_mag: mag_param+= 'Sh' - if not self._mix_dist: - return "Data_augV6(Uniform%s-%dTF(%d)x%d-%s)" % (dist_param, self._nb_tf, self._TF_set_size, self._N_seqTF, mag_param) - else: - return "Data_augV6(Mix%.1f%s-%dTF(%d)x%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._TF_set_size, self._N_seqTF, mag_param) - - class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide """RandAugment implementation. @@ -1176,91 +442,103 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide """ return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag) -class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training) - def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX): - super(RandAugUDA, self).__init__() +import higher +class Higher_model(nn.Module): + """Model wrapper for higher gradient tracking. - self._data_augmentation = True + Keep in memory the orginial model and it's functionnal, higher, version. - self._TF_dict = TF_dict - self._TF= list(self._TF_dict.keys()) - self._nb_tf= len(self._TF) - self._N_seqTF = N_TF + Might not be needed anymore if Higher implement detach for fmodel. 
- self.mag=nn.Parameter(torch.tensor(float(mag))) - self._params = nn.ParameterDict({ - "prob": nn.Parameter(torch.tensor(0.5).unsqueeze(dim=0)), - "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX))), - }) - self._shared_mag = True - self._fixed_mag = True + see : https://github.com/facebookresearch/higher - self._op_list =[] - for tf in self._TF: - for mag in range(1, int(self._params['mag']*10), 1): - self._op_list+=[(tf, self._params['prob'].item(), mag/10)] - self._nb_op = len(self._op_list) + TODO: Get rid of the original model if not needed by user. + + Attributes: + _name (string): Name of the model. + _mods (nn.ModuleDict): Models (Orginial and Higher version). + """ + def __init__(self, model): + """Init Higher_model. + + Args: + model (nn.Module): Network for which higher gradients can be tracked. + """ + super(Higher_model, self).__init__() + + self._name = model.__str__() + self._mods = nn.ModuleDict({ + 'original': model, + 'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) + }) + + def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True): + """Get a differentiable version of an Optimizer. + + Higher/Differentiable optimizer required to be used for higher gradient tracking. + Usage : diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step) + + Be warry that if track_higher_grads is set to True, a new state of the model would be saved each time diffopt.step() is called. + Thus increasing memory consumption. The detach_() method should be called to reset the gradient tape and prevent memory saturation. + + Args: + opt (torch.optim): Optimizer to make differentiable. + grad_callback (fct(grads)=grads): Function applied to the list of gradients parameters (ex: clipping). (default: None) + track_higher_grads (bool): Wether higher gradient are tracked. If True, the graph/states will be retained to allow backpropagation. (default: True) + + Returns: + (Higher.DifferentiableOptimizer): Differentiable version of the optimizer. + """ + return higher.optim.get_diff_optim(opt, + self._mods['original'].parameters(), + fmodel=self._mods['functional'], + grad_callback=grad_callback, + track_higher_grads=track_higher_grads) def forward(self, x): - if self._data_augmentation:# and TF.random.random() < 0.5: - device = x.device - batch_size, h, w = x.shape[0], x.shape[2], x.shape[3] + """ Main method of the model. - x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) - - for _ in range(self._N_seqTF): - ## Echantillonage ## == sampled_ops = np.random.choice(transforms, N) - uniforme_dist = torch.ones(1, self._nb_op, device=device).softmax(dim=1) - cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_op), device=device)*uniforme_dist) - sample = cat_distrib.sample() + Args: + x (Tensor): Batch of data. - ## Transformations ## - x = self.apply_TF(x, sample) - return x + Returns: + Tensor : Output of the network. Should be logits. + """ + return self._mods['functional'](x) - def apply_TF(self, x, sampled_TF): - smps_x=[] - - for op_idx in range(self._nb_op): - mask = sampled_TF==op_idx #Create selection mask - smp_x = x[mask] #torch.masked_select() ? (Necessite d'expand le mask au meme dim) + def detach_(self): + """Detach from the graph. - if smp_x.shape[0]!=0: #if there's data to TF - if TF.random.random() < self._op_list[op_idx][1]: - magnitude=self._op_list[op_idx][2] - tf=self._op_list[op_idx][0] + Needed to limit the number of state kept in memory. 
+ """ + tmp = self._mods['functional'].fast_params + self._mods['functional']._fast_params=[] + self._mods['functional'].update_params(tmp) + for p in self._mods['functional'].fast_params: + p.detach_().requires_grad_() - #In place - x[mask]=self._TF_dict[tf](x=smp_x, mag=torch.tensor(magnitude, device=x.device)) - - return x - - def adjust_param(self, soft=False): - pass #Pas de parametre a opti - - def loss_weight(self): - return 1 #Pas d'echantillon = pas de ponderation - - def reg_loss(self, reg_factor=0.005): - return torch.tensor(0) #Pas de regularisation - - def train(self, mode=None): - if mode is None : - mode=self._data_augmentation - self.augment(mode=mode) #Inutile si mode=None - super(RandAugUDA, self).train(mode) - - def eval(self): - self.train(mode=False) - - def augment(self, mode=True): - self._data_augmentation=mode + def state_dict(self): + """Returns a dictionary containing a whole state of the module. + """ + return self._mods['functional'].state_dict() def __getitem__(self, key): - return self._params[key] + """Access to modules + Args: + key (string): Name of the module to access. + + Returns: + nn.Module. + """ + return self._mods[key] def __str__(self): - return "RandAugUDA(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag) + """Name of the module + + Returns: + String containing the name of the module. + """ + return self._name class Augmented_model(nn.Module): """Wrapper for a Data Augmentation module and a model. @@ -1377,81 +655,4 @@ class Augmented_model(nn.Module): Returns: String containing the name of the module as well as the higher levels parameters. """ - return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")" - -''' -import higher -class Augmented_model2(nn.Module): - def __init__(self, data_augmenter, model): - super(Augmented_model2, self).__init__() - - self._mods = nn.ModuleDict({ - 'data_aug': data_augmenter, - 'model': model, - 'fmodel': None - }) - - self.augment(mode=True) - - def initialize(self): - self._mods['model'].initialize() - - def forward(self, x): - if self._mods['fmodel']: - return self._mods['fmodel'](self._mods['data_aug'](x)) - else: - return self._mods['model'](self._mods['data_aug'](x)) - - def functional(self, opt, track_higher_grads=True): - self._mods['fmodel'] = higher.patch.monkeypatch(self._mods['model'], device=None, copy_initial_weights=True) - - return higher.optim.get_diff_optim(opt, - self._mods['model'].parameters(), - fmodel=self._mods['fmodel'], - track_higher_grads=track_higher_grads) - - def detach_(self): - tmp = self._mods['fmodel'].fast_params - self._mods['fmodel']._fast_params=[] - self._mods['fmodel'].update_params(tmp) - for p in self._mods['fmodel'].fast_params: - p.detach_().requires_grad_() - - def augment(self, mode=True): - self._data_augmentation=mode - self._mods['data_aug'].augment(mode) - - def train(self, mode=None): - if mode is None : - mode=self._data_augmentation - self._mods['data_aug'].augment(mode) - super(Augmented_model2, self).train(mode) - return self - - def eval(self): - return self.train(mode=False) - #super(Augmented_model, self).eval() - - def items(self): - """Return an iterable of the ModuleDict key/value pairs. 
- """ - return self._mods.items() - - def update(self, modules): - self._mods.update(modules) - - def is_augmenting(self): - return self._data_augmentation - - def TF_names(self): - try: - return self._mods['data_aug']._TF - except: - return None - - def __getitem__(self, key): - return self._mods[key] - - def __str__(self): - return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")" -''' \ No newline at end of file + return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")" \ No newline at end of file diff --git a/higher/model.py b/higher/model.py index d51dd6a..a38bbae 100755 --- a/higher/model.py +++ b/higher/model.py @@ -3,154 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F - -import higher -class Higher_model(nn.Module): - """Model wrapper for higher gradient tracking. - - Keep in memory the orginial model and it's functionnal, higher, version. - - Might not be needed anymore if Higher implement detach for fmodel. - - see : https://github.com/facebookresearch/higher - - TODO: Get rid of the original model if not needed by user. - - Attributes: - _name (string): Name of the model. - _mods (nn.ModuleDict): Models (Orginial and Higher version). - """ - def __init__(self, model): - """Init Higher_model. - - Args: - model (nn.Module): Network for which higher gradients can be tracked. - """ - super(Higher_model, self).__init__() - - self._name = model.__str__() - self._mods = nn.ModuleDict({ - 'original': model, - 'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) - }) - - def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True): - """Get a differentiable version of an Optimizer. - - Higher/Differentiable optimizer required to be used for higher gradient tracking. - Usage : diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step) - - Be warry that if track_higher_grads is set to True, a new state of the model would be saved each time diffopt.step() is called. - Thus increasing memory consumption. The detach_() method should be called to reset the gradient tape and prevent memory saturation. - - Args: - opt (torch.optim): Optimizer to make differentiable. - grad_callback (fct(grads)=grads): Function applied to the list of gradients parameters (ex: clipping). (default: None) - track_higher_grads (bool): Wether higher gradient are tracked. If True, the graph/states will be retained to allow backpropagation. (default: True) - - Returns: - (Higher.DifferentiableOptimizer): Differentiable version of the optimizer. - """ - return higher.optim.get_diff_optim(opt, - self._mods['original'].parameters(), - fmodel=self._mods['functional'], - grad_callback=grad_callback, - track_higher_grads=track_higher_grads) - - def forward(self, x): - """ Main method of the model. - - Args: - x (Tensor): Batch of data. - - Returns: - Tensor : Output of the network. Should be logits. - """ - return self._mods['functional'](x) - - def detach_(self): - """Detach from the graph. - - Needed to limit the number of state kept in memory. - """ - tmp = self._mods['functional'].fast_params - self._mods['functional']._fast_params=[] - self._mods['functional'].update_params(tmp) - for p in self._mods['functional'].fast_params: - p.detach_().requires_grad_() - - def state_dict(self): - """Returns a dictionary containing a whole state of the module. 
- """ - return self._mods['functional'].state_dict() - - def __getitem__(self, key): - """Access to modules - Args: - key (string): Name of the module to access. - - Returns: - nn.Module. - """ - return self._mods[key] - - def __str__(self): - """Name of the module - - Returns: - String containing the name of the module. - """ - return self._name - ## Basic CNN ## -class LeNet_F(nn.Module): - def __init__(self, num_inp, num_out): - super(LeNet_F, self).__init__() - self._params = nn.ParameterDict({ - 'w1': nn.Parameter(torch.zeros(20, num_inp, 5, 5)), - 'b1': nn.Parameter(torch.zeros(20)), - 'w2': nn.Parameter(torch.zeros(50, 20, 5, 5)), - 'b2': nn.Parameter(torch.zeros(50)), - #'w3': nn.Parameter(torch.zeros(500,4*4*50)), #num_imp=1 - 'w3': nn.Parameter(torch.zeros(500,5*5*50)), #num_imp=3 - 'b3': nn.Parameter(torch.zeros(500)), - 'w4': nn.Parameter(torch.zeros(num_out, 500)), - 'b4': nn.Parameter(torch.zeros(num_out)) - }) - self.initialize() - - - def initialize(self): - nn.init.kaiming_uniform_(self._params["w1"], a=math.sqrt(5)) - nn.init.kaiming_uniform_(self._params["w2"], a=math.sqrt(5)) - nn.init.kaiming_uniform_(self._params["w3"], a=math.sqrt(5)) - nn.init.kaiming_uniform_(self._params["w4"], a=math.sqrt(5)) - - def forward(self, x): - #print("Start Shape ", x.shape) - out = F.relu(F.conv2d(input=x, weight=self._params["w1"], bias=self._params["b1"])) - #print("Shape ", out.shape) - out = F.max_pool2d(out, 2) - #print("Shape ", out.shape) - out = F.relu(F.conv2d(input=out, weight=self._params["w2"], bias=self._params["b2"])) - #print("Shape ", out.shape) - out = F.max_pool2d(out, 2) - #print("Shape ", out.shape) - out = out.view(out.size(0), -1) - #print("Shape ", out.shape) - out = F.relu(F.linear(out, self._params["w3"], self._params["b3"])) - #print("Shape ", out.shape) - out = F.linear(out, self._params["w4"], self._params["b4"]) - #print("Shape ", out.shape) - #return F.log_softmax(out, dim=1) - return out - - def __getitem__(self, key): - return self._params[key] - - def __str__(self): - return "LeNet" - class LeNet(nn.Module): def __init__(self, num_inp, num_out): super(LeNet, self).__init__() @@ -171,451 +24,3 @@ class LeNet(nn.Module): def __str__(self): return "LeNet" - -## MobileNetv2 ## - -def _make_divisible(v, divisor, min_value=None): - """ - This function is taken from the original tf repo. - It ensures that all layers have a channel number that is divisible by 8 - It can be seen here: - https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py - :param v: - :param divisor: - :param min_value: - :return: - """ - if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. 
- if new_v < 0.9 * v: - new_v += divisor - return new_v - - -class ConvBNReLU(nn.Sequential): - def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): - padding = (kernel_size - 1) // 2 - super(ConvBNReLU, self).__init__( - nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), - nn.BatchNorm2d(out_planes), - nn.ReLU6(inplace=True) - ) - - -class InvertedResidual(nn.Module): - def __init__(self, inp, oup, stride, expand_ratio): - super(InvertedResidual, self).__init__() - self.stride = stride - assert stride in [1, 2] - - hidden_dim = int(round(inp * expand_ratio)) - self.use_res_connect = self.stride == 1 and inp == oup - - layers = [] - if expand_ratio != 1: - # pw - layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) - layers.extend([ - # dw - ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), - # pw-linear - nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), - nn.BatchNorm2d(oup), - ]) - self.conv = nn.Sequential(*layers) - - def forward(self, x): - if self.use_res_connect: - return x + self.conv(x) - else: - return self.conv(x) - - -class MobileNetV2(nn.Module): - def __init__(self, - num_classes=1000, - width_mult=1.0, - inverted_residual_setting=None, - round_nearest=8, - block=None): - """ - MobileNet V2 main class - Args: - num_classes (int): Number of classes - width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount - inverted_residual_setting: Network structure - round_nearest (int): Round the number of channels in each layer to be a multiple of this number - Set to 1 to turn off rounding - block: Module specifying inverted residual building block for mobilenet - """ - super(MobileNetV2, self).__init__() - - if block is None: - block = InvertedResidual - input_channel = 32 - last_channel = 1280 - - if inverted_residual_setting is None: - inverted_residual_setting = [ - # t, c, n, s - [1, 16, 1, 1], - [6, 24, 2, 2], - [6, 32, 3, 2], - [6, 64, 4, 2], - [6, 96, 3, 1], - [6, 160, 3, 2], - [6, 320, 1, 1], - ] - - # only check the first element, assuming user knows t,c,n,s are required - if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: - raise ValueError("inverted_residual_setting should be non-empty " - "or a 4-element list, got {}".format(inverted_residual_setting)) - - # building first layer - input_channel = _make_divisible(input_channel * width_mult, round_nearest) - self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) - features = [ConvBNReLU(3, input_channel, stride=2)] - # building inverted residual blocks - for t, c, n, s in inverted_residual_setting: - output_channel = _make_divisible(c * width_mult, round_nearest) - for i in range(n): - stride = s if i == 0 else 1 - features.append(block(input_channel, output_channel, stride, expand_ratio=t)) - input_channel = output_channel - # building last several layers - features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) - # make it nn.Sequential - self.features = nn.Sequential(*features) - - # building classifier - self.classifier = nn.Sequential( - nn.Dropout(0.2), - nn.Linear(self.last_channel, num_classes), - ) - - # weight initialization - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out') - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.BatchNorm2d): - nn.init.ones_(m.weight) - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Linear): - 
nn.init.normal_(m.weight, 0, 0.01) - nn.init.zeros_(m.bias) - - def _forward_impl(self, x): - # This exists since TorchScript doesn't support inheritance, so the superclass method - # (this one) needs to have a name other than `forward` that can be accessed in a subclass - x = self.features(x) - x = x.mean([2, 3]) - x = self.classifier(x) - return x - - def forward(self, x): - return self._forward_impl(x) - - def __str__(self): - return "MobileNetV2" - -## ResNet ## -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) - - -def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - - -class BasicBlock(nn.Module): - expansion = 1 - __constants__ = ['downsample'] - - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None): - super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = norm_layer(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - expansion = 4 - __constants__ = ['downsample'] - - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None): - super(Bottleneck, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(inplanes, width) - self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation) - self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion) - self.bn3 = norm_layer(planes * self.expansion) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - -#ResNet18 : block=BasicBlock, layers=[2, 2, 2, 2] -class ResNet(nn.Module): - - def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], num_classes=1000, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, - norm_layer=None): - super(ResNet, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - self._norm_layer = norm_layer - - self.inplanes = 64 
- self.dilation = 1 - if replace_stride_with_dilation is None: - # each element in the tuple indicates if we should replace - # the 2x2 stride with a dilated convolution instead - replace_stride_with_dilation = [False, False, False] - if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) - self.groups = groups - self.base_width = width_per_group - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, - bias=False) - self.bn1 = norm_layer(self.inplanes) - self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - dilate=replace_stride_with_dilation[2]) - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(512 * block.expansion, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - # Zero-initialize the last BN in each residual branch, - # so that the residual branch starts with zeros, and each residual block behaves like an identity. - # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 - if zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) - elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) - - def _make_layer(self, block, planes, blocks, stride=1, dilate=False): - norm_layer = self._norm_layer - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - norm_layer(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation, norm_layer)) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append(block(self.inplanes, planes, groups=self.groups, - base_width=self.base_width, dilation=self.dilation, - norm_layer=norm_layer)) - - return nn.Sequential(*layers) - - def _forward_impl(self, x): - # See note [TorchScript super()] - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.avgpool(x) - x = torch.flatten(x, 1) - x = self.fc(x) - - return x - - def forward(self, x): - return self._forward_impl(x) - - def __str__(self): - return "ResNet18" - -## Wide ResNet ## -#https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py -#https://github.com/arcelien/pba/blob/master/pba/wrn.py -#https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py -''' -class BasicBlock(nn.Module): - def __init__(self, in_planes, out_planes, stride, dropRate=0.0): - super(BasicBlock, self).__init__() - self.bn1 = nn.BatchNorm2d(in_planes) - self.relu1 = nn.ReLU(inplace=True) - 
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(out_planes) - self.relu2 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, - padding=1, bias=False) - self.droprate = dropRate - self.equalInOut = (in_planes == out_planes) - self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, - padding=0, bias=False) or None - def forward(self, x): - if not self.equalInOut: - x = self.relu1(self.bn1(x)) - else: - out = self.relu1(self.bn1(x)) - out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) - if self.droprate > 0: - out = F.dropout(out, p=self.droprate, training=self.training) - out = self.conv2(out) - return torch.add(x if self.equalInOut else self.convShortcut(x), out) - -class NetworkBlock(nn.Module): - def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): - super(NetworkBlock, self).__init__() - self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) - def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): - layers = [] - for i in range(int(nb_layers)): - layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) - return nn.Sequential(*layers) - def forward(self, x): - return self.layer(x) - -#wrn_size: 32 = WRN-28-2 ? 160 = WRN-28-10 -class WideResNet(nn.Module): - #def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): - def __init__(self, num_classes, wrn_size, depth=28, dropRate=0.0): - super(WideResNet, self).__init__() - - self.kernel_size = wrn_size - self.depth=depth - filter_size = 3 - nChannels = [min(self.kernel_size, 16), self.kernel_size, self.kernel_size * 2, self.kernel_size * 4] - strides = [1, 2, 2] # stride for each resblock - - #nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] - assert((depth - 4) % 6 == 0) - n = (depth - 4) / 6 - block = BasicBlock - # 1st conv before any network block - self.conv1 = nn.Conv2d(filter_size, nChannels[0], kernel_size=3, stride=1, - padding=1, bias=False) - # 1st block - self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, strides[0], dropRate) - # 2nd block - self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, strides[1], dropRate) - # 3rd block - self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, strides[2], dropRate) - # global average pooling and classifier - self.bn1 = nn.BatchNorm2d(nChannels[3]) - self.relu = nn.ReLU(inplace=True) - self.fc = nn.Linear(nChannels[3], num_classes) - self.nChannels = nChannels[3] - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - m.bias.data.zero_() - def forward(self, x): - out = self.conv1(x) - out = self.block1(out) - out = self.block2(out) - out = self.block3(out) - out = self.relu(self.bn1(out)) - out = F.avg_pool2d(out, 8) - out = out.view(-1, self.nChannels) - return self.fc(out) - - def architecture(self): - return super(WideResNet, self).__str__() - - def __str__(self): - return "WideResNet(s{}-d{})".format(self.kernel_size, self.depth) -''' \ No newline at end of file diff --git a/higher/old/dataug_old.py b/higher/old/dataug_old.py new file mode 100644 index 0000000..2ffdf51 --- /dev/null +++ 
b/higher/old/dataug_old.py @@ -0,0 +1,1065 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributions import * + +#import kornia +#import random +import numpy as np +import copy + +import transformations as TF + +class Data_aug(nn.Module): #Rotation parametree + def __init__(self): + super(Data_aug, self).__init__() + self._data_augmentation = True + self._params = nn.ParameterDict({ + "prob": nn.Parameter(torch.tensor(0.5)), + "mag": nn.Parameter(torch.tensor(1.0)) + }) + + #self.params["mag"].register_hook(print) + + def forward(self, x): + + if self._data_augmentation and random.random() < self._params["prob"]: + #print('Aug') + batch_size = x.shape[0] + # create transformation (rotation) + alpha = self._params["mag"]*180 # in degrees + angle = torch.ones(batch_size, device=x.device) * alpha + + # define the rotation center + center = torch.ones(batch_size, 2, device=x.device) + center[..., 0] = x.shape[3] / 2 # x + center[..., 1] = x.shape[2] / 2 # y + + #print(x.shape, center) + # define the scale factor + scale = torch.ones(batch_size, device=x.device) + + # compute the transformation matrix + M = kornia.get_rotation_matrix2d(center, angle, scale) + + # apply the transformation to original image + x = kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w) + + return x + + def eval(self): + self.augment(mode=False) + nn.Module.eval(self) + + def augment(self, mode=True): + self._data_augmentation=mode + + def __getitem__(self, key): + return self._params[key] + + def __str__(self): + return "Data_aug(Mag-1 TF)" + +class Data_augV2(nn.Module): #Methode exacte + def __init__(self): + super(Data_augV2, self).__init__() + self._data_augmentation = True + + self._fixed_transf=[0.0, 45.0, 180.0] #Degree rotation + #self._fixed_transf=[0.0] + self._nb_tf= len(self._fixed_transf) + + self._params = nn.ParameterDict({ + "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme + #"prob2": nn.Parameter(torch.ones(len(self._fixed_transf)).softmax(dim=0)) + }) + + #print(self._params["prob"], self._params["prob2"]) + + self.transf_idx=0 + + def forward(self, x): + + if self._data_augmentation: + #print('Aug',self._fixed_transf[self.transf_idx]) + device = x.device + batch_size = x.shape[0] + + # create transformation (rotation) + #alpha = 180 # in degrees + alpha = self._fixed_transf[self.transf_idx] + angle = torch.ones(batch_size, device=device) * alpha + + x = self.rotate(x,angle) + + return x + + def rotate(self, x, angle): + + device = x.device + batch_size = x.shape[0] + # define the rotation center + center = torch.ones(batch_size, 2, device=device) + center[..., 0] = x.shape[3] / 2 # x + center[..., 1] = x.shape[2] / 2 # y + + #print(x.shape, center) + # define the scale factor + scale = torch.ones(batch_size, device=device) + + # compute the transformation matrix + M = kornia.get_rotation_matrix2d(center, angle, scale) + + # apply the transformation to original image + return kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w) + + + def adjust_param(self): #Detach from gradient ? 
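+        # Descriptive note: the learned probabilities are mapped back onto the
+        # probability simplex after an update: values are clamped to [0,1] and then
+        # renormalized so that sum(prob) == 1. Writing to .data keeps this
+        # re-projection itself out of the autograd graph.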
+ self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) + #print('proba',self._params['prob']) + self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 + #print('Sum p', sum(self._params['prob'])) + + def eval(self): + self.augment(mode=False) + nn.Module.eval(self) + + def augment(self, mode=True): + self._data_augmentation=mode + + def __getitem__(self, key): + return self._params[key] + + def __str__(self): + return "Data_augV2(Exact-%d TF)" % self._nb_tf + +class Data_augV3(nn.Module): #Echantillonage uniforme/Mixte + def __init__(self, mix_dist=0.0): + super(Data_augV3, self).__init__() + self._data_augmentation = True + + #self._fixed_transf=[0.0, 45.0, 180.0] #Degree rotation + self._fixed_transf=[0.0, 1.0, -1.0] #Flips (Identity,Horizontal,Vertical) + #self._fixed_transf=[0.0] + self._nb_tf= len(self._fixed_transf) + + self._params = nn.ParameterDict({ + "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme + #"prob2": nn.Parameter(torch.ones(len(self._fixed_transf)).softmax(dim=0)) + }) + + #print(self._params["prob"], self._params["prob2"]) + self._sample = [] + + self._mix_dist = False + if mix_dist != 0.0: + self._mix_dist = True + self._mix_factor = max(min(mix_dist, 1.0), 0.0) + + def forward(self, x): + + if self._data_augmentation: + device = x.device + batch_size = x.shape[0] + + + #good_distrib = Uniform(low=torch.zeros(batch_size,1, device=device),high=torch.new_full((batch_size,1),self._params["prob"], device=device)) + #bad_distrib = Uniform(low=torch.zeros(batch_size,1, device=device),high=torch.new_full((batch_size,1), 1-self._params["prob"], device=device)) + + #transform_dist = Categorical(probs=torch.tensor([self._params["prob"], 1-self._params["prob"]], device=device)) + #self._sample = transform_dist._sample(sample_shape=torch.Size([batch_size,1])) + + uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=0) + + if not self._mix_dist: + distrib = uniforme_dist + else: + distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=0) #Mix distrib reel / uniforme avec mix_factor + + cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*distrib) + self._sample = cat_distrib.sample() + + TF_param = torch.tensor([self._fixed_transf[x] for x in self._sample], device=device) #Approche de marco peut-etre plus rapide + + #x = self.rotate(x,angle=TF_param) + x = self.flip(x,flip_mat=TF_param) + + return x + + def rotate(self, x, angle): + + device = x.device + batch_size = x.shape[0] + # define the rotation center + center = torch.ones(batch_size, 2, device=device) + center[..., 0] = x.shape[3] / 2 # x + center[..., 1] = x.shape[2] / 2 # y + + #print(x.shape, center) + # define the scale factor + scale = torch.ones(batch_size, device=device) + + # compute the transformation matrix + M = kornia.get_rotation_matrix2d(center, angle, scale) + + # apply the transformation to original image + return kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w) + + def flip(self, x, flip_mat): + + #print(flip_mat) + device = x.device + batch_size = x.shape[0] + + h, w = x.shape[2], x.shape[3] # destination size + #points_src = torch.ones(batch_size, 4, 2, device=device) + #points_dst = torch.ones(batch_size, 4, 2, device=device) + + #Identity + iM=torch.tensor(np.eye(3)) + + #Horizontal flip + # the source points are the region to crop corners + #points_src = torch.FloatTensor([[ + # [w - 1, 0], [0, 
0], [0, h - 1], [w - 1, h - 1], + #]]) + # the destination points are the image vertexes + #points_dst = torch.FloatTensor([[ + # [0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1], + #]]) + # compute perspective transform + #hM = kornia.get_perspective_transform(points_src, points_dst) + hM =torch.tensor( [[[-1., 0., w-1], + [ 0., 1., 0.], + [ 0., 0., 1.]]]) + + #Vertical flip + # the source points are the region to crop corners + #points_src = torch.FloatTensor([[ + # [0, h - 1], [w - 1, h - 1], [w - 1, 0], [0, 0], + #]]) + # the destination points are the image vertexes + #points_dst = torch.FloatTensor([[ + # [0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1], + #]]) + # compute perspective transform + #vM = kornia.get_perspective_transform(points_src, points_dst) + vM =torch.tensor( [[[ 1., 0., 0.], + [ 0., -1., h-1], + [ 0., 0., 1.]]]) + #print(vM) + + M=torch.ones(batch_size, 3, 3, device=device) + + for i in range(batch_size): # A optimiser + if flip_mat[i]==0.0: + M[i,]=iM + elif flip_mat[i]==1.0: + M[i,]=hM + elif flip_mat[i]==-1.0: + M[i,]=vM + + # warp the original image by the found transform + return kornia.warp_perspective(x, M, dsize=(h, w)) + + def adjust_param(self, soft=False): #Detach from gradient ? + + if soft : + self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible + else: + #self._params['prob'].clamp(min=0.0,max=1.0) + self._params['prob'].data = F.relu(self._params['prob'].data) + #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) + #print('proba',self._params['prob']) + self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 + #print('Sum p', sum(self._params['prob'])) + + def loss_weight(self): + #w_loss = [self._params["prob"][x] for x in self._sample] + #print(self._sample.view(-1,1).shape) + #print(self._sample[:10]) + + w_loss = torch.zeros((self._sample.shape[0],self._nb_tf), device=self._sample.device) + w_loss.scatter_(1, self._sample.view(-1,1), 1) + #print(w_loss.shape) + #print(w_loss[:10,:]) + w_loss = w_loss * self._params["prob"] + #print(w_loss.shape) + #print(w_loss[:10,:]) + w_loss = torch.sum(w_loss,dim=1) + #print(w_loss.shape) + #print(w_loss[:10]) + return w_loss + + def train(self, mode=None): + if mode is None : + mode=self._data_augmentation + self.augment(mode=mode) #Inutile si mode=None + super(Data_augV3, self).train(mode) + + def eval(self): + self.train(mode=False) + #super(Augmented_model, self).eval() + + def augment(self, mode=True): + self._data_augmentation=mode + + def __getitem__(self, key): + return self._params[key] + + def __str__(self): + if not self._mix_dist: + return "Data_augV3(Uniform-%d TF)" % self._nb_tf + else: + return "Data_augV3(Mix %.1f-%d TF)" % (self._mix_factor, self._nb_tf) + +''' +TF_dict={ #Dataugv4 + ## Geometric TF ## + 'Identity' : (lambda x, mag: x), + 'FlipUD' : (lambda x, mag: flipUD(x)), + 'FlipLR' : (lambda x, mag: flipLR(x)), + 'Rotate': (lambda x, mag: rotate(x, angle=torch.tensor([rand_int(mag, maxval=30)for _ in x], device=x.device))), + 'TranslateX': (lambda x, mag: translate(x, translation=torch.tensor([[rand_int(mag, maxval=20), 0] for _ in x], device=x.device))), + 'TranslateY': (lambda x, mag: translate(x, translation=torch.tensor([[0, rand_int(mag, maxval=20)] for _ in x], device=x.device))), + 'ShearX': (lambda x, mag: shear(x, shear=torch.tensor([[rand_float(mag, maxval=0.3), 0] for _ in x], device=x.device))), + 'ShearY': (lambda x, mag: shear(x, 
shear=torch.tensor([[0, rand_float(mag, maxval=0.3)] for _ in x], device=x.device))), + + ## Color TF (Expect image in the range of [0, 1]) ## + 'Contrast': (lambda x, mag: contrast(x, contrast_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))), + 'Color':(lambda x, mag: color(x, color_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))), + 'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))), + 'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))), + 'Posterize': (lambda x, mag: posterize(x, bits=torch.tensor([rand_int(mag, minval=4, maxval=8) for _ in x], device=x.device))), + 'Solarize': (lambda x, mag: solarize(x, thresholds=torch.tensor([rand_int(mag,minval=1, maxval=256)/256. for _ in x], device=x.device))) , #=>Image entre [0,1] #Pas opti pour des batch + + #Non fonctionnel + #'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent) + #'Equalize': (lambda mag: None), +} +''' +class Data_augV4(nn.Module): #Transformations avec mask + def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0): + super(Data_augV4, self).__init__() + assert len(TF_dict)>0 + + self._data_augmentation = True + + #self._TF_matrix={} + #self._input_info={'h':0, 'w':0, 'device':None} #Input associe a TF_matrix + #self._mag_fct = TF_dict + self._TF_dict = TF_dict + self._TF= list(self._TF_dict.keys()) + self._nb_tf= len(self._TF) + + self._N_seqTF = N_TF + + self._fixed_mag=5 #[0, PARAMETER_MAX] + self._params = nn.ParameterDict({ + "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme + }) + + self._samples = [] + + self._mix_dist = False + if mix_dist != 0.0: + self._mix_dist = True + self._mix_factor = max(min(mix_dist, 1.0), 0.0) + + def forward(self, x): + if self._data_augmentation: + device = x.device + batch_size, h, w = x.shape[0], x.shape[2], x.shape[3] + + x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) + self._samples = [] + + for _ in range(self._N_seqTF): + ## Echantillonage ## + uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1) + + if not self._mix_dist: + self._distrib = uniforme_dist + else: + self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor + + cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib) + sample = cat_distrib.sample() + self._samples.append(sample) + + ## Transformations ## + x = self.apply_TF(x, sample) + return x + ''' + def compute_TF_matrix(self, magnitude=None, sample_info= None): + print('Computing TF_matrix...') + if not magnitude : + magnitude=self._fixed_mag + + if sample_info: + self._input_info['h']= sample_info['h'] + self._input_info['w']= sample_info['w'] + self._input_info['device'] = sample_info['device'] + h, w, device= self._input_info['h'], self._input_info['w'], self._input_info['device'] + + self._TF_matrix={} + for tf in self._TF : + if tf=='Id': + self._TF_matrix[tf]=torch.tensor([[[ 1., 0., 0.], + [ 0., 1., 0.], + [ 0., 0., 1.]]], device=device) + elif tf=='Rot': + center = torch.ones(1, 2, device=device) + center[0, 0] = w / 2 # x + center[0, 1] = h / 2 # y + scale = torch.ones(1, device=device) + angle = 
self._mag_fct[tf](magnitude) * torch.ones(1, device=device) + R = kornia.get_rotation_matrix2d(center, angle, scale) #Rotation matrix (1,2,3) + self._TF_matrix[tf]=torch.cat((R,torch.tensor([[[ 0., 0., 1.]]], device=device)), dim=1) #TF matrix (1,3,3) + elif tf=='FlipLR': + self._TF_matrix[tf]=torch.tensor([[[-1., 0., w-1], + [ 0., 1., 0.], + [ 0., 0., 1.]]], device=device) + elif tf=='FlipUD': + self._TF_matrix[tf]=torch.tensor([[[ 1., 0., 0.], + [ 0., -1., h-1], + [ 0., 0., 1.]]], device=device) + else: + raise Exception("Invalid TF requested") + ''' + def apply_TF(self, x, sampled_TF): + device = x.device + smps_x=[] + masks=[] + for tf_idx in range(self._nb_tf): + mask = sampled_TF==tf_idx #Create selection mask + smp_x = x[mask] #torch.masked_select() ? + + if smp_x.shape[0]!=0: #if there's data to TF + magnitude=self._fixed_mag + tf=self._TF[tf_idx] + + ''' + ## Geometric TF ## + if tf=='Identity': + pass + elif tf=='FlipLR': + smp_x = TF.flipLR(smp_x) + elif tf=='FlipUD': + smp_x = TF.flipUD(smp_x) + elif tf=='Rotate': + smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='TranslateX' or tf=='TranslateY': + smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='ShearX' or tf=='ShearY' : + smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + + ## Color TF (Expect image in the range of [0, 1]) ## + elif tf=='Contrast': + smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='Color': + smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='Brightness': + smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='Sharpness': + smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='Posterize': + smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device)) + elif tf=='Solarize': + smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='Equalize': + smp_x = TF.equalize(smp_x) + elif tf=='Auto_Contrast': + smp_x = TF.auto_contrast(smp_x) + else: + raise Exception("Invalid TF requested : ", tf) + + x[mask]=smp_x # Refusionner eviter x[mask] : in place + ''' + x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Refusionner eviter x[mask] : in place + + #idx= mask.nonzero() + #print('-'*8) + #print(idx[0], tf_idx) + #print(smp_x[0,]) + #x=x.view(-1,3*32*32) + #x=x.scatter(dim=0, index=idx, src=smp_x.view(-1,3*32*32)) #Changement des Tensor mais pas visible sur la visualisation... 
+ #x=x.view(-1,3,32,32) + #print(x[0,]) + + ''' + if len(self._TF_matrix)==0 or self._input_info['h']!=h or self._input_info['w']!=w or self._input_info['device']!=device: #Device different:Pas necessaire de tout recalculer + self.compute_TF_matrix(sample_info={'h': x.shape[2], + 'w': x.shape[3], + 'device': x.device}) + + TF_matrix = torch.zeros(batch_size, 3, 3, device=device) #All geom TF + + for tf_idx in range(self._nb_tf): + mask = self._sample==tf_idx #Create selection mask + TF_matrix[mask,]=self._TF_matrix[self._TF[tf_idx]] + + x=kornia.warp_perspective(x, TF_matrix, dsize=(h, w)) + ''' + return x + + def adjust_param(self, soft=False): #Detach from gradient ? + + if soft : + self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible + else: + #self._params['prob'].clamp(min=0.0,max=1.0) + self._params['prob'].data = F.relu(self._params['prob'].data) + #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) + + self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 + + def loss_weight(self): + # 1 seule TF + #self._sample = self._samples[-1] + #w_loss = torch.zeros((self._sample.shape[0],self._nb_tf), device=self._sample.device) + #w_loss.scatter_(dim=1, index=self._sample.view(-1,1), value=1) + #w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss) + #w_loss = torch.sum(w_loss,dim=1) + + #Plusieurs TF sequentielles + w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device) + for sample in self._samples: + tmp_w = torch.zeros(w_loss.size(),device=w_loss.device) + tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF) + w_loss += tmp_w + + w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss) + w_loss = torch.sum(w_loss,dim=1) + return w_loss + + + def train(self, mode=None): + if mode is None : + mode=self._data_augmentation + self.augment(mode=mode) #Inutile si mode=None + super(Data_augV4, self).train(mode) + + def eval(self): + self.train(mode=False) + + def augment(self, mode=True): + self._data_augmentation=mode + + def __getitem__(self, key): + return self._params[key] + + def __str__(self): + if not self._mix_dist: + return "Data_augV4(Uniform-%d TF x %d)" % (self._nb_tf, self._N_seqTF) + else: + return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF) + + +class Data_augV6(nn.Module): #Optimisation sequentielle #Mauvais resultats + def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=None, fixed_mag=True, shared_mag=True): + super(Data_augV6, self).__init__() + assert len(TF_dict)>0 + + self._data_augmentation = True + + self._TF_dict = TF_dict + self._TF= list(self._TF_dict.keys()) + self._nb_tf= len(self._TF) + + self._N_seqTF = N_TF + self._shared_mag = shared_mag + self._fixed_mag = fixed_mag + + self._TF_set_size = prob_set_size if prob_set_size else self._nb_tf + + self._fixed_TF=[0] #Identite + assert self._TF_set_size>=len(self._fixed_TF) + + if self._TF_set_size>self._nb_tf: + print("Warning : TF sets size higher than number of TF. 
Reducing set size to %d"%self._nb_tf) + self._TF_set_size=self._nb_tf + + ## Genenerate TF sets ## + if self._TF_set_size==len(self._fixed_TF): + print("Warning : using only fixed set of TF : ", self._fixed_TF) + self._TF_sets=torch.tensor([self._fixed_TF]) + else: + def generate_TF_sets(n_TF, set_size, idx_prefix=[]): + TF_sets=[] + if len(idx_prefix)!=0: + if set_size>2: + for i in range(idx_prefix[-1]+1, n_TF): + TF_sets += generate_TF_sets(n_TF=n_TF, set_size=set_size-1, idx_prefix=idx_prefix+[i]) + else: + #if i not in idx_prefix: + TF_sets+=[torch.tensor(idx_prefix+[i]) for i in range(idx_prefix[-1]+1, n_TF)] + elif set_size>1: + for i in range(0, n_TF): + TF_sets += generate_TF_sets(n_TF=n_TF, set_size=set_size, idx_prefix=[i]) + else: + TF_sets+=[torch.tensor([i]) for i in range(0, n_TF)] + return TF_sets + + self._TF_sets=generate_TF_sets(self._nb_tf, self._TF_set_size, self._fixed_TF) + + ## Plan TF learning schedule ## + self._TF_schedule = [list(range(len(self._TF_sets))) for _ in range(self._N_seqTF)] + for n_tf in range(self._N_seqTF) : + TF.random.shuffle(self._TF_schedule[n_tf]) + + self._current_TF_idx=0 #random.randint + self._start_prob = 1/self._TF_set_size + + + self._params = nn.ParameterDict({ + "prob": nn.Parameter(torch.tensor(self._start_prob).expand(self._nb_tf)), #Proba independantes + "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)) if self._shared_mag + else torch.tensor(float(TF.PARAMETER_MAX)).expand(self._nb_tf)), #[0, PARAMETER_MAX] + }) + + #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag + + #Distribution + self._fixed_prob=fixed_prob + self._samples = [] + self._mix_dist = False + if mix_dist != 0.0: + self._mix_dist = True + self._mix_factor = max(min(mix_dist, 0.999), 0.0) + + #Mag regularisation + if not self._fixed_mag: + if self._shared_mag : + self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max + else: + self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag] + self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max + + def forward(self, x): + self._samples = [] + if self._data_augmentation:# and TF.random.random() < 0.5: + device = x.device + batch_size, h, w = x.shape[0], x.shape[2], x.shape[3] + + x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) + + for n_tf in range(self._N_seqTF): + + tf_set = self._TF_sets[self._TF_schedule[n_tf][self._current_TF_idx]].to(device) + #print(n_tf, tf_set) + ## Echantillonage ## + uniforme_dist = torch.ones(1,len(tf_set),device=device).softmax(dim=1) + + if not self._mix_dist: + self._distrib = uniforme_dist + else: + prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"] + curr_prob = torch.index_select(prob, 0, tf_set) + curr_prob = curr_prob /sum(curr_prob) #Contrainte sum(p)=1 + self._distrib = (self._mix_factor*curr_prob+(1-self._mix_factor)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor + + cat_distrib= Categorical(probs=torch.ones((batch_size, len(tf_set)), device=device)*self._distrib) + sample = cat_distrib.sample() + self._samples.append(sample) + + ## Transformations ## + x = self.apply_TF(x, sample) + return x + + def apply_TF(self, x, sampled_TF): + device = x.device + batch_size, channels, h, w = x.shape + smps_x=[] + + for sel_idx, tf_idx in 
enumerate(self._TF_sets[self._current_TF_idx]): + mask = sampled_TF==sel_idx #Create selection mask + smp_x = x[mask] #torch.masked_select() ? (NEcessite d'expand le mask au meme dim) + + if smp_x.shape[0]!=0: #if there's data to TF + magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx] + if self._fixed_mag: magnitude=magnitude.detach() #Fmodel tente systematiquement de tracker les gradient de tout les param + + tf=self._TF[tf_idx] + #print(magnitude) + + #In place + #x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) + + #Out of place + smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude) + idx= mask.nonzero() + idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ... + x=x.scatter(dim=0, index=idx, src=smp_x) + + return x + + def adjust_param(self, soft=False): #Detach from gradient ? + if not self._fixed_prob: + if soft : + self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible + else: + self._params['prob'].data = F.relu(self._params['prob'].data) + #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) + #self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 + + self._params['prob'].data[0]=self._start_prob #Fixe p identite + + if not self._fixed_mag: + #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Bloque une fois au extreme + self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX) + + def loss_weight(self): #A verifier + if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation + + prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"] + + #Plusieurs TF sequentielles (Attention ne prend pas en compte ordre !) 
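+        # Descriptive note: builds a (batch_size, TF_set_size) weight matrix; each of
+        # the N_seqTF sequential draws adds 1/N_seqTF in the column of the transform it
+        # sampled, the result is scaled by the (renormalized) learned probability of
+        # that transform and divided by the distribution it was actually sampled from
+        # (an importance weight), then summed over columns to give one weight per sample.
+        # e.g. with N_seqTF=2, a sample that drew transform 0 and then transform 3 gets
+        # 0.5 in columns 0 and 3 before reweighting (1.0 in a single column if both
+        # draws picked the same transform).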
+ w_loss = torch.zeros((self._samples[0].shape[0],self._TF_set_size), device=self._samples[0].device) + for n_tf in range(self._N_seqTF): + tmp_w = torch.zeros(w_loss.size(),device=w_loss.device) + tmp_w.scatter_(dim=1, index=self._samples[n_tf].view(-1,1), value=1/self._N_seqTF) + + tf_set = self._TF_sets[self._TF_schedule[n_tf][self._current_TF_idx]].to(prob.device) + curr_prob = torch.index_select(prob, 0, tf_set) + curr_prob = curr_prob /sum(curr_prob) #Contrainte sum(p)=1 + + #ATTENTION DISTRIB DIFFERENTE AVEC MIX + assert not self._mix_dist + w_loss += tmp_w * curr_prob /self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss) + + w_loss = torch.sum(w_loss,dim=1) + return w_loss + + def reg_loss(self, reg_factor=0.005): + if self._fixed_mag: # or self._fixed_prob: #Pas de regularisation si trop peu de DOF + return torch.tensor(0) + else: + #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean') + params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask] + return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean') + + def next_TF_set(self, idx=None): + if idx: + self._current_TF_idx=idx + else: + self._current_TF_idx+=1 + + if self._current_TF_idx>=len(self._TF_schedule[0]): + self._current_TF_idx=0 + for n_tf in range(self._N_seqTF) : + TF.random.shuffle(self._TF_schedule[n_tf]) + #print('-- New schedule --') + + def train(self, mode=None): + if mode is None : + mode=self._data_augmentation + self.augment(mode=mode) #Inutile si mode=None + super(Data_augV6, self).train(mode) + + def eval(self): + self.train(mode=False) + + def augment(self, mode=True): + self._data_augmentation=mode + + def __getitem__(self, key): + return self._params[key] + + def __str__(self): + dist_param='' + if self._fixed_prob: dist_param+='Fx' + mag_param='Mag' + if self._fixed_mag: mag_param+= 'Fx' + if self._shared_mag: mag_param+= 'Sh' + if not self._mix_dist: + return "Data_augV6(Uniform%s-%dTF(%d)x%d-%s)" % (dist_param, self._nb_tf, self._TF_set_size, self._N_seqTF, mag_param) + else: + return "Data_augV6(Mix%.1f%s-%dTF(%d)x%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._TF_set_size, self._N_seqTF, mag_param) + +class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training) + def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX): + super(RandAugUDA, self).__init__() + + self._data_augmentation = True + + self._TF_dict = TF_dict + self._TF= list(self._TF_dict.keys()) + self._nb_tf= len(self._TF) + self._N_seqTF = N_TF + + self.mag=nn.Parameter(torch.tensor(float(mag))) + self._params = nn.ParameterDict({ + "prob": nn.Parameter(torch.tensor(0.5).unsqueeze(dim=0)), + "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX))), + }) + self._shared_mag = True + self._fixed_mag = True + + self._op_list =[] + for tf in self._TF: + for mag in range(1, int(self._params['mag']*10), 1): + self._op_list+=[(tf, self._params['prob'].item(), mag/10)] + self._nb_op = len(self._op_list) + + def forward(self, x): + if self._data_augmentation:# and TF.random.random() < 0.5: + device = x.device + batch_size, h, w = x.shape[0], x.shape[2], x.shape[3] + + x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) + + for _ in range(self._N_seqTF): + ## Echantillonage ## == sampled_ops = np.random.choice(transforms, N) + uniforme_dist = torch.ones(1, 
self._nb_op, device=device).softmax(dim=1) + cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_op), device=device)*uniforme_dist) + sample = cat_distrib.sample() + + ## Transformations ## + x = self.apply_TF(x, sample) + return x + + def apply_TF(self, x, sampled_TF): + smps_x=[] + + for op_idx in range(self._nb_op): + mask = sampled_TF==op_idx #Create selection mask + smp_x = x[mask] #torch.masked_select() ? (Necessite d'expand le mask au meme dim) + + if smp_x.shape[0]!=0: #if there's data to TF + if TF.random.random() < self._op_list[op_idx][1]: + magnitude=self._op_list[op_idx][2] + tf=self._op_list[op_idx][0] + + #In place + x[mask]=self._TF_dict[tf](x=smp_x, mag=torch.tensor(magnitude, device=x.device)) + + return x + + def adjust_param(self, soft=False): + pass #Pas de parametre a opti + + def loss_weight(self): + return 1 #Pas d'echantillon = pas de ponderation + + def reg_loss(self, reg_factor=0.005): + return torch.tensor(0) #Pas de regularisation + + def train(self, mode=None): + if mode is None : + mode=self._data_augmentation + self.augment(mode=mode) #Inutile si mode=None + super(RandAugUDA, self).train(mode) + + def eval(self): + self.train(mode=False) + + def augment(self, mode=True): + self._data_augmentation=mode + + def __getitem__(self, key): + return self._params[key] + + def __str__(self): + return "RandAugUDA(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag) + +''' +import higher +class Augmented_model2(nn.Module): + def __init__(self, data_augmenter, model): + super(Augmented_model2, self).__init__() + + self._mods = nn.ModuleDict({ + 'data_aug': data_augmenter, + 'model': model, + 'fmodel': None + }) + + self.augment(mode=True) + + def initialize(self): + self._mods['model'].initialize() + + def forward(self, x): + if self._mods['fmodel']: + return self._mods['fmodel'](self._mods['data_aug'](x)) + else: + return self._mods['model'](self._mods['data_aug'](x)) + + def functional(self, opt, track_higher_grads=True): + self._mods['fmodel'] = higher.patch.monkeypatch(self._mods['model'], device=None, copy_initial_weights=True) + + return higher.optim.get_diff_optim(opt, + self._mods['model'].parameters(), + fmodel=self._mods['fmodel'], + track_higher_grads=track_higher_grads) + + def detach_(self): + tmp = self._mods['fmodel'].fast_params + self._mods['fmodel']._fast_params=[] + self._mods['fmodel'].update_params(tmp) + for p in self._mods['fmodel'].fast_params: + p.detach_().requires_grad_() + + def augment(self, mode=True): + self._data_augmentation=mode + self._mods['data_aug'].augment(mode) + + def train(self, mode=None): + if mode is None : + mode=self._data_augmentation + self._mods['data_aug'].augment(mode) + super(Augmented_model2, self).train(mode) + return self + + def eval(self): + return self.train(mode=False) + #super(Augmented_model, self).eval() + + def items(self): + """Return an iterable of the ModuleDict key/value pairs. 
+ """ + return self._mods.items() + + def update(self, modules): + self._mods.update(modules) + + def is_augmenting(self): + return self._data_augmentation + + def TF_names(self): + try: + return self._mods['data_aug']._TF + except: + return None + + def __getitem__(self, key): + return self._mods[key] + + def __str__(self): + return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")" +''' + +from torchvision.datasets.vision import VisionDataset +from PIL import Image +import augmentation_transforms +import numpy as np +class AugmentedDataset(VisionDataset): + def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None): + + super(AugmentedDataset, self).__init__(root, transform=transform, target_transform=target_transform) + + supervised_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=transform) + + self.sup_data = supervised_dataset.data if not subset else supervised_dataset.data[subset[0]:subset[1]] + self.sup_targets = supervised_dataset.targets if not subset else supervised_dataset.targets[subset[0]:subset[1]] + assert len(self.sup_data)==len(self.sup_targets) + + for idx, img in enumerate(self.sup_data): + self.sup_data[idx]= Image.fromarray(img) #to PIL Image + + self.unsup_data=[] + self.unsup_targets=[] + + self.data= self.sup_data + self.targets= self.sup_targets + + self.dataset_info= { + 'name': 'CIFAR10', + 'sup': len(self.sup_data), + 'unsup': len(self.unsup_data), + 'length': len(self.sup_data)+len(self.unsup_data), + } + + + self._TF = [ + ## Geometric TF ## + 'Rotate', + 'TranslateX', + 'TranslateY', + 'ShearX', + 'ShearY', + + 'Cutout', + + ## Color TF ## + 'Contrast', + 'Color', + 'Brightness', + 'Sharpness', + #'Posterize', + #'Solarize', + + 'Invert', + 'AutoContrast', + 'Equalize', + ] + self._op_list =[] + self.prob=0.5 + for tf in self._TF: + for mag in range(1, 10): + self._op_list+=[(tf, self.prob, mag)] + self._nb_op = len(self._op_list) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. + """ + img, target = self.data[index], self.targets[index] + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + #img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def augement_data(self, aug_copy=1): + + policies = [] + for op_1 in self._op_list: + for op_2 in self._op_list: + policies += [[op_1, op_2]] + + for idx, image in enumerate(self.sup_data): + if (idx/self.dataset_info['sup'])%0.2==0: print("Augmenting data... 
", idx,"/", self.dataset_info['sup']) + #if idx==10000:break + + for _ in range(aug_copy): + chosen_policy = policies[np.random.choice(len(policies))] + aug_image = augmentation_transforms.apply_policy(chosen_policy, image, use_mean_std=False) #Cast en float image + #aug_image = augmentation_transforms.cutout_numpy(aug_image) + + self.unsup_data+=[(aug_image*255.).astype(self.sup_data.dtype)]#Cast float image to uint8 + self.unsup_targets+=[self.sup_targets[idx]] + + #self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8 + self.unsup_data=np.array(self.unsup_data) + self.data= np.concatenate((self.sup_data, self.unsup_data), axis=0) + self.targets= np.concatenate((self.sup_targets, self.unsup_targets), axis=0) + + assert len(self.unsup_data)==len(self.unsup_targets) + assert len(self.data)==len(self.targets) + self.dataset_info['unsup']=len(self.unsup_data) + self.dataset_info['length']=self.dataset_info['sup']+self.dataset_info['unsup'] + + def len_supervised(self): + return self.dataset_info['sup'] + + def len_unsupervised(self): + return self.dataset_info['unsup'] + + def __len__(self): + return self.dataset_info['length'] + + def __str__(self): + return "CIFAR10(Sup:{}-Unsup:{}-{}TF)".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF)) diff --git a/higher/old/higher_repro.py b/higher/old/higher_repro.py new file mode 100644 index 0000000..3c57c67 --- /dev/null +++ b/higher/old/higher_repro.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import higher +import time + +data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=torchvision.transforms.ToTensor()) +dl_train = torch.utils.data.DataLoader(data_train, batch_size=300, shuffle=True, num_workers=0, pin_memory=False) + + +class Aug_model(nn.Module): + def __init__(self, model, hyper_param=True): + super(Aug_model, self).__init__() + + #### Origin of the issue ? #### + if hyper_param: + self._params = nn.ParameterDict({ + "hyper_param": nn.Parameter(torch.Tensor([0.5])), + }) + ############################### + + self._mods = nn.ModuleDict({ + 'model': model, + }) + + def forward(self, x): + return self._mods['model'](x) #* self._params['hyper_param'] + + def __getitem__(self, key): + return self._mods[key] + +class Aug_model2(nn.Module): #Slow increase like no hyper_param + def __init__(self, model, hyper_param=True): + super(Aug_model2, self).__init__() + + #### Origin of the issue ? 
#### + if hyper_param: + self._params = nn.ParameterDict({ + "hyper_param": nn.Parameter(torch.Tensor([0.5])), + }) + ############################### + + self._mods = nn.ModuleDict({ + 'model': model, + 'fmodel': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) + }) + + def forward(self, x): + return self._mods['fmodel'](x) * self._params['hyper_param'] + + def get_diffopt(self, opt, track_higher_grads=True): + return higher.optim.get_diff_optim(opt, + self._mods['model'].parameters(), + fmodel=self._mods['fmodel'], + track_higher_grads=track_higher_grads) + + def __getitem__(self, key): + return self._mods[key] + +if __name__ == "__main__": + + device = torch.device('cuda:1') + aug_model = Aug_model2( + model=torch.hub.load('pytorch/vision:v0.4.2', 'resnet18', pretrained=False), + hyper_param=True #False will not extend step time + ).to(device) + + inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9) + + #fmodel = higher.patch.monkeypatch(aug_model, device=None, copy_initial_weights=True) + #diffopt = higher.optim.get_diff_optim(inner_opt, aug_model.parameters(),fmodel=fmodel,track_higher_grads=True) + diffopt = aug_model.get_diffopt(inner_opt) + + for i, (xs, ys) in enumerate(dl_train): + xs, ys = xs.to(device), ys.to(device) + + #logits = fmodel(xs) + logits = aug_model(xs) + loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='mean') + + t = time.process_time() + diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step) + #print(len(fmodel._fast_params),"step", time.process_time()-t) + print(len(aug_model['fmodel']._fast_params),"step", time.process_time()-t) \ No newline at end of file diff --git a/higher/old/model_old.py b/higher/old/model_old.py new file mode 100644 index 0000000..ec24f25 --- /dev/null +++ b/higher/old/model_old.py @@ -0,0 +1,502 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +## Basic CNN ## +class LeNet_F(nn.Module): + def __init__(self, num_inp, num_out): + super(LeNet_F, self).__init__() + self._params = nn.ParameterDict({ + 'w1': nn.Parameter(torch.zeros(20, num_inp, 5, 5)), + 'b1': nn.Parameter(torch.zeros(20)), + 'w2': nn.Parameter(torch.zeros(50, 20, 5, 5)), + 'b2': nn.Parameter(torch.zeros(50)), + #'w3': nn.Parameter(torch.zeros(500,4*4*50)), #num_imp=1 + 'w3': nn.Parameter(torch.zeros(500,5*5*50)), #num_imp=3 + 'b3': nn.Parameter(torch.zeros(500)), + 'w4': nn.Parameter(torch.zeros(num_out, 500)), + 'b4': nn.Parameter(torch.zeros(num_out)) + }) + self.initialize() + + + def initialize(self): + nn.init.kaiming_uniform_(self._params["w1"], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self._params["w2"], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self._params["w3"], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self._params["w4"], a=math.sqrt(5)) + + def forward(self, x): + #print("Start Shape ", x.shape) + out = F.relu(F.conv2d(input=x, weight=self._params["w1"], bias=self._params["b1"])) + #print("Shape ", out.shape) + out = F.max_pool2d(out, 2) + #print("Shape ", out.shape) + out = F.relu(F.conv2d(input=out, weight=self._params["w2"], bias=self._params["b2"])) + #print("Shape ", out.shape) + out = F.max_pool2d(out, 2) + #print("Shape ", out.shape) + out = out.view(out.size(0), -1) + #print("Shape ", out.shape) + out = F.relu(F.linear(out, self._params["w3"], self._params["b3"])) + #print("Shape ", out.shape) + out = F.linear(out, self._params["w4"], self._params["b4"]) + #print("Shape ", out.shape) + #return F.log_softmax(out, dim=1) + return out + + 
def __getitem__(self, key): + return self._params[key] + + def __str__(self): + return "LeNet" + + +## MobileNetv2 ## + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + padding = (kernel_size - 1) // 2 + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, + num_classes=1000, + width_mult=1.0, + inverted_residual_setting=None, + round_nearest=8, + block=None): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + if block is None: + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + + if inverted_residual_setting is None: + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(3, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, 
expand_ratio=t)) + input_channel = output_channel + # building last several layers + features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) + # make it nn.Sequential + self.features = nn.Sequential(*features) + + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(0.2), + nn.Linear(self.last_channel, num_classes), + ) + + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + x = self.features(x) + x = x.mean([2, 3]) + x = self.classifier(x) + return x + + def forward(self, x): + return self._forward_impl(x) + + def __str__(self): + return "MobileNetV2" + +## ResNet ## +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = 
self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + +#ResNet18 : block=BasicBlock, layers=[2, 2, 2, 2] +class ResNet(nn.Module): + + def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x): + return self._forward_impl(x) + + def __str__(self): + return "ResNet18" + +## Wide ResNet ## +#https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py +#https://github.com/arcelien/pba/blob/master/pba/wrn.py +#https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py +''' +class BasicBlock(nn.Module): + def __init__(self, in_planes, out_planes, stride, dropRate=0.0): + super(BasicBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.relu1 = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(out_planes) + self.relu2 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.droprate = dropRate + self.equalInOut = (in_planes == out_planes) + self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=False) or None + def forward(self, x): + if not self.equalInOut: + x = self.relu1(self.bn1(x)) + else: + out = self.relu1(self.bn1(x)) + out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, training=self.training) + out = self.conv2(out) + return torch.add(x if self.equalInOut else self.convShortcut(x), out) + +class NetworkBlock(nn.Module): + def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): + super(NetworkBlock, self).__init__() + self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) + def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): + layers = [] + for i in range(int(nb_layers)): + layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) + return nn.Sequential(*layers) + def forward(self, x): + return self.layer(x) + +#wrn_size: 32 = WRN-28-2 ? 
160 = WRN-28-10 +class WideResNet(nn.Module): + #def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): + def __init__(self, num_classes, wrn_size, depth=28, dropRate=0.0): + super(WideResNet, self).__init__() + + self.kernel_size = wrn_size + self.depth=depth + filter_size = 3 + nChannels = [min(self.kernel_size, 16), self.kernel_size, self.kernel_size * 2, self.kernel_size * 4] + strides = [1, 2, 2] # stride for each resblock + + #nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] + assert((depth - 4) % 6 == 0) + n = (depth - 4) / 6 + block = BasicBlock + # 1st conv before any network block + self.conv1 = nn.Conv2d(filter_size, nChannels[0], kernel_size=3, stride=1, + padding=1, bias=False) + # 1st block + self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, strides[0], dropRate) + # 2nd block + self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, strides[1], dropRate) + # 3rd block + self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, strides[2], dropRate) + # global average pooling and classifier + self.bn1 = nn.BatchNorm2d(nChannels[3]) + self.relu = nn.ReLU(inplace=True) + self.fc = nn.Linear(nChannels[3], num_classes) + self.nChannels = nChannels[3] + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.bias.data.zero_() + def forward(self, x): + out = self.conv1(x) + out = self.block1(out) + out = self.block2(out) + out = self.block3(out) + out = self.relu(self.bn1(out)) + out = F.avg_pool2d(out, 8) + out = out.view(-1, self.nChannels) + return self.fc(out) + + def architecture(self): + return super(WideResNet, self).__str__() + + def __str__(self): + return "WideResNet(s{}-d{})".format(self.kernel_size, self.depth) +''' \ No newline at end of file diff --git a/higher/test_lr.py b/higher/old/test_lr.py similarity index 100% rename from higher/test_lr.py rename to higher/old/test_lr.py diff --git a/higher/old/train_utils_old.py b/higher/old/train_utils_old.py new file mode 100644 index 0000000..389dd9d --- /dev/null +++ b/higher/old/train_utils_old.py @@ -0,0 +1,590 @@ +import torch +#import torch.optim +import torchvision +import higher + +from datasets import * +from utils import * + +def train_classic_tests(model, epochs=1): + device = next(model.parameters()).device + #opt = torch.optim.Adam(model.parameters(), lr=1e-3) + optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9) + + countcopy=0 + model.train() + dl_val_it = iter(dl_val) + log = [] + + fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) + doptim = higher.optim.get_diff_optim(optim, model.parameters(), fmodel=fmodel, track_higher_grads=False) + for epoch in range(epochs): + print_torch_mem("Start epoch") + print(len(fmodel._fast_params)) + t0 = time.process_time() + #with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=True) as (fmodel, doptim): + + #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) + #doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True) + + for i, (features, labels) in enumerate(dl_train): + features,labels = features.to(device), labels.to(device) + + #with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, doptim): + + + #optim.zero_grad() + pred = 
fmodel.forward(features) + loss = F.cross_entropy(pred,labels) + doptim.step(loss) #(opt.zero_grad, loss.backward, opt.step) + #loss.backward() + #new_params = doptim.step(loss, params=fmodel.parameters()) + #fmodel.update_params(new_params) + + + #print('Fast param',len(fmodel._fast_params)) + #print('opt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][2]['momentum_buffer'].shape) + + if False or (len(fmodel._fast_params)>1): + print("fmodel fast param",len(fmodel._fast_params)) + ''' + #val_loss = F.cross_entropy(fmodel(features), labels) + + #print_graph(val_loss) + + #val_loss.backward() + #print('bip') + + tmp = fmodel.parameters() + + #print(list(tmp)[1]) + tmp = [higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in tmp] + #print(len(tmp)) + + #fmodel._fast_params.clear() + del fmodel._fast_params + fmodel._fast_params=None + + fmodel.fast_params=tmp # Surcharge la memoire + #fmodel.update_params(tmp) #Meilleur perf / Surcharge la memoire avec trach higher grad + + #optim._fmodel=fmodel + ''' + + + countcopy+=1 + model_copy(src=fmodel, dst=model, patch_copy=False) + fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) + #doptim.detach_dyn() + #tmp = doptim.state + #tmp = doptim.state_dict() + #for k, v in tmp['state'].items(): + # print('dict',k, type(v)) + + a = optim.param_groups[0]['params'][0] + state = optim.state[a] + #state['momentum_buffer'] = None + #print('opt state', type(optim.state[a]), len(optim.state[a])) + #optim.load_state_dict(tmp) + + + for group_idx, group in enumerate(optim.param_groups): + # print('gp idx',group_idx) + for p_idx, p in enumerate(group['params']): + optim.state[p]=doptim.state[group_idx][p_idx] + + #print('opt state', type(optim.state[a]['momentum_buffer']), optim.state[a]['momentum_buffer'][0:10]) + #print('dopt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][0]['momentum_buffer'][0:10]) + ''' + for a in tmp: + #print(type(a), len(a)) + for nb, b in a.items(): + #print(nb, type(b), len(b)) + for n, state in b.items(): + #print(n, type(states)) + #print(state.grad_fn) + state = torch.tensor(state.data).requires_grad_() + #print(state.grad_fn) + ''' + + + doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True) + #doptim.state = tmp + + + countcopy+=1 + model_copy(src=fmodel, dst=model) + optim_copy(dopt=diffopt, opt=inner_opt) + + #### Tests #### + tf = time.process_time() + try: + xs_val, ys_val = next(dl_val_it) + except StopIteration: #Fin epoch val + dl_val_it = iter(dl_val) + xs_val, ys_val = next(dl_val_it) + xs_val, ys_val = xs_val.to(device), ys_val.to(device) + + val_loss = F.cross_entropy(model(xs_val), ys_val) + accuracy, _ =test(model) + model.train() + #### Log #### + data={ + "epoch": epoch, + "train_loss": loss.item(), + "val_loss": val_loss.item(), + "acc": accuracy, + "time": tf - t0, + + "param": None, + } + log.append(data) + + #countcopy+=1 + #model_copy(src=fmodel, dst=model, patch_copy=False) + #optim.load_state_dict(doptim.state_dict()) #Besoin sauver etat otpim ? 
+ + print("Copy ", countcopy) + return log + + + +def run_simple_dataug(inner_it, epochs=1): + device = next(model.parameters()).device + dl_train_it = iter(dl_train) + dl_val_it = iter(dl_val) + + #aug_model = nn.Sequential( + # Data_aug(), + # LeNet(1,10), + # ) + aug_model = Augmented_model(Data_aug(), LeNet(1,10)).to(device) + print(str(aug_model)) + meta_opt = torch.optim.Adam(aug_model['data_aug'].parameters(), lr=1e-2) + inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9) + + log = [] + t0 = time.process_time() + + epoch = 0 + while epoch < epochs: + meta_opt.zero_grad() + aug_model.train() + with higher.innerloop_ctx(aug_model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt): #effet copy_initial_weight pas clair... + + for i in range(n_inner_iter): + try: + xs, ys = next(dl_train_it) + except StopIteration: #Fin epoch train + tf = time.process_time() + epoch +=1 + dl_train_it = iter(dl_train) + xs, ys = next(dl_train_it) + + accuracy, _ =test(model) + aug_model.train() + + #### Print #### + print('-'*9) + print('Epoch %d/%d'%(epoch,epochs)) + print('train loss',loss.item(), '/ val loss', val_loss.item()) + print('acc', accuracy) + print('mag', aug_model['data_aug']['mag'].item()) + + #### Log #### + data={ + "epoch": epoch, + "train_loss": loss.item(), + "val_loss": val_loss.item(), + "acc": accuracy, + "time": tf - t0, + + "param": aug_model['data_aug']['mag'].item(), + } + log.append(data) + t0 = time.process_time() + + xs, ys = xs.to(device), ys.to(device) + + logits = fmodel(xs) # modified `params` can also be passed as a kwarg + + loss = F.cross_entropy(logits, ys) # no need to call loss.backwards() + #loss.backward(retain_graph=True) + #print(fmodel['model']._params['b4'].grad) + #print('mag', fmodel['data_aug']['mag'].grad) + + diffopt.step(loss) # note that `step` must take `loss` as an argument! + # The line above gets P[t+1] from P[t] and loss[t]. `step` also returns + # these new parameters, as an alternative to getting them from + # `fmodel.fast_params` or `fmodel.parameters()` after calling + # `diffopt.step`. + + # At this point, or at any point in the iteration, you can take the + # gradient of `fmodel.parameters()` (or equivalently + # `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently + # `fmodel.init_fast_params`). i.e. `fast_params` will always have + # `grad_fn` as an attribute, and be part of the gradient tape. + + # At the end of your inner loop you can obtain these e.g. ... + #grad_of_grads = torch.autograd.grad( + # meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0)) + try: + xs_val, ys_val = next(dl_val_it) + except StopIteration: #Fin epoch val + dl_val_it = iter(dl_val) + xs_val, ys_val = next(dl_val_it) + xs_val, ys_val = xs_val.to(device), ys_val.to(device) + + fmodel.augment(mode=False) + val_logits = fmodel(xs_val) #Validation sans transfornations ! + val_loss = F.cross_entropy(val_logits, ys_val) + #print('val_loss',val_loss.item()) + val_loss.backward() + + #print('mag', fmodel['data_aug']['mag'], '/', fmodel['data_aug']['mag'].grad) + + #model=copy.deepcopy(fmodel) + aug_model.load_state_dict(fmodel.state_dict()) #Do not copy gradient ! 
+ #Copie des gradients + for paramName, paramValue, in fmodel.named_parameters(): + for netCopyName, netCopyValue, in aug_model.named_parameters(): + if paramName == netCopyName: + netCopyValue.grad = paramValue.grad + + #print('mag', aug_model['data_aug']['mag'], '/', aug_model['data_aug']['mag'].grad) + meta_opt.step() + + plot_res(log, fig_name="res/{}-{} epochs- {} in_it".format(str(aug_model),epochs,inner_it)) + print('-'*9) + times = [x["time"] for x in log] + print(str(aug_model),": acc", max([x["acc"] for x in log]), "in (ms):", np.mean(times), "+/-", np.std(times)) + +def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0): + device = next(model.parameters()).device + dl_train_it = iter(dl_train) + dl_val_it = iter(dl_val) + + meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-3) + inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9) + + high_grad_track = True + if dataug_epoch_start>0: + model.augment(mode=False) + high_grad_track = False + + model.train() + + log = [] + t0 = time.process_time() + + countcopy=0 + val_loss=torch.tensor(0) + opt_param=None + + epoch = 0 + while epoch < epochs: + meta_opt.zero_grad() + with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt): #effet copy_initial_weight pas clair... + + for i in range(n_inner_iter): + try: + xs, ys = next(dl_train_it) + except StopIteration: #Fin epoch train + tf = time.process_time() + epoch +=1 + dl_train_it = iter(dl_train) + xs, ys = next(dl_train_it) + + #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) + #viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch)) + + accuracy, _ =test(model) + model.train() + + #### Print #### + print('-'*9) + print('Epoch : %d/%d'%(epoch,epochs)) + print('Train loss :',loss.item(), '/ val loss', val_loss.item()) + print('Accuracy :', accuracy) + print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start)) + print('TF Proba :', model['data_aug']['prob'].data) + #print('proba grad',aug_model['data_aug']['prob'].grad) + ############# + #### Log #### + data={ + "epoch": epoch, + "train_loss": loss.item(), + "val_loss": val_loss.item(), + "acc": accuracy, + "time": tf - t0, + + "param": [p for p in model['data_aug']['prob']], + } + log.append(data) + ############# + + if epoch == dataug_epoch_start: + print('Starting Data Augmention...') + model.augment(mode=True) + high_grad_track = True + + t0 = time.process_time() + + xs, ys = xs.to(device), ys.to(device) + + ''' + #Methode exacte + final_loss = 0 + for tf_idx in range(fmodel['data_aug']._nb_tf): + fmodel['data_aug'].transf_idx=tf_idx + logits = fmodel(xs) + loss = F.cross_entropy(logits, ys) + #loss.backward(retain_graph=True) + #print('idx', tf_idx) + #print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad) + final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ? 
+ + loss = final_loss + ''' + #Methode uniforme + logits = fmodel(xs) # modified `params` can also be passed as a kwarg + loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards() + if fmodel._data_augmentation: #Weight loss + w_loss = fmodel['data_aug'].loss_weight().to(device) + loss = loss * w_loss + loss = loss.mean() + #''' + + #to visualize computational graph + #print_graph(loss) + + #loss.backward(retain_graph=True) + #print(fmodel['model']._params['b4'].grad) + #print('prob grad', fmodel['data_aug']['prob'].grad) + + diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step) + + try: + xs_val, ys_val = next(dl_val_it) + except StopIteration: #Fin epoch val + dl_val_it = iter(dl_val) + xs_val, ys_val = next(dl_val_it) + xs_val, ys_val = xs_val.to(device), ys_val.to(device) + + fmodel.augment(mode=False) #Validation sans transfornations ! + val_loss = F.cross_entropy(fmodel(xs_val), ys_val) + + #print_graph(val_loss) + + val_loss.backward() + + countcopy+=1 + model_copy(src=fmodel, dst=model) + optim_copy(dopt=diffopt, opt=inner_opt) + + meta_opt.step() + model['data_aug'].adjust_param() #Contrainte sum(proba)=1 + + print("Copy ", countcopy) + return log + +def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False): + device = next(model.parameters()).device + log = [] + countcopy=0 + val_loss=torch.tensor(0) #Necessaire si pas de metastep sur une epoch + dl_val_it = iter(dl_val) + + #if inner_it!=0: + meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2 + inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9 + + high_grad_track = True + if inner_it == 0: + high_grad_track=False + if dataug_epoch_start!=0: + model.augment(mode=False) + high_grad_track = False + + val_loss_monitor= None + if loss_patience != None : + if dataug_epoch_start==-1: val_loss_monitor = loss_monitor(patience=loss_patience, end_train=2) #1st limit = dataug start + else: val_loss_monitor = loss_monitor(patience=loss_patience) #Val loss monitor (Not on val data : used by Dataug... => Test data) + + model.train() + + fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) + diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track) + + meta_opt.zero_grad() + + for epoch in range(1, epochs+1): + #print_torch_mem("Start epoch "+str(epoch)) + #print(high_grad_track, fmodel._data_augmentation, len(fmodel._fast_params)) + t0 = time.process_time() + #with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt): + + for i, (xs, ys) in enumerate(dl_train): + xs, ys = xs.to(device), ys.to(device) + + #Methode exacte + #final_loss = 0 + #for tf_idx in range(fmodel['data_aug']._nb_tf): + # fmodel['data_aug'].transf_idx=tf_idx + # logits = fmodel(xs) + # loss = F.cross_entropy(logits, ys) + # #loss.backward(retain_graph=True) + # final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ? 
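+            # Note on the two branches below: both scale the augmented per-sample loss
+            # with fmodel['data_aug'].loss_weight(). With KLdiv=True an extra consistency
+            # term is added: sum_c p_sup(c)*(log p_sup(c) - log p_aug(c)) is exactly
+            # KL(p_sup || p_aug), so aug_loss penalises augmented predictions that drift
+            # from the clean (non-augmented) ones, in the spirit of UDA.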
+ #loss = final_loss + + if(not KLdiv): + #Methode uniforme + logits = fmodel(xs) # modified `params` can also be passed as a kwarg + loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards() + + if fmodel._data_augmentation: #Weight loss + w_loss = fmodel['data_aug'].loss_weight()#.to(device) + loss = loss * w_loss + loss = loss.mean() + + else: + #Methode KL div + if fmodel._data_augmentation : + fmodel.augment(mode=False) + sup_logits = fmodel(xs) + fmodel.augment(mode=True) + else: + sup_logits = fmodel(xs) + log_sup=F.log_softmax(sup_logits, dim=1) + loss = F.cross_entropy(log_sup, ys) + + if fmodel._data_augmentation: + aug_logits = fmodel(xs) + log_aug=F.log_softmax(aug_logits, dim=1) + + w_loss = fmodel['data_aug'].loss_weight() #Weight loss + + #if epoch>50: #debut differe ? + #KL div w/ logits - Similarite predictions (distributions) + aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug) + aug_loss = aug_loss.sum(dim=-1) + #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') + aug_loss = (w_loss * aug_loss).mean() + + aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean() + + unsupp_coeff = 1 + loss += aug_loss * unsupp_coeff + + #to visualize computational graph + #print_graph(loss) + + #loss.backward(retain_graph=True) + #print(fmodel['model']._params['b4'].grad) + #print('prob grad', fmodel['data_aug']['prob'].grad) + + #t = time.process_time() + diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step) + #print(len(fmodel._fast_params),"step", time.process_time()-t) + + if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step + #print("meta") + + val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) #+ fmodel['data_aug'].reg_loss() + #print_graph(val_loss) + + #t = time.process_time() + val_loss.backward() + #print("meta", time.process_time()-t) + #print('proba grad',model['data_aug']['prob'].grad) + if model['data_aug']['prob'].grad is None or model['data_aug']['mag'] is None: + print("Warning no grad (iter",i,") :\n Prob-",model['data_aug']['prob'].grad,"\n Mag-", model['data_aug']['mag'].grad) + + countcopy+=1 + model_copy(src=fmodel, dst=model) + optim_copy(dopt=diffopt, opt=inner_opt) + + torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN + + #if epoch>50: + meta_opt.step() + model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1 + try: #Dataugv6 + model['data_aug'].next_TF_set() + except: + pass + + fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) + diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track) + + meta_opt.zero_grad() + + tf = time.process_time() + + #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) + #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight()) + + if(not high_grad_track): + countcopy+=1 + model_copy(src=fmodel, dst=model) + optim_copy(dopt=diffopt, opt=inner_opt) + val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + + #Necessaire pour reset higher (Accumule les fast_param meme avec track_higher_grads = False) + fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) + diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, 
track_higher_grads=high_grad_track) + + accuracy, test_loss =test(model) + model.train() + + #### Log #### + #print(type(model['data_aug']) is dataug.Data_augV5) + param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])] + data={ + "epoch": epoch, + "train_loss": loss.item(), + "val_loss": val_loss.item(), + "acc": accuracy, + "time": tf - t0, + + "param": param #if isinstance(model['data_aug'], Data_augV5) + #else [p.item() for p in model['data_aug']['prob']], + } + log.append(data) + ############# + #### Print #### + if(print_freq and epoch%print_freq==0): + print('-'*9) + print('Epoch : %d/%d'%(epoch,epochs)) + print('Time : %.00f'%(tf - t0)) + print('Train loss :',loss.item(), '/ val loss', val_loss.item()) + print('Accuracy :', max([x["acc"] for x in log])) + print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start)) + print('TF Proba :', model['data_aug']['prob'].data) + #print('proba grad',model['data_aug']['prob'].grad) + print('TF Mag :', model['data_aug']['mag'].data) + #print('Mag grad',model['data_aug']['mag'].grad) + #print('Reg loss:', model['data_aug'].reg_loss().item()) + #print('Aug loss', aug_loss.item()) + ############# + if val_loss_monitor : + model.eval() + val_loss_monitor.register(test_loss)#val_loss.item()) + if val_loss_monitor.end_training(): break #Stop training + model.train() + + if not model.is_augmenting() and (epoch == dataug_epoch_start or (val_loss_monitor and val_loss_monitor.limit_reached()==1)): + print('Starting Data Augmention...') + dataug_epoch_start = epoch + model.augment(mode=True) + if inner_it != 0: high_grad_track = True + + try: + viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) + viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight()) + except: + print("Couldn't save finals samples") + pass + + #print("Copy ", countcopy) + return log diff --git a/higher/old/utils_old.py b/higher/old/utils_old.py new file mode 100644 index 0000000..8416486 --- /dev/null +++ b/higher/old/utils_old.py @@ -0,0 +1,161 @@ +import numpy as np +import json, math, time, os +import matplotlib.pyplot as plt +import copy +import gc + +from torchviz import make_dot + +import torch +import torch.nn.functional as F + +import time + +class timer(): + def __init__(self): + self._start_time=time.time() + def exec_time(self): + end = time.time() + res = end-self._start_time + self._start_time=end + return res + +def plot_res(log, fig_name='res', param_names=None): + + epochs = [x["epoch"] for x in log] + + fig, ax = plt.subplots(ncols=3, figsize=(15, 3)) + + ax[0].set_title('Loss') + ax[0].plot(epochs,[x["train_loss"] for x in log], label='Train') + ax[0].plot(epochs,[x["val_loss"] for x in log], label='Val') + ax[0].legend() + + ax[1].set_title('Acc') + ax[1].plot(epochs,[x["acc"] for x in log]) + + if log[0]["param"]!= None: + if isinstance(log[0]["param"],float): + ax[2].set_title('Mag') + ax[2].plot(epochs,[x["param"] for x in log], label='Mag') + ax[2].legend() + else : + ax[2].set_title('Prob') + #for idx, _ in enumerate(log[0]["param"]): + #ax[2].plot(epochs,[x["param"][idx] for x in log], label='P'+str(idx)) + if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])] + 
proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])] + ax[2].stackplot(epochs, proba, labels=param_names) + ax[2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5)) + + + fig_name = fig_name.replace('.',',') + plt.savefig(fig_name) + plt.close() + +def plot_res_compare(filenames, fig_name='res'): + + all_data=[] + #legend="" + for idx, file in enumerate(filenames): + #legend+=str(idx)+'-'+file+'\n' + with open(file) as json_file: + data = json.load(json_file) + all_data.append(data) + + n_tf = [len(x["Param_names"]) for x in all_data] + acc = [x["Accuracy"] for x in all_data] + time = [x["Time"][0] for x in all_data] + + fig, ax = plt.subplots(ncols=3, figsize=(30, 8)) + + ax[0].plot(n_tf, acc) + ax[1].plot(n_tf, time) + + ax[0].set_title('Acc') + ax[1].set_title('Time') + #for a in ax: a.legend() + + fig_name = fig_name.replace('.',',') + plt.savefig(fig_name, bbox_inches='tight') + plt.close() + +def plot_TF_res(log, tf_names, fig_name='res'): + + mean = np.mean([x["param"] for x in log], axis=0) + std = np.std([x["param"] for x in log], axis=0) + + fig, ax = plt.subplots(1, 1, figsize=(30, 8), sharey=True) + ax.bar(tf_names, mean, yerr=std) + #ax.bar(tf_names, log[-1]["param"]) + + fig_name = fig_name.replace('.',',') + plt.savefig(fig_name, bbox_inches='tight') + plt.close() + +def model_copy(src,dst, patch_copy=True, copy_grad=True): + #model=copy.deepcopy(fmodel) #Pas approprie, on ne souhaite que les poids/grad (pas tout fmodel et ses etats) + + dst.load_state_dict(src.state_dict()) #Do not copy gradient ! + + if patch_copy: + dst['model'].load_state_dict(src['model'].state_dict()) #Copie donnee manquante ? + dst['data_aug'].load_state_dict(src['data_aug'].state_dict()) + + #Copie des gradients + if copy_grad: + for paramName, paramValue, in src.named_parameters(): + for netCopyName, netCopyValue, in dst.named_parameters(): + if paramName == netCopyName: + netCopyValue.grad = paramValue.grad + #netCopyValue=copy.deepcopy(paramValue) + + try: #Data_augV4 + dst['data_aug']._input_info = src['data_aug']._input_info + dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix + except: + pass + +def optim_copy(dopt, opt): + + #inner_opt.load_state_dict(diffopt.state_dict()) #Besoin sauver etat otpim (momentum, etc.) => Ne copie pas le state... 
+ #opt_param=higher.optim.get_trainable_opt_params(diffopt) + + for group_idx, group in enumerate(opt.param_groups): + # print('gp idx',group_idx) + for p_idx, p in enumerate(group['params']): + opt.state[p]=dopt.state[group_idx][p_idx] + +class loss_monitor(): #Voir https://github.com/pytorch/ignite + def __init__(self, patience, end_train=1): + self.patience = patience + self.end_train = end_train + self.counter = 0 + self.best_score = None + self.reached_limit = 0 + + def register(self, loss): + if self.best_score is None: + self.best_score = loss + elif loss > self.best_score: + self.counter += 1 + #if not self.reached_limit: + print("loss no improve counter", self.counter, self.reached_limit) + else: + self.best_score = loss + self.counter = 0 + def limit_reached(self): + if self.counter >= self.patience: + self.counter = 0 + self.reached_limit +=1 + self.best_score = None + return self.reached_limit + + def end_training(self): + if self.limit_reached() >= self.end_train: + return True + else: + return False + + def reset(self): + self.__init__(self.patience, self.end_train) diff --git a/higher/train_utils.py b/higher/train_utils.py index 5a792bb..7846ef5 100755 --- a/higher/train_utils.py +++ b/higher/train_utils.py @@ -157,147 +157,6 @@ def train_classic_higher(model, epochs=1): return log -def train_classic_tests(model, epochs=1): - device = next(model.parameters()).device - #opt = torch.optim.Adam(model.parameters(), lr=1e-3) - optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9) - - countcopy=0 - model.train() - dl_val_it = iter(dl_val) - log = [] - - fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) - doptim = higher.optim.get_diff_optim(optim, model.parameters(), fmodel=fmodel, track_higher_grads=False) - for epoch in range(epochs): - print_torch_mem("Start epoch") - print(len(fmodel._fast_params)) - t0 = time.process_time() - #with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=True) as (fmodel, doptim): - - #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) - #doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True) - - for i, (features, labels) in enumerate(dl_train): - features,labels = features.to(device), labels.to(device) - - #with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, doptim): - - - #optim.zero_grad() - pred = fmodel.forward(features) - loss = F.cross_entropy(pred,labels) - doptim.step(loss) #(opt.zero_grad, loss.backward, opt.step) - #loss.backward() - #new_params = doptim.step(loss, params=fmodel.parameters()) - #fmodel.update_params(new_params) - - - #print('Fast param',len(fmodel._fast_params)) - #print('opt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][2]['momentum_buffer'].shape) - - if False or (len(fmodel._fast_params)>1): - print("fmodel fast param",len(fmodel._fast_params)) - ''' - #val_loss = F.cross_entropy(fmodel(features), labels) - - #print_graph(val_loss) - - #val_loss.backward() - #print('bip') - - tmp = fmodel.parameters() - - #print(list(tmp)[1]) - tmp = [higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in tmp] - #print(len(tmp)) - - #fmodel._fast_params.clear() - del fmodel._fast_params - fmodel._fast_params=None - - fmodel.fast_params=tmp # Surcharge la memoire - #fmodel.update_params(tmp) #Meilleur perf / Surcharge la memoire avec trach higher grad - - #optim._fmodel=fmodel - ''' - - - 
countcopy+=1 - model_copy(src=fmodel, dst=model, patch_copy=False) - fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) - #doptim.detach_dyn() - #tmp = doptim.state - #tmp = doptim.state_dict() - #for k, v in tmp['state'].items(): - # print('dict',k, type(v)) - - a = optim.param_groups[0]['params'][0] - state = optim.state[a] - #state['momentum_buffer'] = None - #print('opt state', type(optim.state[a]), len(optim.state[a])) - #optim.load_state_dict(tmp) - - - for group_idx, group in enumerate(optim.param_groups): - # print('gp idx',group_idx) - for p_idx, p in enumerate(group['params']): - optim.state[p]=doptim.state[group_idx][p_idx] - - #print('opt state', type(optim.state[a]['momentum_buffer']), optim.state[a]['momentum_buffer'][0:10]) - #print('dopt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][0]['momentum_buffer'][0:10]) - ''' - for a in tmp: - #print(type(a), len(a)) - for nb, b in a.items(): - #print(nb, type(b), len(b)) - for n, state in b.items(): - #print(n, type(states)) - #print(state.grad_fn) - state = torch.tensor(state.data).requires_grad_() - #print(state.grad_fn) - ''' - - - doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True) - #doptim.state = tmp - - - countcopy+=1 - model_copy(src=fmodel, dst=model) - optim_copy(dopt=diffopt, opt=inner_opt) - - #### Tests #### - tf = time.process_time() - try: - xs_val, ys_val = next(dl_val_it) - except StopIteration: #Fin epoch val - dl_val_it = iter(dl_val) - xs_val, ys_val = next(dl_val_it) - xs_val, ys_val = xs_val.to(device), ys_val.to(device) - - val_loss = F.cross_entropy(model(xs_val), ys_val) - accuracy, _ =test(model) - model.train() - #### Log #### - data={ - "epoch": epoch, - "train_loss": loss.item(), - "val_loss": val_loss.item(), - "acc": accuracy, - "time": tf - t0, - - "param": None, - } - log.append(data) - - #countcopy+=1 - #model_copy(src=fmodel, dst=model, patch_copy=False) - #optim.load_state_dict(doptim.state_dict()) #Besoin sauver etat otpim ? - - print("Copy ", countcopy) - return log - def train_UDA(model, dl_unsup, opt_param, epochs=1, print_freq=1): device = next(model.parameters()).device @@ -383,446 +242,6 @@ def train_UDA(model, dl_unsup, opt_param, epochs=1, print_freq=1): return log -def run_simple_dataug(inner_it, epochs=1): - device = next(model.parameters()).device - dl_train_it = iter(dl_train) - dl_val_it = iter(dl_val) - - #aug_model = nn.Sequential( - # Data_aug(), - # LeNet(1,10), - # ) - aug_model = Augmented_model(Data_aug(), LeNet(1,10)).to(device) - print(str(aug_model)) - meta_opt = torch.optim.Adam(aug_model['data_aug'].parameters(), lr=1e-2) - inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9) - - log = [] - t0 = time.process_time() - - epoch = 0 - while epoch < epochs: - meta_opt.zero_grad() - aug_model.train() - with higher.innerloop_ctx(aug_model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt): #effet copy_initial_weight pas clair... 
- - for i in range(n_inner_iter): - try: - xs, ys = next(dl_train_it) - except StopIteration: #Fin epoch train - tf = time.process_time() - epoch +=1 - dl_train_it = iter(dl_train) - xs, ys = next(dl_train_it) - - accuracy, _ =test(model) - aug_model.train() - - #### Print #### - print('-'*9) - print('Epoch %d/%d'%(epoch,epochs)) - print('train loss',loss.item(), '/ val loss', val_loss.item()) - print('acc', accuracy) - print('mag', aug_model['data_aug']['mag'].item()) - - #### Log #### - data={ - "epoch": epoch, - "train_loss": loss.item(), - "val_loss": val_loss.item(), - "acc": accuracy, - "time": tf - t0, - - "param": aug_model['data_aug']['mag'].item(), - } - log.append(data) - t0 = time.process_time() - - xs, ys = xs.to(device), ys.to(device) - - logits = fmodel(xs) # modified `params` can also be passed as a kwarg - - loss = F.cross_entropy(logits, ys) # no need to call loss.backwards() - #loss.backward(retain_graph=True) - #print(fmodel['model']._params['b4'].grad) - #print('mag', fmodel['data_aug']['mag'].grad) - - diffopt.step(loss) # note that `step` must take `loss` as an argument! - # The line above gets P[t+1] from P[t] and loss[t]. `step` also returns - # these new parameters, as an alternative to getting them from - # `fmodel.fast_params` or `fmodel.parameters()` after calling - # `diffopt.step`. - - # At this point, or at any point in the iteration, you can take the - # gradient of `fmodel.parameters()` (or equivalently - # `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently - # `fmodel.init_fast_params`). i.e. `fast_params` will always have - # `grad_fn` as an attribute, and be part of the gradient tape. - - # At the end of your inner loop you can obtain these e.g. ... - #grad_of_grads = torch.autograd.grad( - # meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0)) - try: - xs_val, ys_val = next(dl_val_it) - except StopIteration: #Fin epoch val - dl_val_it = iter(dl_val) - xs_val, ys_val = next(dl_val_it) - xs_val, ys_val = xs_val.to(device), ys_val.to(device) - - fmodel.augment(mode=False) - val_logits = fmodel(xs_val) #Validation sans transfornations ! - val_loss = F.cross_entropy(val_logits, ys_val) - #print('val_loss',val_loss.item()) - val_loss.backward() - - #print('mag', fmodel['data_aug']['mag'], '/', fmodel['data_aug']['mag'].grad) - - #model=copy.deepcopy(fmodel) - aug_model.load_state_dict(fmodel.state_dict()) #Do not copy gradient ! 
- #Copie des gradients - for paramName, paramValue, in fmodel.named_parameters(): - for netCopyName, netCopyValue, in aug_model.named_parameters(): - if paramName == netCopyName: - netCopyValue.grad = paramValue.grad - - #print('mag', aug_model['data_aug']['mag'], '/', aug_model['data_aug']['mag'].grad) - meta_opt.step() - - plot_res(log, fig_name="res/{}-{} epochs- {} in_it".format(str(aug_model),epochs,inner_it)) - print('-'*9) - times = [x["time"] for x in log] - print(str(aug_model),": acc", max([x["acc"] for x in log]), "in (ms):", np.mean(times), "+/-", np.std(times)) - -def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0): - device = next(model.parameters()).device - dl_train_it = iter(dl_train) - dl_val_it = iter(dl_val) - - meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-3) - inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9) - - high_grad_track = True - if dataug_epoch_start>0: - model.augment(mode=False) - high_grad_track = False - - model.train() - - log = [] - t0 = time.process_time() - - countcopy=0 - val_loss=torch.tensor(0) - opt_param=None - - epoch = 0 - while epoch < epochs: - meta_opt.zero_grad() - with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt): #effet copy_initial_weight pas clair... - - for i in range(n_inner_iter): - try: - xs, ys = next(dl_train_it) - except StopIteration: #Fin epoch train - tf = time.process_time() - epoch +=1 - dl_train_it = iter(dl_train) - xs, ys = next(dl_train_it) - - #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) - #viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch)) - - accuracy, _ =test(model) - model.train() - - #### Print #### - print('-'*9) - print('Epoch : %d/%d'%(epoch,epochs)) - print('Train loss :',loss.item(), '/ val loss', val_loss.item()) - print('Accuracy :', accuracy) - print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start)) - print('TF Proba :', model['data_aug']['prob'].data) - #print('proba grad',aug_model['data_aug']['prob'].grad) - ############# - #### Log #### - data={ - "epoch": epoch, - "train_loss": loss.item(), - "val_loss": val_loss.item(), - "acc": accuracy, - "time": tf - t0, - - "param": [p for p in model['data_aug']['prob']], - } - log.append(data) - ############# - - if epoch == dataug_epoch_start: - print('Starting Data Augmention...') - model.augment(mode=True) - high_grad_track = True - - t0 = time.process_time() - - xs, ys = xs.to(device), ys.to(device) - - ''' - #Methode exacte - final_loss = 0 - for tf_idx in range(fmodel['data_aug']._nb_tf): - fmodel['data_aug'].transf_idx=tf_idx - logits = fmodel(xs) - loss = F.cross_entropy(logits, ys) - #loss.backward(retain_graph=True) - #print('idx', tf_idx) - #print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad) - final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ? 
- - loss = final_loss - ''' - #Methode uniforme - logits = fmodel(xs) # modified `params` can also be passed as a kwarg - loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards() - if fmodel._data_augmentation: #Weight loss - w_loss = fmodel['data_aug'].loss_weight().to(device) - loss = loss * w_loss - loss = loss.mean() - #''' - - #to visualize computational graph - #print_graph(loss) - - #loss.backward(retain_graph=True) - #print(fmodel['model']._params['b4'].grad) - #print('prob grad', fmodel['data_aug']['prob'].grad) - - diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step) - - try: - xs_val, ys_val = next(dl_val_it) - except StopIteration: #Fin epoch val - dl_val_it = iter(dl_val) - xs_val, ys_val = next(dl_val_it) - xs_val, ys_val = xs_val.to(device), ys_val.to(device) - - fmodel.augment(mode=False) #Validation sans transfornations ! - val_loss = F.cross_entropy(fmodel(xs_val), ys_val) - - #print_graph(val_loss) - - val_loss.backward() - - countcopy+=1 - model_copy(src=fmodel, dst=model) - optim_copy(dopt=diffopt, opt=inner_opt) - - meta_opt.step() - model['data_aug'].adjust_param() #Contrainte sum(proba)=1 - - print("Copy ", countcopy) - return log - -def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False): - device = next(model.parameters()).device - log = [] - countcopy=0 - val_loss=torch.tensor(0) #Necessaire si pas de metastep sur une epoch - dl_val_it = iter(dl_val) - - #if inner_it!=0: - meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2 - inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9 - - high_grad_track = True - if inner_it == 0: - high_grad_track=False - if dataug_epoch_start!=0: - model.augment(mode=False) - high_grad_track = False - - val_loss_monitor= None - if loss_patience != None : - if dataug_epoch_start==-1: val_loss_monitor = loss_monitor(patience=loss_patience, end_train=2) #1st limit = dataug start - else: val_loss_monitor = loss_monitor(patience=loss_patience) #Val loss monitor (Not on val data : used by Dataug... => Test data) - - model.train() - - fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) - diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track) - - meta_opt.zero_grad() - - for epoch in range(1, epochs+1): - #print_torch_mem("Start epoch "+str(epoch)) - #print(high_grad_track, fmodel._data_augmentation, len(fmodel._fast_params)) - t0 = time.process_time() - #with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt): - - for i, (xs, ys) in enumerate(dl_train): - xs, ys = xs.to(device), ys.to(device) - - #Methode exacte - #final_loss = 0 - #for tf_idx in range(fmodel['data_aug']._nb_tf): - # fmodel['data_aug'].transf_idx=tf_idx - # logits = fmodel(xs) - # loss = F.cross_entropy(logits, ys) - # #loss.backward(retain_graph=True) - # final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ? 
- #loss = final_loss - - if(not KLdiv): - #Methode uniforme - logits = fmodel(xs) # modified `params` can also be passed as a kwarg - loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards() - - if fmodel._data_augmentation: #Weight loss - w_loss = fmodel['data_aug'].loss_weight()#.to(device) - loss = loss * w_loss - loss = loss.mean() - - else: - #Methode KL div - if fmodel._data_augmentation : - fmodel.augment(mode=False) - sup_logits = fmodel(xs) - fmodel.augment(mode=True) - else: - sup_logits = fmodel(xs) - log_sup=F.log_softmax(sup_logits, dim=1) - loss = F.cross_entropy(log_sup, ys) - - if fmodel._data_augmentation: - aug_logits = fmodel(xs) - log_aug=F.log_softmax(aug_logits, dim=1) - - w_loss = fmodel['data_aug'].loss_weight() #Weight loss - - #if epoch>50: #debut differe ? - #KL div w/ logits - Similarite predictions (distributions) - aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug) - aug_loss = aug_loss.sum(dim=-1) - #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') - aug_loss = (w_loss * aug_loss).mean() - - aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean() - - unsupp_coeff = 1 - loss += aug_loss * unsupp_coeff - - #to visualize computational graph - #print_graph(loss) - - #loss.backward(retain_graph=True) - #print(fmodel['model']._params['b4'].grad) - #print('prob grad', fmodel['data_aug']['prob'].grad) - - #t = time.process_time() - diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step) - #print(len(fmodel._fast_params),"step", time.process_time()-t) - - if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step - #print("meta") - - val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) #+ fmodel['data_aug'].reg_loss() - #print_graph(val_loss) - - #t = time.process_time() - val_loss.backward() - #print("meta", time.process_time()-t) - #print('proba grad',model['data_aug']['prob'].grad) - if model['data_aug']['prob'].grad is None or model['data_aug']['mag'] is None: - print("Warning no grad (iter",i,") :\n Prob-",model['data_aug']['prob'].grad,"\n Mag-", model['data_aug']['mag'].grad) - - countcopy+=1 - model_copy(src=fmodel, dst=model) - optim_copy(dopt=diffopt, opt=inner_opt) - - torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN - - #if epoch>50: - meta_opt.step() - model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1 - try: #Dataugv6 - model['data_aug'].next_TF_set() - except: - pass - - fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) - diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track) - - meta_opt.zero_grad() - - tf = time.process_time() - - #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) - #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight()) - - if(not high_grad_track): - countcopy+=1 - model_copy(src=fmodel, dst=model) - optim_copy(dopt=diffopt, opt=inner_opt) - val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) - - #Necessaire pour reset higher (Accumule les fast_param meme avec track_higher_grads = False) - fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) - diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, 
track_higher_grads=high_grad_track) - - accuracy, test_loss =test(model) - model.train() - - #### Log #### - #print(type(model['data_aug']) is dataug.Data_augV5) - param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])] - data={ - "epoch": epoch, - "train_loss": loss.item(), - "val_loss": val_loss.item(), - "acc": accuracy, - "time": tf - t0, - - "param": param #if isinstance(model['data_aug'], Data_augV5) - #else [p.item() for p in model['data_aug']['prob']], - } - log.append(data) - ############# - #### Print #### - if(print_freq and epoch%print_freq==0): - print('-'*9) - print('Epoch : %d/%d'%(epoch,epochs)) - print('Time : %.00f'%(tf - t0)) - print('Train loss :',loss.item(), '/ val loss', val_loss.item()) - print('Accuracy :', max([x["acc"] for x in log])) - print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start)) - print('TF Proba :', model['data_aug']['prob'].data) - #print('proba grad',model['data_aug']['prob'].grad) - print('TF Mag :', model['data_aug']['mag'].data) - #print('Mag grad',model['data_aug']['mag'].grad) - #print('Reg loss:', model['data_aug'].reg_loss().item()) - #print('Aug loss', aug_loss.item()) - ############# - if val_loss_monitor : - model.eval() - val_loss_monitor.register(test_loss)#val_loss.item()) - if val_loss_monitor.end_training(): break #Stop training - model.train() - - if not model.is_augmenting() and (epoch == dataug_epoch_start or (val_loss_monitor and val_loss_monitor.limit_reached()==1)): - print('Starting Data Augmention...') - dataug_epoch_start = epoch - model.augment(mode=True) - if inner_it != 0: high_grad_track = True - - try: - viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) - viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight()) - except: - print("Couldn't save finals samples") - pass - - #print("Copy ", countcopy) - return log - def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, hp_opt=False, save_sample=False): device = next(model.parameters()).device log = [] @@ -1004,4 +423,4 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start print("Couldn't save finals samples") pass - return log + return log \ No newline at end of file diff --git a/higher/transformations.py b/higher/transformations.py index 430e7e8..0eb4456 100755 --- a/higher/transformations.py +++ b/higher/transformations.py @@ -1,58 +1,25 @@ +""" PyTorch implementation of some PIL image transformations. + + Those implementation are thinked to take advantages of batched computation of PyTorch on GPU. + + Based on Kornia library. + See: https://github.com/kornia/kornia + + And PIL. + See: + https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py + https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818 + + Inspired from AutoAugment. 
+ See: https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py +""" + import torch import kornia import random ### Available TF for Dataug ### -''' -TF_dict={ #Dataugv4 - ## Geometric TF ## - 'Identity' : (lambda x, mag: x), - 'FlipUD' : (lambda x, mag: flipUD(x)), - 'FlipLR' : (lambda x, mag: flipLR(x)), - 'Rotate': (lambda x, mag: rotate(x, angle=torch.tensor([rand_int(mag, maxval=30)for _ in x], device=x.device))), - 'TranslateX': (lambda x, mag: translate(x, translation=torch.tensor([[rand_int(mag, maxval=20), 0] for _ in x], device=x.device))), - 'TranslateY': (lambda x, mag: translate(x, translation=torch.tensor([[0, rand_int(mag, maxval=20)] for _ in x], device=x.device))), - 'ShearX': (lambda x, mag: shear(x, shear=torch.tensor([[rand_float(mag, maxval=0.3), 0] for _ in x], device=x.device))), - 'ShearY': (lambda x, mag: shear(x, shear=torch.tensor([[0, rand_float(mag, maxval=0.3)] for _ in x], device=x.device))), - ## Color TF (Expect image in the range of [0, 1]) ## - 'Contrast': (lambda x, mag: contrast(x, contrast_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))), - 'Color':(lambda x, mag: color(x, color_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))), - 'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))), - 'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))), - 'Posterize': (lambda x, mag: posterize(x, bits=torch.tensor([rand_int(mag, minval=4, maxval=8) for _ in x], device=x.device))), - 'Solarize': (lambda x, mag: solarize(x, thresholds=torch.tensor([rand_int(mag,minval=1, maxval=256)/256. 
for _ in x], device=x.device))) , #=>Image entre [0,1] #Pas opti pour des batch - - #Non fonctionnel - #'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent) - #'Equalize': (lambda mag: None), -} -''' -''' -TF_dict={ #Dataugv5 #AutoAugment - ## Geometric TF ## - 'Identity' : (lambda x, mag: x), - 'FlipUD' : (lambda x, mag: flipUD(x)), - 'FlipLR' : (lambda x, mag: flipLR(x)), - 'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))), - 'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))), - 'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))), - 'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))), - 'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))), - - ## Color TF (Expect image in the range of [0, 1]) ## - 'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))), - 'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))), - 'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))), - 'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))), - 'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient - 'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] - - #Non fonctionnel - #'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent) - #'Equalize': (lambda mag: None), -} -''' # Dictionnary mapping tranformations identifiers to their function. # Each value of the dict should be a lambda function taking a (batch of data, magnitude of transformations) tuple as input and returns a batch of data. TF_dict={ #Dataugv5 @@ -112,6 +79,9 @@ TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'} #TF that don't TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'} #TF which implemetation doesn't allow gradient propagaition. TF_ignore_mag= TF_no_mag | TF_no_grad #TF for which magnitude should be ignored (Magnitude fixed). +PARAMETER_MAX = 1 # What is the max 'level' a transform could be predicted +PARAMETER_MIN = 0.1 # What is the min 'level' a transform could be predicted + def int_image(float_image): """Convert a float Tensor/Image to an int Tensor/Image. @@ -121,10 +91,10 @@ def int_image(float_image): This will also result in the loss of the gradient associated to input as gradient cannot be tracked on int Tensor. Args: - float_image (torch.float): Image tensor. + float_image (FloatTensor): Image tensor. Returns: - (torch.uint8) Converted tensor. + (ByteTensor) Converted tensor. """ return (float_image*255.).type(torch.uint8) @@ -132,10 +102,10 @@ def float_image(int_image): """Convert a int Tensor/Image to an float Tensor/Image. Args: - int_image (torch.uint8): Image tensor. + int_image (ByteTensor): Image tensor. Returns: - (torch.float) Converted tensor. + (FloatTensor) Converted tensor. """ return int_image.type(torch.float)/255. 
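As a minimal sketch of how the two helpers above compose (illustrative only; `x` is not a name from the code): the uint8 round-trip quantizes pixel values to 1/255 steps and, as the docstrings note, drops any gradient tracking.

import torch

x = torch.rand(1, 3, 32, 32)              # float image batch in [0, 1]
x_int = (x * 255.).type(torch.uint8)      # int_image(): integer values, gradient cannot be tracked
x_back = x_int.type(torch.float) / 255.   # float_image(): back to [0, 1], now quantized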
@@ -162,7 +132,7 @@ def rand_floats(size, mag, maxval, minval=None): minval (float): Minimum value that can be generated. (default: -maxval) Returns: - Generated batch of float values between [minval, maxval]. + (Tensor) Generated batch of float values between [minval, maxval]. """ real_mag = float_parameter(mag, maxval=maxval) if not minval : minval = -real_mag @@ -170,30 +140,52 @@ def rand_floats(size, mag, maxval, minval=None): return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[min_val, real_mag] def invScale_rand_floats(size, mag, maxval, minval): - #Mag=[0,PARAMETER_MAX] => [PARAMETER_MAX, 0] = [maxval, minval] - real_mag = float_parameter(float(PARAMETER_MAX) - mag, maxval=maxval-minval)+minval - return real_mag + (maxval-real_mag) * torch.rand(size, device=mag.device) #[real_mag, max_val] + """Generate a batch of random values. + + Similar to rand_floats() except that the mag is used in an inversed scale. + + Mag:[0,PARAMETER_MAX] => [PARAMETER_MAX, 0] + + Args: + size (int): Number of value to generate. + mag (float): Level of the operation that will be between [PARAMETER_MIN, PARAMETER_MAX]. + maxval (float): Maximum value that can be generated. This will be scaled to mag/PARAMETER_MAX. + minval (float): Minimum value that can be generated. (default: -maxval) + + Returns: + (Tensor) Generated batch of float values between [minval, maxval]. + """ + real_mag = float_parameter(float(PARAMETER_MAX) - mag, maxval=maxval-minval)+minval + return real_mag + (maxval-real_mag) * torch.rand(size, device=mag.device) #[real_mag, max_val] def zero_stack(tensor, zero_pos): - if zero_pos==0: - return torch.stack((tensor, torch.zeros((tensor.shape[0],), device=tensor.device)), dim=1) - if zero_pos==1: - return torch.stack((torch.zeros((tensor.shape[0],), device=tensor.device), tensor), dim=1) - else: - raise Exception("Invalid zero_pos : ", zero_pos) + """Add a row of zeros to a Tensor. + + This function is intended to be used with single row Tensor, thus returning a 2 dimension Tensor. + + Args: + tensor (Tensor): Tensor to be stacked with zeros. + zero_pos (int): Wheter the zeros should be added before or after the Tensor. Either 0 or 1. + + Returns: + Stacked Tensor. + """ + if zero_pos==0: + return torch.stack((tensor, torch.zeros((tensor.shape[0],), device=tensor.device)), dim=1) + if zero_pos==1: + return torch.stack((torch.zeros((tensor.shape[0],), device=tensor.device), tensor), dim=1) + else: + raise Exception("Invalid zero_pos : ", zero_pos) -#https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137 -PARAMETER_MAX = 1 # What is the max 'level' a transform could be predicted -PARAMETER_MIN = 0.1 def float_parameter(level, maxval): - """Helper function to scale `val` between 0 and maxval . - Args: - level: Level of the operation that will be between [0, `PARAMETER_MAX`]. - maxval: Maximum value that the operation can have. This will be scaled - to level/PARAMETER_MAX. - Returns: - A float that results from scaling `maxval` according to `level`. - """ + """Scale level between 0 and maxval. + + Args: + level (float): Level of the operation that will be between [PARAMETER_MIN, PARAMETER_MAX]. + maxval: Maximum value that the operation can have. This will be scaled to level/PARAMETER_MAX. + Returns: + A float that results from scaling `maxval` according to `level`. 
+ """ #return float(level) * maxval / PARAMETER_MAX return (level * maxval / PARAMETER_MAX)#.to(torch.float) @@ -211,6 +203,14 @@ def float_parameter(level, maxval): # return (level * maxval / PARAMETER_MAX) def flipLR(x): + """Flip horizontaly/Left-Right images. + + Args: + x (Tensor): Batch of images. + + Returns: + (Tensor): Batch of fliped images. + """ device = x.device (batch_size, channels, h, w) = x.shape @@ -222,6 +222,14 @@ def flipLR(x): return kornia.warp_perspective(x, M, dsize=(h, w)) def flipUD(x): + """Flip vertically/Up-Down images. + + Args: + x (Tensor): Batch of images. + + Returns: + (Tensor): Batch of fliped images. + """ device = x.device (batch_size, channels, h, w) = x.shape @@ -233,20 +241,65 @@ def flipUD(x): return kornia.warp_perspective(x, M, dsize=(h, w)) def rotate(x, angle): - return kornia.rotate(x, angle=angle.type(torch.float)) #Kornia ne supporte pas les int + """Rotate images. + + Args: + x (Tensor): Batch of images. + angle (Tensor): Angles (degrees) of rotation for each images. + + Returns: + (Tensor): Batch of rotated images. + """ + return kornia.rotate(x, angle=angle.type(torch.float)) #Kornia ne supporte pas les int def translate(x, translation): - #print(translation) - return kornia.translate(x, translation=translation.type(torch.float)) #Kornia ne supporte pas les int + """Translate images. + + Args: + x (Tensor): Batch of images. + translation (Tensor): Distance (pixels) of translation for each images. + + Returns: + (Tensor): Batch of translated images. + """ + return kornia.translate(x, translation=translation.type(torch.float)) #Kornia ne supporte pas les int def shear(x, shear): - return kornia.shear(x, shear=shear) + """Shear images. + + Args: + x (Tensor): Batch of images. + shear (Tensor): Angle of shear for each images. + + Returns: + (Tensor): Batch of skewed images. + """ + return kornia.shear(x, shear=shear) def contrast(x, contrast_factor): - return kornia.adjust_contrast(x, contrast_factor=contrast_factor) #Expect image in the range of [0, 1] + """Adjust contast of images. + + Args: + x (FloatTensor): Batch of images. + contrast_factor (FloatTensor): Contrast adjust factor per element in the batch. + 0 generates a compleatly black image, 1 does not modify the input image while any other non-negative number modify the brightness by this factor. + + Returns: + (Tensor): Batch of adjusted images. + """ + return kornia.adjust_contrast(x, contrast_factor=contrast_factor) #Expect image in the range of [0, 1] -#https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageEnhance.py def color(x, color_factor): + """Adjust color of images. + + Args: + x (Tensor): Batch of images. + color_factor (Tensor): Color factor for each images. + 0.0 gives a black and white image. A factor of 1.0 gives the original image. + + Returns: + (Tensor): Batch of adjusted images. + """ (batch_size, channels, h, w) = x.shape gray_x = kornia.rgb_to_grayscale(x) @@ -254,11 +307,31 @@ def color(x, color_factor): return blend(gray_x, x, color_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1] def brightness(x, brightness_factor): + """Adjust brightness of images. + + Args: + x (Tensor): Batch of images. + brightness_factor (Tensor): Brightness factor for each images. + 0.0 gives a black image. A factor of 1.0 gives the original image. + + Returns: + (Tensor): Batch of adjusted images. 
+ """ device = x.device return blend(torch.zeros(x.size(), device=device), x, brightness_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1] def sharpeness(x, sharpness_factor): + """Adjust sharpness of images. + + Args: + x (Tensor): Batch of images. + sharpness_factor (Tensor): Sharpness factor for each images. + 0.0 gives a black image. A factor of 1.0 gives the original image. + + Returns: + (Tensor): Batch of adjusted images. + """ device = x.device (batch_size, channels, h, w) = x.shape @@ -269,7 +342,6 @@ def sharpeness(x, sharpness_factor): return blend(smooth_x, x, sharpness_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1] -#https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py def posterize(x, bits): bits = bits.type(torch.uint8) #Perte du gradient x = int_image(x) #Expect image in the range of [0, 1] @@ -365,7 +437,6 @@ def solarize(x, thresholds): return x -#https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818 def blend(x,y,alpha): #out = image1 * (1.0 - alpha) + image2 * alpha #return kornia.add_weighted(src1=x, alpha=(1-alpha), src2=y, beta=alpha, gamma=0) #out=src1∗alpha+src2∗beta+gamma #Ne fonctionne pas pour des batch de alpha diff --git a/higher/utils.py b/higher/utils.py index 02fc1eb..6fab9bc 100755 --- a/higher/utils.py +++ b/higher/utils.py @@ -11,53 +11,11 @@ import torch.nn.functional as F import time -class timer(): - def __init__(self): - self._start_time=time.time() - def exec_time(self): - end = time.time() - res = end-self._start_time - self._start_time=end - return res - def print_graph(PyTorch_obj, fig_name='graph'): graph=make_dot(PyTorch_obj) #Loss give the whole graph graph.format = 'pdf' #https://graphviz.readthedocs.io/en/stable/manual.html#formats graph.render(fig_name) -def plot_res(log, fig_name='res', param_names=None): - - epochs = [x["epoch"] for x in log] - - fig, ax = plt.subplots(ncols=3, figsize=(15, 3)) - - ax[0].set_title('Loss') - ax[0].plot(epochs,[x["train_loss"] for x in log], label='Train') - ax[0].plot(epochs,[x["val_loss"] for x in log], label='Val') - ax[0].legend() - - ax[1].set_title('Acc') - ax[1].plot(epochs,[x["acc"] for x in log]) - - if log[0]["param"]!= None: - if isinstance(log[0]["param"],float): - ax[2].set_title('Mag') - ax[2].plot(epochs,[x["param"] for x in log], label='Mag') - ax[2].legend() - else : - ax[2].set_title('Prob') - #for idx, _ in enumerate(log[0]["param"]): - #ax[2].plot(epochs,[x["param"][idx] for x in log], label='P'+str(idx)) - if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])] - proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])] - ax[2].stackplot(epochs, proba, labels=param_names) - ax[2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5)) - - - fig_name = fig_name.replace('.',',') - plt.savefig(fig_name) - plt.close() - def plot_resV2(log, fig_name='res', param_names=None): epochs = [x["epoch"] for x in log] @@ -144,33 +102,6 @@ def plot_compare(filenames, fig_name='res'): plt.savefig(fig_name, bbox_inches='tight') plt.close() -def plot_res_compare(filenames, fig_name='res'): - - all_data=[] - #legend="" - for idx, file in enumerate(filenames): - #legend+=str(idx)+'-'+file+'\n' - with open(file) as json_file: - data = json.load(json_file) - all_data.append(data) - - n_tf = [len(x["Param_names"]) for x in all_data] - acc = [x["Accuracy"] for x in all_data] - time = [x["Time"][0] for x in all_data] - - fig, ax = 
plt.subplots(ncols=3, figsize=(30, 8)) - - ax[0].plot(n_tf, acc) - ax[1].plot(n_tf, time) - - ax[0].set_title('Acc') - ax[1].set_title('Time') - #for a in ax: a.legend() - - fig_name = fig_name.replace('.',',') - plt.savefig(fig_name, bbox_inches='tight') - plt.close() - def plot_TF_res(log, tf_names, fig_name='res'): mean = np.mean([x["param"] for x in log], axis=0) @@ -203,39 +134,6 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None): print("Sample saved :", fig_name) plt.close() -def model_copy(src,dst, patch_copy=True, copy_grad=True): - #model=copy.deepcopy(fmodel) #Pas approprie, on ne souhaite que les poids/grad (pas tout fmodel et ses etats) - - dst.load_state_dict(src.state_dict()) #Do not copy gradient ! - - if patch_copy: - dst['model'].load_state_dict(src['model'].state_dict()) #Copie donnee manquante ? - dst['data_aug'].load_state_dict(src['data_aug'].state_dict()) - - #Copie des gradients - if copy_grad: - for paramName, paramValue, in src.named_parameters(): - for netCopyName, netCopyValue, in dst.named_parameters(): - if paramName == netCopyName: - netCopyValue.grad = paramValue.grad - #netCopyValue=copy.deepcopy(paramValue) - - try: #Data_augV4 - dst['data_aug']._input_info = src['data_aug']._input_info - dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix - except: - pass - -def optim_copy(dopt, opt): - - #inner_opt.load_state_dict(diffopt.state_dict()) #Besoin sauver etat otpim (momentum, etc.) => Ne copie pas le state... - #opt_param=higher.optim.get_trainable_opt_params(diffopt) - - for group_idx, group in enumerate(opt.param_groups): - # print('gp idx',group_idx) - for p_idx, p in enumerate(group['params']): - opt.state[p]=dopt.state[group_idx][p_idx] - def print_torch_mem(add_info=''): nb=0 @@ -282,43 +180,8 @@ def plot_TF_influence(log, fig_name='TF_influence', param_names=None): plt.savefig(fig_name, bbox_inches='tight') plt.close() -class loss_monitor(): #Voir https://github.com/pytorch/ignite - def __init__(self, patience, end_train=1): - self.patience = patience - self.end_train = end_train - self.counter = 0 - self.best_score = None - self.reached_limit = 0 - - def register(self, loss): - if self.best_score is None: - self.best_score = loss - elif loss > self.best_score: - self.counter += 1 - #if not self.reached_limit: - print("loss no improve counter", self.counter, self.reached_limit) - else: - self.best_score = loss - self.counter = 0 - def limit_reached(self): - if self.counter >= self.patience: - self.counter = 0 - self.reached_limit +=1 - self.best_score = None - return self.reached_limit - - def end_training(self): - if self.limit_reached() >= self.end_train: - return True - else: - return False - - def reset(self): - self.__init__(self.patience, self.end_train) - ### https://github.com/facebookresearch/higher/issues/18 #### from torch._six import inf - def clip_norm(tensors, max_norm, norm_type=2): r"""Clips norm of passed tensors. The norm is computed over all tensors together, as if they were