Mise a jour de toute les modifs... (Higher: Ajout deux TF, modification val loss, ajout prob dans sample image, ...)

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-10 13:21:34 -05:00
parent e75fb96716
commit c8ce6c8024
6 changed files with 299 additions and 64 deletions

View file

@ -22,7 +22,7 @@ class timer():
def print_graph(PyTorch_obj, fig_name='graph'):
graph=make_dot(PyTorch_obj) #Loss give the whole graph
graph.format = 'svg' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
graph.format = 'pdf' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
graph.render(fig_name)
def plot_res(log, fig_name='res', param_names=None):
@ -183,7 +183,7 @@ def plot_TF_res(log, tf_names, fig_name='res'):
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def viz_sample_data(imgs, labels, fig_name='data_sample'):
def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()
@ -194,7 +194,9 @@ def viz_sample_data(imgs, labels, fig_name='data_sample'):
plt.yticks([])
plt.grid(False)
plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
plt.xlabel(labels[i].item())
label = str(labels[i].item())
if weight_labels is not None : label+= ("- p %.2f" % weight_labels[i].item())
plt.xlabel(label)
plt.savefig(fig_name)
print("Sample saved :", fig_name)