mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-03 11:40:46 +02:00
option plot F1
This commit is contained in:
parent
d0a49a9d61
commit
49472adfab
1 changed files with 3 additions and 2 deletions
|
@ -131,13 +131,14 @@ def print_graph(PyTorch_obj, fig_name='graph'):
|
|||
graph.format = 'pdf' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
|
||||
graph.render(fig_name)
|
||||
|
||||
def plot_resV2(log, fig_name='res', param_names=None):
|
||||
def plot_resV2(log, fig_name='res', param_names=None, f1=True):
|
||||
"""Save a visual graph of the logs.
|
||||
|
||||
Args:
|
||||
log (dict): Logs of the training generated by most of train_utils.
|
||||
fig_name (string): Relative path where to save the graph. (default: res)
|
||||
param_names (list): Labels for the parameters. (default: None)
|
||||
f1 (bool): Wether to plot F1 scores. (default: True)
|
||||
"""
|
||||
epochs = [x["epoch"] for x in log]
|
||||
|
||||
|
@ -151,7 +152,7 @@ def plot_resV2(log, fig_name='res', param_names=None):
|
|||
ax[1, 0].set_title('Test')
|
||||
ax[1, 0].plot(epochs,[x["acc"] for x in log], label='Acc')
|
||||
|
||||
if "f1" in log[0].keys():
|
||||
if f1 and "f1" in log[0].keys():
|
||||
#ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1')
|
||||
#'''
|
||||
#print(log[0]["f1"])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue