mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 20:20:46 +02:00
Data_aug main comments
This commit is contained in:
parent
2fe5070b09
commit
2d6d2f7397
1 changed files with 234 additions and 1 deletions
235
higher/dataug.py
235
higher/dataug.py
|
@ -531,9 +531,42 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
||||||
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
|
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
|
||||||
|
|
||||||
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
|
"""Data augmentation module with learnable parameters.
|
||||||
|
|
||||||
|
Applies transformations (TF) to batch of data.
|
||||||
|
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple which can be learned. For the full definiton of the TF, see transformations.py.
|
||||||
|
The TF probabilities defines a distribution from which we sample the TF applied.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_data_augmentation (bool): Wether TF will be applied during forward pass.
|
||||||
|
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
|
||||||
|
_TF (list) : List of TF names.
|
||||||
|
_nb_tf (int) : Number of TF used.
|
||||||
|
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
|
||||||
|
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF.
|
||||||
|
_fixed_mag (bool): Wether to lock the TF magnitudes.
|
||||||
|
_fixed_prob (bool): Wether to lock the TF probabilies.
|
||||||
|
_samples (list): Sampled TF index during last forward pass.
|
||||||
|
_mix_dist (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
|
||||||
|
_fixed_mix (bool): Wether we lock the mix distribution factor.
|
||||||
|
_params (nn.ParameterDict): Learnable parameters.
|
||||||
|
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
|
||||||
|
_reg_mask (list): Mask selecting the TF considered for the regularisation.
|
||||||
|
"""
|
||||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
|
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
|
||||||
|
"""Init Data_augv5.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
|
||||||
|
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
|
||||||
|
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
|
||||||
|
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
|
||||||
|
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
|
||||||
|
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
|
||||||
|
"""
|
||||||
super(Data_augV5, self).__init__()
|
super(Data_augV5, self).__init__()
|
||||||
assert len(TF_dict)>0
|
assert len(TF_dict)>0
|
||||||
|
assert N_TF>=0
|
||||||
|
|
||||||
self._data_augmentation = True
|
self._data_augmentation = True
|
||||||
|
|
||||||
|
@ -582,6 +615,14 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
|
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
""" Main method of the Data augmentation module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Batch of data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor : Batch of tranformed data.
|
||||||
|
"""
|
||||||
self._samples = []
|
self._samples = []
|
||||||
if self._data_augmentation:# and TF.random.random() < 0.5:
|
if self._data_augmentation:# and TF.random.random() < 0.5:
|
||||||
device = x.device
|
device = x.device
|
||||||
|
@ -609,6 +650,15 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def apply_TF(self, x, sampled_TF):
|
def apply_TF(self, x, sampled_TF):
|
||||||
|
""" Applies the sampled transformations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Batch of data.
|
||||||
|
sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Batch of tranformed data.
|
||||||
|
"""
|
||||||
device = x.device
|
device = x.device
|
||||||
batch_size, channels, h, w = x.shape
|
batch_size, channels, h, w = x.shape
|
||||||
smps_x=[]
|
smps_x=[]
|
||||||
|
@ -635,6 +685,13 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def adjust_param(self, soft=False): #Detach from gradient ?
|
def adjust_param(self, soft=False): #Detach from gradient ?
|
||||||
|
""" Enforce limitations to the learned parameters.
|
||||||
|
|
||||||
|
Ensure that the parameters value stays in the right intevals. This should be called after each update of those parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
soft (bool): Wether to use a softmax function for TF probabilites. Not Recommended as it tends to lock the probabilities, preventing them to be learned. (default: False)
|
||||||
|
"""
|
||||||
if not self._fixed_prob:
|
if not self._fixed_prob:
|
||||||
if soft :
|
if soft :
|
||||||
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
|
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
|
||||||
|
@ -649,6 +706,15 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
|
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
|
||||||
|
|
||||||
def loss_weight(self):
|
def loss_weight(self):
|
||||||
|
""" Weights for the loss.
|
||||||
|
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
|
||||||
|
Should be applied to the loss before reduction.
|
||||||
|
|
||||||
|
TODO: Take into account the order of application of the TF.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor : Loss weights.
|
||||||
|
"""
|
||||||
if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
|
if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
|
||||||
|
|
||||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||||
|
@ -665,6 +731,14 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
return w_loss
|
return w_loss
|
||||||
|
|
||||||
def reg_loss(self, reg_factor=0.005):
|
def reg_loss(self, reg_factor=0.005):
|
||||||
|
""" Regularisation term used to learn the magnitudes.
|
||||||
|
Use an L2 loss to encourage high magnitudes TF.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reg_factor (float): Factor by wich the regularisation loss is multiplied. (default: 0.005)
|
||||||
|
Returns:
|
||||||
|
Tensor containing the regularisation loss value.
|
||||||
|
"""
|
||||||
if self._fixed_mag:
|
if self._fixed_mag:
|
||||||
return torch.tensor(0)
|
return torch.tensor(0)
|
||||||
else:
|
else:
|
||||||
|
@ -674,21 +748,45 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
return max_mag_reg
|
return max_mag_reg
|
||||||
|
|
||||||
def train(self, mode=None):
|
def train(self, mode=None):
|
||||||
|
""" Set the module training mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
|
||||||
|
"""
|
||||||
if mode is None :
|
if mode is None :
|
||||||
mode=self._data_augmentation
|
mode=self._data_augmentation
|
||||||
self.augment(mode=mode) #Inutile si mode=None
|
self.augment(mode=mode) #Inutile si mode=None
|
||||||
super(Data_augV5, self).train(mode)
|
super(Data_augV5, self).train(mode)
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
|
""" Set the module to evaluation mode.
|
||||||
|
"""
|
||||||
self.train(mode=False)
|
self.train(mode=False)
|
||||||
|
|
||||||
def augment(self, mode=True):
|
def augment(self, mode=True):
|
||||||
|
""" Set the augmentation mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
|
||||||
|
"""
|
||||||
self._data_augmentation=mode
|
self._data_augmentation=mode
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
|
"""Access to the learnable parameters
|
||||||
|
Args:
|
||||||
|
key (string): Name of the learnable parameter to access.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Parameter.
|
||||||
|
"""
|
||||||
return self._params[key]
|
return self._params[key]
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
"""Name of the module
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
String containing the name of the module as well as the higher levels parameters.
|
||||||
|
"""
|
||||||
dist_param=''
|
dist_param=''
|
||||||
if self._fixed_prob: dist_param+='Fx'
|
if self._fixed_prob: dist_param+='Fx'
|
||||||
mag_param='Mag'
|
mag_param='Mag'
|
||||||
|
@ -925,7 +1023,30 @@ class Data_augV6(nn.Module): #Optimisation sequentielle #Mauvais resultats
|
||||||
|
|
||||||
|
|
||||||
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
||||||
|
"""RandAugment implementation.
|
||||||
|
|
||||||
|
Applies transformations (TF) to batch of data.
|
||||||
|
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple. For the full definiton of the TF, see transformations.py.
|
||||||
|
The TF probabilities are ignored and, instead selected randomly.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_data_augmentation (bool): Wether TF will be applied during forward pass.
|
||||||
|
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
|
||||||
|
_TF (list) : List of TF names.
|
||||||
|
_nb_tf (int) : Number of TF used.
|
||||||
|
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
|
||||||
|
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF. Should be True.
|
||||||
|
_fixed_mag (bool): Wether to lock the TF magnitudes. Should be True.
|
||||||
|
_params (nn.ParameterDict): Data augmentation parameters.
|
||||||
|
"""
|
||||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
|
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
|
||||||
|
"""Init RandAug.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
|
||||||
|
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
|
||||||
|
mag (float): Magnitude of the TF. Should be between [PARAMETER_MIN, PARAMETER_MAX] defined in transformations.py. (default: PARAMETER_MAX)
|
||||||
|
"""
|
||||||
super(RandAug, self).__init__()
|
super(RandAug, self).__init__()
|
||||||
|
|
||||||
self._data_augmentation = True
|
self._data_augmentation = True
|
||||||
|
@ -937,13 +1058,23 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
||||||
|
|
||||||
self.mag=nn.Parameter(torch.tensor(float(mag)))
|
self.mag=nn.Parameter(torch.tensor(float(mag)))
|
||||||
self._params = nn.ParameterDict({
|
self._params = nn.ParameterDict({
|
||||||
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #pas utilise
|
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Ignored
|
||||||
"mag" : nn.Parameter(torch.tensor(float(mag))),
|
"mag" : nn.Parameter(torch.tensor(float(mag))),
|
||||||
})
|
})
|
||||||
self._shared_mag = True
|
self._shared_mag = True
|
||||||
self._fixed_mag = True
|
self._fixed_mag = True
|
||||||
|
|
||||||
|
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
""" Main method of the Data augmentation module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Batch of data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor : Batch of tranformed data.
|
||||||
|
"""
|
||||||
if self._data_augmentation:# and TF.random.random() < 0.5:
|
if self._data_augmentation:# and TF.random.random() < 0.5:
|
||||||
device = x.device
|
device = x.device
|
||||||
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
|
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
|
||||||
|
@ -961,6 +1092,15 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def apply_TF(self, x, sampled_TF):
|
def apply_TF(self, x, sampled_TF):
|
||||||
|
""" Applies the sampled transformations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Batch of data.
|
||||||
|
sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Batch of tranformed data.
|
||||||
|
"""
|
||||||
smps_x=[]
|
smps_x=[]
|
||||||
|
|
||||||
for tf_idx in range(self._nb_tf):
|
for tf_idx in range(self._nb_tf):
|
||||||
|
@ -979,30 +1119,60 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def adjust_param(self, soft=False):
|
def adjust_param(self, soft=False):
|
||||||
|
"""Not used
|
||||||
|
"""
|
||||||
pass #Pas de parametre a opti
|
pass #Pas de parametre a opti
|
||||||
|
|
||||||
def loss_weight(self):
|
def loss_weight(self):
|
||||||
|
"""Not used
|
||||||
|
"""
|
||||||
return 1 #Pas d'echantillon = pas de ponderation
|
return 1 #Pas d'echantillon = pas de ponderation
|
||||||
|
|
||||||
def reg_loss(self, reg_factor=0.005):
|
def reg_loss(self, reg_factor=0.005):
|
||||||
|
"""Not used
|
||||||
|
"""
|
||||||
return torch.tensor(0) #Pas de regularisation
|
return torch.tensor(0) #Pas de regularisation
|
||||||
|
|
||||||
def train(self, mode=None):
|
def train(self, mode=None):
|
||||||
|
""" Set the module training mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
|
||||||
|
"""
|
||||||
if mode is None :
|
if mode is None :
|
||||||
mode=self._data_augmentation
|
mode=self._data_augmentation
|
||||||
self.augment(mode=mode) #Inutile si mode=None
|
self.augment(mode=mode) #Inutile si mode=None
|
||||||
super(RandAug, self).train(mode)
|
super(RandAug, self).train(mode)
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
|
""" Set the module to evaluation mode.
|
||||||
|
"""
|
||||||
self.train(mode=False)
|
self.train(mode=False)
|
||||||
|
|
||||||
def augment(self, mode=True):
|
def augment(self, mode=True):
|
||||||
|
""" Set the augmentation mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
|
||||||
|
"""
|
||||||
self._data_augmentation=mode
|
self._data_augmentation=mode
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
|
"""Access to the learnable parameters
|
||||||
|
Args:
|
||||||
|
key (string): Name of the learnable parameter to access.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Parameter.
|
||||||
|
"""
|
||||||
return self._params[key]
|
return self._params[key]
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
"""Name of the module
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
String containing the name of the module as well as the higher levels parameters.
|
||||||
|
"""
|
||||||
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
||||||
|
|
||||||
class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training)
|
class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training)
|
||||||
|
@ -1092,7 +1262,21 @@ class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training)
|
||||||
return "RandAugUDA(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
return "RandAugUDA(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
||||||
|
|
||||||
class Augmented_model(nn.Module):
|
class Augmented_model(nn.Module):
|
||||||
|
"""Wrapper for a Data Augmentation module and a model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_mods (nn.ModuleDict): A dictionary containing the modules.
|
||||||
|
_data_augmentation (bool): Wether data augmentation is used.
|
||||||
|
"""
|
||||||
def __init__(self, data_augmenter, model):
|
def __init__(self, data_augmenter, model):
|
||||||
|
"""Init Augmented Model.
|
||||||
|
|
||||||
|
By default, data augmentation will be performed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_augmenter (nn.Module): Data augmentation module.
|
||||||
|
model (nn.Module): Network.
|
||||||
|
"""
|
||||||
super(Augmented_model, self).__init__()
|
super(Augmented_model, self).__init__()
|
||||||
|
|
||||||
self._mods = nn.ModuleDict({
|
self._mods = nn.ModuleDict({
|
||||||
|
@ -1103,13 +1287,33 @@ class Augmented_model(nn.Module):
|
||||||
self.augment(mode=True)
|
self.augment(mode=True)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
""" Main method of the Augmented model.
|
||||||
|
|
||||||
|
Perform the forward pass of both modules.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Batch of data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor : Output of the networks. Should be logits.
|
||||||
|
"""
|
||||||
return self._mods['model'](self._mods['data_aug'](x))
|
return self._mods['model'](self._mods['data_aug'](x))
|
||||||
|
|
||||||
def augment(self, mode=True):
|
def augment(self, mode=True):
|
||||||
|
""" Set the augmentation mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
|
||||||
|
"""
|
||||||
self._data_augmentation=mode
|
self._data_augmentation=mode
|
||||||
self._mods['data_aug'].augment(mode)
|
self._mods['data_aug'].augment(mode)
|
||||||
|
|
||||||
def train(self, mode=None):
|
def train(self, mode=None):
|
||||||
|
""" Set the module training mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
|
||||||
|
"""
|
||||||
if mode is None :
|
if mode is None :
|
||||||
mode=self._data_augmentation
|
mode=self._data_augmentation
|
||||||
self._mods['data_aug'].augment(mode)
|
self._mods['data_aug'].augment(mode)
|
||||||
|
@ -1117,6 +1321,8 @@ class Augmented_model(nn.Module):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
|
""" Set the module to evaluation mode.
|
||||||
|
"""
|
||||||
return self.train(mode=False)
|
return self.train(mode=False)
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
|
@ -1125,21 +1331,48 @@ class Augmented_model(nn.Module):
|
||||||
return self._mods.items()
|
return self._mods.items()
|
||||||
|
|
||||||
def update(self, modules):
|
def update(self, modules):
|
||||||
|
"""Update the module dictionnary.
|
||||||
|
|
||||||
|
The new dictionnary should keep the same structure.
|
||||||
|
"""
|
||||||
|
assert(self._mods.keys()==modules.keys())
|
||||||
self._mods.update(modules)
|
self._mods.update(modules)
|
||||||
|
|
||||||
def is_augmenting(self):
|
def is_augmenting(self):
|
||||||
|
""" Return wether data augmentation is applied.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool : True if data augmentation is applied.
|
||||||
|
"""
|
||||||
return self._data_augmentation
|
return self._data_augmentation
|
||||||
|
|
||||||
def TF_names(self):
|
def TF_names(self):
|
||||||
|
""" Get the transformations names used by the data augmentation module.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list : names of the transformations of the data augmentation module.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return self._mods['data_aug']._TF
|
return self._mods['data_aug']._TF
|
||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
|
"""Access to the modules.
|
||||||
|
Args:
|
||||||
|
key (string): Name of the module to access.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module.
|
||||||
|
"""
|
||||||
return self._mods[key]
|
return self._mods[key]
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
"""Name of the module
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
String containing the name of the module as well as the higher levels parameters.
|
||||||
|
"""
|
||||||
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
|
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue