Modification for shared_mag

Harle, Antoine (Contractor) 2019-11-18 14:18:15 -05:00
parent 9ad3f0453b
commit 860d9f1bbb
3 changed files with 10 additions and 8 deletions


@@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations with mask
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
class Data_augV5(nn.Module): #Joint optimization (mag, prob)
- def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, glob_mag=True):
+ def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, shared_mag=True):
super(Data_augV5, self).__init__()
assert len(TF_dict)>0
@@ -542,11 +542,13 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
self._nb_tf= len(self._TF)
self._N_seqTF = N_TF
+ self._shared_mag = shared_mag
#self._fixed_mag=5 #[0, PARAMETER_MAX]
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
"mag" : nn.Parameter(torch.tensor(0.5).expand(self._nb_tf) if glob_mag else torch.tensor(0.5).repeat(self._nb_tf)) #[0, PARAMETER_MAX]/10
"mag" : nn.Parameter(torch.tensor(0.5) if shared_mag
else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]/10
})
self._samples = []
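
A minimal sketch (not part of the commit) of what the new shared_mag switch means for the shape of the "mag" parameter; nb_tf = 3 is a made-up count and torch.full stands in for the per-transformation initialization, purely for illustration:

import torch
import torch.nn as nn

nb_tf = 3  # hypothetical number of transformations

# shared_mag=True: a single scalar magnitude, learned once and applied to every TF
shared = nn.ParameterDict({
    "prob": nn.Parameter(torch.ones(nb_tf)/nb_tf),    # uniform prob distribution
    "mag" : nn.Parameter(torch.tensor(0.5)),           # 0-dim tensor
})

# shared_mag=False: one magnitude entry per TF (torch.full used here for clarity)
per_tf = nn.ParameterDict({
    "prob": nn.Parameter(torch.ones(nb_tf)/nb_tf),
    "mag" : nn.Parameter(torch.full((nb_tf,), 0.5)),   # shape (nb_tf,)
})

print(shared["mag"].shape)  # torch.Size([])
print(per_tf["mag"].shape)  # torch.Size([3])
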
@@ -591,7 +593,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
smp_x = x[mask] #torch.masked_select()? (requires expanding the mask to the same dim)
if smp_x.shape[0]!=0: #if there's data to TF
magnitude=self._params["mag"][tf_idx]*10
magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
tf=self._TF[tf_idx]
#print(magnitude)
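
A matching sketch of the forward-time lookup this hunk changes: with a shared magnitude the same scalar serves every transformation, otherwise the entry for the sampled TF index is picked. get_magnitude is a hypothetical helper, not a method of Data_augV5:

# Hypothetical helper mirroring the selection logic above
def get_magnitude(params, shared_mag, tf_idx):
    # shared: the single scalar is used regardless of tf_idx
    # per-TF: index the vector with the sampled transformation index
    return params["mag"] if shared_mag else params["mag"][tf_idx]

# Usage with the ParameterDicts from the previous sketch:
mag_shared = get_magnitude(shared, shared_mag=True,  tf_idx=2)  # 0-dim tensor
mag_per_tf = get_magnitude(per_tf, shared_mag=False, tf_idx=2)  # one entry of the vector
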