+controle mag reg + Tests fonction mag reg

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-25 14:05:17 -05:00
parent fc0fb25148
commit 534e244307
6 changed files with 45 additions and 22 deletions

View file

@ -115,8 +115,17 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
if self._shared_mag :
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max
else:
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in self._TF_ignore_mag]
TF_mag=[t for t in self._TF if t not in self._TF_ignore_mag] #TF w/ differentiable mag
self._reg_mask=[self._TF.index(t) for t in TF_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
#Prevent Identity
#print(TF.TF_identity)
#self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=0.0)
#for val in TF.TF_identity.keys():
# idx=[self._reg_mask.index(self._TF.index(t)) for t in TF_mag if t in TF.TF_identity[val]]
# self._reg_tgt[idx]=val
#print(TF_mag, self._reg_tgt)
def forward(self, x):
""" Main method of the Data augmentation module.
@ -247,7 +256,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
else:
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') #Close to target ?
#max_mag_reg = - reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') #Far from target ?
return max_mag_reg
def train(self, mode=True):