Weight_loss option with mean normalization + mixed_loss readability

Harle, Antoine (Contracteur) 2020-03-06 14:25:07 -05:00
parent b820f49437
commit 755e3ca024
9 changed files with 23214 additions and 1396 deletions


@@ -217,13 +217,17 @@ class Data_augV5(nn.Module): # Joint optimization (mag, proba)
         if not self._fixed_mix:
             self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
-    def loss_weight(self):
+    def loss_weight(self, mean_norm=False):
         """ Weights for the loss.
             Compute the weights for the loss of each input depending on which TF was applied to it.
             Should be applied to the loss before reduction.
             Does not take into account the order in which the TFs are applied. See Data_augV7.
+            Args:
+                mean_norm (bool): Whether to normalize the weights by their mean or by the distribution. (Default: normalize by distribution.)
+                    Normalizing by the mean yields an exact normalization but can make the behavior of the probabilities unstable.
+                    Normalizing by the distribution is a statistical approximation of the exact normalization. It gives a smoother evolution of the probabilities, but the weights only average to 1 when mix_dist=1.
             Returns:
                 Tensor : Loss weights.
         """
@@ -238,8 +242,13 @@ class Data_augV5(nn.Module): # Joint optimization (mag, proba)
             tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
             w_loss += tmp_w
-        w_loss = w_loss * prob/self._distrib # Weighted by the probabilities (divided by the distribution so the loss is not shrunk)
-        w_loss = torch.sum(w_loss, dim=1)
+        if mean_norm:
+            w_loss = w_loss * prob
+            w_loss = torch.sum(w_loss, dim=1)
+            w_loss = w_loss/w_loss.mean() # mean(w_loss)=1
+        else:
+            w_loss = w_loss * prob/self._distrib # Weighted by the probabilities (divided by the distribution so the loss is not shrunk)
+            w_loss = torch.sum(w_loss, dim=1)
         return w_loss

     def reg_loss(self, reg_factor=0.005):
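
For intuition, the two branches of loss_weight can be reproduced in isolation. The sketch below is hypothetical: the batch size N, the number of TFs n_tf, the uniform distrib, and the random prob stand in for the module's real state.

import torch

# Hypothetical stand-ins for the module's state (not the repository's values).
N, n_tf = 8, 4
w_loss = torch.zeros(N, n_tf).scatter_(1, torch.randint(n_tf, (N, 1)), 1.0)  # which TF hit each input
prob = torch.softmax(torch.randn(n_tf), dim=0)  # learned TF probabilities
distrib = torch.full((n_tf,), 1.0 / n_tf)       # sampling distribution

# mean_norm=True: exact normalization (the weights average to 1), but the
# division by a batch statistic can make the probabilities behave unstably.
w_mean = (w_loss * prob).sum(dim=1)
w_mean = w_mean / w_mean.mean()  # mean(w_mean) == 1

# mean_norm=False: divide by the sampling distribution instead; a statistical
# approximation of the same normalization whose weights evolve more smoothly.
w_dist = (w_loss * prob / distrib).sum(dim=1)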


@@ -113,9 +113,10 @@ def mixed_loss(xs, ys, model, unsup_factor=1):
     # Unsupervised loss
     aug_logits = model(xs)
     w_loss = model['data_aug'].loss_weight() # Loss weights
     log_aug = F.log_softmax(aug_logits, dim=1)
-    aug_loss = (F.cross_entropy(log_aug, ys, reduction='none') * w_loss).mean()
+    aug_loss = F.cross_entropy(log_aug, ys, reduction='none')
+    aug_loss = (aug_loss * w_loss).mean()

     # KL divergence loss (w/ logits) - prediction/distribution similarity
     kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
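
Pulled out of mixed_loss, the reworked unsupervised term amounts to the small function below. The name weighted_unsup_loss and its signature are invented for illustration; note that F.cross_entropy applies log_softmax internally, so this sketch feeds it the raw logits.

import torch.nn.functional as F

# Hypothetical standalone version of the unsupervised term: the per-sample
# cross entropy is kept unreduced, weighted by w_loss, then averaged.
def weighted_unsup_loss(aug_logits, sup_logits, ys, w_loss):
    log_aug = F.log_softmax(aug_logits, dim=1)
    log_sup = F.log_softmax(sup_logits, dim=1)
    aug_loss = F.cross_entropy(aug_logits, ys, reduction='none')  # shape (N,)
    aug_loss = (aug_loss * w_loss).mean()
    # KL divergence between the supervised and augmented predictions
    kl_loss = (F.softmax(sup_logits, dim=1) * (log_sup - log_aug)).sum(dim=-1)
    return aug_loss, kl_loss.mean()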
@@ -295,7 +296,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
     meta_scheduler=None
     if opt_param['Meta']['scheduler']=='multiStep':
         meta_scheduler=torch.optim.lr_scheduler.MultiStepLR(meta_opt,
-            milestones=[int(epochs/3), int(epochs*2/3)]#, int(epochs*2.7/3)],
+            milestones=[int(epochs/3), int(epochs*2/3)],# int(epochs*2.7/3)],
             gamma=3.16)#10)
     elif opt_param['Meta']['scheduler'] is not None:
         raise ValueError("Lr scheduler unknown : %s"%opt_param['Meta']['scheduler'])
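
The multiStep branch above boils down to a standard MultiStepLR schedule. In this sketch, the epochs value and the dummy parameter are placeholders; the point is that gamma=3.16 (roughly sqrt(10)) is greater than 1, so the meta learning rate is increased at 1/3 and 2/3 of training rather than decayed.

import torch

epochs = 90  # placeholder
meta_opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-2)  # dummy parameter
meta_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    meta_opt,
    milestones=[int(epochs/3), int(epochs*2/3)],  # here: [30, 60]
    gamma=3.16)  # multiplies the lr by ~sqrt(10) at each milestone

for epoch in range(epochs):
    # ... one epoch of meta-optimization with meta_opt ...
    meta_scheduler.step()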