mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Mise a jour de toute les modifs... (Higher: Ajout deux TF, modification val loss, ajout prob dans sample image, ...)
This commit is contained in:
parent
e75fb96716
commit
c8ce6c8024
6 changed files with 299 additions and 64 deletions
|
@ -22,7 +22,7 @@ class timer():
|
|||
|
||||
def print_graph(PyTorch_obj, fig_name='graph'):
|
||||
graph=make_dot(PyTorch_obj) #Loss give the whole graph
|
||||
graph.format = 'svg' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
|
||||
graph.format = 'pdf' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
|
||||
graph.render(fig_name)
|
||||
|
||||
def plot_res(log, fig_name='res', param_names=None):
|
||||
|
@ -183,7 +183,7 @@ def plot_TF_res(log, tf_names, fig_name='res'):
|
|||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
def viz_sample_data(imgs, labels, fig_name='data_sample'):
|
||||
def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
|
||||
|
||||
sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()
|
||||
|
||||
|
@ -194,7 +194,9 @@ def viz_sample_data(imgs, labels, fig_name='data_sample'):
|
|||
plt.yticks([])
|
||||
plt.grid(False)
|
||||
plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
|
||||
plt.xlabel(labels[i].item())
|
||||
label = str(labels[i].item())
|
||||
if weight_labels is not None : label+= ("- p %.2f" % weight_labels[i].item())
|
||||
plt.xlabel(label)
|
||||
|
||||
plt.savefig(fig_name)
|
||||
print("Sample saved :", fig_name)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue