Remplacement mag regularization par clamp

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-13 14:35:48 -05:00
parent 9693dd7113
commit 53bd421670
3 changed files with 7 additions and 5 deletions

View file

@ -633,8 +633,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
if not self._fixed_mag:
#self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Bloque une fois au extreme
self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX) #Bloque une fois au extreme
#self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
def loss_weight(self):
if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
@ -663,8 +663,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
return torch.tensor(0)
else:
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean')
mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
return max_mag_reg
def train(self, mode=None):
if mode is None :

View file

@ -728,7 +728,7 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
#print("meta")
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) #+ fmodel['data_aug'].reg_loss()
#print_graph(val_loss)
#t = time.process_time()

View file

@ -145,6 +145,7 @@ def zero_stack(tensor, zero_pos):
#https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
PARAMETER_MAX = 1 # What is the max 'level' a transform could be predicted
PARAMETER_MIN = 0.1
def float_parameter(level, maxval):
"""Helper function to scale `val` between 0 and maxval .
Args: