diff --git a/higher/dataug.py b/higher/dataug.py index a377e51..1dbb0b9 100644 --- a/higher/dataug.py +++ b/higher/dataug.py @@ -316,6 +316,8 @@ class Data_augV3(nn.Module): #Echantillonage uniforme/Mixte class Data_augV4(nn.Module): #Transformations avec mask def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0): super(Data_augV4, self).__init__() + assert len(TF_dict)>0 + self._data_augmentation = True #self._TF_matrix={} diff --git a/higher/res/Aug_mod(Data_augV4(Uniform-11 TF)-LeNet)-100 epochs (dataug:0)- 0 in_it.png b/higher/res/Aug_mod(Data_augV4(Uniform-11 TF)-LeNet)-100 epochs (dataug:0)- 0 in_it.png deleted file mode 100644 index 62cf98d..0000000 Binary files a/higher/res/Aug_mod(Data_augV4(Uniform-11 TF)-LeNet)-100 epochs (dataug:0)- 0 in_it.png and /dev/null differ diff --git a/higher/res/Aug_mod(Data_augV4(Uniform-11 TF)-LeNet)-100 epochs (dataug:0)- 10 in_it.png b/higher/res/Aug_mod(Data_augV4(Uniform-11 TF)-LeNet)-100 epochs (dataug:0)- 10 in_it.png deleted file mode 100644 index 5885842..0000000 Binary files a/higher/res/Aug_mod(Data_augV4(Uniform-11 TF)-LeNet)-100 epochs (dataug:0)- 10 in_it.png and /dev/null differ diff --git a/higher/test_dataug.py b/higher/test_dataug.py index 15b1550..00cea07 100644 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -696,7 +696,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f model.augment(mode=True) if inner_it != 0: high_grad_track = True - print("Copy ", countcopy) + #print("Copy ", countcopy) return log ########################################## @@ -728,7 +728,7 @@ if __name__ == "__main__": print('-'*9) ''' #### Augmented Model #### - #''' + ''' aug_model = Augmented_model(Data_augV4(TF_dict=TF.TF_dict, mix_dist=0.0), LeNet(3,10)).to(device) print(str(aug_model), 'on', device_name) #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) @@ -744,7 +744,43 @@ if __name__ == "__main__": json.dump(out, f, indent=True) print('Log :\"',f.name, '\" saved !') print('-'*9) - #''' + ''' + ## TF number tests ## + res_folder="res/TF_nb_tests/" + epochs= 200 + inner_its = [0, 10] + dataug_epoch_starts= [0, -1] + max_TF_nb = len(TF.TF_dict) + + try: + os.mkdir(res_folder) + os.mkdir(res_folder+"log/") + except FileExistsError: + pass + + for n_inner_iter in inner_its: + print("---Starting inner_it", n_inner_iter,"---") + for dataug_epoch_start in dataug_epoch_starts: + print("---Starting dataug", dataug_epoch_start,"---") + for i in range(1,max_TF_nb): + keys = list(TF.TF_dict.keys())[0:i] + ntf_dict = {k: TF.TF_dict[k] for k in keys} + + aug_model = Augmented_model(Data_augV4(TF_dict=ntf_dict, mix_dist=0.0), LeNet(3,10)).to(device) + print(str(aug_model), 'on', device_name) + #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=10) + + #### + plot_res(log, fig_name=res_folder+"{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)) + print('-'*9) + times = [x["time"] for x in log] + out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} + print(str(aug_model),": acc", out["Accuracy"], "in (ms):", out["Time"][0], "+/-", out["Time"][1]) + with open(res_folder+"log/%s.json" % "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), "w+") as f: + json.dump(out, f, indent=True) + print('Log :\"',f.name, '\" saved !') + print('-'*9) #### Comparison #### ''' @@ -757,8 +793,8 @@ if __name__ == "__main__": #"res/log/Aug_mod(Data_augV4(Mix 0,5-3 TF)-LeNet)-100 epochs (dataug:0)- 1 in_it.json", #"res/log/Aug_mod(Data_augV4(Mix 0.5-3 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json", #"res/log/Aug_mod(Data_augV4(Uniform-3 TF)-LeNet)-100 epochs (dataug:0)- 10 in_it.json", - "res/log/Aug_mod(Data_augV4(Uniform-10 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json", - "res/log/Aug_mod(Data_augV4(Uniform-10 TF)-LeNet)-100 epochs (dataug:50)- 0 in_it.json", + #"res/log/Aug_mod(Data_augV4(Uniform-10 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json", + #"res/log/Aug_mod(Data_augV4(Uniform-10 TF)-LeNet)-100 epochs (dataug:50)- 0 in_it.json", ] plot_compare(filenames=files, fig_name="res/compare") ''' \ No newline at end of file diff --git a/higher/utils.py b/higher/utils.py index b9826bb..5eef43d 100644 --- a/higher/utils.py +++ b/higher/utils.py @@ -43,6 +43,7 @@ def plot_res(log, fig_name='res'): fig_name = fig_name.replace('.',',') plt.savefig(fig_name) + plt.close() def plot_compare(filenames, fig_name='res'): @@ -82,6 +83,7 @@ def plot_compare(filenames, fig_name='res'): fig_name = fig_name.replace('.',',') plt.savefig(fig_name, bbox_inches='tight') + plt.close() def viz_sample_data(imgs, labels, fig_name='data_sample'): @@ -97,6 +99,7 @@ def viz_sample_data(imgs, labels, fig_name='data_sample'): plt.xlabel(labels[i].item()) plt.savefig(fig_name) + plt.close() def model_copy(src,dst, patch_copy=True, copy_grad=True): #model=copy.deepcopy(fmodel) #Pas approprie, on ne souhaite que les poids/grad (pas tout fmodel et ses etats)