mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Brutus bis
This commit is contained in:
parent
b67ec3c469
commit
2fe5070b09
5 changed files with 49 additions and 35 deletions
|
@ -829,9 +829,9 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
|||
dl_val_it = iter(dl_val)
|
||||
|
||||
high_grad_track = True
|
||||
if inner_it == 0:
|
||||
if inner_it == 0: #No HP optimization
|
||||
high_grad_track=False
|
||||
if dataug_epoch_start!=0:
|
||||
if dataug_epoch_start!=0: #Augmentation de donnee differee
|
||||
model.augment(mode=False)
|
||||
high_grad_track = False
|
||||
|
||||
|
@ -874,6 +874,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
|||
|
||||
else:
|
||||
#Methode KL div
|
||||
# Supervised loss (classic)
|
||||
if model._data_augmentation :
|
||||
model.augment(mode=False)
|
||||
sup_logits = model(xs)
|
||||
|
@ -883,6 +884,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
|||
log_sup=F.log_softmax(sup_logits, dim=1)
|
||||
loss = F.cross_entropy(log_sup, ys)
|
||||
|
||||
# Unsupervised loss (KLdiv)
|
||||
if model._data_augmentation:
|
||||
aug_logits = model(xs)
|
||||
log_aug=F.log_softmax(aug_logits, dim=1)
|
||||
|
@ -916,21 +918,22 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
|||
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
|
||||
|
||||
meta_opt.step()
|
||||
model['data_aug'].adjust_param(soft=True) #Contrainte sum(proba)=1
|
||||
|
||||
if hp_opt:
|
||||
#Adjust Hyper-parameters
|
||||
model['data_aug'].adjust_param(soft=True) #Contrainte sum(proba)=1
|
||||
if hp_opt:
|
||||
for param_group in diffopt.param_groups:
|
||||
for param in list(opt_param['Inner'].keys())[1:]:
|
||||
param_group[param].data = param_group[param].data.clamp(min=1e-4)
|
||||
|
||||
#Reset gradients
|
||||
diffopt.detach_()
|
||||
model['model'].detach_()
|
||||
|
||||
meta_opt.zero_grad()
|
||||
|
||||
tf = time.process_time()
|
||||
|
||||
if save_sample:
|
||||
if save_sample: #Data sample saving
|
||||
try:
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
|
@ -939,10 +942,10 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
|||
pass
|
||||
|
||||
|
||||
if(not val_loss):
|
||||
if(not val_loss): #Compute val loss for logs
|
||||
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
|
||||
|
||||
|
||||
# Test model
|
||||
accuracy, test_loss =test(model)
|
||||
model.train()
|
||||
|
||||
|
@ -956,8 +959,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
|||
"time": tf - t0,
|
||||
|
||||
"mix_dist": model['data_aug']['mix_dist'].item(),
|
||||
"param": param, #if isinstance(model['data_aug'], Data_augV5)
|
||||
#else [p.item() for p in model['data_aug']['prob']],
|
||||
"param": param,
|
||||
}
|
||||
if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'].item(), 'momentum': p_grp['momentum'].item()} for p_grp in diffopt.param_groups]
|
||||
log.append(data)
|
||||
|
@ -981,12 +983,15 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
|||
for param_group in diffopt.param_groups:
|
||||
print('Opt param - lr:', param_group['lr'].item(),'- momentum:', param_group['momentum'].item())
|
||||
#############
|
||||
|
||||
#Augmentation de donnee differee
|
||||
if not model.is_augmenting() and (epoch == dataug_epoch_start):
|
||||
print('Starting Data Augmention...')
|
||||
dataug_epoch_start = epoch
|
||||
model.augment(mode=True)
|
||||
if inner_it != 0: high_grad_track = True
|
||||
|
||||
#Data sample saving
|
||||
try:
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue