import torch import torch.nn as nn import torch.nn.functional as F from torch.distributions import * #import kornia #import random #import numpy as np import copy import transformations as TF class Data_aug(nn.Module): #Rotation parametree def __init__(self): super(Data_aug, self).__init__() self._data_augmentation = True self._params = nn.ParameterDict({ "prob": nn.Parameter(torch.tensor(0.5)), "mag": nn.Parameter(torch.tensor(1.0)) }) #self.params["mag"].register_hook(print) def forward(self, x): if self._data_augmentation and random.random() < self._params["prob"]: #print('Aug') batch_size = x.shape[0] # create transformation (rotation) alpha = self._params["mag"]*180 # in degrees angle = torch.ones(batch_size, device=x.device) * alpha # define the rotation center center = torch.ones(batch_size, 2, device=x.device) center[..., 0] = x.shape[3] / 2 # x center[..., 1] = x.shape[2] / 2 # y #print(x.shape, center) # define the scale factor scale = torch.ones(batch_size, device=x.device) # compute the transformation matrix M = kornia.get_rotation_matrix2d(center, angle, scale) # apply the transformation to original image x = kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w) return x def eval(self): self.augment(mode=False) nn.Module.eval(self) def augment(self, mode=True): self._data_augmentation=mode def __getitem__(self, key): return self._params[key] def __str__(self): return "Data_aug(Mag-1 TF)" class Data_augV2(nn.Module): #Methode exacte def __init__(self): super(Data_augV2, self).__init__() self._data_augmentation = True self._fixed_transf=[0.0, 45.0, 180.0] #Degree rotation #self._fixed_transf=[0.0] self._nb_tf= len(self._fixed_transf) self._params = nn.ParameterDict({ "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme #"prob2": nn.Parameter(torch.ones(len(self._fixed_transf)).softmax(dim=0)) }) #print(self._params["prob"], self._params["prob2"]) self.transf_idx=0 def forward(self, x): if self._data_augmentation: #print('Aug',self._fixed_transf[self.transf_idx]) device = x.device batch_size = x.shape[0] # create transformation (rotation) #alpha = 180 # in degrees alpha = self._fixed_transf[self.transf_idx] angle = torch.ones(batch_size, device=device) * alpha x = self.rotate(x,angle) return x def rotate(self, x, angle): device = x.device batch_size = x.shape[0] # define the rotation center center = torch.ones(batch_size, 2, device=device) center[..., 0] = x.shape[3] / 2 # x center[..., 1] = x.shape[2] / 2 # y #print(x.shape, center) # define the scale factor scale = torch.ones(batch_size, device=device) # compute the transformation matrix M = kornia.get_rotation_matrix2d(center, angle, scale) # apply the transformation to original image return kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w) def adjust_prob(self): #Detach from gradient ? self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) #print('proba',self._params['prob']) self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 #print('Sum p', sum(self._params['prob'])) def eval(self): self.augment(mode=False) nn.Module.eval(self) def augment(self, mode=True): self._data_augmentation=mode def __getitem__(self, key): return self._params[key] def __str__(self): return "Data_augV2(Exact-%d TF)" % self._nb_tf class Data_augV3(nn.Module): #Echantillonage uniforme/Mixte def __init__(self, mix_dist=0.0): super(Data_augV3, self).__init__() self._data_augmentation = True #self._fixed_transf=[0.0, 45.0, 180.0] #Degree rotation self._fixed_transf=[0.0, 1.0, -1.0] #Flips (Identity,Horizontal,Vertical) #self._fixed_transf=[0.0] self._nb_tf= len(self._fixed_transf) self._params = nn.ParameterDict({ "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme #"prob2": nn.Parameter(torch.ones(len(self._fixed_transf)).softmax(dim=0)) }) #print(self._params["prob"], self._params["prob2"]) self._sample = [] self._mix_dist = False if mix_dist != 0.0: self._mix_dist = True self._mix_factor = max(min(mix_dist, 1.0), 0.0) def forward(self, x): if self._data_augmentation: device = x.device batch_size = x.shape[0] #good_distrib = Uniform(low=torch.zeros(batch_size,1, device=device),high=torch.new_full((batch_size,1),self._params["prob"], device=device)) #bad_distrib = Uniform(low=torch.zeros(batch_size,1, device=device),high=torch.new_full((batch_size,1), 1-self._params["prob"], device=device)) #transform_dist = Categorical(probs=torch.tensor([self._params["prob"], 1-self._params["prob"]], device=device)) #self._sample = transform_dist._sample(sample_shape=torch.Size([batch_size,1])) uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=0) if not self._mix_dist: distrib = uniforme_dist else: distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=0) #Mix distrib reel / uniforme avec mix_factor cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*distrib) self._sample = cat_distrib.sample() TF_param = torch.tensor([self._fixed_transf[x] for x in self._sample], device=device) #Approche de marco peut-etre plus rapide #x = self.rotate(x,angle=TF_param) x = self.flip(x,flip_mat=TF_param) return x def rotate(self, x, angle): device = x.device batch_size = x.shape[0] # define the rotation center center = torch.ones(batch_size, 2, device=device) center[..., 0] = x.shape[3] / 2 # x center[..., 1] = x.shape[2] / 2 # y #print(x.shape, center) # define the scale factor scale = torch.ones(batch_size, device=device) # compute the transformation matrix M = kornia.get_rotation_matrix2d(center, angle, scale) # apply the transformation to original image return kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w) def flip(self, x, flip_mat): #print(flip_mat) device = x.device batch_size = x.shape[0] h, w = x.shape[2], x.shape[3] # destination size #points_src = torch.ones(batch_size, 4, 2, device=device) #points_dst = torch.ones(batch_size, 4, 2, device=device) #Identity iM=torch.tensor(np.eye(3)) #Horizontal flip # the source points are the region to crop corners #points_src = torch.FloatTensor([[ # [w - 1, 0], [0, 0], [0, h - 1], [w - 1, h - 1], #]]) # the destination points are the image vertexes #points_dst = torch.FloatTensor([[ # [0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1], #]]) # compute perspective transform #hM = kornia.get_perspective_transform(points_src, points_dst) hM =torch.tensor( [[[-1., 0., w-1], [ 0., 1., 0.], [ 0., 0., 1.]]]) #Vertical flip # the source points are the region to crop corners #points_src = torch.FloatTensor([[ # [0, h - 1], [w - 1, h - 1], [w - 1, 0], [0, 0], #]]) # the destination points are the image vertexes #points_dst = torch.FloatTensor([[ # [0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1], #]]) # compute perspective transform #vM = kornia.get_perspective_transform(points_src, points_dst) vM =torch.tensor( [[[ 1., 0., 0.], [ 0., -1., h-1], [ 0., 0., 1.]]]) #print(vM) M=torch.ones(batch_size, 3, 3, device=device) for i in range(batch_size): # A optimiser if flip_mat[i]==0.0: M[i,]=iM elif flip_mat[i]==1.0: M[i,]=hM elif flip_mat[i]==-1.0: M[i,]=vM # warp the original image by the found transform return kornia.warp_perspective(x, M, dsize=(h, w)) def adjust_prob(self, soft=False): #Detach from gradient ? if soft : self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible else: #self._params['prob'].clamp(min=0.0,max=1.0) self._params['prob'].data = F.relu(self._params['prob'].data) #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) #print('proba',self._params['prob']) self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 #print('Sum p', sum(self._params['prob'])) def loss_weight(self): #w_loss = [self._params["prob"][x] for x in self._sample] #print(self._sample.view(-1,1).shape) #print(self._sample[:10]) w_loss = torch.zeros((self._sample.shape[0],self._nb_tf), device=self._sample.device) w_loss.scatter_(1, self._sample.view(-1,1), 1) #print(w_loss.shape) #print(w_loss[:10,:]) w_loss = w_loss * self._params["prob"] #print(w_loss.shape) #print(w_loss[:10,:]) w_loss = torch.sum(w_loss,dim=1) #print(w_loss.shape) #print(w_loss[:10]) return w_loss def train(self, mode=None): if mode is None : mode=self._data_augmentation self.augment(mode=mode) #Inutile si mode=None super(Data_augV3, self).train(mode) def eval(self): self.train(mode=False) #super(Augmented_model, self).eval() def augment(self, mode=True): self._data_augmentation=mode def __getitem__(self, key): return self._params[key] def __str__(self): if not self._mix_dist: return "Data_augV3(Uniform-%d TF)" % self._nb_tf else: return "Data_augV3(Mix %.1f-%d TF)" % (self._mix_factor, self._nb_tf) class Data_augV4(nn.Module): #Transformations avec mask def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0): super(Data_augV4, self).__init__() assert len(TF_dict)>0 self._data_augmentation = True #self._TF_matrix={} #self._input_info={'h':0, 'w':0, 'device':None} #Input associe a TF_matrix self._mag_fct = TF_dict self._TF=list(self._mag_fct.keys()) self._nb_tf= len(self._TF) self._N_TF = N_TF self._fixed_mag=5 #[0, PARAMETER_MAX] self._params = nn.ParameterDict({ "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme }) self._samples = [] self._mix_dist = False if mix_dist != 0.0: self._mix_dist = True self._mix_factor = max(min(mix_dist, 1.0), 0.0) def forward(self, x): if self._data_augmentation: device = x.device batch_size, h, w = x.shape[0], x.shape[2], x.shape[3] x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) self._samples = [] for _ in range(self._N_TF): ## Echantillonage ## uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1) if not self._mix_dist: self._distrib = uniforme_dist else: self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor print(self.distrib.shape) cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib) sample = cat_distrib.sample() self._samples.append(sample) ## Transformations ## x = self.apply_TF(x, sample) return x ''' def compute_TF_matrix(self, magnitude=None, sample_info= None): print('Computing TF_matrix...') if not magnitude : magnitude=self._fixed_mag if sample_info: self._input_info['h']= sample_info['h'] self._input_info['w']= sample_info['w'] self._input_info['device'] = sample_info['device'] h, w, device= self._input_info['h'], self._input_info['w'], self._input_info['device'] self._TF_matrix={} for tf in self._TF : if tf=='Id': self._TF_matrix[tf]=torch.tensor([[[ 1., 0., 0.], [ 0., 1., 0.], [ 0., 0., 1.]]], device=device) elif tf=='Rot': center = torch.ones(1, 2, device=device) center[0, 0] = w / 2 # x center[0, 1] = h / 2 # y scale = torch.ones(1, device=device) angle = self._mag_fct[tf](magnitude) * torch.ones(1, device=device) R = kornia.get_rotation_matrix2d(center, angle, scale) #Rotation matrix (1,2,3) self._TF_matrix[tf]=torch.cat((R,torch.tensor([[[ 0., 0., 1.]]], device=device)), dim=1) #TF matrix (1,3,3) elif tf=='FlipLR': self._TF_matrix[tf]=torch.tensor([[[-1., 0., w-1], [ 0., 1., 0.], [ 0., 0., 1.]]], device=device) elif tf=='FlipUD': self._TF_matrix[tf]=torch.tensor([[[ 1., 0., 0.], [ 0., -1., h-1], [ 0., 0., 1.]]], device=device) else: raise Exception("Invalid TF requested") ''' def apply_TF(self, x, sampled_TF): device = x.device smps_x=[] masks=[] for tf_idx in range(self._nb_tf): mask = sampled_TF==tf_idx #Create selection mask smp_x = x[mask] #torch.masked_select() ? if smp_x.shape[0]!=0: #if there's data to TF magnitude=self._fixed_mag tf=self._TF[tf_idx] ## Geometric TF ## if tf=='Identity': pass elif tf=='FlipLR': smp_x = TF.flipLR(smp_x) elif tf=='FlipUD': smp_x = TF.flipUD(smp_x) elif tf=='Rotate': smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) elif tf=='TranslateX' or tf=='TranslateY': smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) elif tf=='ShearX' or tf=='ShearY' : smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) ## Color TF (Expect image in the range of [0, 1]) ## elif tf=='Contrast': smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) elif tf=='Color': smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) elif tf=='Brightness': smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) elif tf=='Sharpness': smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) elif tf=='Posterize': smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device)) elif tf=='Solarize': smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) elif tf=='Equalize': smp_x = TF.equalize(smp_x) elif tf=='Auto_Contrast': smp_x = TF.auto_contrast(smp_x) else: raise Exception("Invalid TF requested : ", tf) x[mask]=smp_x # Refusionner eviter x[mask] : in place #idx= mask.nonzero() #print('-'*8) #print(idx[0], tf_idx) #print(smp_x[0,]) #x=x.view(-1,3*32*32) #x=x.scatter(dim=0, index=idx, src=smp_x.view(-1,3*32*32)) #Changement des Tensor mais pas visible sur la visualisation... #x=x.view(-1,3,32,32) #print(x[0,]) ''' if len(self._TF_matrix)==0 or self._input_info['h']!=h or self._input_info['w']!=w or self._input_info['device']!=device: #Device different:Pas necessaire de tout recalculer self.compute_TF_matrix(sample_info={'h': x.shape[2], 'w': x.shape[3], 'device': x.device}) TF_matrix = torch.zeros(batch_size, 3, 3, device=device) #All geom TF for tf_idx in range(self._nb_tf): mask = self._sample==tf_idx #Create selection mask TF_matrix[mask,]=self._TF_matrix[self._TF[tf_idx]] x=kornia.warp_perspective(x, TF_matrix, dsize=(h, w)) ''' return x def adjust_prob(self, soft=False): #Detach from gradient ? if soft : self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible else: #self._params['prob'].clamp(min=0.0,max=1.0) self._params['prob'].data = F.relu(self._params['prob'].data) #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 def loss_weight(self): # 1 seule TF #self._sample = self._samples[-1] #w_loss = torch.zeros((self._sample.shape[0],self._nb_tf), device=self._sample.device) #w_loss.scatter_(dim=1, index=self._sample.view(-1,1), value=1) #w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss) #w_loss = torch.sum(w_loss,dim=1) #Plusieurs TF sequentielles w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device) for sample in self._samples: tmp_w = torch.zeros(w_loss.size(),device=w_loss.device) tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF) w_loss += tmp_w w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss) w_loss = torch.sum(w_loss,dim=1) return w_loss def train(self, mode=None): if mode is None : mode=self._data_augmentation self.augment(mode=mode) #Inutile si mode=None super(Data_augV4, self).train(mode) def eval(self): self.train(mode=False) def augment(self, mode=True): self._data_augmentation=mode def __getitem__(self, key): return self._params[key] def __str__(self): if not self._mix_dist: return "Data_augV4(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF) else: return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF) class Augmented_model(nn.Module): def __init__(self, data_augmenter, model): super(Augmented_model, self).__init__() self._mods = nn.ModuleDict({ 'data_aug': data_augmenter, 'model': model }) self.augment(mode=True) def initialize(self): self._mods['model'].initialize() def forward(self, x): return self._mods['model'](self._mods['data_aug'](x)) def augment(self, mode=True): self._data_augmentation=mode self._mods['data_aug'].augment(mode) def train(self, mode=None): if mode is None : mode=self._data_augmentation self._mods['data_aug'].augment(mode) super(Augmented_model, self).train(mode) def eval(self): self.train(mode=False) #super(Augmented_model, self).eval() def items(self): """Return an iterable of the ModuleDict key/value pairs. """ return self._mods.items() def update(self, modules): self._mods.update(modules) def is_augmenting(self): return self._data_augmentation def TF_names(self): try: return self._mods['data_aug']._TF except: return None def __getitem__(self, key): return self._mods[key] def __str__(self): return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"