Minimum de proba !=0 (ReLu=>Clamp)

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-13 13:26:53 -05:00
parent 18fe14079a
commit 9693dd7113
2 changed files with 4 additions and 2 deletions

View file

@ -628,8 +628,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
if soft : if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
else: else:
self._params['prob'].data = F.relu(self._params['prob'].data) #self._params['prob'].data = F.relu(self._params['prob'].data)
#self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
if not self._fixed_mag: if not self._fixed_mag:

View file

@ -735,6 +735,8 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
val_loss.backward() val_loss.backward()
#print("meta", time.process_time()-t) #print("meta", time.process_time()-t)
#print('proba grad',model['data_aug']['prob'].grad) #print('proba grad',model['data_aug']['prob'].grad)
if model['data_aug']['prob'].grad is None or model['data_aug']['mag'] is None:
print("Warning no grad (iter",i,") :\n Prob-",model['data_aug']['prob'].grad,"\n Mag-", model['data_aug']['mag'].grad)
countcopy+=1 countcopy+=1
model_copy(src=fmodel, dst=model) model_copy(src=fmodel, dst=model)