Mirror of https://github.com/AntoineHX/smart_augmentation.git, synced 2025-05-04 20:20:46 +02:00
Weight_loss option with mean normalization + mixed_loss readability
Parent: b820f49437
Commit: 755e3ca024
9 changed files with 23214 additions and 1396 deletions
@@ -217,13 +217,17 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         if not self._fixed_mix:
             self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
 
-    def loss_weight(self):
+    def loss_weight(self, mean_norm=False):
         """ Weights for the loss.
             Compute the weights for the loss of each input depending on which TF was applied to it.
             Should be applied to the loss before reduction.
 
             Does not take into account the order of application of the TFs. See Data_augV7.
 
+            Args:
+                mean_norm (bool): Whether to normalize the weights by their mean or by the distribution. (Default: normalize by distribution.)
+                    Normalizing by the mean yields an exact normalization but can lead to unstable behavior of the probabilities.
+                    Normalizing by the distribution is a statistical approximation of the exact normalization. It leads to a smoother evolution of the probabilities but is only exact when mix_dist=1.
             Returns:
                 Tensor : Loss weights.
         """
@@ -238,8 +242,13 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
             tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
             w_loss += tmp_w
 
-        w_loss = w_loss * prob/self._distrib #Weight by the probabilities (divided by the distribution so as not to shrink the loss)
-        w_loss = torch.sum(w_loss, dim=1)
+        if mean_norm:
+            w_loss = w_loss * prob
+            w_loss = torch.sum(w_loss, dim=1)
+            w_loss = w_loss/w_loss.mean() #mean(w_loss)=1
+        else:
+            w_loss = w_loss * prob/self._distrib #Weight by the probabilities (divided by the distribution so as not to shrink the loss)
+            w_loss = torch.sum(w_loss, dim=1)
         return w_loss
 
     def reg_loss(self, reg_factor=0.005):
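The docstring says the weights should be applied to the loss before reduction. A hedged usage sketch under that assumption follows; `aug`, `model`, `x`, and `y` are placeholders for the caller's Data_augV5 instance, network, and batch, not names from this commit.

```python
import torch.nn.functional as F

# Hypothetical training-step fragment: weight the per-sample losses
# with loss_weight(), then reduce.
logits = model(aug(x))                                   # forward pass on the augmented batch
loss = F.cross_entropy(logits, y, reduction='none')      # keep per-sample losses
loss = (loss * aug.loss_weight(mean_norm=True)).mean()   # weight, then reduce
loss.backward()
```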