Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)
brutus test suite
This commit is contained in:
parent cc737b7997
commit 0e7ec8b5b0
6 changed files with 65 additions and 47 deletions
@@ -563,8 +563,11 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
 
         #Mag regularization
         if not self._fixed_mag:
-            self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
-            self._reg_tgt = torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude
+            if self._shared_mag :
+                self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage max amplitude
+            else:
+                self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
+                self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude
 
     def forward(self, x):
         if self._data_augmentation:
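The new branch builds a different regularization target depending on whether a single magnitude is shared across all transformations. A minimal sketch of the idea, with stand-in values for TF.PARAMETER_MAX and TF.TF_ignore_mag (both assumptions; the repo's TF module defines the real ones):

import torch

PARAMETER_MAX = 1.0                      # assumed scale; stands in for TF.PARAMETER_MAX
TF_ignore_mag = {'Identity', 'FlipLR'}   # hypothetical magnitude-free TFs
tf_names = ['Identity', 'FlipLR', 'Rotate', 'Contrast']

shared_mag = False
if shared_mag:
    # One magnitude shared by every TF: a scalar target is enough.
    reg_tgt = torch.tensor(PARAMETER_MAX, dtype=torch.float)
else:
    # Per-TF magnitudes: only regularize the TFs whose magnitude matters.
    reg_mask = [tf_names.index(t) for t in tf_names if t not in TF_ignore_mag]
    reg_tgt = torch.full(size=(len(reg_mask),), fill_value=PARAMETER_MAX)
    print(reg_mask, reg_tgt)  # [2, 3] tensor([1., 1.])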
@@ -628,7 +631,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
 
         #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Locks once at the extremes
-        self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
+        self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX) #Clamps at PARAMETER_MAX
 
     def loss_weight(self):
         # Single TF
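The retained line is a two-sided ReLU that is algebraically identical to clamping into [0, TF.PARAMETER_MAX]: relu(x) lifts negatives to 0, and subtracting relu(x - MAX) caps the top. A quick check of the identity, with an assumed PARAMETER_MAX of 1.0:

import torch
import torch.nn.functional as F

PARAMETER_MAX = 1.0  # assumed; stands in for TF.PARAMETER_MAX
x = torch.tensor([-0.5, 0.3, 0.9, 1.7])

relu_clamp = F.relu(x) - F.relu(x - PARAMETER_MAX)
assert torch.allclose(relu_clamp, x.clamp(min=0.0, max=PARAMETER_MAX))
print(relu_clamp)  # tensor([0.0000, 0.3000, 0.9000, 1.0000])

Since the assignment goes through .data, both the commented-out clamp and the ReLU form act outside autograd, so the swap appears to be a matter of expression rather than a gradient fix.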
@@ -638,7 +641,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         #w_loss = w_loss * self._params["prob"]/self._distrib #Weighting by the probabilities (divided by the distribution so as not to shrink the loss)
         #w_loss = torch.sum(w_loss,dim=1)
 
-        #Several sequential TFs
+        #Several sequential TFs (Warning: does not take order into account!)
         w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
         for sample in self._samples:
             tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
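The reworded comment flags that sequential TFs are weighted without regard to order. The hunk truncates the loop body, so the accumulation below is only a guess at the pattern: each sampled step scatters its share of weight into per-TF columns, and summing over steps erases exactly the ordering information the warning refers to.

import torch

batch, nb_tf = 4, 3
# Hypothetical sampled TF indices for two sequential augmentation steps
samples = [torch.tensor([0, 2, 1, 0]), torch.tensor([1, 1, 2, 0])]

w_loss = torch.zeros((batch, nb_tf))
for sample in samples:
    tmp_w = torch.zeros(w_loss.size())
    tmp_w.scatter_(1, sample.view(-1, 1), 1.0 / len(samples))
    w_loss += tmp_w

# Swapping the two steps in `samples` yields the same w_loss,
# which is why order is not taken into account.
print(w_loss)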
@@ -650,8 +653,12 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         return w_loss
 
     def reg_loss(self, reg_factor=0.005):
-        #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
-        return reg_factor * F.mse_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt.to(self._params['mag'].device), reduction='mean')
+        if self._fixed_mag:
+            return torch.tensor(0)
+        else:
+            #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
+            params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
+            return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean')
 
     def train(self, mode=None):
         if mode is None :
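reg_loss now short-circuits to zero when magnitudes are fixed, handles the shared-magnitude scalar case, and otherwise pulls the masked magnitudes toward PARAMETER_MAX with an MSE penalty. A self-contained sketch of the non-fixed path with assumed values (the mask, magnitudes, and PARAMETER_MAX are illustrative):

import torch
import torch.nn.functional as F

PARAMETER_MAX = 1.0   # assumed; stands in for TF.PARAMETER_MAX
reg_factor = 0.005

mag = torch.tensor([0.2, 0.8, 0.5], requires_grad=True)  # per-TF magnitudes
reg_mask = [0, 2]                                        # TFs whose magnitude is regularized
reg_tgt = torch.full((len(reg_mask),), PARAMETER_MAX)

# Scalar (shared) magnitude keeps the whole tensor; vector case applies the mask.
params = mag if mag.shape == torch.Size([]) else mag[reg_mask]
loss = reg_factor * F.mse_loss(params, target=reg_tgt.to(params.device), reduction='mean')
loss.backward()
print(mag.grad)  # negative at indices 0 and 2: gradient descent pushes those magnitudes up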