From 4a7e73088de4b75e1c3b3c2f4b57ccc3e3c9afed Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Wed, 27 Nov 2019 12:54:19 -0500
Subject: [PATCH 1/2] Add RandAugment

---
 higher/dataug.py          | 170 +++++++++++++++++++++++++++++++++++++-
 higher/test_dataug.py     |  47 +++++++----
 higher/train_utils.py     |   4 +
 higher/transformations.py |  65 ++++++++++-----
 4 files changed, 249 insertions(+), 37 deletions(-)

diff --git a/higher/dataug.py b/higher/dataug.py
index 32d1aee..b4ba0d9 100644
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -659,7 +659,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         return w_loss
 
     def reg_loss(self, reg_factor=0.005):
-        if self._fixed_mag:
+        if self._fixed_mag: # or self._fixed_prob: #No regularization if too few degrees of freedom
             return torch.tensor(0)
         else:
             #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
@@ -692,6 +692,174 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         else:
             return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)
 
+class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
+        super(RandAug, self).__init__()
+
+        self._data_augmentation = True
+
+        self._TF_dict = TF_dict
+        self._TF= list(self._TF_dict.keys())
+        self._nb_tf= len(self._TF)
+        self._N_seqTF = N_TF
+
+        self.mag=nn.Parameter(torch.tensor(float(mag)))
+        self._params = nn.ParameterDict({
+            "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #not used
+            "mag" : nn.Parameter(torch.tensor(float(mag))),
+        })
+        self._shared_mag = True
+        self._fixed_mag = True
+
+    def forward(self, x):
+        if self._data_augmentation:# and TF.random.random() < 0.5:
+            device = x.device
+            batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
+
+            x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
+
+            for _ in range(self._N_seqTF):
+                ## Sampling ## == sampled_ops = np.random.choice(transforms, N)
+                uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
+                cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*uniforme_dist)
+                sample = cat_distrib.sample()
+
+                ## Transformations ##
+                x = self.apply_TF(x, sample)
+        return x
+
+    def apply_TF(self, x, sampled_TF):
+        smps_x=[]
+
+        for tf_idx in range(self._nb_tf):
+            mask = sampled_TF==tf_idx #Create selection mask
+            smp_x = x[mask] #torch.masked_select() ? (Requires expanding the mask to the same dims)
+
+            if smp_x.shape[0]!=0: #if there's data to TF
+                magnitude=self._params["mag"].detach()
+
+                tf=self._TF[tf_idx]
+                #print(magnitude)
+
+                #In place
+                x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
+
+        return x
+
+    def adjust_param(self, soft=False):
+        pass #No parameters to optimize
+
+    def loss_weight(self):
+        return 1 #No sampling = no weighting
+
+    def reg_loss(self, reg_factor=0.005):
+        return torch.tensor(0) #No regularization
+
+    def train(self, mode=None):
+        if mode is None :
+            mode=self._data_augmentation
+        self.augment(mode=mode) #Useless if mode=None
+        super(RandAug, self).train(mode)
+
+    def eval(self):
+        self.train(mode=False)
+
+    def augment(self, mode=True):
+        self._data_augmentation=mode
+
+    def __getitem__(self, key):
+        return self._params[key]
+
+    def __str__(self):
+        return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
+
+class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
+        super(RandAug, self).__init__()
+
+        self._data_augmentation = True
+
+        self._TF_dict = TF_dict
+        self._TF= list(self._TF_dict.keys())
+        self._nb_tf= len(self._TF)
+        self._N_seqTF = N_TF
+
+        self.mag=nn.Parameter(torch.tensor(float(mag)))
+        self._params = nn.ParameterDict({
+            "prob": nn.Parameter(torch.tensor(0.5)),
+            "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX))),
+        })
+        self._shared_mag = True
+        self._fixed_mag = True
+
+        self._op_list =[]
+        for tf in self._TF:
+            for mag in range(0.1, self._params['mag'], 0.1):
+                op_list+=[(tf, self._params['prob'], mag)]
+        self._nb_op = len(self._op_list)
+
+        print(self._op_list)
+
+    def forward(self, x):
+        if self._data_augmentation:# and TF.random.random() < 0.5:
+            device = x.device
+            batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
+
+            x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
+
+            for _ in range(self._N_seqTF):
+                ## Sampling ## == sampled_ops = np.random.choice(transforms, N)
+                uniforme_dist = torch.ones(1, self._nb_op, device=device).softmax(dim=1)
+                cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_op), device=device)*uniforme_dist)
+                sample = cat_distrib.sample()
+
+                ## Transformations ##
+                x = self.apply_TF(x, sample)
+        return x
+
+    def apply_TF(self, x, sampled_TF):
+        smps_x=[]
+
+        for op_idx in range(self._nb_op):
+            mask = sampled_TF==tf_idx #Create selection mask
+            smp_x = x[mask] #torch.masked_select() ? (Requires expanding the mask to the same dims)
+
+            if smp_x.shape[0]!=0: #if there's data to TF
+                if TF.random.random() < self.op_list[op_idx][1]:
+                    magnitude=self.op_list[op_idx][2]
+                    tf=self.op_list[op_idx][0]
+
+                    #In place
+                    x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
+
+        return x
+
+    def adjust_param(self, soft=False):
+        pass #No parameters to optimize
+
+    def loss_weight(self):
+        return 1 #No sampling = no weighting
+
+    def reg_loss(self, reg_factor=0.005):
+        return torch.tensor(0) #No regularization
+
+    def train(self, mode=None):
+        if mode is None :
+            mode=self._data_augmentation
+        self.augment(mode=mode) #Useless if mode=None
+        super(RandAug, self).train(mode)
+
+    def eval(self):
+        self.train(mode=False)
+
+    def augment(self, mode=True):
+        self._data_augmentation=mode
+
+    def __getitem__(self, key):
+        return self._params[key]
+
+    def __str__(self):
+        return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
 
 class Augmented_model(nn.Module):
     def __init__(self, data_augmenter, model):
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index d950bde..84f6bbb 100644
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -14,6 +14,26 @@ tf_names = [
     'ShearX',
     'ShearY',
 
+    ## Color TF (Expect image in the range of [0, 1]) ##
+    'Contrast',
+    'Color',
+    'Brightness',
+    'Sharpness',
+    'Posterize',
+    'Solarize', #=>Image in [0,1] #Not optimized for batches
+
+    #Color TF (Common mag scale)
+    #'+Contrast',
+    #'+Color',
+    #'+Brightness',
+    #'+Sharpness',
+    #'-Contrast',
+    #'-Color',
+    #'-Brightness',
+    #'-Sharpness',
+    #'=Posterize',
+    #'=Solarize',
+
     #'BRotate',
     #'BTranslateX',
     #'BTranslateY',
@@ -24,14 +44,10 @@ tf_names = [
     #'BadTranslateY',
     #'BadTranslateY_neg',
 
-    ## Color TF (Expect image in the range of [0, 1]) ##
-    'Contrast',
-    'Color',
-    'Brightness',
-    'Sharpness',
-    'Posterize',
-    'Solarize', #=>Image in [0,1] #Not optimized for batches
-
+    #'BadColor',
+    #'BadSharpness',
+    #'BadContrast',
+    #'BadBrightness',
 
     #Not functional
     #'Auto_Contrast', #Not optimized for batches (very slow)
     #'Equalize',
@@ -47,8 +63,8 @@ else:
 ##########################################
 if __name__ == "__main__":
 
-    n_inner_iter = 10
-    epochs = 100
+    n_inner_iter = 0
+    epochs = 150
     dataug_epoch_start=0
 
     #### Classic ####
@@ -74,12 +90,13 @@ if __name__ == "__main__":
     print('-'*9)
     '''
     #### Augmented Model ####
-    '''
+    #'''
     t0 = time.process_time()
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
     #tf_dict = TF.TF_dict
-    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=True, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
+    #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
     #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
+    aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device)
     print(str(aug_model), 'on', device_name)
     #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
     log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
@@ -98,9 +115,9 @@ if __name__ == "__main__":
 
         print('Execution Time : %.00f '%(time.process_time() - t0))
         print('-'*9)
-    '''
-    #### TF tests ####
     #'''
+    #### TF tests ####
+    '''
     res_folder="res/brutus-tests/"
     epochs= 150
     inner_its = [1, 10]
@@ -150,4 +167,4 @@ if __name__ == "__main__":
 
             #plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names)
             print('-'*9)
-    #'''
+    '''
diff --git a/higher/train_utils.py b/higher/train_utils.py
index 0020928..9fd46eb 100644
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -540,6 +540,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
     val_loss=torch.tensor(0) #Needed if there is no meta-step during an epoch
     dl_val_it = iter(dl_val)
 
+    #if inner_it!=0:
     meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2)
     inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
 
@@ -680,5 +681,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
             model.augment(mode=True)
             if inner_it != 0: high_grad_track = True
 
+        viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
+        viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
+
     #print("Copy ", countcopy)
     return log
\ No newline at end of file
diff --git a/higher/transformations.py b/higher/transformations.py
index e1e0be3..a8b708e 100644
--- a/higher/transformations.py
+++ b/higher/transformations.py
@@ -64,6 +64,27 @@ TF_dict={ #Dataugv5
     'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
     'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
 
+    ## Color TF (Expect image in the range of [0, 1]) ##
+    'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Gradient lost
+    'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient lost #=>Image in [0,1] #Not optimized for batches
+
+    #Color TF (Common mag scale)
+    '+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
+    '+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
+    '+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
+    '+Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
+    '-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
+    '-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
+    '-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
+    '-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
+    '=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Gradient lost
+    '=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient lost #=>Image in [0,1] #Not optimized for batches
+
+
     'BRotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30*3))),
     'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20*3), zero_pos=0))),
     'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20*3), zero_pos=1))),
@@ -74,14 +95,11 @@ TF_dict={ #Dataugv5
     'BadTranslateX_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=0))),
     'BadTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=20*2, maxval=20*3), zero_pos=1))),
     'BadTranslateY_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=1))),
-
-    ## Color TF (Expect image in the range of [0, 1]) ##
-    'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
-    'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
-    'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
-    'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
-    'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Gradient lost
-    'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient lost #=>Image in [0,1] #Not optimized for batches
+
+    'BadColor':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
+    'BadSharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
+    'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
+    'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
 
     #Not functional
     #'Auto_Contrast': (lambda mag: None), #Not optimized for batches (very slow)
@@ -111,10 +129,15 @@ def float_image(int_image):
 #    return random.uniform(minval, real_max)
 
 def rand_floats(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
-    real_max = float_parameter(mag, maxval=maxval)
-    if not minval : minval = -real_max
+    real_mag = float_parameter(mag, maxval=maxval)
+    if not minval : minval = -real_mag
     #return random.uniform(minval, real_max)
-    return minval +(real_max-minval) * torch.rand(size, device=mag.device)
+    return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[min_val, real_mag]
+
+def invScale_rand_floats(size, mag, maxval, minval):
+    #Mag=[0,PARAMETER_MAX] => [PARAMETER_MAX, 0] = [maxval, minval]
+    real_mag = float_parameter(float(PARAMETER_MAX) - mag, maxval=maxval-minval)+minval
+    return real_mag + (maxval-real_mag) * torch.rand(size, device=mag.device) #[real_mag, max_val]
 
 def zero_stack(tensor, zero_pos):
     if zero_pos==0:
@@ -139,7 +162,7 @@ def float_parameter(level, maxval):
     #return float(level) * maxval / PARAMETER_MAX
     return (level * maxval / PARAMETER_MAX)#.to(torch.float)
 
-def int_parameter(level, maxval): #Gradient loss
+#def int_parameter(level, maxval): #Gradient loss
     """Helper function to scale `val` between 0 and maxval .
     Args:
         level: Level of the operation that will be between [0, `PARAMETER_MAX`].
@@ -149,7 +172,7 @@ def int_parameter(level, maxval): #Gradient loss
         An int that results from scaling `maxval` according to `level`.
     """
     #return int(level * maxval / PARAMETER_MAX)
-    return (level * maxval / PARAMETER_MAX)
+#    return (level * maxval / PARAMETER_MAX)
 
 def flipLR(x):
     device = x.device
@@ -279,19 +302,19 @@ def solarize(x, thresholds): #NOT OPTIMIZED FOR BATCHES
 
     for idx, t in enumerate(thresholds): #Per-image operation
         mask = x[idx] > t #Gradient lost
         #In place
-        #inv_x = 1-x[idx][mask]
-        #x[idx][mask]=inv_x
+        inv_x = 1-x[idx][mask]
+        x[idx][mask]=inv_x
 
         #
         #Out of place
-        im = x[idx]
-        inv_x = 1-im[mask]
+        # im = x[idx]
+        # inv_x = 1-im[mask]
 
-        imgs.append(im.masked_scatter(mask,inv_x))
+        # imgs.append(im.masked_scatter(mask,inv_x))
 
-    idxs=torch.tensor(range(x.shape[0]), device=x.device)
-    idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There must be something simpler ...
-    x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
+    #idxs=torch.tensor(range(x.shape[0]), device=x.device)
+    #idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There must be something simpler ...
+    #x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
     #
 
     return x
 

From d822f8f92e7a056ca4a8f7315faea0a90aa0ec85 Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Wed, 27 Nov 2019 17:19:51 -0500
Subject: [PATCH 2/2] Modify solarize (still not differentiable...)
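The per-image loop in solarize() is replaced below by a single torch.where() over a broadcast threshold tensor. This removes the Python loop, but the transform stays non-differentiable with respect to the thresholds: torch.where() only propagates gradients through its two value arguments, never through the boolean condition. A minimal check of this limitation (illustrative snippet, not part of the diff below; the names x, t, th, out are hypothetical):

    x = torch.rand(2, 3, 4, 4, requires_grad=True)
    t = torch.rand(2, requires_grad=True)           #per-image thresholds
    th = t.view(-1, 1, 1, 1).expand_as(x)
    out = torch.where(x > th, 1 - x, x)
    out.sum().backward()
    print(t.grad)  #None: the hard comparison carries no gradient back to t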
---
 higher/compare_res.py     |  9 ++++++---
 higher/dataug.py          | 28 +++++++++++++---------------
 higher/test_dataug.py     | 37 +++++++++++++++++++------------------
 higher/train_utils.py     |  2 +-
 higher/transformations.py | 22 +++++++++++++++++-----
 5 files changed, 56 insertions(+), 42 deletions(-)

diff --git a/higher/compare_res.py b/higher/compare_res.py
index c399aa1..57f16e0 100644
--- a/higher/compare_res.py
+++ b/higher/compare_res.py
@@ -2,11 +2,12 @@ from utils import *
 
 if __name__ == "__main__":
 
-    '''
+    #'''
     files=[
         #"res/good_TF_tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
         #"res/good_TF_tests/log/Aug_mod(Data_augV5(Uniform-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
-        "res/brutus-tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx1-Mag)-LeNet)-150epochs(dataug:0)-1in_it-0.json",
+        #"res/brutus-tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx1-Mag)-LeNet)-150epochs(dataug:0)-1in_it-0.json",
+        "res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
     ]
 
     for idx, file in enumerate(files):
@@ -15,7 +16,7 @@ if __name__ == "__main__":
             data = json.load(json_file)
         plot_resV2(data['Log'], fig_name=file.replace('.json','').replace('log/',''), param_names=data['Param_names'])
         #plot_TF_influence(data['Log'], param_names=data['Param_names'])
-    '''
+    #'''
 
     ## Loss , Acc, Proba = f(epoch) ##
     #plot_compare(filenames=files, fig_name="res/compare")
@@ -75,6 +76,7 @@ if __name__ == "__main__":
     '''
 
     #Print results
+    '''
     nb_run=3
     accs = []
     times = []
@@ -88,3 +90,4 @@ if __name__ == "__main__":
             times.append(data['Time'][0])
 
     print(files[0], np.mean(accs), np.std(accs), np.mean(times))
+    '''
\ No newline at end of file
diff --git a/higher/dataug.py b/higher/dataug.py
index b4ba0d9..538f4ab 100644
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -692,7 +692,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         else:
             return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)
 
-class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh
+class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
     def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
         super(RandAug, self).__init__()
 
@@ -773,9 +773,9 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh
     def __str__(self):
         return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
 
-class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
+class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training)
     def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
-        super(RandAug, self).__init__()
+        super(RandAugUDA, self).__init__()
 
         self._data_augmentation = True
 
@@ -786,7 +786,7 @@ class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
 
         self.mag=nn.Parameter(torch.tensor(float(mag)))
         self._params = nn.ParameterDict({
-            "prob": nn.Parameter(torch.tensor(0.5)),
+            "prob": nn.Parameter(torch.tensor(0.5).unsqueeze(dim=0)),
             "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX))),
         })
         self._shared_mag = True
@@ -794,12 +794,10 @@ class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
 
         self._op_list =[]
         for tf in self._TF:
-            for mag in range(0.1, self._params['mag'], 0.1):
-                op_list+=[(tf, self._params['prob'], mag)]
+            for mag in range(1, int(self._params['mag']*10), 1):
+                self._op_list+=[(tf, self._params['prob'].item(), mag/10)]
         self._nb_op = len(self._op_list)
 
-        print(self._op_list)
-
     def forward(self, x):
         if self._data_augmentation:# and TF.random.random() < 0.5:
             device = x.device
 
@@ -821,16 +819,16 @@ class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
         smps_x=[]
 
         for op_idx in range(self._nb_op):
-            mask = sampled_TF==tf_idx #Create selection mask
+            mask = sampled_TF==op_idx #Create selection mask
             smp_x = x[mask] #torch.masked_select() ? (Requires expanding the mask to the same dims)
 
             if smp_x.shape[0]!=0: #if there's data to TF
-                if TF.random.random() < self.op_list[op_idx][1]:
-                    magnitude=self.op_list[op_idx][2]
-                    tf=self.op_list[op_idx][0]
+                if TF.random.random() < self._op_list[op_idx][1]:
+                    magnitude=self._op_list[op_idx][2]
+                    tf=self._op_list[op_idx][0]
 
                     #In place
-                    x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
+                    x[mask]=self._TF_dict[tf](x=smp_x, mag=torch.tensor(magnitude, device=x.device))
 
         return x
 
@@ -845,7 +845,7 @@ class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
         if mode is None :
             mode=self._data_augmentation
         self.augment(mode=mode) #Useless if mode=None
-        super(RandAug, self).train(mode)
+        super(RandAugUDA, self).train(mode)
 
     def eval(self):
         self.train(mode=False)
@@ -859,7 +857,7 @@ class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
         return self._params[key]
 
     def __str__(self):
-        return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
+        return "RandAugUDA(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
 
 class Augmented_model(nn.Module):
     def __init__(self, data_augmenter, model):
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index 84f6bbb..d443aa0 100644
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -5,21 +5,21 @@ from train_utils import *
 
 tf_names = [
     ## Geometric TF ##
-    'Identity',
-    'FlipUD',
-    'FlipLR',
-    'Rotate',
-    'TranslateX',
-    'TranslateY',
-    'ShearX',
-    'ShearY',
+    #'Identity',
+    #'FlipUD',
+    #'FlipLR',
+    #'Rotate',
+    #'TranslateX',
+    #'TranslateY',
+    #'ShearX',
+    #'ShearY',
 
     ## Color TF (Expect image in the range of [0, 1]) ##
-    'Contrast',
-    'Color',
-    'Brightness',
-    'Sharpness',
-    'Posterize',
+    #'Contrast',
+    #'Color',
+    #'Brightness',
+    #'Sharpness',
+    #'Posterize',
     'Solarize', #=>Image in [0,1] #Not optimized for batches
 
     #Color TF (Common mag scale)
@@ -48,6 +48,7 @@ tf_names = [
     #'BadSharpness',
     #'BadContrast',
     #'BadBrightness',
+
     #Not functional
     #'Auto_Contrast', #Not optimized for batches (very slow)
     #'Equalize',
@@ -63,8 +64,8 @@ else:
 ##########################################
 if __name__ == "__main__":
 
-    n_inner_iter = 0
-    epochs = 150
+    n_inner_iter = 10
+    epochs = 1
     dataug_epoch_start=0
 
     #### Classic ####
@@ -94,12 +95,12 @@ if __name__ == "__main__":
     t0 = time.process_time()
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
     #tf_dict = TF.TF_dict
-    #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
+    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=True, fixed_mag=False, shared_mag=True), LeNet(3,10)).to(device)
     #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
-    aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device)
+    #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device)
     print(str(aug_model), 'on', device_name)
     #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
-    log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
+    log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=None)
 
     ####
     print('-'*9)
diff --git a/higher/train_utils.py b/higher/train_utils.py
index 9fd46eb..0c3c750 100644
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -651,7 +651,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
                 print('TF Proba :', model['data_aug']['prob'].data)
                 #print('proba grad',model['data_aug']['prob'].grad)
                 print('TF Mag :', model['data_aug']['mag'].data)
-                #print('Mag grad',model['data_aug']['mag'].grad)
+                print('Mag grad',model['data_aug']['mag'].grad)
                 #print('Reg loss:', model['data_aug'].reg_loss().item())
             #############
             #### Log ####
diff --git a/higher/transformations.py b/higher/transformations.py
index a8b708e..cefe253 100644
--- a/higher/transformations.py
+++ b/higher/transformations.py
@@ -298,12 +298,12 @@ def equalize(x): #NOT OPTIMIZED FOR BATCHES
 
 def solarize(x, thresholds): #NOT OPTIMIZED FOR BATCHES
     # Optimization: direct mask over all the data (Mask = (B,C,H,W) > (B))
     batch_size, channels, h, w = x.shape
-    imgs=[]
-    for idx, t in enumerate(thresholds): #Per-image operation
-        mask = x[idx] > t #Gradient lost
+    #imgs=[]
+    #for idx, t in enumerate(thresholds): #Per-image operation
+    #    mask = x[idx] > t #Gradient lost
         #In place
-        inv_x = 1-x[idx][mask]
-        x[idx][mask]=inv_x
+    #    inv_x = 1-x[idx][mask]
+    #    x[idx][mask]=inv_x
 
         #
         #Out of place
@@ -316,6 +316,18 @@ def solarize(x, thresholds): #NOT OPTIMIZED FOR BATCHES
     #idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There must be something simpler ...
     #x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
     #
+
+    thresholds = thresholds.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There must be something simpler ...
+    #print(thresholds.grad_fn)
+    x=torch.where(x>thresholds,1-x, x)
+    #print(mask.grad_fn)
+
+    #x=x.min(thresholds)
+    #inv_x = 1-x[mask]
+    #x=x.where(x
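
The torch.where() version above keeps solarize batched but, as the commit subject notes, it is still not differentiable with respect to the thresholds: both value branches (1-x and x) are independent of the threshold, and the boolean condition passes no gradient. One possible smooth relaxation, sketched below under the assumption that images stay in [0, 1] (the name soft_solarize and the temperature parameter are illustrative, not part of this codebase), replaces the hard comparison with a sigmoid gate so gradients reach the threshold/magnitude parameters:

    import torch

    def soft_solarize(x, thresholds, temperature=50.0):
        #x: (B,C,H,W) images in [0,1]; thresholds: (B,) per-image values in [0,1]
        t = thresholds.view(-1, 1, 1, 1)             #broadcast over C,H,W
        gate = torch.sigmoid(temperature * (x - t))  #soft version of (x > t), differentiable w.r.t. t
        return gate * (1 - x) + (1 - gate) * x       #blend inverted and original pixels

As temperature grows, the gate approaches the hard mask used by torch.where(), so the exact transform is recovered in the limit, while finite temperatures keep a usable gradient signal for threshold optimization.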