diff --git a/higher/dataug.py b/higher/dataug.py
index 9a0db02..86616e2 100755
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -19,6 +19,7 @@ import copy
 
 import transformations as TF
 
+
 class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
     """Data augmentation module with learnable parameters.
 
@@ -68,6 +69,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         #Mag
         self._shared_mag = shared_mag
         self._fixed_mag = fixed_mag
+        if not self._fixed_mag and len([tf for tf in self._TF if tf not in TF.TF_ignore_mag])==0:
+            print("WARNING: Magnitudes will be kept fixed because the current TF don't allow gradient propagation:", self._TF)
+            self._fixed_mag=True
 
         #Distribution
         self._fixed_prob=fixed_prob
@@ -289,6 +293,308 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         else:
             return "Data_augV5(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
 
+class Data_augV7(nn.Module): #Sequential TF probabilities
+    """Data augmentation module with learnable parameters.
+
+        Applies transformations (TF) to batches of data.
+        Each TF is defined by a (name, probability of application, magnitude of distortion) tuple which can be learned. For the full definition of the TF, see transformations.py.
+        The TF probabilities define a distribution from which the applied TF are sampled.
+
+        Replaces the use of individual TF by TF sets, which are combinations of classic TF.
+
+        Attributes:
+            _data_augmentation (bool): Whether TF will be applied during the forward pass.
+            _TF_dict (dict): A dictionary containing the data transformations (TF) to be applied.
+            _TF (list): List of TF names.
+            _nb_tf (int): Number of TF used.
+            _N_seqTF (int): Number of TF applied sequentially to each input.
+            _shared_mag (bool): Whether to share a single magnitude parameter for all TF.
+            _fixed_mag (bool): Whether to lock the TF magnitudes.
+            _fixed_prob (bool): Whether to lock the TF probabilities.
+            _samples (list): Indexes of the TF sets sampled during the last forward pass.
+            _mix_dist (bool): Whether to use a mix of a uniform distribution and the real distribution (TF probabilities). If False, only a uniform distribution is used.
+            _fixed_mix (bool): Whether to lock the mix distribution factor.
+            _params (nn.ParameterDict): Learnable parameters.
+            _reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to False (i.e. when the magnitudes are learned).
+            _reg_mask (list): Mask selecting the TF considered for the regularisation.
+    """
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
+        """Init Data_augV7.
+
+            Args:
+                TF_dict (dict): A dictionary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
+                N_TF (int): Number of TF applied sequentially to each input. (default: 1)
+                mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, this parameter is learned. (default: 0)
+                fixed_prob (bool): Whether to lock the TF probabilities. (default: False)
+                fixed_mag (bool): Whether to lock the TF magnitudes. (default: True)
+                shared_mag (bool): Whether to share a single magnitude parameter for all TF. (default: True)
+        """
+        super(Data_augV7, self).__init__()
+        assert len(TF_dict)>0
+        assert N_TF>=0
+
+        self._data_augmentation = True
+
+        #TF
+        self._TF_dict = TF_dict
+        self._TF= list(self._TF_dict.keys())
+        self._nb_tf= len(self._TF)
+        self._N_seqTF = N_TF
+
+        #Mag
+        self._shared_mag = shared_mag
+        self._fixed_mag = fixed_mag
+        if not self._fixed_mag and len([tf for tf in self._TF if tf not in TF.TF_ignore_mag])==0:
+            print("WARNING: Magnitudes will be kept fixed because the current TF don't allow gradient propagation:", self._TF)
+            self._fixed_mag=True
+
+        #Distribution
+        self._fixed_prob=fixed_prob
+        self._samples = []
+
+        self._mix_dist = False
+        if mix_dist != 0.0: #Mix dist
+            self._mix_dist = True
+
+        self._fixed_mix=True
+        if mix_dist is None: #Learn the mix dist factor
+            self._fixed_mix = False
+            mix_dist=0.5
+
+        #TF sets
+        no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}} #TF that must not be applied twice in a row
+        cons_test = (lambda i, idxs: i in no_consecutive and len(idxs)!=0 and i==idxs[-1]) #Exclude forbidden consecutive TF
+        def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF
+            TF_sets=[]
+            if set_size>1:
+                for i in range(n_TF):
+                    if not cons_test(i, idx_prefix):
+                        TF_sets += generate_TF_sets(n_TF, set_size=set_size-1, idx_prefix=idx_prefix+[i])
+            else:
+                TF_sets+=[[idx_prefix+[i]] for i in range(n_TF) if not cons_test(i, idx_prefix)]
+            return TF_sets
+
+        self._TF_sets=torch.ByteTensor(generate_TF_sets(self._nb_tf, self._N_seqTF)).squeeze()
+        self._nb_TF_sets=len(self._TF_sets)
+        print("Number of TF sets:",self._nb_TF_sets)
+
+        #Params
+        init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
+        self._params = nn.ParameterDict({
+            "prob": nn.Parameter(torch.ones(self._nb_TF_sets)/self._nb_TF_sets), #Uniform distribution over the TF sets
+            "mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
+                                 else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
+            "mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
+        })
+
+        #for tf in TF.TF_no_grad :
+        #    if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #Lock those TF at the max parameter
+        #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag is useless for the ignore_mag TF
+
+        #Mag regularisation
+        if not self._fixed_mag:
+            if self._shared_mag :
+                self._reg_tgt = torch.tensor(float(TF.PARAMETER_MAX)) #Encourage maximum amplitude
+            else:
+                self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in TF.TF_ignore_mag]
+                self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage maximum amplitude
+
+    def forward(self, x):
+        """ Main method of the data augmentation module.
+
+            Args:
+                x (Tensor): Batch of data.
+
+            Returns:
+                Tensor : Batch of transformed data.
+ """ + self._samples = None + if self._data_augmentation:# and TF.random.random() < 0.5: + device = x.device + batch_size, h, w = x.shape[0], x.shape[2], x.shape[3] + + x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) + + + ## Echantillonage ## + uniforme_dist = torch.ones(1,self._nb_TF_sets,device=device).softmax(dim=1) + + if not self._mix_dist: + self._distrib = uniforme_dist + else: + prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"] + mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"] + self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor + + cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_TF_sets), device=device)*self._distrib) + sample = cat_distrib.sample() + + self._samples=sample + TF_samples=self._TF_sets[sample,:].to(device) #[Batch_size, TFseq] + + for i in range(self._N_seqTF): + ## Transformations ## + x = self.apply_TF(x, TF_samples[:,i]) + return x + + def apply_TF(self, x, sampled_TF): + """ Applies the sampled transformations. + + Args: + x (Tensor): Batch of data. + sampled_TF (Tensor): Indexes of the TF to be applied to each element of data. + + Returns: + Tensor: Batch of tranformed data. + """ + device = x.device + batch_size, channels, h, w = x.shape + smps_x=[] + + for tf_idx in range(self._nb_tf): + mask = sampled_TF==tf_idx #Create selection mask + smp_x = x[mask] #torch.masked_select() ? (Necessite d'expand le mask au meme dim) + + if smp_x.shape[0]!=0: #if there's data to TF + magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx] + if self._fixed_mag: magnitude=magnitude.detach() #Fmodel tente systematiquement de tracker les gradient de tout les param + + tf=self._TF[tf_idx] + + #In place + #x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) + + #Out of place + smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude) + idx= mask.nonzero() + idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ... + x=x.scatter(dim=0, index=idx, src=smp_x) + + return x + + def adjust_param(self, soft=False): #Detach from gradient ? + """ Enforce limitations to the learned parameters. + + Ensure that the parameters value stays in the right intevals. This should be called after each update of those parameters. + + Args: + soft (bool): Wether to use a softmax function for TF probabilites. Not Recommended as it tends to lock the probabilities, preventing them to be learned. (default: False) + """ + if not self._fixed_prob: + if soft : + self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible + else: + self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0) + self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 + + if not self._fixed_mag: + self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX) + + if not self._fixed_mix: + self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999) + + def loss_weight(self): + """ Weights for the loss. + Compute the weights for the loss of each inputs depending on wich TF was applied to them. + Should be applied to the loss before reduction. 
+ + TODO: Take into account the order of application of the TF. + + Returns: + Tensor : Loss weights. + """ + if self._samples is None : return 1 #Pas d'echantillon = pas de ponderation + + prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"] + + w_loss = torch.zeros((self._samples.shape[0],self._nb_TF_sets), device=self._samples.device) + w_loss.scatter_(1, self._samples.view(-1,1), 1) + w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss) + w_loss = torch.sum(w_loss,dim=1) + return w_loss + + def reg_loss(self, reg_factor=0.005): + """ Regularisation term used to learn the magnitudes. + Use an L2 loss to encourage high magnitudes TF. + + Args: + reg_factor (float): Factor by wich the regularisation loss is multiplied. (default: 0.005) + Returns: + Tensor containing the regularisation loss value. + """ + if self._fixed_mag: + return torch.tensor(0) + else: + #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean') + mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask] + max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') + return max_mag_reg + + def TF_prob(self): #Eviter recalcul si pas de changement des proba + #print("WARNING: Calcul de proba inexact") + res=torch.zeros(self._nb_tf) + for idx_tf in range(self._nb_tf): + for i, t_set in enumerate(self._TF_sets): + if idx_tf in t_set: + res[idx_tf]+=self._params['prob'][i] + + return res/sum(res) #*(self._nb_tf/self._nb_TF_sets) + + def train(self, mode=True): + """ Set the module training mode. + + Args: + mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None) + """ + #if mode is None : + # mode=self._data_augmentation + self.augment(mode=mode) #Inutile si mode=None + super(Data_augV7, self).train(mode) + return self + + def eval(self): + """ Set the module to evaluation mode. + """ + return self.train(mode=False) + + def augment(self, mode=True): + """ Set the augmentation mode. + + Args: + mode (bool): Wether to perform data augmentation on the forward pass. (default: True) + """ + self._data_augmentation=mode + + def __getitem__(self, key): + """Access to the learnable parameters + Args: + key (string): Name of the learnable parameter to access. + + Returns: + nn.Parameter. + """ + if key == 'prob': #Override prob access + return self.TF_prob() + return self._params[key] + + def __str__(self): + """Name of the module + + Returns: + String containing the name of the module as well as the higher levels parameters. + """ + dist_param='' + if self._fixed_prob: dist_param+='Fx' + mag_param='Mag' + if self._fixed_mag: mag_param+= 'Fx' + if self._shared_mag: mag_param+= 'Sh' + if not self._mix_dist: + return "Data_augV7(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param) + elif self._fixed_mix: + return "Data_augV7(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param) + else: + return "Data_augV7(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param) + class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide """RandAugment implementation. 
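For reference, a minimal sketch of how the new Data_augV7 module could be exercised on its own, outside the project's bi-level training loop (test_dataug.py instead wraps the augmentation module in Augmented_model together with a Higher_model and run_dist_dataugV3). The stand-in task model, optimizer and random batch below are illustrative assumptions, and the TF in transformations.py are assumed to accept float image batches of shape (batch, channel, height, width):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    import transformations as TF
    from dataug import Data_augV7

    aug = Data_augV7(TF_dict=TF.TF_dict, N_TF=2, mix_dist=0.5,
                     fixed_prob=False, fixed_mag=True, shared_mag=True)
    task_model = nn.Sequential(nn.Flatten(), nn.Linear(3*32*32, 10))   #Stand-in task model (illustrative)
    optim = torch.optim.SGD(list(task_model.parameters()) + list(aug.parameters()), lr=0.01)

    inputs = torch.rand(8, 3, 32, 32)                                  #Dummy CIFAR-like batch (illustrative)
    labels = torch.randint(0, 10, (8,))

    aug.augment(mode=True)
    aug_x = aug(inputs)                                                #Sample one TF set per input and apply it
    loss = F.cross_entropy(task_model(aug_x), labels, reduction='none')
    loss = (loss * aug.loss_weight()).mean() + aug.reg_loss()          #Per-sample weighting, then magnitude regularisation

    optim.zero_grad()
    loss.backward()
    optim.step()
    aug.adjust_param()                                                 #Clamp prob/mag/mix_dist back to their valid ranges

In this sketch, gradients reach the 'prob' parameter only through loss_weight(), since the Categorical sampling step itself is not differentiable; the project learns the augmentation parameters through the meta-learning loop of test_dataug.py rather than the plain SGD step shown here.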
diff --git a/higher/old/dataug_old.py b/higher/old/dataug_old.py
index 2ffdf51..523022a 100644
--- a/higher/old/dataug_old.py
+++ b/higher/old/dataug_old.py
@@ -585,7 +585,7 @@ class Data_augV6(nn.Module): #Optimisation sequentielle #Mauvais resultats
             print("Warning : using only fixed set of TF : ", self._fixed_TF)
             self._TF_sets=torch.tensor([self._fixed_TF])
         else:
-            def generate_TF_sets(n_TF, set_size, idx_prefix=[]):
+            def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every combination (without reuse) of TF
                 TF_sets=[]
                 if len(idx_prefix)!=0:
                     if set_size>2:
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index 972c6aa..8bcd5d7 100755
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -170,7 +170,7 @@ if __name__ == "__main__":
         tf_dict = {k: TF.TF_dict[k] for k in tf_names}
 
         model = Higher_model(model) #run_dist_dataugV3
-        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
+        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
         #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
 
         print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
@@ -179,7 +179,7 @@ if __name__ == "__main__":
             inner_it=n_inner_iter,
             dataug_epoch_start=dataug_epoch_start,
            opt_param=optim_param,
-            print_freq=1,
+            print_freq=20,
             KLdiv=True,
             hp_opt=False)
diff --git a/higher/transformations.py b/higher/transformations.py
index 0eb4456..c6cefde 100755
--- a/higher/transformations.py
+++ b/higher/transformations.py
@@ -187,20 +187,20 @@ def float_parameter(level, maxval):
     A float that results from scaling `maxval` according to `level`.
   """
-  #return float(level) * maxval / PARAMETER_MAX
-  return (level * maxval / PARAMETER_MAX)#.to(torch.float)
+  #return float(level) * maxval / PARAMETER_MAX
+  return (level * maxval / PARAMETER_MAX)#.to(torch.float)
 
 #def int_parameter(level, maxval): #Perte de gradient
-  """Helper function to scale `val` between 0 and maxval .
-  Args:
+  """Helper function to scale `val` between 0 and maxval .
+  Args:
     level: Level of the operation that will be between [0, `PARAMETER_MAX`].
     maxval: Maximum value that the operation can have. This will be scaled to
      level/PARAMETER_MAX.
-  Returns:
+  Returns:
     An int that results from scaling `maxval` according to `level`.
-  """
-  #return int(level * maxval / PARAMETER_MAX)
-# return (level * maxval / PARAMETER_MAX)
+  """
+  #return int(level * maxval / PARAMETER_MAX)
+  # return (level * maxval / PARAMETER_MAX)
 
 def flipLR(x):
   """Flip horizontaly/Left-Right images.
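The comment added to the old Data_augV6 helper above highlights the difference with the new module: Data_augV6 enumerated combinations of TF without reuse, while Data_augV7 enumerates arrangements with reuse, only forbidding an immediate repetition of 'FlipUD'/'FlipLR'. A small standalone sketch of that enumeration (the TF names are an illustrative subset of TF.TF_dict):

    #Standalone sketch of the TF set enumeration performed in Data_augV7.__init__
    #(arrangements with reuse, forbidding an immediate repeat of 'FlipUD'/'FlipLR').
    def enumerate_tf_sets(tf_names, set_size):
        no_consecutive = {i for i, t in enumerate(tf_names) if t in {'FlipUD', 'FlipLR'}}

        def rec(prefix, remaining):
            if remaining == 0:
                return [prefix]
            sets = []
            for i in range(len(tf_names)):
                if i in no_consecutive and prefix and prefix[-1] == i:
                    continue  #Skip the forbidden consecutive repetition
                sets += rec(prefix + [i], remaining - 1)
            return sets

        return rec([], set_size)

    tf_names = ['Identity', 'FlipLR', 'Contrast']
    print(enumerate_tf_sets(tf_names, set_size=2))
    #8 sets: the 3*3 arrangements minus the forbidden ['FlipLR', 'FlipLR'] repetition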