diff --git a/higher/smart_aug/dataug.py b/higher/smart_aug/dataug.py index f14a3d3..bc36b04 100755 --- a/higher/smart_aug/dataug.py +++ b/higher/smart_aug/dataug.py @@ -251,7 +251,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba) Returns: Tensor containing the regularisation loss value. """ - if self._fixed_mag: + if self._fixed_mag or self._fixed_prob: #Not enough DOF return torch.tensor(0) else: #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean') @@ -559,7 +559,7 @@ class Data_augV7(nn.Module): #Proba sequentielles Returns: Tensor containing the regularisation loss value. """ - if self._fixed_mag: + if self._fixed_mag or self._fixed_prob: #Not enough DOF return torch.tensor(0) else: #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')