This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-13 18:02:36 -05:00
parent 53bd421670
commit e291bc2e44
9 changed files with 55 additions and 44 deletions

View file

@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, ):
super(Data_augV5, self).__init__()
assert len(TF_dict)>0
@ -548,8 +548,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
#self._fixed_mag=5 #[0, PARAMETER_MAX]
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
"mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)) if self._shared_mag
else torch.tensor(float(TF.PARAMETER_MAX)).expand(self._nb_tf)), #[0, PARAMETER_MAX]
"mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)/2) if self._shared_mag
else torch.tensor(float(TF.PARAMETER_MAX)/2).expand(self._nb_tf)), #[0, PARAMETER_MAX]
})
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
@ -633,7 +633,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
if not self._fixed_mag:
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX) #Bloque une fois au extreme
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
#self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
def loss_weight(self):