Suppression Softmax pour les mix distrib (Pousse seulement vers dist uniforme)

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-10 16:28:01 -05:00
parent c8ce6c8024
commit 23351ec13c

View file

@ -560,7 +560,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
self._mix_dist = False
if mix_dist != 0.0:
self._mix_dist = True
self._mix_factor = max(min(mix_dist, 1.0), 0.0)
self._mix_factor = max(min(mix_dist, 0.999), 0.0)
#Mag regularisation
if not self._fixed_mag:
@ -586,7 +586,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
self._distrib = uniforme_dist
else:
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
self._distrib = (self._mix_factor*prob+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
self._distrib = (self._mix_factor*prob+(1-self._mix_factor)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
sample = cat_distrib.sample()
@ -762,7 +762,7 @@ class Data_augV6(nn.Module): #Optimisation sequentielle
self._mix_dist = False
if mix_dist != 0.0:
self._mix_dist = True
self._mix_factor = max(min(mix_dist, 1.0), 0.0)
self._mix_factor = max(min(mix_dist, 0.999), 0.0)
#Mag regularisation
if not self._fixed_mag:
@ -793,7 +793,7 @@ class Data_augV6(nn.Module): #Optimisation sequentielle
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
curr_prob = torch.index_select(prob, 0, tf_set)
curr_prob = curr_prob /sum(curr_prob) #Contrainte sum(p)=1
self._distrib = (self._mix_factor*curr_prob+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
self._distrib = (self._mix_factor*curr_prob+(1-self._mix_factor)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
cat_distrib= Categorical(probs=torch.ones((batch_size, len(tf_set)), device=device)*self._distrib)
sample = cat_distrib.sample()
@ -885,6 +885,7 @@ class Data_augV6(nn.Module): #Optimisation sequentielle
self._current_TF_idx=0
for n_tf in range(self._N_seqTF) :
TF.random.shuffle(self._TF_schedule[n_tf])
#print('-- New schedule --')
def train(self, mode=None):
if mode is None :