Mirror of https://github.com/AntoineHX/smart_augmentation.git, synced 2025-05-04 12:10:45 +02:00
Weight_loss option with mean + mixed_loss readability
This commit is contained in:
parent b820f49437
commit 755e3ca024
9 changed files with 23214 additions and 1396 deletions
@@ -113,9 +113,10 @@ def mixed_loss(xs, ys, model, unsup_factor=1):
     # Unsupervised loss
     aug_logits = model(xs)
     w_loss = model['data_aug'].loss_weight() #Weight loss

     log_aug = F.log_softmax(aug_logits, dim=1)
-    aug_loss = (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
+    aug_loss = F.cross_entropy(log_aug, ys , reduction='none')
+    aug_loss = (aug_loss * w_loss).mean()

     #KL divergence loss (w/ logits) - Prediction/Distribution similarity
     kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
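The change above splits the weighted unsupervised loss into two steps: cross-entropy with reduction='none' to keep one loss value per sample, then scaling by the data-augmentation loss weights before taking the mean. A minimal standalone sketch of that weighted-mean pattern; the tensor shapes, the uniform w_loss placeholder, and the use of raw logits are illustrative assumptions, not code from the repository:

import torch
import torch.nn.functional as F

# Illustrative shapes: batch of 4 samples, 10 classes (assumed for the sketch).
logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))

# Per-sample weights standing in for model['data_aug'].loss_weight();
# a uniform vector is used here purely as a placeholder.
w_loss = torch.full((4,), 0.25)

# reduction='none' keeps one loss value per sample ...
per_sample = F.cross_entropy(logits, targets, reduction='none')  # shape: (4,)
# ... which is then weighted and averaged, as in the new two-line version.
aug_loss = (per_sample * w_loss).mean()
print(aug_loss)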
@@ -295,7 +296,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
     meta_scheduler=None
     if opt_param['Meta']['scheduler']=='multiStep':
         meta_scheduler=torch.optim.lr_scheduler.MultiStepLR(meta_opt,
-            milestones=[int(epochs/3), int(epochs*2/3)]#, int(epochs*2.7/3)],
+            milestones=[int(epochs/3), int(epochs*2/3)],# int(epochs*2.7/3)],
             gamma=3.16)#10)
     elif opt_param['Meta']['scheduler'] is not None:
         raise ValueError("Lr scheduler unknown : %s"%opt_param['Meta']['scheduler'])
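For reference, a self-contained sketch of the MultiStepLR configuration touched by this hunk, with the comma now closing the milestones list before gamma. The dummy parameter, the Adam optimizer, the base learning rate, and the epoch count are assumptions for illustration only:

import torch

epochs = 9  # small assumed value to keep the printout short
param = torch.nn.Parameter(torch.zeros(1))
meta_opt = torch.optim.Adam([param], lr=1e-2)

# Same milestone/gamma pattern as in the hunk above: the learning rate is
# multiplied by gamma=3.16 at epochs/3 and 2*epochs/3.
meta_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    meta_opt,
    milestones=[int(epochs/3), int(epochs*2/3)],
    gamma=3.16)

for epoch in range(epochs):
    meta_opt.step()        # optimizer step (no real training in this sketch)
    meta_scheduler.step()  # advance the schedule once per epoch
    print(epoch, meta_opt.param_groups[0]['lr'])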