mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)
Confmat / F1 + Minor fix
parent 250ce2c3cf
commit 3ccacd0366
5 changed files with 120 additions and 32 deletions
@@ -10,6 +10,8 @@ import higher
 from datasets import *
 from utils import *
 
+confmat = ConfusionMatrix(num_classes=len(dl_test.dataset.classes))
+
 def test(model):
     """Evaluate a model on test data.
 
@@ -17,7 +19,7 @@ def test(model):
         model (nn.Module): Model to test.
 
     Returns:
-        (float, Tensor) Returns the accuracy and test loss of the model.
+        (float, Tensor) Returns the accuracy and F1 score of the model.
     """
     device = next(model.parameters()).device
     model.eval()
@@ -30,7 +32,8 @@ def test(model):
 
     correct = 0
     total = 0
-    loss = []
+    #loss = []
+    confmat.reset()
     with torch.no_grad():
         for features, labels in dl_test:
             features,labels = features.to(device), labels.to(device)
@@ -40,11 +43,16 @@ def test(model):
             total += labels.size(0)
             correct += (predicted == labels).sum().item()
 
-            loss.append(F.cross_entropy(outputs, labels).item())
+            #loss.append(F.cross_entropy(outputs, labels).item())
+            confmat.update(labels, predicted)
 
     accuracy = 100 * correct / total
 
-    return accuracy, np.mean(loss)
+    #print(confmat)
+    #from sklearn.metrics import f1_score
+    #f1 = f1_score(labels.data.to('cpu'), predicted.data.to('cpu'), average="macro")
+
+    return accuracy, confmat.f1_metric(average="macro")
 
 def compute_vaLoss(model, dl_it, dl):
     """Evaluate a model on a batch of data.
@ -202,7 +210,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
|
||||
unsup_loss (float): Proportion of the unsup_loss loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
|
||||
hp_opt (bool): Wether to learn inner optimizer parameters. (default: False)
|
||||
save_sample_freq (int): Number of epochs between saves of samples of data. If set to None, only one save would be done at the end of the training. (default: None)
|
||||
save_sample_freq (int): Number of epochs between saves of samples of data. If set to None, no sample will be saved. (default: None)
|
||||
|
||||
Returns:
|
||||
(list) Logs of training. Each items is a dict containing results of an epoch.
|
||||
|
@ -310,7 +318,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
|
||||
|
||||
# Test model
|
||||
accuracy, test_loss =test(model)
|
||||
accuracy, f1 =test(model)
|
||||
model.train()
|
||||
|
||||
#### Log ####
|
||||
|
@ -320,6 +328,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"f1": f1.cpu().numpy().tolist(),
|
||||
"time": tf - t0,
|
||||
|
||||
"param": param,
|
||||
|
@ -360,15 +369,6 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
|
||||
track_higher_grads=high_grad_track)
|
||||
|
||||
|
||||
#Data sample saving
|
||||
try:
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch))
|
||||
except:
|
||||
print("Couldn't save finals samples")
|
||||
pass
|
||||
|
||||
return log
|
||||
|
||||
def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, unsup_loss=1):
|
||||
|