mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 04:00:46 +02:00
Suppresion Softmax pour les mix distrib (Pousse seulement vers dist uniforme)
This commit is contained in:
parent
c8ce6c8024
commit
23351ec13c
1 changed files with 5 additions and 4 deletions
|
@ -560,7 +560,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
self._mix_dist = False
|
||||
if mix_dist != 0.0:
|
||||
self._mix_dist = True
|
||||
self._mix_factor = max(min(mix_dist, 1.0), 0.0)
|
||||
self._mix_factor = max(min(mix_dist, 0.999), 0.0)
|
||||
|
||||
#Mag regularisation
|
||||
if not self._fixed_mag:
|
||||
|
@ -586,7 +586,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
self._distrib = uniforme_dist
|
||||
else:
|
||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||
self._distrib = (self._mix_factor*prob+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
self._distrib = (self._mix_factor*prob+(1-self._mix_factor)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
|
||||
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
|
||||
sample = cat_distrib.sample()
|
||||
|
@ -762,7 +762,7 @@ class Data_augV6(nn.Module): #Optimisation sequentielle
|
|||
self._mix_dist = False
|
||||
if mix_dist != 0.0:
|
||||
self._mix_dist = True
|
||||
self._mix_factor = max(min(mix_dist, 1.0), 0.0)
|
||||
self._mix_factor = max(min(mix_dist, 0.999), 0.0)
|
||||
|
||||
#Mag regularisation
|
||||
if not self._fixed_mag:
|
||||
|
@ -793,7 +793,7 @@ class Data_augV6(nn.Module): #Optimisation sequentielle
|
|||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||
curr_prob = torch.index_select(prob, 0, tf_set)
|
||||
curr_prob = curr_prob /sum(curr_prob) #Contrainte sum(p)=1
|
||||
self._distrib = (self._mix_factor*curr_prob+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
self._distrib = (self._mix_factor*curr_prob+(1-self._mix_factor)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
|
||||
cat_distrib= Categorical(probs=torch.ones((batch_size, len(tf_set)), device=device)*self._distrib)
|
||||
sample = cat_distrib.sample()
|
||||
|
@ -885,6 +885,7 @@ class Data_augV6(nn.Module): #Optimisation sequentielle
|
|||
self._current_TF_idx=0
|
||||
for n_tf in range(self._N_seqTF) :
|
||||
TF.random.shuffle(self._TF_schedule[n_tf])
|
||||
#print('-- New schedule --')
|
||||
|
||||
def train(self, mode=None):
|
||||
if mode is None :
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue