Mirror of https://github.com/AntoineHX/smart_augmentation.git
Remove softmax on the mixed distributions (it only pushed them toward the uniform distribution)
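Context for the change: the mixture mix_factor*prob + (1-mix_factor)*uniforme_dist already sums to 1 along the TF axis, so applying .softmax(dim=1) on top re-exponentiates values that all lie in [0,1] and compresses their differences, flattening the result toward uniform regardless of the learned probs. A minimal sketch of the effect (toy 3-transformation prob vector; values are illustrative, not from the repo):

    import torch

    # Hypothetical values: 3 transformations, learned probs far from uniform.
    prob = torch.tensor([[0.7, 0.2, 0.1]])
    uniforme_dist = torch.full_like(prob, 1/3)
    mix_factor = 0.5

    mixed = mix_factor*prob + (1-mix_factor)*uniforme_dist
    print(mixed)                 # ~[[0.517, 0.267, 0.217]] -- already a valid distribution
    print(mixed.softmax(dim=1))  # ~[[0.397, 0.309, 0.294]] -- flattened toward uniform

Dropping the softmax keeps the mixture as-is, which is already normalized since both operands are distributions and the mixture weights sum to 1.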
parent c8ce6c8024
commit 23351ec13c

1 changed file with 5 additions and 4 deletions
@@ -560,7 +560,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         self._mix_dist = False
         if mix_dist != 0.0:
             self._mix_dist = True
-            self._mix_factor = max(min(mix_dist, 1.0), 0.0)
+            self._mix_factor = max(min(mix_dist, 0.999), 0.0)
 
         #Mag regularisation
         if not self._fixed_mag:
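Note on the clamp change above: with the softmax gone, mix_factor = 1.0 would make the mixture exactly equal to the learned probs, so a transformation whose prob reaches 0 could never be sampled again. Capping at 0.999 is presumably there to keep a small uniform component in the mix; a quick sketch of the floor it guarantees (hypothetical nb_tf = 3):

    mix_dist, nb_tf = 1.0, 3                     # hypothetical inputs
    mix_factor = max(min(mix_dist, 0.999), 0.0)  # clamped just below 1
    floor = (1 - mix_factor) / nb_tf             # minimum sampling prob of any TF
    print(mix_factor, floor)                     # 0.999 0.000333...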
@@ -586,7 +586,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
             self._distrib = uniforme_dist
         else:
             prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
-            self._distrib = (self._mix_factor*prob+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
+            self._distrib = (self._mix_factor*prob+(1-self._mix_factor)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
 
         cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
         sample = cat_distrib.sample()
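The unchanged context lines above build one categorical sampler per image by broadcasting the shared distribution across the batch. A standalone sketch under assumed shapes (batch of 4 images, 3 TFs; not the repo's code):

    import torch
    from torch.distributions import Categorical

    batch_size, nb_tf, device = 4, 3, "cpu"           # hypothetical shapes
    distrib = torch.tensor([0.5167, 0.2667, 0.2167])  # shared mixture, shape (nb_tf,)

    # ones((batch, nb_tf)) * distrib broadcasts the same row to every image,
    # so each image draws its TF index from the same distribution.
    cat_distrib = Categorical(probs=torch.ones((batch_size, nb_tf), device=device)*distrib)
    sample = cat_distrib.sample()  # shape (batch_size,): one TF index per image
    print(sample)                  # e.g. tensor([0, 2, 0, 1])

Since Categorical normalizes probs internally, the sampler stays valid without the softmax as long as the mixture entries are nonnegative.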
@@ -762,7 +762,7 @@ class Data_augV6(nn.Module): #Optimisation sequentielle
         self._mix_dist = False
         if mix_dist != 0.0:
             self._mix_dist = True
-            self._mix_factor = max(min(mix_dist, 1.0), 0.0)
+            self._mix_factor = max(min(mix_dist, 0.999), 0.0)
 
         #Mag regularisation
         if not self._fixed_mag:
@@ -793,7 +793,7 @@ class Data_augV6(nn.Module): #Optimisation sequentielle
             prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
             curr_prob = torch.index_select(prob, 0, tf_set)
             curr_prob = curr_prob /sum(curr_prob) #Contrainte sum(p)=1
-            self._distrib = (self._mix_factor*curr_prob+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
+            self._distrib = (self._mix_factor*curr_prob+(1-self._mix_factor)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
 
         cat_distrib= Categorical(probs=torch.ones((batch_size, len(tf_set)), device=device)*self._distrib)
         sample = cat_distrib.sample()
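The sequential variant first restricts the learned probs to the TFs in the current schedule slot and renormalizes them before mixing, so the mixture operand is still a distribution. A standalone sketch with hypothetical values (6 TFs total, current slot uses TFs {1, 4, 5}):

    import torch

    prob = torch.tensor([0.10, 0.30, 0.05, 0.15, 0.25, 0.15])  # hypothetical full prob vector
    tf_set = torch.tensor([1, 4, 5])                           # TFs in the current slot

    curr_prob = torch.index_select(prob, 0, tf_set)  # tensor([0.30, 0.25, 0.15])
    curr_prob = curr_prob / curr_prob.sum()          # constraint sum(p)=1 over the subset
    print(curr_prob)                                 # tensor([0.4286, 0.3571, 0.2143])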
@@ -885,6 +885,7 @@ class Data_augV6(nn.Module): #Optimisation sequentielle
         self._current_TF_idx=0
         for n_tf in range(self._N_seqTF) :
             TF.random.shuffle(self._TF_schedule[n_tf])
+        #print('-- New schedule --')
 
     def train(self, mode=None):
         if mode is None :
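The last hunk only adds a commented-out debug print after the schedule reshuffle. For readers unfamiliar with the surrounding code: _TF_schedule appears to hold one list of TF indices per sequential position, each reshuffled in place. A rough stand-in using the stdlib (names and sizes are assumptions, not the repo's API):

    import random

    N_seqTF, nb_tf = 3, 6                                      # hypothetical sizes
    TF_schedule = [list(range(nb_tf)) for _ in range(N_seqTF)]
    for n_tf in range(N_seqTF):
        random.shuffle(TF_schedule[n_tf])                      # in-place reshuffle per position
    #print('-- New schedule --')                               # the (disabled) debug marker added here
    print(TF_schedule)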