Mirror of https://github.com/AntoineHX/smart_augmentation.git, synced 2025-05-04 20:20:46 +02:00
Minimum probability != 0 (ReLU => Clamp)
This commit is contained in:
parent 18fe14079a
commit 9693dd7113

2 changed files with 4 additions and 2 deletions
@@ -628,8 +628,8 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck at a uniform distribution if the lr is too low
         else:
-            self._params['prob'].data = F.relu(self._params['prob'].data)
-            #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
+            #self._params['prob'].data = F.relu(self._params['prob'].data)
+            self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
             self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1

         if not self._fixed_mag:
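The swap above replaces the ReLU projection, which can drive a transformation's probability to exactly 0 and permanently remove it from sampling, with a clamp to a small positive floor (1% of a uniform share per transformation) before renormalizing. Below is a minimal standalone sketch of the two schemes, not from the repo: `nb_tf` stands in for `self._nb_tf`, and the input values are invented.

```python
import torch
import torch.nn.functional as F

nb_tf = 4                                 # stands in for self._nb_tf
p = torch.tensor([0.5, -0.2, 0.0, 0.7])  # hypothetical raw 'prob' values after a gradient step

# Old behavior: ReLU zeroes negative entries, so a transformation can hit
# probability exactly 0 and stop being sampled.
p_relu = F.relu(p)
p_relu = p_relu / p_relu.sum()            # constraint: sum(p) = 1

# New behavior: clamp to a small positive floor (1% of a uniform share),
# keeping every transformation's probability strictly positive.
p_clamp = p.clamp(min=1 / (nb_tf * 100), max=1.0)
p_clamp = p_clamp / p_clamp.sum()

print(p_relu)   # tensor([0.4167, 0.0000, 0.0000, 0.5833])
print(p_clamp)  # every entry stays > 0
```

With the floor in place, every transformation keeps a strictly positive probability after the projection, so it can still be sampled and continue to receive gradient signal.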
|
|
|
@@ -735,6 +735,8 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
             val_loss.backward()
             #print("meta", time.process_time()-t)
             #print('proba grad',model['data_aug']['prob'].grad)
+            if model['data_aug']['prob'].grad is None or model['data_aug']['mag'] is None:
+                print("Warning no grad (iter",i,") :\n Prob-",model['data_aug']['prob'].grad,"\n Mag-", model['data_aug']['mag'].grad)

             countcopy+=1
             model_copy(src=fmodel, dst=model)
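The added guard reports iterations where `val_loss.backward()` left the augmentation parameters without gradients (note that the committed condition tests `model['data_aug']['mag'] is None`, where `model['data_aug']['mag'].grad is None` appears intended, given what gets printed). A minimal sketch of when `.grad` stays `None`, using hypothetical stand-in tensors:

```python
import torch

# Hypothetical stand-ins for model['data_aug']['prob'] and
# model['data_aug']['mag'], to illustrate when .grad stays None.
prob = torch.rand(4, requires_grad=True)
mag = torch.rand(4, requires_grad=True)

loss = (prob * 2.0).sum()  # only 'prob' participates in this loss
loss.backward()

# A leaf tensor that never entered the loss keeps grad == None, so a
# None check surfaces a silently broken meta-gradient path.
for name, t in [('prob', prob), ('mag', mag)]:
    if t.grad is None:
        print("Warning no grad:", name)  # fires for 'mag' here
```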
|
|