diff --git a/higher/smart_aug/utils.py b/higher/smart_aug/utils.py index 4a4b5ae..8fbe1a6 100755 --- a/higher/smart_aug/utils.py +++ b/higher/smart_aug/utils.py @@ -131,13 +131,14 @@ def print_graph(PyTorch_obj, fig_name='graph'): graph.format = 'pdf' #https://graphviz.readthedocs.io/en/stable/manual.html#formats graph.render(fig_name) -def plot_resV2(log, fig_name='res', param_names=None): +def plot_resV2(log, fig_name='res', param_names=None, f1=True): """Save a visual graph of the logs. Args: log (dict): Logs of the training generated by most of train_utils. fig_name (string): Relative path where to save the graph. (default: res) param_names (list): Labels for the parameters. (default: None) + f1 (bool): Wether to plot F1 scores. (default: True) """ epochs = [x["epoch"] for x in log] @@ -151,7 +152,7 @@ def plot_resV2(log, fig_name='res', param_names=None): ax[1, 0].set_title('Test') ax[1, 0].plot(epochs,[x["acc"] for x in log], label='Acc') - if "f1" in log[0].keys(): + if f1 and "f1" in log[0].keys(): #ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1') #''' #print(log[0]["f1"])