Mirror of https://github.com/AntoineHX/smart_augmentation.git, synced 2025-05-04 12:10:45 +02:00
Replace mag regularization with clamp
parent 9693dd7113
commit 53bd421670
3 changed files with 7 additions and 5 deletions
@@ -633,8 +633,8 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
             self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
 
             if not self._fixed_mag:
-                #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Locks once at an extreme
-                self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
+                self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX) #Locks once at an extreme
+                #self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
 
     def loss_weight(self):
         if len(self._samples)==0 : return 1 #No samples = no weighting
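For context on this first hunk: the old code kept the magnitude parameters in range with a pair of ReLUs, which only bounds them to [0, TF.PARAMETER_MAX]; the new code uses Tensor.clamp, which also enforces the lower bound TF.PARAMETER_MIN. A minimal standalone sketch of the two projections follows — the bound values here are illustrative placeholders, not the repo's actual TF constants:

import torch
import torch.nn.functional as F

# Placeholder bounds standing in for TF.PARAMETER_MIN / TF.PARAMETER_MAX
PARAMETER_MIN, PARAMETER_MAX = 0.1, 10.0

mag = torch.tensor([-2.0, 0.05, 5.0, 42.0])

# Old projection: two ReLUs fold values back into [0, PARAMETER_MAX]
old = F.relu(mag) - F.relu(mag - PARAMETER_MAX)

# New projection: a plain clamp with an explicit, non-zero lower bound
new = mag.clamp(min=PARAMETER_MIN, max=PARAMETER_MAX)

print(old)  # tensor([ 0.0000,  0.0500,  5.0000, 10.0000])
print(new)  # tensor([ 0.1000,  0.1000,  5.0000, 10.0000])

In the diff both projections are applied to .data, outside the autograd graph, so the swap does not change gradient flow; it only changes where out-of-range magnitudes are projected.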
@@ -663,8 +663,9 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
             return torch.tensor(0)
         else:
             #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
-            params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
-            return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean')
+            mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
+            max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
+            return max_mag_reg
 
     def train(self, mode=None):
         if mode is None :
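This second hunk is a readability refactor of the magnitude regularization term: the masked selection is renamed from params to mags and the result gets a descriptive name, max_mag_reg, before being returned; the computed value is unchanged. A self-contained sketch of what the term computes, with placeholder values standing in for reg_factor, self._params['mag'], self._reg_mask, and self._reg_tgt:

import torch
import torch.nn.functional as F

reg_factor = 0.5                                    # placeholder weight
mag = torch.tensor([3.0, 7.0, 2.0, 9.0])            # stands in for self._params['mag']
reg_mask = torch.tensor([True, True, True, False])  # stands in for self._reg_mask
reg_tgt = torch.full((3,), 10.0)                    # stands in for self._reg_tgt

# A scalar parameter (shape torch.Size([])) cannot be indexed by a mask, hence the branch
mags = mag if mag.shape == torch.Size([]) else mag[reg_mask]

# MSE pulls the selected magnitudes toward the regularization target
max_mag_reg = reg_factor * F.mse_loss(mags, target=reg_tgt.to(mags.device), reduction='mean')
print(max_mag_reg)  # tensor(20.3333): 0.5 * mean([(3-10)^2, (7-10)^2, (2-10)^2])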