diff --git a/higher/compare_res.py b/higher/compare_res.py index b29bc7f..429645a 100644 --- a/higher/compare_res.py +++ b/higher/compare_res.py @@ -1,5 +1,29 @@ from utils import * +tf_names = [ + ## Geometric TF ## + 'Identity', + 'FlipUD', + 'FlipLR', + 'Rotate', + 'TranslateX', + 'TranslateY', + 'ShearX', + 'ShearY', + + ## Color TF (Expect image in the range of [0, 1]) ## + 'Contrast', + 'Color', + 'Brightness', + 'Sharpness', + 'Posterize', + 'Solarize', #=>Image entre [0,1] #Pas opti pour des batch + + #Non fonctionnel + #'Auto_Contrast', #Pas opti pour des batch (Super lent) + #'Equalize', +] + if __name__ == "__main__": #### Comparison #### @@ -21,9 +45,9 @@ if __name__ == "__main__": ## Acc, Time, Epochs = f(n_tf) ## fig_name="res/TF_nb_tests_compare" - inner_its = [0, 10] - dataug_epoch_starts= [0, -1] - TF_nb = range(1,14+1) + inner_its = [10] + dataug_epoch_starts= [0] + TF_nb = [14]#range(1,14+1) fig, ax = plt.subplots(ncols=3, figsize=(30, 8)) for in_it in inner_its: @@ -48,9 +72,11 @@ if __name__ == "__main__": ax[1].plot(n_tf, time, label="{} in_it/{} dataug".format(in_it,dataug)) ax[2].plot(n_tf, epochs, label="{} in_it/{} dataug".format(in_it,dataug)) + #for data in all_data: #print(np.mean([x["param"] for x in data["Log"]], axis=0)) - # print(len(data["Param_names"]), np.argsort(np.argsort(np.mean([x["param"] for x in data["Log"]], axis=0)))) + #print(len(data["Param_names"]), np.argsort(np.argsort(np.mean([x["param"] for x in data["Log"]], axis=0)))) + ax[0].set_title('Acc') ax[1].set_title('Time') diff --git a/higher/res/TF_nb_tests_compare.png b/higher/res/TF_nb_tests_compare.png index 7421a17..adbd2bf 100644 Binary files a/higher/res/TF_nb_tests_compare.png and b/higher/res/TF_nb_tests_compare.png differ diff --git a/higher/test_dataug.py b/higher/test_dataug.py index 33f6394..a7f58b7 100644 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -734,7 +734,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f ########################################## if __name__ == "__main__": - n_inner_iter = 10 + n_inner_iter = 0 epochs = 100 dataug_epoch_start=0 @@ -764,7 +764,7 @@ if __name__ == "__main__": #''' tf_dict = {k: TF.TF_dict[k] for k in tf_names} #tf_dict = TF.TF_dict - aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=1, mix_dist=0.0), LeNet(3,10)).to(device) + aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), LeNet(3,10)).to(device) print(str(aug_model), 'on', device_name) #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10) diff --git a/higher/utils.py b/higher/utils.py index 5ced829..4c33a15 100644 --- a/higher/utils.py +++ b/higher/utils.py @@ -112,6 +112,16 @@ def plot_res_compare(filenames, fig_name='res'): plt.savefig(fig_name, bbox_inches='tight') plt.close() +def plot_TF_res(log, tf_names, fig_name='res'): + + fig, ax = plt.subplots(1, 1, figsize=(30, 8), sharey=True) + ax.bar(tf_names, np.mean([x["param"] for x in log], axis=0), yerr=np.std([x["param"] for x in log], axis=0)) + #ax.bar(tf_names, log[-1]["param"]) + + fig_name = fig_name.replace('.',',') + plt.savefig(fig_name, bbox_inches='tight') + plt.close() + def viz_sample_data(imgs, labels, fig_name='data_sample'): sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()