diff --git a/higher/dataug.py b/higher/dataug.py
index 714ae4c..431bd28 100755
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -628,8 +628,8 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
             if soft :
                 self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft', gets stuck in a uniform distribution if the lr is too low
             else:
-                self._params['prob'].data = F.relu(self._params['prob'].data)
-                #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
+                #self._params['prob'].data = F.relu(self._params['prob'].data)
+                self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
                 self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint sum(p)=1

         if not self._fixed_mag:
diff --git a/higher/train_utils.py b/higher/train_utils.py
index f4fd85f..c51668d 100755
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -735,6 +735,8 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                 val_loss.backward()
                 #print("meta", time.process_time()-t)
                 #print('proba grad',model['data_aug']['prob'].grad)
+                if model['data_aug']['prob'].grad is None or model['data_aug']['mag'].grad is None:
+                    print("Warning no grad (iter",i,") :\n Prob-",model['data_aug']['prob'].grad,"\n Mag-", model['data_aug']['mag'].grad)

                 countcopy+=1
                 model_copy(src=fmodel, dst=model)
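
The dataug.py hunk replaces the ReLU-based projection of the transformation probabilities with a clamp to [1/(nb_tf*100), 1.0] before the sum(p)=1 renormalisation. The sketch below is illustrative only (the tensor values and nb_tf are made up and it does not reuse the Data_augV5 class): it shows why the clamp keeps every transformation selectable, while ReLU can drive a probability to exactly zero, after which that transformation is never sampled again and receives no further gradient.

    # Illustrative sketch only -- not code from the repository.
    import torch
    import torch.nn.functional as F

    nb_tf = 4                                    # hypothetical number of transformations
    prob = torch.tensor([0.5, -0.2, 0.05, 0.8])  # raw parameters after a gradient step

    # Old behaviour: ReLU then renormalise -> negative entries collapse to exactly 0.
    old = F.relu(prob)
    old = old / old.sum()

    # New behaviour: clamp into [1/(nb_tf*100), 1.0] then renormalise -> strictly positive.
    new = prob.clamp(min=1 / (nb_tf * 100), max=1.0)
    new = new / new.sum()

    print("ReLU + renorm :", old)   # ~[0.3704, 0.0000, 0.0370, 0.5926]
    print("clamp + renorm:", new)   # every entry stays > 0, so every transformation keeps a chance

The train_utils.py hunk is purely diagnostic: if the meta backward pass leaves the prob or mag parameters without a gradient, a warning is printed instead of the problem passing silently.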