F1 par classes + plot OK

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-03 11:21:54 -05:00
parent 3ccacd0366
commit fcd0217d54
5 changed files with 57 additions and 22 deletions

View file

@ -52,7 +52,7 @@ def test(model):
#from sklearn.metrics import f1_score
#f1 = f1_score(labels.data.to('cpu'), predicted.data.to('cpu'), average="macro")
return accuracy, confmat.f1_metric(average="macro")
return accuracy, confmat.f1_metric(average=None)
def compute_vaLoss(model, dl_it, dl):
"""Evaluate a model on a batch of data.
@ -167,7 +167,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
tf = time.process_time()
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
accuracy, _ =test(model)
accuracy, f1 =test(model)
model.train()
#### Print ####
@ -177,6 +177,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Accuracy :', accuracy)
print('F1 :', f1.data)
#### Log ####
data={
@ -184,6 +185,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"f1": f1.cpu().numpy().tolist(),
"time": tf - t0,
"param": None,