Mirror of https://github.com/AntoineHX/smart_augmentation.git, synced 2025-05-04 12:10:45 +02:00
Replace mag regularization with clamp
This commit is contained in:
parent 9693dd7113
commit 53bd421670
3 changed files with 7 additions and 5 deletions
@@ -633,8 +633,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
 
         if not self._fixed_mag:
-            #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Bloque une fois au extreme
-            self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
+            self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX) #Bloque une fois au extreme
+            #self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
 
     def loss_weight(self):
         if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
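For context: the two-ReLU expression that is now commented out is algebraically the same clipping as clamp(min=0.0, max=TF.PARAMETER_MAX); the new active line keeps the clamp form but raises the floor to TF.PARAMETER_MIN. A minimal sketch contrasting the two operations on toy values (PARAMETER_MIN and PARAMETER_MAX mirror the constants added in the last hunk; the sample tensor is made up, and in the commit the clamp is applied to the parameter's .data):

import torch
import torch.nn.functional as F

PARAMETER_MIN, PARAMETER_MAX = 0.1, 1.0           # values taken from the diff
mag = torch.tensor([-0.3, 0.5, 1.7])              # toy magnitude parameters

# Previous behaviour: clip into [0, MAX], written with two ReLUs
old_bound = F.relu(mag) - F.relu(mag - PARAMETER_MAX)            # tensor([0.0, 0.5, 1.0])

# New behaviour: hard clamp into [MIN, MAX]
new_bound = mag.clamp(min=PARAMETER_MIN, max=PARAMETER_MAX)      # tensor([0.1, 0.5, 1.0])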
@@ -663,8 +663,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
             return torch.tensor(0)
         else:
             #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
-            params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
-            return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean')
+            mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
+            max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
+            return max_mag_reg
 
     def train(self, mode=None):
         if mode is None :
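The rename to mags / max_mag_reg makes the surviving branch read as what it is: an MSE penalty pulling the masked magnitudes toward a target. A self-contained sketch of that computation, with made-up values for reg_factor, the mask and the target (in the repository these live on the Data_augV5 instance as self._reg_mask and self._reg_tgt):

import torch
import torch.nn.functional as F

reg_factor = 0.005                                   # assumed weighting, not the repo's value
mag = torch.tensor([0.2, 0.9, 0.4], requires_grad=True)
reg_mask = torch.tensor([True, True, False])         # which magnitudes are regularized
reg_tgt = torch.full((int(reg_mask.sum()),), 1.0)    # pull the selected mags toward the max

# Scalar mag tensors are used as-is; otherwise only the masked entries are penalized
mags = mag if mag.shape == torch.Size([]) else mag[reg_mask]
max_mag_reg = reg_factor * F.mse_loss(mags, target=reg_tgt.to(mags.device), reduction='mean')
max_mag_reg.backward()                               # gradients flow only into the masked entries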
@@ -728,7 +728,7 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
             if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
                 #print("meta")
 
-                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
+                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) #+ fmodel['data_aug'].reg_loss()
                 #print_graph(val_loss)
 
                 #t = time.process_time()
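With magnitudes now bounded by the clamp, the regularization term is dropped from the validation loss that drives the meta step; the commented-out call marks where it would be re-enabled. A toy sketch of that toggle (the two stand-in functions below only mimic the roles of compute_vaLoss and fmodel['data_aug'].reg_loss from the training loop and are not the real API):

import torch

def compute_val_loss():            # stand-in for compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
    return torch.tensor(0.42)

def mag_reg_loss():                # stand-in for fmodel['data_aug'].reg_loss()
    return torch.tensor(0.01)

use_mag_regularization = False     # this commit: rely on the clamp instead of the penalty

val_loss = compute_val_loss()
if use_mag_regularization:         # previous behaviour added the penalty to the meta objective
    val_loss = val_loss + mag_reg_loss()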
@@ -145,6 +145,7 @@ def zero_stack(tensor, zero_pos):
 
 #https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
 PARAMETER_MAX = 1 # What is the max 'level' a transform could be predicted
+PARAMETER_MIN = 0.1
 def float_parameter(level, maxval):
     """Helper function to scale `val` between 0 and maxval .
     Args:
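PARAMETER_MIN = 0.1 gives the clamp in the first hunk a strictly positive floor, alongside the existing PARAMETER_MAX = 1 ceiling. For reference, the AutoAugment helper linked in the comment scales a level into [0, maxval] as sketched below; the repository's own float_parameter body is not part of this diff and may differ:

PARAMETER_MAX = 1     # from the diff
PARAMETER_MIN = 0.1   # from the diff

def float_parameter(level, maxval):
    # Scale maxval by level / PARAMETER_MAX (AutoAugment-style); purely illustrative here.
    return float(level) * maxval / PARAMETER_MAX

print(float_parameter(0.5, maxval=30))   # 15.0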