Separation script plot compare

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-08 17:41:19 -05:00
parent 0d1a684aed
commit aef002891e
2 changed files with 29 additions and 19 deletions

View file

@ -781,20 +781,3 @@ if __name__ == "__main__":
json.dump(out, f, indent=True) json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !') print('Log :\"',f.name, '\" saved !')
print('-'*9) print('-'*9)
#### Comparison ####
'''
files=[
#"res/log/LeNet-100 epochs.json",
#"res/log/Aug_mod(Data_augV4(Uniform-4 TF)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
#"res/log/Aug_mod(Data_augV4(Uniform-4 TF)-LeNet)-100 epochs (dataug:50)- 0 in_it.json",
#"res/log/Aug_mod(Data_augV4(Uniform-3 TF)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
#"res/log/Aug_mod(Data_augV3(Uniform-3 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json",
#"res/log/Aug_mod(Data_augV4(Mix 0,5-3 TF)-LeNet)-100 epochs (dataug:0)- 1 in_it.json",
#"res/log/Aug_mod(Data_augV4(Mix 0.5-3 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json",
#"res/log/Aug_mod(Data_augV4(Uniform-3 TF)-LeNet)-100 epochs (dataug:0)- 10 in_it.json",
#"res/log/Aug_mod(Data_augV4(Uniform-10 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json",
#"res/log/Aug_mod(Data_augV4(Uniform-10 TF)-LeNet)-100 epochs (dataug:50)- 0 in_it.json",
]
plot_compare(filenames=files, fig_name="res/compare")
'''

View file

@ -80,8 +80,35 @@ def plot_compare(filenames, fig_name='res'):
ax[1].set_title('Acc') ax[1].set_title('Acc')
ax[2].set_title('Param') ax[2].set_title('Param')
for a in ax: a.legend() for a in ax: a.legend()
fig_name = fig_name.replace('.',',')
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def plot_res_compare(filenames, fig_name='res'):
all_data=[]
#legend=""
for idx, file in enumerate(filenames):
#legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
all_data.append(data)
n_tf = [len(x["Param_names"]) for x in all_data]
acc = [x["Accuracy"] for x in all_data]
time = [x["Time"][0] for x in all_data]
fig, ax = plt.subplots(ncols=3, figsize=(30, 8))
ax[0].plot(n_tf, acc)
ax[1].plot(n_tf, time)
ax[0].set_title('Acc')
ax[1].set_title('Time')
#for a in ax: a.legend()
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight') plt.savefig(fig_name, bbox_inches='tight')
plt.close() plt.close()