Remplacement mag regularization par clamp

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-13 14:35:48 -05:00
parent 9693dd7113
commit 53bd421670
3 changed files with 7 additions and 5 deletions

View file

@ -633,8 +633,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
if not self._fixed_mag:
#self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Bloque une fois au extreme
self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX) #Bloque une fois au extreme
#self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
def loss_weight(self):
if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
@ -663,8 +663,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
return torch.tensor(0)
else:
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean')
mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
return max_mag_reg
def train(self, mode=None):
if mode is None :