diff --git a/higher/dataug.py b/higher/dataug.py index 03609f2..a377e51 100644 --- a/higher/dataug.py +++ b/higher/dataug.py @@ -320,31 +320,6 @@ class Data_augV4(nn.Module): #Transformations avec mask #self._TF_matrix={} #self._input_info={'h':0, 'w':0, 'device':None} #Input associe a TF_matrix - ''' - self._mag_fct={ #f(mag_normalise)=mag_reelle - ## Geometric TF ## - 'Identity' : (lambda mag: None), - 'FlipUD' : (lambda mag: None), - 'FlipLR' : (lambda mag: None), - 'Rotate': (lambda mag: random.randint(-int_parameter(mag, maxval=30), int_parameter(mag, maxval=30))), - 'TranslateX': (lambda mag: [random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20)), 0]), - 'TranslateY': (lambda mag: [0, random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20))]), - 'ShearX': (lambda mag: [random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3)), 0]), - 'ShearY': (lambda mag: [0, random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3))]), - - ## Color TF (Expect image in the range of [0, 1]) ## - 'Contrast': (lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))), - 'Color':(lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))), - 'Brightness':(lambda mag: random.uniform(1., float_parameter(mag, maxval=1.9))), - 'Sharpness':(lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))), - 'Posterize': (lambda mag: random.randint(4, int_parameter(mag, maxval=8))), - 'Solarize': (lambda mag: random.randint(1, int_parameter(mag, maxval=256))/256.), #=>Image entre [0,1] #Pas opti pour des batch - - #Non fonctionnel - 'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent) - #'Equalize': (lambda mag: None), - } - ''' self._mag_fct = TF_dict self._TF=list(self._mag_fct.keys()) self._nb_tf= len(self._TF) @@ -380,77 +355,8 @@ class Data_augV4(nn.Module): #Transformations avec mask self._sample = cat_distrib.sample() ## Transformations ## - #''' x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) - smps_x=[] - masks=[] - for tf_idx in range(self._nb_tf): - mask = self._sample==tf_idx #Create selection mask - smp_x = x[mask] #torch.masked_select() ? - - if smp_x.shape[0]!=0: #if there's data to TF - magnitude=self._fixed_mag - tf=self._TF[tf_idx] - - ## Geometric TF ## - if tf=='Identity': - pass - elif tf=='FlipLR': - smp_x = TF.flipLR(smp_x) - elif tf=='FlipUD': - smp_x = TF.flipUD(smp_x) - elif tf=='Rotate': - smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='TranslateX' or tf=='TranslateY': - smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='ShearX' or tf=='ShearY' : - smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - - ## Color TF (Expect image in the range of [0, 1]) ## - elif tf=='Contrast': - smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='Color': - smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='Brightness': - smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='Sharpness': - smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='Posterize': - smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device)) - elif tf=='Solarize': - smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) - elif tf=='Equalize': - smp_x = TF.equalize(smp_x) - elif tf=='Auto_Contrast': - smp_x = TF.auto_contrast(smp_x) - else: - raise Exception("Invalid TF requested : ", tf) - - x[mask]=smp_x # Refusionner eviter x[mask] : in place - - #idx= mask.nonzero() - #print('-'*8) - #print(idx[0], tf_idx) - #print(smp_x[0,]) - #x=x.view(-1,3*32*32) - #x=x.scatter(dim=0, index=idx, src=smp_x.view(-1,3*32*32)) #Changement des Tensor mais pas visible sur la visualisation... - #x=x.view(-1,3,32,32) - #print(x[0,]) - - ''' - if len(self._TF_matrix)==0 or self._input_info['h']!=h or self._input_info['w']!=w or self._input_info['device']!=device: #Device different:Pas necessaire de tout recalculer - self.compute_TF_matrix(sample_info={'h': x.shape[2], - 'w': x.shape[3], - 'device': x.device}) - - TF_matrix = torch.zeros(batch_size, 3, 3, device=device) #All geom TF - - for tf_idx in range(self._nb_tf): - mask = self._sample==tf_idx #Create selection mask - TF_matrix[mask,]=self._TF_matrix[self._TF[tf_idx]] - - x=kornia.warp_perspective(x, TF_matrix, dsize=(h, w)) - ''' + x = self.apply_TF(x, self._sample) return x ''' def compute_TF_matrix(self, magnitude=None, sample_info= None): @@ -489,6 +395,79 @@ class Data_augV4(nn.Module): #Transformations avec mask else: raise Exception("Invalid TF requested") ''' + def apply_TF(self, x, sampled_TF): + device = x.device + smps_x=[] + masks=[] + for tf_idx in range(self._nb_tf): + mask = sampled_TF==tf_idx #Create selection mask + smp_x = x[mask] #torch.masked_select() ? + + if smp_x.shape[0]!=0: #if there's data to TF + magnitude=self._fixed_mag + tf=self._TF[tf_idx] + + ## Geometric TF ## + if tf=='Identity': + pass + elif tf=='FlipLR': + smp_x = TF.flipLR(smp_x) + elif tf=='FlipUD': + smp_x = TF.flipUD(smp_x) + elif tf=='Rotate': + smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='TranslateX' or tf=='TranslateY': + smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='ShearX' or tf=='ShearY' : + smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + + ## Color TF (Expect image in the range of [0, 1]) ## + elif tf=='Contrast': + smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='Color': + smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='Brightness': + smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='Sharpness': + smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='Posterize': + smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device)) + elif tf=='Solarize': + smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device)) + elif tf=='Equalize': + smp_x = TF.equalize(smp_x) + elif tf=='Auto_Contrast': + smp_x = TF.auto_contrast(smp_x) + else: + raise Exception("Invalid TF requested : ", tf) + + x[mask]=smp_x # Refusionner eviter x[mask] : in place + + #idx= mask.nonzero() + #print('-'*8) + #print(idx[0], tf_idx) + #print(smp_x[0,]) + #x=x.view(-1,3*32*32) + #x=x.scatter(dim=0, index=idx, src=smp_x.view(-1,3*32*32)) #Changement des Tensor mais pas visible sur la visualisation... + #x=x.view(-1,3,32,32) + #print(x[0,]) + + ''' + if len(self._TF_matrix)==0 or self._input_info['h']!=h or self._input_info['w']!=w or self._input_info['device']!=device: #Device different:Pas necessaire de tout recalculer + self.compute_TF_matrix(sample_info={'h': x.shape[2], + 'w': x.shape[3], + 'device': x.device}) + + TF_matrix = torch.zeros(batch_size, 3, 3, device=device) #All geom TF + + for tf_idx in range(self._nb_tf): + mask = self._sample==tf_idx #Create selection mask + TF_matrix[mask,]=self._TF_matrix[self._TF[tf_idx]] + + x=kornia.warp_perspective(x, TF_matrix, dsize=(h, w)) + ''' + return x + def adjust_prob(self, soft=False): #Detach from gradient ? if soft : diff --git a/higher/test_dataug.py b/higher/test_dataug.py index 061a2dc..15b1550 100644 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -646,8 +646,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f tf = time.process_time() - #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) - #viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch)) + viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) + viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch)) if(not high_grad_track): countcopy+=1 @@ -732,7 +732,7 @@ if __name__ == "__main__": aug_model = Augmented_model(Data_augV4(TF_dict=TF.TF_dict, mix_dist=0.0), LeNet(3,10)).to(device) print(str(aug_model), 'on', device_name) #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=10) #### plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter))