Comment Confmat + Cross-Val (sans Skorch) + minor improv

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-03 17:46:32 -05:00
parent 385bc9977c
commit be8491268a
4 changed files with 133 additions and 33 deletions

View file

@ -177,7 +177,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Accuracy max:', accuracy)
print('F1 :', f1)
print('F1 :', ["{0:0.4f}".format(i) for i in f1])
#### Log ####
data={
@ -185,7 +185,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"f1": f1.cpu().numpy().tolist(),
"f1": f1.tolist(),
"time": tf - t0,
"param": None,
@ -253,7 +253,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
for epoch in range(1, epochs+1):
t0 = time.perf_counter()
dl_train, dl_val = next_CVSplit()
dl_train, dl_val = cvs.next_split()
dl_val_it = iter(dl_val)
for i, (xs, ys) in enumerate(dl_train):
@ -333,7 +333,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"f1": f1.cpu().numpy().tolist(),
"f1": f1.tolist(),
"time": tf - t0,
"param": param,
@ -349,11 +349,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Accuracy max:', max([x["acc"] for x in log]))
print('F1 :', f1)
print('F1 :', ["{0:0.4f}".format(i) for i in f1])
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data)
if not model['data_aug']._fixed_prob: print('TF Proba :', ["{0:0.4f}".format(p) for p in model['data_aug']['prob']])
#print('proba grad',model['data_aug']['prob'].grad)
if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
if not model['data_aug']._fixed_mag: print('TF Mag :', ["{0:0.4f}".format(m) for m in model['data_aug']['mag']])
#print('Mag grad',model['data_aug']['mag'].grad)
if not model['data_aug']._fixed_mix: print('Mix:', model['data_aug']['mix_dist'].item())
#print('Reg loss:', model['data_aug'].reg_loss().item())