Refactoring de TF_dict

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-14 21:17:54 -05:00
parent fd4dcdb392
commit 103277fadd
8 changed files with 245 additions and 23 deletions

View file

@ -48,6 +48,38 @@ def plot_res(log, fig_name='res', param_names=None):
plt.savefig(fig_name)
plt.close()
def plot_resV2(log, fig_name='res', param_names=None):
epochs = [x["epoch"] for x in log]
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 15))
ax[0, 0].set_title('Loss')
ax[0, 0].plot(epochs,[x["train_loss"] for x in log], label='Train')
ax[0, 0].plot(epochs,[x["val_loss"] for x in log], label='Val')
ax[0, 0].legend()
ax[0, 1].set_title('Acc')
ax[0, 1].plot(epochs,[x["acc"] for x in log])
if log[0]["param"]!= None:
ax[1, 1].set_title('Prob')
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
ax[1, 1].stackplot(epochs, proba, labels=param_names)
ax[1, 1].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
ax[1, 0].set_title('Mean prob')
mean = np.mean([x["param"] for x in log], axis=0)
std = np.std([x["param"] for x in log], axis=0)
ax[1, 0].bar(param_names, mean, yerr=std)
plt.sca(ax[1, 0]), plt.xticks(rotation=90)
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def plot_compare(filenames, fig_name='res'):
all_data=[]