diff --git a/higher/dataug.py b/higher/dataug.py
index f0d6bf8..63ddaf2 100644
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -327,7 +327,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         self._TF= list(self._TF_dict.keys())
         self._nb_tf= len(self._TF)
 
-        self._N_TF = N_TF
+        self._N_seqTF = N_TF
 
         self._fixed_mag=5 #[0, PARAMETER_MAX]
         self._params = nn.ParameterDict({
@@ -349,7 +349,7 @@ class Data_augV4(nn.Module): #Transformations with mask
             x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
 
         self._samples = []
-        for _ in range(self._N_TF):
+        for _ in range(self._N_seqTF):
             ## Sampling ##
             uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
@@ -501,7 +501,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
         for sample in self._samples:
             tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
-            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF)
+            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
             w_loss += tmp_w
 
         w_loss = w_loss * self._params["prob"]/self._distrib #Weighting by the probs (divided by the distrib so the loss isn't reduced)
@@ -526,28 +526,27 @@ class Data_augV4(nn.Module): #Transformations with mask
 
     def __str__(self):
         if not self._mix_dist:
-            return "Data_augV4(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
+            return "Data_augV4(Uniform-%d TF x %d)" % (self._nb_tf, self._N_seqTF)
         else:
-            return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
+            return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
 
-class Data_augV5(nn.Module): #Transformations with mask
-    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0):
+class Data_augV5(nn.Module): #Joint optimization (mag, prob)
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, glob_mag=True):
         super(Data_augV5, self).__init__()
         assert len(TF_dict)>0
 
         self._data_augmentation = True
 
-        #self._TF_matrix={}
-        #self._input_info={'h':0, 'w':0, 'device':None} #Input associated with TF_matrix
-        self._mag_fct = TF_dict
-        self._TF=list(self._mag_fct.keys())
+        self._TF_dict = TF_dict
+        self._TF= list(self._TF_dict.keys())
         self._nb_tf= len(self._TF)
 
-        self._N_TF = N_TF
+        self._N_seqTF = N_TF
 
-        self._fixed_mag=5 #[0, PARAMETER_MAX]
+        #self._fixed_mag=5 #[0, PARAMETER_MAX]
         self._params = nn.ParameterDict({
             "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform prob distribution
+            "mag" : nn.Parameter(torch.tensor(0.5).expand(self._nb_tf) if glob_mag else torch.tensor(0.5).repeat(self._nb_tf)) #[0, PARAMETER_MAX]/10
         })
 
         self._samples = []
@@ -565,7 +564,7 @@ class Data_augV5(nn.Module): #Transformations with mask
             x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
 
         self._samples = []
-        for _ in range(self._N_TF):
+        for _ in range(self._N_seqTF):
             ## Sampling ##
             uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
@@ -581,7 +580,7 @@ class Data_augV5(nn.Module): #Transformations with mask
             ## Transformations ##
             x = self.apply_TF(x, sample)
         return x
-    
+
     def apply_TF(self, x, sampled_TF):
         device = x.device
         smps_x=[]
@@ -591,44 +590,12 @@ class Data_augV5(nn.Module): #Transformations with mask
             smp_x = x[mask] #torch.masked_select() ?
             if smp_x.shape[0]!=0: #if there's data to TF
-                magnitude=self._fixed_mag
+                magnitude=self._params["mag"][tf_idx]*10
                 tf=self._TF[tf_idx]
+                #print(magnitude)
 
-                ## Geometric TF ##
-                if tf=='Identity':
-                    pass
-                elif tf=='FlipLR':
-                    smp_x = TF.flipLR(smp_x)
-                elif tf=='FlipUD':
-                    smp_x = TF.flipUD(smp_x)
-                elif tf=='Rotate':
-                    smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='TranslateX' or tf=='TranslateY':
-                    smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='ShearX' or tf=='ShearY' :
-                    smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-
-                ## Color TF (Expect image in the range of [0, 1]) ##
-                elif tf=='Contrast':
-                    smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Color':
-                    smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Brightness':
-                    smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Sharpness':
-                    smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Posterize':
-                    smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device))
-                elif tf=='Solarize':
-                    smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Equalize':
-                    smp_x = TF.equalize(smp_x)
-                elif tf=='Auto_Contrast':
-                    smp_x = TF.auto_contrast(smp_x)
-                else:
-                    raise Exception("Invalid TF requested : ", tf)
-
-                x[mask]=smp_x # Merge back; avoid x[mask]: in place
+                x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Merge back; avoid x[mask]: in place
+
         return x
 
     def adjust_prob(self, soft=False): #Detach from gradient ?
@@ -636,9 +603,14 @@ class Data_augV5(nn.Module): #Transformations with mask
         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck in a uniform dist if the lr is too small
         else:
+            #self._params['prob'].clamp(min=0.0,max=1.0)
             self._params['prob'].data = F.relu(self._params['prob'].data)
+            #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
+
             self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
+
+
     def loss_weight(self):
         # Single TF
         #self._sample = self._samples[-1]
@@ -647,11 +619,11 @@ class Data_augV5(nn.Module): #Transformations with mask
         #w_loss = w_loss * self._params["prob"]/self._distrib #Weighting by the probs (divided by the distrib so the loss isn't reduced)
         #w_loss = torch.sum(w_loss,dim=1)
 
-        #Several sequential TFs (assumes their order is negligible)
+        #Several sequential TFs
         w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
         for sample in self._samples:
             tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
-            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF)
+            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
             w_loss += tmp_w
 
         w_loss = w_loss * self._params["prob"]/self._distrib #Weighting by the probs (divided by the distrib so the loss isn't reduced)
@@ -676,9 +648,9 @@ class Data_augV5(nn.Module): #Transformations with mask
 
     def __str__(self):
         if not self._mix_dist:
-            return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
+            return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_seqTF)
         else:
-            return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
+            return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
 
 class Augmented_model(nn.Module):
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index f2b88ab..3485f35 100644
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -68,7 +68,7 @@ if __name__ == "__main__":
     t0 = time.process_time()
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
     #tf_dict = TF.TF_dict
-    aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.5), LeNet(3,10)).to(device)
+    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5), LeNet(3,10)).to(device)
     #aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
     print(str(aug_model), 'on', device_name)
     #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
@@ -88,7 +88,7 @@ if __name__ == "__main__":
     print('-'*9)
     #'''
     #### TF number tests ####
-    #'''
+    '''
     res_folder="res/TF_nb_tests/"
     epochs= 100
     inner_its = [0, 1, 10]
@@ -128,4 +128,4 @@ if __name__ == "__main__":
             print('Log :\"',f.name, '\" saved !')
             print('-'*9)
 
-    #'''
\ No newline at end of file
+    '''
\ No newline at end of file
diff --git a/higher/train_utils.py b/higher/train_utils.py
index 6ed1e09..97c2c81 100644
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -649,6 +649,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
         print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
         print('TF Proba :', model['data_aug']['prob'].data)
         #print('proba grad',aug_model['data_aug']['prob'].grad)
+        print('TF Mag :', model['data_aug']['mag'].data)
     #############
     #### Log ####
     data={
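
Review note: the substance of this change is that Data_augV5 now learns the transform probabilities and magnitudes jointly, in a single nn.ParameterDict, with adjust_prob projecting "prob" back onto the probability simplex after each optimizer step. Below is a minimal, self-contained sketch of that pattern, not the repo's code: the class name JointAugParams and the toy TF names are placeholders; only the ParameterDict layout, the expand/repeat choice behind glob_mag, and the hard (soft=False) projection mirror the diff.

import torch
import torch.nn as nn
import torch.nn.functional as F

class JointAugParams(nn.Module):
    def __init__(self, tf_names, glob_mag=True):
        super(JointAugParams, self).__init__()
        self._nb_tf = len(tf_names)
        self._params = nn.ParameterDict({
            # Start from a uniform distribution over the TFs
            "prob": nn.Parameter(torch.ones(self._nb_tf) / self._nb_tf),
            # expand: all entries alias a single storage element, so one global
            # magnitude is learned; repeat: one independent magnitude per TF
            "mag": nn.Parameter(torch.tensor(0.5).expand(self._nb_tf) if glob_mag
                                else torch.tensor(0.5).repeat(self._nb_tf)),
        })

    def adjust_prob(self):
        # Hard projection after an optimizer step: clip negative entries,
        # then renormalize to enforce the constraint sum(p) = 1
        self._params["prob"].data = F.relu(self._params["prob"].data)
        self._params["prob"].data = self._params["prob"].data / self._params["prob"].data.sum()

aug = JointAugParams(["Identity", "FlipLR", "Rotate"], glob_mag=False)
aug._params["prob"].data[0] = -0.3   # simulate a step that left the simplex
aug.adjust_prob()
print(aug._params["prob"].data)      # non-negative, sums to 1
print(aug._params["mag"].data * 10)  # rescaled toward [0, PARAMETER_MAX], as in apply_TF

This also clarifies the glob_mag design choice: expand yields a view whose entries share memory, so a single magnitude is optimized for all TFs, while repeat materializes independent per-TF values.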
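
A second consequence of this diff is a new contract for the TF module: replacing the if/elif chain with x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) assumes every entry of TF_dict is a callable of the form f(x, mag). A short sketch under that assumption (the toy transforms below are illustrative, not the repo's TF implementations):

import torch

TF_dict = {
    "Identity":   lambda x, mag: x,
    "FlipLR":     lambda x, mag: x.flip(-1),  # magnitude-free TFs simply ignore mag
    "Brightness": lambda x, mag: (x * (mag / 10.0)).clamp(0.0, 1.0),
}

batch = torch.rand(8, 3, 32, 32)  # images expected in [0, 1]
mag = torch.tensor(5.0)           # corresponds to self._params["mag"][tf_idx]*10
out = TF_dict["Brightness"](x=batch, mag=mag)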