Suite de test brutus

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-20 16:06:27 -05:00
parent cc737b7997
commit 0e7ec8b5b0
6 changed files with 65 additions and 47 deletions

View file

@ -563,8 +563,11 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
#Mag regularisation
if not self._fixed_mag:
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
self._reg_tgt = torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
if self._shared_mag :
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max
else:
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
def forward(self, x):
if self._data_augmentation:
@ -628,7 +631,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
#self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Bloque une fois au extreme
self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX) #Bloque a PARAMETER_MAX
def loss_weight(self):
# 1 seule TF
@ -638,7 +641,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
#w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
#w_loss = torch.sum(w_loss,dim=1)
#Plusieurs TF sequentielles
#Plusieurs TF sequentielles (Attention ne prend pas en compte ordre !)
w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
for sample in self._samples:
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
@ -650,8 +653,12 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
return w_loss
def reg_loss(self, reg_factor=0.005):
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
return reg_factor * F.mse_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt.to(self._params['mag'].device), reduction='mean')
if self._fixed_mag:
return torch.tensor(0)
else:
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean')
def train(self, mode=None):
if mode is None :