mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Comment Confmat + Cross-Val (sans Skorch) + minor improv
This commit is contained in:
parent
385bc9977c
commit
be8491268a
4 changed files with 133 additions and 33 deletions
|
@ -177,7 +177,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy max:', accuracy)
|
||||
print('F1 :', f1)
|
||||
print('F1 :', ["{0:0.4f}".format(i) for i in f1])
|
||||
|
||||
#### Log ####
|
||||
data={
|
||||
|
@ -185,7 +185,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"f1": f1.cpu().numpy().tolist(),
|
||||
"f1": f1.tolist(),
|
||||
"time": tf - t0,
|
||||
|
||||
"param": None,
|
||||
|
@ -253,7 +253,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
for epoch in range(1, epochs+1):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
dl_train, dl_val = next_CVSplit()
|
||||
dl_train, dl_val = cvs.next_split()
|
||||
dl_val_it = iter(dl_val)
|
||||
|
||||
for i, (xs, ys) in enumerate(dl_train):
|
||||
|
@ -333,7 +333,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"f1": f1.cpu().numpy().tolist(),
|
||||
"f1": f1.tolist(),
|
||||
"time": tf - t0,
|
||||
|
||||
"param": param,
|
||||
|
@ -349,11 +349,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy max:', max([x["acc"] for x in log]))
|
||||
print('F1 :', f1)
|
||||
print('F1 :', ["{0:0.4f}".format(i) for i in f1])
|
||||
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
|
||||
if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data)
|
||||
if not model['data_aug']._fixed_prob: print('TF Proba :', ["{0:0.4f}".format(p) for p in model['data_aug']['prob']])
|
||||
#print('proba grad',model['data_aug']['prob'].grad)
|
||||
if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
|
||||
if not model['data_aug']._fixed_mag: print('TF Mag :', ["{0:0.4f}".format(m) for m in model['data_aug']['mag']])
|
||||
#print('Mag grad',model['data_aug']['mag'].grad)
|
||||
if not model['data_aug']._fixed_mix: print('Mix:', model['data_aug']['mix_dist'].item())
|
||||
#print('Reg loss:', model['data_aug'].reg_loss().item())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue