mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
F1 par classes + plot OK
This commit is contained in:
parent
3ccacd0366
commit
fcd0217d54
5 changed files with 57 additions and 22 deletions
|
@ -52,7 +52,7 @@ def test(model):
|
|||
#from sklearn.metrics import f1_score
|
||||
#f1 = f1_score(labels.data.to('cpu'), predicted.data.to('cpu'), average="macro")
|
||||
|
||||
return accuracy, confmat.f1_metric(average="macro")
|
||||
return accuracy, confmat.f1_metric(average=None)
|
||||
|
||||
def compute_vaLoss(model, dl_it, dl):
|
||||
"""Evaluate a model on a batch of data.
|
||||
|
@ -167,7 +167,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
tf = time.process_time()
|
||||
|
||||
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
|
||||
accuracy, _ =test(model)
|
||||
accuracy, f1 =test(model)
|
||||
model.train()
|
||||
|
||||
#### Print ####
|
||||
|
@ -177,6 +177,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy :', accuracy)
|
||||
print('F1 :', f1.data)
|
||||
|
||||
#### Log ####
|
||||
data={
|
||||
|
@ -184,6 +185,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"f1": f1.cpu().numpy().tolist(),
|
||||
"time": tf - t0,
|
||||
|
||||
"param": None,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue