mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Fix Translate + TF loader
This commit is contained in:
parent
79de0191a8
commit
b170af076f
9 changed files with 674 additions and 40 deletions
|
@ -39,6 +39,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
_data_augmentation (bool): Wether TF will be applied during forward pass.
|
||||
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
|
||||
_TF (list) : List of TF names.
|
||||
_TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
|
||||
_nb_tf (int) : Number of TF used.
|
||||
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
|
||||
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF. Beware using shared mag with basic color TF as their lowest magnitude is at PARAMETER_MAX/2.
|
||||
|
@ -51,7 +52,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
|
||||
_reg_mask (list): Mask selecting the TF considered for the regularisation.
|
||||
"""
|
||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
|
||||
def __init__(self, TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
|
||||
"""Init Data_augv5.
|
||||
|
||||
Args:
|
||||
|
@ -61,6 +62,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
|
||||
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
|
||||
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
|
||||
TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
|
||||
"""
|
||||
super(Data_augV5, self).__init__()
|
||||
assert len(TF_dict)>0
|
||||
|
@ -71,13 +73,14 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
#TF
|
||||
self._TF_dict = TF_dict
|
||||
self._TF= list(self._TF_dict.keys())
|
||||
self._TF_ignore_mag=TF_ignore_mag
|
||||
self._nb_tf= len(self._TF)
|
||||
self._N_seqTF = N_TF
|
||||
|
||||
#Mag
|
||||
self._shared_mag = shared_mag
|
||||
self._fixed_mag = fixed_mag
|
||||
if not self._fixed_mag and len([tf for tf in self._TF if tf not in TF.TF_ignore_mag])==0:
|
||||
if not self._fixed_mag and len([tf for tf in self._TF if tf not in self._TF_ignore_mag])==0:
|
||||
print("WARNING: Mag would be fixed as current TF doesn't allow gradient propagation:",self._TF)
|
||||
self._fixed_mag=True
|
||||
|
||||
|
@ -112,7 +115,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
if self._shared_mag :
|
||||
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max
|
||||
else:
|
||||
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
|
||||
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in self._TF_ignore_mag]
|
||||
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -324,6 +327,7 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
_data_augmentation (bool): Wether TF will be applied during forward pass.
|
||||
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
|
||||
_TF (list) : List of TF names.
|
||||
_TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
|
||||
_nb_tf (int) : Number of TF used.
|
||||
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
|
||||
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF. Beware using shared mag with basic color TF as their lowest magnitude is at PARAMETER_MAX/2.
|
||||
|
@ -336,7 +340,7 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
|
||||
_reg_mask (list): Mask selecting the TF considered for the regularisation.
|
||||
"""
|
||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
|
||||
def __init__(self, TF_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
|
||||
"""Init Data_augv7.
|
||||
|
||||
Args:
|
||||
|
@ -346,6 +350,7 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
|
||||
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
|
||||
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
|
||||
TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
|
||||
"""
|
||||
super(Data_augV7, self).__init__()
|
||||
assert len(TF_dict)>0
|
||||
|
@ -359,13 +364,14 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
#TF
|
||||
self._TF_dict = TF_dict
|
||||
self._TF= list(self._TF_dict.keys())
|
||||
self._TF_ignore_mag= TF_ignore_mag
|
||||
self._nb_tf= len(self._TF)
|
||||
self._N_seqTF = N_TF
|
||||
|
||||
#Mag
|
||||
self._shared_mag = shared_mag
|
||||
self._fixed_mag = fixed_mag
|
||||
if not self._fixed_mag and len([tf for tf in self._TF if tf not in TF.TF_ignore_mag])==0:
|
||||
if not self._fixed_mag and len([tf for tf in self._TF if tf not in self._TF_ignore_mag])==0:
|
||||
print("WARNING: Mag would be fixed as current TF doesn't allow gradient propagation:",self._TF)
|
||||
self._fixed_mag=True
|
||||
|
||||
|
@ -423,7 +429,7 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
if self._shared_mag :
|
||||
self._reg_tgt = torch.FloatTensor(TF.PARAMETER_MAX) #Encourage amplitude max
|
||||
else:
|
||||
self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in TF.TF_ignore_mag]
|
||||
self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in self._TF_ignore_mag]
|
||||
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -657,7 +663,7 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
|||
_fixed_mag (bool): Wether to lock the TF magnitudes. Should be True.
|
||||
_params (nn.ParameterDict): Data augmentation parameters.
|
||||
"""
|
||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
|
||||
def __init__(self, TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
|
||||
"""Init RandAug.
|
||||
|
||||
Args:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue