mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-03 11:40:46 +02:00
Change default data_augV5
This commit is contained in:
parent
7ddb4a41b8
commit
2cbe3d09aa
1 changed files with 4 additions and 4 deletions
|
@ -52,13 +52,13 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
|
||||
_reg_mask (list): Mask selecting the TF considered for the regularisation.
|
||||
"""
|
||||
def __init__(self, TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
|
||||
def __init__(self, TF_dict, N_TF=1, mix_dist=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
|
||||
"""Init Data_augv5.
|
||||
|
||||
Args:
|
||||
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
|
||||
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
|
||||
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
|
||||
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0.5)
|
||||
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
|
||||
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
|
||||
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
|
||||
|
@ -106,8 +106,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
"mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
|
||||
})
|
||||
|
||||
#for tf in TF.TF_no_grad :
|
||||
# if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
|
||||
for tf in self._TF_ignore_mag :
|
||||
self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
|
||||
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
|
||||
|
||||
#Mag regularisation
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue