Brutus bis

Harle, Antoine (Contracteur) 2020-01-20 11:05:40 -05:00
parent b67ec3c469
commit 2fe5070b09
5 changed files with 49 additions and 35 deletions


@@ -829,9 +829,9 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
dl_val_it = iter(dl_val)
high_grad_track = True
-if inner_it == 0:
+if inner_it == 0: #No HP optimization
high_grad_track=False
-if dataug_epoch_start!=0:
+if dataug_epoch_start!=0: #Deferred data augmentation
model.augment(mode=False)
high_grad_track = False
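Here `high_grad_track` controls whether the inner-loop steps keep the graph needed to backpropagate into the augmentation and optimizer hyper-parameters. The loop itself uses the repo's own differentiable-optimizer wrapper; as a hedged, standalone illustration of what the flag does, here is a minimal sketch with the stock `higher` API (the model, data and hyper-parameters below are placeholders, not the repo's):

import torch
import torch.nn.functional as F
import higher

model = torch.nn.Linear(10, 2)
inner_opt = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
xs, ys = torch.randn(8, 10), torch.randint(0, 2, (8,))

high_grad_track = False  # e.g. inner_it == 0: no HP optimization planned

# With track_higher_grads=False the inner steps behave like a plain optimizer
# (no graph kept across updates, much less memory); with True, gradients can
# later flow back through every diffopt.step() for the meta update.
with higher.innerloop_ctx(model, inner_opt,
                          track_higher_grads=high_grad_track) as (fmodel, diffopt):
    for _ in range(3):
        loss = F.cross_entropy(fmodel(xs), ys)
        diffopt.step(loss)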
@@ -874,6 +874,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
else:
#KL div method
+# Supervised loss (classic)
if model._data_augmentation :
model.augment(mode=False)
sup_logits = model(xs)
@@ -883,6 +884,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
log_sup=F.log_softmax(sup_logits, dim=1)
loss = F.cross_entropy(log_sup, ys)
+# Unsupervised loss (KLdiv)
if model._data_augmentation:
aug_logits = model(xs)
log_aug=F.log_softmax(aug_logits, dim=1)
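This hunk only adds the two comments, but the surrounding logic is the mixed objective: plain cross-entropy on the non-augmented batch plus a KL-divergence consistency term between augmented and non-augmented predictions. A self-contained sketch of that loss, assuming the `model.augment(mode=...)` switch shown above; the `unsup_factor` weight and the `detach()` on the clean predictions are assumptions, and raw logits are fed to `F.cross_entropy` (which applies log-softmax internally):

import torch.nn.functional as F

def mixed_loss(model, xs, ys, unsup_factor=1.0):
    # Supervised term: cross-entropy on the clean (non-augmented) images.
    model.augment(mode=False)
    sup_logits = model(xs)
    sup_loss = F.cross_entropy(sup_logits, ys)

    # Unsupervised term: KL(aug || clean) between the two prediction distributions.
    model.augment(mode=True)
    aug_logits = model(xs)
    log_aug = F.log_softmax(aug_logits, dim=1)
    clean_proba = F.softmax(sup_logits.detach(), dim=1)  # clean preds as soft targets
    kl_loss = F.kl_div(log_aug, clean_proba, reduction='batchmean')

    return sup_loss + unsup_factor * kl_loss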
@@ -916,21 +918,22 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
meta_opt.step()
-model['data_aug'].adjust_param(soft=True) #Constraint: sum(proba)=1
-if hp_opt:
+#Adjust Hyper-parameters
+model['data_aug'].adjust_param(soft=True) #Constraint: sum(proba)=1
+if hp_opt:
for param_group in diffopt.param_groups:
for param in list(opt_param['Inner'].keys())[1:]:
param_group[param].data = param_group[param].data.clamp(min=1e-4)
#Reset gradients
diffopt.detach_()
model['model'].detach_()
meta_opt.zero_grad()
tf = time.process_time()
-if save_sample:
+if save_sample: #Data sample saving
try:
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
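The hp_opt branch above clamps the meta-learned inner hyper-parameters back into a valid range after each meta step, so the next inner loop cannot receive a negative or zero learning rate. A hedged generic version of that clamp, assuming the hyper-parameters live as tensors in the differentiable optimizer's param_groups (the key names are illustrative):

import torch

def clamp_inner_hyperparams(diffopt, keys=('lr', 'momentum'), min_value=1e-4):
    # Keep meta-learned optimizer hyper-parameters in a valid range after
    # the meta optimizer step, mirroring the clamp(min=1e-4) above.
    for param_group in diffopt.param_groups:
        for key in keys:
            if torch.is_tensor(param_group.get(key)):
                param_group[key].data = param_group[key].data.clamp(min=min_value)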
@@ -939,10 +942,10 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
pass
-if(not val_loss):
+if(not val_loss): #Compute val loss for logs
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
# Test model
accuracy, test_loss =test(model)
model.train()
@@ -956,8 +959,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
"time": tf - t0,
"mix_dist": model['data_aug']['mix_dist'].item(),
"param": param, #if isinstance(model['data_aug'], Data_augV5)
#else [p.item() for p in model['data_aug']['prob']],
"param": param,
}
if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'].item(), 'momentum': p_grp['momentum'].item()} for p_grp in diffopt.param_groups]
log.append(data)
@@ -981,12 +983,15 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
for param_group in diffopt.param_groups:
print('Opt param - lr:', param_group['lr'].item(),'- momentum:', param_group['momentum'].item())
+#############
+#Deferred data augmentation
if not model.is_augmenting() and (epoch == dataug_epoch_start):
print('Starting Data Augmentation...')
dataug_epoch_start = epoch
model.augment(mode=True)
if inner_it != 0: high_grad_track = True
+#Data sample saving
try:
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
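This last hunk is the deferred-start logic: training runs without augmentation until `dataug_epoch_start`, then augmentation is switched on and higher-order gradient tracking is re-enabled only if an inner loop is configured. A minimal sketch of that switch, using the same `is_augmenting()` / `augment(mode=...)` interface shown in the diff (the helper name is illustrative, not the repo's):

def maybe_start_augmentation(model, epoch, dataug_epoch_start, inner_it, high_grad_track):
    # Deferred data augmentation: keep training un-augmented until the target
    # epoch, then enable augmentation; meta-gradient tracking is only turned
    # back on when an inner loop (inner_it != 0) is actually run.
    if not model.is_augmenting() and epoch == dataug_epoch_start:
        print('Starting Data Augmentation...')
        model.augment(mode=True)
        if inner_it != 0:
            high_grad_track = True
    return high_grad_track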