diff --git a/higher/dataug.py b/higher/dataug.py
index 431bd28..f492f2a 100755
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -633,8 +633,8 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
 
         self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
         if not self._fixed_mag:
-            #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Gets stuck once at an extreme
-            self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
+            self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX) #Gets stuck once at an extreme
+            #self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
 
     def loss_weight(self):
         if len(self._samples)==0 : return 1 #No samples = no weighting
@@ -663,8 +663,9 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
             return torch.tensor(0)
         else:
             #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
-            params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
-            return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean')
+            mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
+            max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
+            return max_mag_reg
 
     def train(self, mode=None):
         if mode is None :
diff --git a/higher/train_utils.py b/higher/train_utils.py
index c51668d..d22fb31 100755
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -728,7 +728,7 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
 
         if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
             #print("meta")
-            val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
+            val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) #+ fmodel['data_aug'].reg_loss()
             #print_graph(val_loss)
 
             #t = time.process_time()
diff --git a/higher/transformations.py b/higher/transformations.py
index d9b1306..0410a78 100755
--- a/higher/transformations.py
+++ b/higher/transformations.py
@@ -145,6 +145,7 @@ def zero_stack(tensor, zero_pos):
 
 #https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
 PARAMETER_MAX = 1 # What is the max 'level' a transform could be predicted
+PARAMETER_MIN = 0.1 # What is the min 'level' a transform could be predicted
 def float_parameter(level, maxval):
     """Helper function to scale `val` between 0 and maxval .
     Args:
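
Note (illustrative sketch, not part of the patch): the behavioral core of this diff is the magnitude-projection swap in Data_augV5 and the renamed MSE regularizer. The snippet below demonstrates both, assuming PARAMETER_MIN = 0.1 and PARAMETER_MAX = 1 as set in transformations.py above; mag, mags, reg_tgt, and the 0.005 factor are hypothetical stand-ins for self._params['mag'], the masked magnitudes, self._reg_tgt, and reg_factor. The now-active clamp pins values to [PARAMETER_MIN, PARAMETER_MAX], whereas the commented-out ReLU form was algebraically a clamp to [0, PARAMETER_MAX], so magnitudes could collapse to exactly 0.

import torch
import torch.nn.functional as F

PARAMETER_MIN, PARAMETER_MAX = 0.1, 1  # values from higher/transformations.py after this patch

mag = torch.tensor([-0.5, 0.05, 0.5, 1.5])  # hypothetical magnitudes right after a gradient step

# New projection: hard clamp to [PARAMETER_MIN, PARAMETER_MAX]
print(mag.clamp(min=PARAMETER_MIN, max=PARAMETER_MAX))  # tensor([0.1000, 0.1000, 0.5000, 1.0000])

# Old projection: relu(x) - relu(x - MAX) is equivalent to a clamp to [0, PARAMETER_MAX]
print(F.relu(mag) - F.relu(mag - PARAMETER_MAX))        # tensor([0.0000, 0.0500, 0.5000, 1.0000])

# reg_loss() after the refactor: an MSE pull of the masked magnitudes toward a target
mags = torch.tensor([0.9, 0.8, 0.95])            # stand-in for self._params['mag'][self._reg_mask]
reg_tgt = torch.full_like(mags, PARAMETER_MAX)   # stand-in for self._reg_tgt
max_mag_reg = 0.005 * F.mse_loss(mags, target=reg_tgt, reduction='mean')  # 0.005: hypothetical reg_factor
print(max_mag_reg)

Note also that the train_utils.py hunk comments reg_loss() out of the meta objective, so this penalty no longer contributes to val_loss during the meta step.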