Option Weight_loss avec mean + lisbilite mixed_loss

This commit is contained in:
Harle, Antoine (Contracteur) 2020-03-06 14:25:07 -05:00
parent b820f49437
commit 755e3ca024
9 changed files with 23214 additions and 1396 deletions

View file

@ -113,9 +113,10 @@ def mixed_loss(xs, ys, model, unsup_factor=1):
# Unsupervised loss
aug_logits = model(xs)
w_loss = model['data_aug'].loss_weight() #Weight loss
log_aug = F.log_softmax(aug_logits, dim=1)
aug_loss = (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
aug_loss = F.cross_entropy(log_aug, ys , reduction='none')
aug_loss = (aug_loss * w_loss).mean()
#KL divergence loss (w/ logits) - Prediction/Distribution similarity
kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
@ -295,7 +296,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
meta_scheduler=None
if opt_param['Meta']['scheduler']=='multiStep':
meta_scheduler=torch.optim.lr_scheduler.MultiStepLR(meta_opt,
milestones=[int(epochs/3), int(epochs*2/3)]#, int(epochs*2.7/3)],
milestones=[int(epochs/3), int(epochs*2/3)],# int(epochs*2.7/3)],
gamma=3.16)#10)
elif opt_param['Meta']['scheduler'] is not None:
raise ValueError("Lr scheduler unknown : %s"%opt_param['Meta']['scheduler'])