From 534e24430706eb07860d564817c33dd77b3b51e3 Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Tue, 25 Feb 2020 14:05:17 -0500
Subject: [PATCH] Add mag reg control + tests for the mag reg function

---
 higher/smart_aug/benchmark.py       |  9 +++++----
 higher/smart_aug/dataug.py          | 14 ++++++++++++--
 higher/smart_aug/process_res.py     |  7 ++++---
 higher/smart_aug/test_dataug.py     | 28 ++++++++++++++++------------
 higher/smart_aug/train_utils.py     |  2 +-
 higher/smart_aug/transformations.py |  7 +++++++
 6 files changed, 45 insertions(+), 22 deletions(-)

diff --git a/higher/smart_aug/benchmark.py b/higher/smart_aug/benchmark.py
index 0d3d332..5b600a4 100644
--- a/higher/smart_aug/benchmark.py
+++ b/higher/smart_aug/benchmark.py
@@ -29,11 +29,11 @@ optim_param={
 res_folder="../res/benchmark/CIFAR10/"
 #res_folder="../res/HPsearch/"
 
-epochs= 400
+epochs= 200
 dataug_epoch_start=0
 nb_run= 1
 
-tf_config='../config/base_tf_config.json'
+tf_config='../config/wide_tf_config.json' #'../config/wide_tf_config.json'#'../config/base_tf_config.json'
 
 TF_loader=TF_loader()
 tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(tf_config)
@@ -55,15 +55,16 @@ if __name__ == "__main__":
 
     ### Benchmark ###
     #'''
-    n_inner_iter = 1#[0, 1]
+    inner_its = [3]
     dist_mix = [0.5]
-    N_seq_TF= [3, 4]
+    N_seq_TF= [3]
     mag_setup = [(False, False)] #[(True, True), (False, False)] #(FxSh, Independant)
 
 
     for model_type in model_list.keys():
         for model_name in model_list[model_type]:
             for run in range(nb_run):
+                for n_inner_iter in inner_its:
                 for n_tf in N_seq_TF:
                     for dist in dist_mix:
                         for m_setup in mag_setup:
diff --git a/higher/smart_aug/dataug.py b/higher/smart_aug/dataug.py
index 93cc598..56fd2e5 100755
--- a/higher/smart_aug/dataug.py
+++ b/higher/smart_aug/dataug.py
@@ -115,8 +115,17 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         if self._shared_mag :
             self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max
         else:
-            self._reg_mask=[self._TF.index(t) for t in self._TF if t not in self._TF_ignore_mag]
+            TF_mag=[t for t in self._TF if t not in self._TF_ignore_mag] #TF w/ differentiable mag
+            self._reg_mask=[self._TF.index(t) for t in TF_mag]
             self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
+
+            #Prevent Identity
+            #print(TF.TF_identity)
+            #self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=0.0)
+            #for val in TF.TF_identity.keys():
+            #    idx=[self._reg_mask.index(self._TF.index(t)) for t in TF_mag if t in TF.TF_identity[val]]
+            #    self._reg_tgt[idx]=val
+            #print(TF_mag, self._reg_tgt)
 
     def forward(self, x):
         """ Main method of the Data augmentation module.
@@ -247,7 +256,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         else:
             #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
             mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
-            max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
+            max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') #Close to target ?
+            #max_mag_reg = - reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') #Far from target ?
             return max_mag_reg
 
     def train(self, mode=True):
diff --git a/higher/smart_aug/process_res.py b/higher/smart_aug/process_res.py
index 4ae65a3..6a288cd 100755
--- a/higher/smart_aug/process_res.py
+++ b/higher/smart_aug/process_res.py
@@ -10,14 +10,15 @@ if __name__ == "__main__":
             #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-2.json",
             #"res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
             ]
-    files = ["../res/benchmark/CIFAR10/log/RandAugment(N%d-M%.2f)-%s-200 epochs -%s.json"%(3,0.17,'wide_resnet50_2', str(run)) for run in range(3)]
-    files = ["../res/benchmark/CIFAR100/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-Mag)-%s)-200 epochs (dataug:0)- 1 in_it-%s.json"%(0.5,3,'wide_resnet50_2', str(run)) for run in range(3)]
+    #files = ["../res/benchmark/CIFAR10/log/RandAugment(N%d-M%.2f)-%s-200 epochs -%s.json"%(3,0.17,'wide_resnet50_2', str(run)) for run in range(3)]
+    #files = ["../res/benchmark/CIFAR10/log/Aug_mod(RandAug(14TFx%d-Mag%d)-%s)-200 epochs (dataug:0)- 0 in_it-%s.json"%(2,1,'resnet18', str(run)) for run in range(1)]
+    files = ["../res/benchmark/CIFAR10/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-Mag)-%s)-200 epochs (dataug:0)- 3 in_it-%s.json"%(0.5,3,'resnet18', str(run)) for run in range(1)]
 
     for idx, file in enumerate(files):
         #legend+=str(idx)+'-'+file+'\n'
         with open(file) as json_file:
             data = json.load(json_file)
-        plot_resV2(data['Log'], fig_name=file.replace("/log","").replace(".json",""))#, param_names=data['Param_names'])
+        plot_resV2(data['Log'], fig_name=file.replace("/log","").replace(".json",""), param_names=data['Param_names'], f1=True)
         #plot_TF_influence(data['Log'], param_names=data['Param_names'])
     #'''
     ## Loss , Acc, Proba = f(epoch) ##
diff --git a/higher/smart_aug/test_dataug.py b/higher/smart_aug/test_dataug.py
index b80fe63..9858d4e 100755
--- a/higher/smart_aug/test_dataug.py
+++ b/higher/smart_aug/test_dataug.py
@@ -34,13 +34,15 @@ if __name__ == "__main__":
     }
     #Parameters
     n_inner_iter = 1
-    epochs = 2
+    epochs = 200
     dataug_epoch_start=0
+    Nb_TF_seq=3
     optim_param={
         'Meta':{
             'optim':'Adam',
             'lr':1e-2, #1e-2
             'epoch_start': 2, #0 / 2 (Resnet?)
+            'reg_factor': 0.001,
         },
         'Inner':{
             'optim': 'SGD',
@@ -110,7 +112,7 @@ if __name__ == "__main__":
 
     #### Augmented Model ####
     if 'aug_model' in tasks:
-        tf_config='../config/base_tf_config.json'
+        tf_config='../config/invScale_wide_tf_config.json'#'../config/base_tf_config.json'
         tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(tf_config)
 
         torch.cuda.reset_max_memory_allocated() #reset_peak_stats
@@ -118,15 +120,17 @@ if __name__ == "__main__":
 
         t0 = time.perf_counter()
         model = Higher_model(model, model_name) #run_dist_dataugV3
-        aug_model = Augmented_model(
-            Data_augV5(TF_dict=tf_dict,
-                N_TF=1,
-                mix_dist=0.5,
-                fixed_prob=False,
-                fixed_mag=False,
-                shared_mag=False,
-                TF_ignore_mag=tf_ignore_mag), model).to(device)
-        #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
+        if n_inner_iter !=0:
+            aug_model = Augmented_model(
+                Data_augV5(TF_dict=tf_dict,
+                    N_TF=Nb_TF_seq,
+                    mix_dist=0.5,
+                    fixed_prob=False,
+                    fixed_mag=False,
+                    shared_mag=False,
+                    TF_ignore_mag=tf_ignore_mag), model).to(device)
+        else:
+            aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=Nb_TF_seq), model).to(device)
 
         print("{} on {} for {} epochs - {} inner_it{}".format(str(aug_model), device_name, epochs, n_inner_iter, postfix))
         log= run_dist_dataugV3(model=aug_model,
@@ -134,7 +138,7 @@ if __name__ == "__main__":
                         inner_it=n_inner_iter,
                         dataug_epoch_start=dataug_epoch_start,
                         opt_param=optim_param,
-                        print_freq=20,
+                        print_freq=10,
                         unsup_loss=1,
                         hp_opt=False,
                         save_sample_freq=None)
diff --git a/higher/smart_aug/train_utils.py b/higher/smart_aug/train_utils.py
index 7bf1c75..ce4295e 100755
--- a/higher/smart_aug/train_utils.py
+++ b/higher/smart_aug/train_utils.py
@@ -329,7 +329,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
 
             if(high_grad_track and i>0 and i%inner_it==0 and epoch>=opt_param['Meta']['epoch_start']): #Perform Meta step
                 #print("meta")
-                val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss()
+                val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss(opt_param['Meta']['reg_factor'])
                 #print_graph(val_loss) #to visualize computational graph
                 val_loss.backward()
 
diff --git a/higher/smart_aug/transformations.py b/higher/smart_aug/transformations.py
index 3b852d8..9a196a7 100755
--- a/higher/smart_aug/transformations.py
+++ b/higher/smart_aug/transformations.py
@@ -31,6 +31,13 @@ PARAMETER_MAX = 1
 # What is the min 'level' a transform could be predicted
 PARAMETER_MIN = 0.1
 
+#Dict containing the value for which TF are closer to identity
+#TF_identity={
+#    PARAMETER_MAX:{'Solarize', 'Posterize'},
+#    PARAMETER_MAX/2:{'Contrast','Color','Brightness','Sharpness'},
+#    PARAMETER_MIN:{'Rotate','TranslateX','TranslateY','ShearX','ShearY'},
+#}
+
 class TF_loader(object):
     """ Transformations builder.