Minor improvement + Comments

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-21 13:53:07 -05:00
parent d21a6bbf5c
commit c1ad787d97
5 changed files with 165 additions and 62 deletions

View file

@ -876,7 +876,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
else:
#Methode KL div
# Supervised loss (classic)
if model._data_augmentation :
if model.is_augmenting() :
model.augment(mode=False)
sup_logits = model(xs)
model.augment(mode=True)
@ -886,7 +886,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
loss = F.cross_entropy(log_sup, ys)
# Unsupervised loss (KLdiv)
if model._data_augmentation:
if model.is_augmenting() :
aug_logits = model(xs)
log_aug=F.log_softmax(aug_logits, dim=1)
aug_loss=0
@ -948,7 +948,6 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
accuracy, test_loss =test(model)
model.train()
print(model['data_aug']._data_augmentation)
#### Log ####
param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
data={