mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 20:20:46 +02:00
+controle mag reg + Tests fonction mag reg
This commit is contained in:
parent
fc0fb25148
commit
534e244307
6 changed files with 45 additions and 22 deletions
|
@ -115,8 +115,17 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
if self._shared_mag :
|
||||
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max
|
||||
else:
|
||||
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in self._TF_ignore_mag]
|
||||
TF_mag=[t for t in self._TF if t not in self._TF_ignore_mag] #TF w/ differentiable mag
|
||||
self._reg_mask=[self._TF.index(t) for t in TF_mag]
|
||||
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
|
||||
|
||||
#Prevent Identity
|
||||
#print(TF.TF_identity)
|
||||
#self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=0.0)
|
||||
#for val in TF.TF_identity.keys():
|
||||
# idx=[self._reg_mask.index(self._TF.index(t)) for t in TF_mag if t in TF.TF_identity[val]]
|
||||
# self._reg_tgt[idx]=val
|
||||
#print(TF_mag, self._reg_tgt)
|
||||
|
||||
def forward(self, x):
|
||||
""" Main method of the Data augmentation module.
|
||||
|
@ -247,7 +256,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
else:
|
||||
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
|
||||
mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
|
||||
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
|
||||
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') #Close to target ?
|
||||
#max_mag_reg = - reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') #Far from target ?
|
||||
return max_mag_reg
|
||||
|
||||
def train(self, mode=True):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue