Fix Translate + TF loader

Harle, Antoine (Contractor) 2020-02-14 13:57:17 -05:00
parent 79de0191a8
commit b170af076f
9 changed files with 674 additions and 40 deletions


@@ -39,6 +39,7 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
_data_augmentation (bool): Whether TF will be applied during forward pass.
_TF_dict (dict) : A dictionary containing the data transformations (TF) to be applied.
_TF (list) : List of TF names.
+_TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
_nb_tf (int) : Number of TF used.
_N_seqTF (int) : Number of TF to be applied sequentially to each input
_shared_mag (bool) : Whether to share a single magnitude parameter for all TF. Beware of using shared mag with basic color TF, as their lowest magnitude is at PARAMETER_MAX/2.
@@ -51,7 +52,7 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (i.e. we learn the magnitudes).
_reg_mask (list): Mask selecting the TF considered for the regularisation.
"""
-def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
+def __init__(self, TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
"""Init Data_augV5.
Args:
@@ -61,6 +62,7 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
fixed_prob (bool): Whether to lock the TF probabilities. (default: False)
fixed_mag (bool): Whether to lock the TF magnitudes. (default: True)
shared_mag (bool): Whether to share a single magnitude parameter for all TF. (default: True)
+TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
"""
super(Data_augV5, self).__init__()
assert len(TF_dict)>0
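With these hunks, TF_dict loses its default value and TF_ignore_mag becomes an explicit argument. A minimal sketch of an updated call site, assuming the transformations module is imported as TF and still exposes TF_dict and TF_ignore_mag (import and module names outside this diff are assumptions):

import transformations as TF  # assumed module name for the TF definitions

# TF_dict no longer has a default, so callers must pass it explicitly;
# TF_ignore_mag can now be overridden instead of always reading TF.TF_ignore_mag.
aug_model = Data_augV5(TF_dict=TF.TF_dict,
                       N_TF=2,
                       mix_dist=0.5,
                       fixed_prob=False,
                       fixed_mag=False,
                       shared_mag=False,
                       TF_ignore_mag=TF.TF_ignore_mag)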
@@ -71,13 +73,14 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
#TF
self._TF_dict = TF_dict
self._TF= list(self._TF_dict.keys())
+self._TF_ignore_mag=TF_ignore_mag
self._nb_tf= len(self._TF)
self._N_seqTF = N_TF
#Mag
self._shared_mag = shared_mag
self._fixed_mag = fixed_mag
-if not self._fixed_mag and len([tf for tf in self._TF if tf not in TF.TF_ignore_mag])==0:
+if not self._fixed_mag and len([tf for tf in self._TF if tf not in self._TF_ignore_mag])==0:
print("WARNING: Mag will be fixed since the current TF do not allow gradient propagation:",self._TF)
self._fixed_mag=True
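The fallback above only depends on set membership; a toy illustration with hypothetical TF names (not from this repo):

# Hypothetical TF names, for illustration only.
tf_names = ["Identity", "FlipLR"]
tf_ignore_mag = {"Identity", "FlipLR"}  # every TF ignores its magnitude

# Same test as in __init__: no TF left whose magnitude could receive gradients,
# so learning magnitudes would be a no-op and the module falls back to fixed mag.
fixed_mag = False
if not fixed_mag and len([tf for tf in tf_names if tf not in tf_ignore_mag]) == 0:
    fixed_mag = True  # mirrors self._fixed_mag=True above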
@@ -112,7 +115,7 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
if self._shared_mag :
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max
else:
-self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
+self._reg_mask=[self._TF.index(t) for t in self._TF if t not in self._TF_ignore_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
def forward(self, x):
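In the per-TF branch above, _reg_mask indexes only the TF whose magnitude is actually learned, and _reg_tgt holds one target per selected TF. A small sketch with hypothetical names (PARAMETER_MAX value assumed):

import torch

PARAMETER_MAX = 1.0  # assumed value; the real module uses TF.PARAMETER_MAX
tf_names = ["Identity", "Rotate", "Contrast"]  # hypothetical
tf_ignore_mag = {"Identity"}

# Indices of the TF whose magnitude is learned and therefore regularised.
reg_mask = [tf_names.index(t) for t in tf_names if t not in tf_ignore_mag]
# reg_mask == [1, 2] -> magnitudes of "Rotate" and "Contrast" are regularised
reg_tgt = torch.full(size=(len(reg_mask),), fill_value=PARAMETER_MAX)
# reg_tgt == tensor([1., 1.]) -> each learned magnitude is pushed toward the max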
@@ -324,6 +327,7 @@ class Data_augV7(nn.Module): #Sequential probabilities
_data_augmentation (bool): Whether TF will be applied during forward pass.
_TF_dict (dict) : A dictionary containing the data transformations (TF) to be applied.
_TF (list) : List of TF names.
+_TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
_nb_tf (int) : Number of TF used.
_N_seqTF (int) : Number of TF to be applied sequentially to each input
_shared_mag (bool) : Whether to share a single magnitude parameter for all TF. Beware of using shared mag with basic color TF, as their lowest magnitude is at PARAMETER_MAX/2.
@@ -336,7 +340,7 @@ class Data_augV7(nn.Module): #Sequential probabilities
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (i.e. we learn the magnitudes).
_reg_mask (list): Mask selecting the TF considered for the regularisation.
"""
-def __init__(self, TF_dict=TF.TF_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
+def __init__(self, TF_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
"""Init Data_augV7.
Args:
@@ -346,6 +350,7 @@ class Data_augV7(nn.Module): #Sequential probabilities
fixed_prob (bool): Whether to lock the TF probabilities. (default: False)
fixed_mag (bool): Whether to lock the TF magnitudes. (default: True)
shared_mag (bool): Whether to share a single magnitude parameter for all TF. (default: True)
+TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
"""
super(Data_augV7, self).__init__()
assert len(TF_dict)>0
@@ -359,13 +364,14 @@ class Data_augV7(nn.Module): #Sequential probabilities
#TF
self._TF_dict = TF_dict
self._TF= list(self._TF_dict.keys())
+self._TF_ignore_mag= TF_ignore_mag
self._nb_tf= len(self._TF)
self._N_seqTF = N_TF
#Mag
self._shared_mag = shared_mag
self._fixed_mag = fixed_mag
-if not self._fixed_mag and len([tf for tf in self._TF if tf not in TF.TF_ignore_mag])==0:
+if not self._fixed_mag and len([tf for tf in self._TF if tf not in self._TF_ignore_mag])==0:
print("WARNING: Mag will be fixed since the current TF do not allow gradient propagation:",self._TF)
self._fixed_mag=True
@@ -423,7 +429,7 @@ class Data_augV7(nn.Module): #Sequential probabilities
if self._shared_mag :
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max
else:
-self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in TF.TF_ignore_mag]
+self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in self._TF_ignore_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
def forward(self, x):
@@ -657,7 +663,7 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
_fixed_mag (bool): Whether to lock the TF magnitudes. Should be True.
_params (nn.ParameterDict): Data augmentation parameters.
"""
-def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
+def __init__(self, TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
"""Init RandAug.
Args:
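The same pattern applies to RandAug: TF_dict is now a required argument. A sketch of a call site under the same assumptions as above:

# RandAugment-style usage: N_TF uniformly sampled TF per input, fixed magnitude.
rand_aug = RandAug(TF_dict=TF.TF_dict, N_TF=2, mag=TF.PARAMETER_MAX)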