Refactoring de TF_dict

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-14 21:17:54 -05:00
parent fd4dcdb392
commit 103277fadd
8 changed files with 245 additions and 23 deletions

View file

@ -586,8 +586,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
#PAS PONDERE LOSS POUR DIST MIX
if fmodel._data_augmentation: # and not fmodel['data_aug']._mix_dist: #Weight loss
if fmodel._data_augmentation: #Weight loss
w_loss = fmodel['data_aug'].loss_weight().to(device)
loss = loss * w_loss
loss = loss.mean()