Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)
Weight_loss option with mean normalization + mixed_loss readability
This commit is contained in:
parent b820f49437
commit 755e3ca024
9 changed files with 23214 additions and 1396 deletions
@@ -217,13 +217,17 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         if not self._fixed_mix:
             self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
 
-    def loss_weight(self):
+    def loss_weight(self, mean_norm=False):
         """ Weights for the loss.
             Compute the weights for the loss of each input depending on which TF was applied to it.
             Should be applied to the loss before reduction.
 
             Does not take into account the order of application of the TFs. See Data_augV7.
 
+            Args:
+                mean_norm (bool): Whether to normalize weights by mean or by distribution. (Default: normalize by distribution.)
+                    Normalizing by mean yields an exact normalization but can lead to unstable behavior of the probabilities.
+                    Normalizing by distribution is a statistical approximation of the exact normalization. It leads to a smoother evolution of the probabilities but will only return 1 when mix_dist=1.
             Returns:
                 Tensor : Loss weights.
         """
@@ -238,8 +242,13 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
             tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
             w_loss += tmp_w
 
-        w_loss = w_loss * prob/self._distrib #Weighting by the probabilities (divided by the distribution so the loss is not reduced)
-        w_loss = torch.sum(w_loss,dim=1)
+        if mean_norm:
+            w_loss = w_loss * prob
+            w_loss = torch.sum(w_loss,dim=1)
+            w_loss = w_loss/w_loss.mean() #mean(w_loss)=1
+        else:
+            w_loss = w_loss * prob/self._distrib #Weighting by the probabilities (divided by the distribution so the loss is not reduced)
+            w_loss = torch.sum(w_loss,dim=1)
         return w_loss
 
     def reg_loss(self, reg_factor=0.005):
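A small standalone sketch of the difference between the two normalization modes introduced above. The tensor names (prob, distrib, applied) and the uniform sampling distribution are assumptions made here for illustration; this is not the repository's code.

# Toy illustration: per-sample weights for a batch where each sample was
# transformed by one of N_TF transformations.
import torch

torch.manual_seed(0)
N_TF, batch = 4, 8
prob = torch.softmax(torch.randn(N_TF), dim=0)          # learned TF probabilities
distrib = torch.full((N_TF,), 1.0 / N_TF)                # assumed uniform sampling distribution
applied = torch.randint(N_TF, (batch,))                  # which TF was applied to each sample
one_hot = torch.nn.functional.one_hot(applied, N_TF).float()

# mean_norm=True: exact normalization, mean(w) == 1 by construction
w_mean = (one_hot * prob).sum(dim=1)
w_mean = w_mean / w_mean.mean()

# mean_norm=False: normalize by the sampling distribution; only approximately
# mean 1, but evolves more smoothly as prob changes
w_dist = (one_hot * prob / distrib).sum(dim=1)

print(w_mean.mean().item(), w_dist.mean().item())        # exactly 1.0 vs. roughly 1.0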
@@ -113,9 +113,10 @@ def mixed_loss(xs, ys, model, unsup_factor=1):
         # Unsupervised loss
         aug_logits = model(xs)
         w_loss = model['data_aug'].loss_weight() #Weight loss
 
         log_aug = F.log_softmax(aug_logits, dim=1)
-        aug_loss = (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
+        aug_loss = F.cross_entropy(log_aug, ys , reduction='none')
+        aug_loss = (aug_loss * w_loss).mean()
 
         #KL divergence loss (w/ logits) - Prediction/Distribution similarity
         kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
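The split above keeps the per-sample loss unreduced so the weights from loss_weight() can be applied before averaging. A minimal runnable sketch of that pattern, with dummy tensors standing in for the model outputs and the returned weights:

import torch
import torch.nn.functional as F

logits  = torch.randn(8, 10)          # stand-in for the model's outputs
targets = torch.randint(10, (8,))     # stand-in for the labels ys
w_loss  = torch.rand(8)               # stand-in for model['data_aug'].loss_weight()

per_sample = F.cross_entropy(logits, targets, reduction='none')  # shape (8,), one loss per sample
aug_loss   = (per_sample * w_loss).mean()                         # scalar, ready for backward()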
@@ -295,7 +296,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
     meta_scheduler=None
     if opt_param['Meta']['scheduler']=='multiStep':
         meta_scheduler=torch.optim.lr_scheduler.MultiStepLR(meta_opt,
-            milestones=[int(epochs/3), int(epochs*2/3)]#, int(epochs*2.7/3)],
+            milestones=[int(epochs/3), int(epochs*2/3)],# int(epochs*2.7/3)],
             gamma=3.16)#10)
     elif opt_param['Meta']['scheduler'] is not None:
         raise ValueError("Lr scheduler unknown : %s"%opt_param['Meta']['scheduler'])
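For context, a hedged sketch of the MultiStepLR configuration fixed above: the meta learning rate is multiplied by gamma at one third and two thirds of training. The dummy parameter and the Adam optimizer are assumptions; only the milestones and gamma follow the diff.

import torch

epochs = 30
params = [torch.zeros(1, requires_grad=True)]            # dummy parameter
meta_opt = torch.optim.Adam(params, lr=1e-2)             # assumed optimizer
meta_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    meta_opt,
    milestones=[int(epochs/3), int(epochs*2/3)],          # epochs 10 and 20 here
    gamma=3.16)                                           # gamma > 1 increases the lr

for epoch in range(epochs):
    # ... training step / meta update would go here ...
    meta_scheduler.step()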