From aef002891e98196c1074d6c487f3802a27466e3d Mon Sep 17 00:00:00 2001 From: "Harle, Antoine (Contracteur)" Date: Fri, 8 Nov 2019 17:41:19 -0500 Subject: [PATCH] Separation script plot compare --- higher/test_dataug.py | 19 +------------------ higher/utils.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/higher/test_dataug.py b/higher/test_dataug.py index 00cea07..a7593e8 100644 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -780,21 +780,4 @@ if __name__ == "__main__": with open(res_folder+"log/%s.json" % "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), "w+") as f: json.dump(out, f, indent=True) print('Log :\"',f.name, '\" saved !') - print('-'*9) - - #### Comparison #### - ''' - files=[ - #"res/log/LeNet-100 epochs.json", - #"res/log/Aug_mod(Data_augV4(Uniform-4 TF)-LeNet)-100 epochs (dataug:0)- 0 in_it.json", - #"res/log/Aug_mod(Data_augV4(Uniform-4 TF)-LeNet)-100 epochs (dataug:50)- 0 in_it.json", - #"res/log/Aug_mod(Data_augV4(Uniform-3 TF)-LeNet)-100 epochs (dataug:0)- 0 in_it.json", - #"res/log/Aug_mod(Data_augV3(Uniform-3 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json", - #"res/log/Aug_mod(Data_augV4(Mix 0,5-3 TF)-LeNet)-100 epochs (dataug:0)- 1 in_it.json", - #"res/log/Aug_mod(Data_augV4(Mix 0.5-3 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json", - #"res/log/Aug_mod(Data_augV4(Uniform-3 TF)-LeNet)-100 epochs (dataug:0)- 10 in_it.json", - #"res/log/Aug_mod(Data_augV4(Uniform-10 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json", - #"res/log/Aug_mod(Data_augV4(Uniform-10 TF)-LeNet)-100 epochs (dataug:50)- 0 in_it.json", - ] - plot_compare(filenames=files, fig_name="res/compare") - ''' \ No newline at end of file + print('-'*9) \ No newline at end of file diff --git a/higher/utils.py b/higher/utils.py index 5eef43d..c1447fd 100644 --- a/higher/utils.py +++ b/higher/utils.py @@ -80,8 +80,35 @@ def plot_compare(filenames, fig_name='res'): ax[1].set_title('Acc') ax[2].set_title('Param') for a in ax: a.legend() - fig_name = fig_name.replace('.',',') + fig_name = fig_name.replace('.',',') + plt.savefig(fig_name, bbox_inches='tight') + plt.close() + +def plot_res_compare(filenames, fig_name='res'): + + all_data=[] + #legend="" + for idx, file in enumerate(filenames): + #legend+=str(idx)+'-'+file+'\n' + with open(file) as json_file: + data = json.load(json_file) + all_data.append(data) + + n_tf = [len(x["Param_names"]) for x in all_data] + acc = [x["Accuracy"] for x in all_data] + time = [x["Time"][0] for x in all_data] + + fig, ax = plt.subplots(ncols=3, figsize=(30, 8)) + + ax[0].plot(n_tf, acc) + ax[1].plot(n_tf, time) + + ax[0].set_title('Acc') + ax[1].set_title('Time') + #for a in ax: a.legend() + + fig_name = fig_name.replace('.',',') plt.savefig(fig_name, bbox_inches='tight') plt.close()