diff --git a/higher/compare_res.py b/higher/compare_res.py index 05e8be7..3fffadf 100644 --- a/higher/compare_res.py +++ b/higher/compare_res.py @@ -1,29 +1,5 @@ from utils import * -tf_names = [ - ## Geometric TF ## - 'Identity', - 'FlipUD', - 'FlipLR', - 'Rotate', - 'TranslateX', - 'TranslateY', - 'ShearX', - 'ShearY', - - ## Color TF (Expect image in the range of [0, 1]) ## - 'Contrast', - 'Color', - 'Brightness', - 'Sharpness', - 'Posterize', - 'Solarize', #=>Image entre [0,1] #Pas opti pour des batch - - #Non fonctionnel - #'Auto_Contrast', #Pas opti pour des batch (Super lent) - #'Equalize', -] - if __name__ == "__main__": #### Comparison #### @@ -44,15 +20,17 @@ if __name__ == "__main__": #plot_compare(filenames=files, fig_name="res/compare") ## Acc, Time, Epochs = f(n_tf) ## - fig_name="res/TF_nb_tests_compare" + fig_name="res/TF_seq_tests_compare" inner_its = [0, 10] - dataug_epoch_starts= [0, -1] - TF_nb = range(1,14+1) + dataug_epoch_starts= [0] + TF_nb = 14 #range(1,14+1) + N_seq_TF= [1, 2, 3, 4, 6] fig, ax = plt.subplots(ncols=3, figsize=(30, 8)) for in_it in inner_its: for dataug in dataug_epoch_starts: - filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF)-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(n_tf, dataug, in_it) for n_tf in TF_nb] + #filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF)-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(n_tf, dataug, in_it) for n_tf in TF_nb] + filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF x {})-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(TF_nb, n_tf, dataug, in_it) for n_tf in N_seq_TF] all_data=[] #legend="" @@ -62,7 +40,8 @@ if __name__ == "__main__": data = json.load(json_file) all_data.append(data) - n_tf = [len(x["Param_names"]) for x in all_data] + n_tf = N_seq_TF + #n_tf = [len(x["Param_names"]) for x in all_data] acc = [x["Accuracy"] for x in all_data] epochs = [len(x["Log"]) for x in all_data] time = [x["Time"][0] for x in all_data] diff --git a/higher/test_dataug.py b/higher/test_dataug.py index a7f58b7..68036fa 100644 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -761,7 +761,7 @@ if __name__ == "__main__": print('-'*9) ''' #### Augmented Model #### - #''' + ''' tf_dict = {k: TF.TF_dict[k] for k in tf_names} #tf_dict = TF.TF_dict aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), LeNet(3,10)).to(device) @@ -770,7 +770,7 @@ if __name__ == "__main__": log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10) #### - plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it (SOFT)".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)) + plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)) print('-'*9) times = [x["time"] for x in log] out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} @@ -779,14 +779,15 @@ if __name__ == "__main__": json.dump(out, f, indent=True) print('Log :\"',f.name, '\" saved !') print('-'*9) - #''' - ## TF number tests ## ''' + #### TF number tests #### + #''' res_folder="res/TF_nb_tests/" epochs= 200 inner_its = [0, 10] - dataug_epoch_starts= [0, -1] - TF_nb = [14] #range(1,len(TF.TF_dict)+1) + dataug_epoch_starts= [0] + TF_nb = [len(TF.TF_dict)] #range(1,len(TF.TF_dict)+1) + N_seq_TF= [1, 2, 3, 4] try: os.mkdir(res_folder) @@ -798,24 +799,28 @@ if __name__ == "__main__": print("---Starting inner_it", n_inner_iter,"---") for dataug_epoch_start in dataug_epoch_starts: print("---Starting dataug", dataug_epoch_start,"---") - for i in TF_nb: - keys = list(TF.TF_dict.keys())[0:i] - ntf_dict = {k: TF.TF_dict[k] for k in keys} + for n_tf in N_seq_TF: + print("---Starting N_TF", n_tf,"---") + for i in TF_nb: + keys = list(TF.TF_dict.keys())[0:i] + ntf_dict = {k: TF.TF_dict[k] for k in keys} - aug_model = Augmented_model(Data_augV4(TF_dict=ntf_dict, mix_dist=0.0), LeNet(3,10)).to(device) - print(str(aug_model), 'on', device_name) - #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10) + aug_model = Augmented_model(Data_augV4(TF_dict=ntf_dict, N_TF=n_tf, mix_dist=0.0), LeNet(3,10)).to(device) + print(str(aug_model), 'on', device_name) + #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10) - #### - plot_res(log, fig_name=res_folder+"{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)) - print('-'*9) - times = [x["time"] for x in log] - out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} - print(str(aug_model),": acc", out["Accuracy"], "in (ms):", out["Time"][0], "+/-", out["Time"][1]) - with open(res_folder+"log/%s.json" % "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), "w+") as f: - json.dump(out, f, indent=True) - print('Log :\"',f.name, '\" saved !') - print('-'*9) + #### + plot_res(log, fig_name=res_folder+"{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)) + print('-'*9) + times = [x["time"] for x in log] + out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} + print(str(aug_model),": acc", out["Accuracy"], "in (ms):", out["Time"][0], "+/-", out["Time"][1]) + with open(res_folder+"log/%s.json" % "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), "w+") as f: + json.dump(out, f, indent=True) + print('Log :\"',f.name, '\" saved !') + print('-'*9) - ''' \ No newline at end of file + #''' + + \ No newline at end of file