option plot F1

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-21 11:32:53 -05:00
parent d0a49a9d61
commit 49472adfab

View file

@ -131,13 +131,14 @@ def print_graph(PyTorch_obj, fig_name='graph'):
graph.format = 'pdf' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
graph.render(fig_name)
def plot_resV2(log, fig_name='res', param_names=None):
def plot_resV2(log, fig_name='res', param_names=None, f1=True):
"""Save a visual graph of the logs.
Args:
log (dict): Logs of the training generated by most of train_utils.
fig_name (string): Relative path where to save the graph. (default: res)
param_names (list): Labels for the parameters. (default: None)
f1 (bool): Wether to plot F1 scores. (default: True)
"""
epochs = [x["epoch"] for x in log]
@ -151,7 +152,7 @@ def plot_resV2(log, fig_name='res', param_names=None):
ax[1, 0].set_title('Test')
ax[1, 0].plot(epochs,[x["acc"] for x in log], label='Acc')
if "f1" in log[0].keys():
if f1 and "f1" in log[0].keys():
#ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1')
#'''
#print(log[0]["f1"])