mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Data_aug main comments
This commit is contained in:
parent
2fe5070b09
commit
2d6d2f7397
1 changed files with 234 additions and 1 deletions
235
higher/dataug.py
235
higher/dataug.py
|
@ -531,9 +531,42 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
|||
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
|
||||
|
||||
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||
"""Data augmentation module with learnable parameters.
|
||||
|
||||
Applies transformations (TF) to batch of data.
|
||||
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple which can be learned. For the full definiton of the TF, see transformations.py.
|
||||
The TF probabilities defines a distribution from which we sample the TF applied.
|
||||
|
||||
Attributes:
|
||||
_data_augmentation (bool): Wether TF will be applied during forward pass.
|
||||
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
|
||||
_TF (list) : List of TF names.
|
||||
_nb_tf (int) : Number of TF used.
|
||||
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
|
||||
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF.
|
||||
_fixed_mag (bool): Wether to lock the TF magnitudes.
|
||||
_fixed_prob (bool): Wether to lock the TF probabilies.
|
||||
_samples (list): Sampled TF index during last forward pass.
|
||||
_mix_dist (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
|
||||
_fixed_mix (bool): Wether we lock the mix distribution factor.
|
||||
_params (nn.ParameterDict): Learnable parameters.
|
||||
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
|
||||
_reg_mask (list): Mask selecting the TF considered for the regularisation.
|
||||
"""
|
||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
|
||||
"""Init Data_augv5.
|
||||
|
||||
Args:
|
||||
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
|
||||
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
|
||||
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
|
||||
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
|
||||
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
|
||||
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
|
||||
"""
|
||||
super(Data_augV5, self).__init__()
|
||||
assert len(TF_dict)>0
|
||||
assert N_TF>=0
|
||||
|
||||
self._data_augmentation = True
|
||||
|
||||
|
@ -582,6 +615,14 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
|
||||
|
||||
def forward(self, x):
|
||||
""" Main method of the Data augmentation module.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of data.
|
||||
|
||||
Returns:
|
||||
Tensor : Batch of tranformed data.
|
||||
"""
|
||||
self._samples = []
|
||||
if self._data_augmentation:# and TF.random.random() < 0.5:
|
||||
device = x.device
|
||||
|
@ -609,6 +650,15 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
return x
|
||||
|
||||
def apply_TF(self, x, sampled_TF):
|
||||
""" Applies the sampled transformations.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of data.
|
||||
sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of tranformed data.
|
||||
"""
|
||||
device = x.device
|
||||
batch_size, channels, h, w = x.shape
|
||||
smps_x=[]
|
||||
|
@ -635,6 +685,13 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
return x
|
||||
|
||||
def adjust_param(self, soft=False): #Detach from gradient ?
|
||||
""" Enforce limitations to the learned parameters.
|
||||
|
||||
Ensure that the parameters value stays in the right intevals. This should be called after each update of those parameters.
|
||||
|
||||
Args:
|
||||
soft (bool): Wether to use a softmax function for TF probabilites. Not Recommended as it tends to lock the probabilities, preventing them to be learned. (default: False)
|
||||
"""
|
||||
if not self._fixed_prob:
|
||||
if soft :
|
||||
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
|
||||
|
@ -649,6 +706,15 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
|
||||
|
||||
def loss_weight(self):
|
||||
""" Weights for the loss.
|
||||
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
|
||||
Should be applied to the loss before reduction.
|
||||
|
||||
TODO: Take into account the order of application of the TF.
|
||||
|
||||
Returns:
|
||||
Tensor : Loss weights.
|
||||
"""
|
||||
if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
|
||||
|
||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||
|
@ -665,6 +731,14 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
return w_loss
|
||||
|
||||
def reg_loss(self, reg_factor=0.005):
|
||||
""" Regularisation term used to learn the magnitudes.
|
||||
Use an L2 loss to encourage high magnitudes TF.
|
||||
|
||||
Args:
|
||||
reg_factor (float): Factor by wich the regularisation loss is multiplied. (default: 0.005)
|
||||
Returns:
|
||||
Tensor containing the regularisation loss value.
|
||||
"""
|
||||
if self._fixed_mag:
|
||||
return torch.tensor(0)
|
||||
else:
|
||||
|
@ -674,21 +748,45 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
return max_mag_reg
|
||||
|
||||
def train(self, mode=None):
|
||||
""" Set the module training mode.
|
||||
|
||||
Args:
|
||||
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
|
||||
"""
|
||||
if mode is None :
|
||||
mode=self._data_augmentation
|
||||
self.augment(mode=mode) #Inutile si mode=None
|
||||
super(Data_augV5, self).train(mode)
|
||||
|
||||
def eval(self):
|
||||
""" Set the module to evaluation mode.
|
||||
"""
|
||||
self.train(mode=False)
|
||||
|
||||
def augment(self, mode=True):
|
||||
""" Set the augmentation mode.
|
||||
|
||||
Args:
|
||||
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
|
||||
"""
|
||||
self._data_augmentation=mode
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Access to the learnable parameters
|
||||
Args:
|
||||
key (string): Name of the learnable parameter to access.
|
||||
|
||||
Returns:
|
||||
nn.Parameter.
|
||||
"""
|
||||
return self._params[key]
|
||||
|
||||
def __str__(self):
|
||||
"""Name of the module
|
||||
|
||||
Returns:
|
||||
String containing the name of the module as well as the higher levels parameters.
|
||||
"""
|
||||
dist_param=''
|
||||
if self._fixed_prob: dist_param+='Fx'
|
||||
mag_param='Mag'
|
||||
|
@ -925,7 +1023,30 @@ class Data_augV6(nn.Module): #Optimisation sequentielle #Mauvais resultats
|
|||
|
||||
|
||||
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
||||
"""RandAugment implementation.
|
||||
|
||||
Applies transformations (TF) to batch of data.
|
||||
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple. For the full definiton of the TF, see transformations.py.
|
||||
The TF probabilities are ignored and, instead selected randomly.
|
||||
|
||||
Attributes:
|
||||
_data_augmentation (bool): Wether TF will be applied during forward pass.
|
||||
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
|
||||
_TF (list) : List of TF names.
|
||||
_nb_tf (int) : Number of TF used.
|
||||
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
|
||||
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF. Should be True.
|
||||
_fixed_mag (bool): Wether to lock the TF magnitudes. Should be True.
|
||||
_params (nn.ParameterDict): Data augmentation parameters.
|
||||
"""
|
||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
|
||||
"""Init RandAug.
|
||||
|
||||
Args:
|
||||
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
|
||||
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
|
||||
mag (float): Magnitude of the TF. Should be between [PARAMETER_MIN, PARAMETER_MAX] defined in transformations.py. (default: PARAMETER_MAX)
|
||||
"""
|
||||
super(RandAug, self).__init__()
|
||||
|
||||
self._data_augmentation = True
|
||||
|
@ -937,13 +1058,23 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
|||
|
||||
self.mag=nn.Parameter(torch.tensor(float(mag)))
|
||||
self._params = nn.ParameterDict({
|
||||
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #pas utilise
|
||||
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Ignored
|
||||
"mag" : nn.Parameter(torch.tensor(float(mag))),
|
||||
})
|
||||
self._shared_mag = True
|
||||
self._fixed_mag = True
|
||||
|
||||
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
|
||||
|
||||
def forward(self, x):
|
||||
""" Main method of the Data augmentation module.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of data.
|
||||
|
||||
Returns:
|
||||
Tensor : Batch of tranformed data.
|
||||
"""
|
||||
if self._data_augmentation:# and TF.random.random() < 0.5:
|
||||
device = x.device
|
||||
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
|
||||
|
@ -961,6 +1092,15 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
|||
return x
|
||||
|
||||
def apply_TF(self, x, sampled_TF):
|
||||
""" Applies the sampled transformations.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of data.
|
||||
sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of tranformed data.
|
||||
"""
|
||||
smps_x=[]
|
||||
|
||||
for tf_idx in range(self._nb_tf):
|
||||
|
@ -979,30 +1119,60 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
|||
return x
|
||||
|
||||
def adjust_param(self, soft=False):
|
||||
"""Not used
|
||||
"""
|
||||
pass #Pas de parametre a opti
|
||||
|
||||
def loss_weight(self):
|
||||
"""Not used
|
||||
"""
|
||||
return 1 #Pas d'echantillon = pas de ponderation
|
||||
|
||||
def reg_loss(self, reg_factor=0.005):
|
||||
"""Not used
|
||||
"""
|
||||
return torch.tensor(0) #Pas de regularisation
|
||||
|
||||
def train(self, mode=None):
|
||||
""" Set the module training mode.
|
||||
|
||||
Args:
|
||||
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
|
||||
"""
|
||||
if mode is None :
|
||||
mode=self._data_augmentation
|
||||
self.augment(mode=mode) #Inutile si mode=None
|
||||
super(RandAug, self).train(mode)
|
||||
|
||||
def eval(self):
|
||||
""" Set the module to evaluation mode.
|
||||
"""
|
||||
self.train(mode=False)
|
||||
|
||||
def augment(self, mode=True):
|
||||
""" Set the augmentation mode.
|
||||
|
||||
Args:
|
||||
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
|
||||
"""
|
||||
self._data_augmentation=mode
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Access to the learnable parameters
|
||||
Args:
|
||||
key (string): Name of the learnable parameter to access.
|
||||
|
||||
Returns:
|
||||
nn.Parameter.
|
||||
"""
|
||||
return self._params[key]
|
||||
|
||||
def __str__(self):
|
||||
"""Name of the module
|
||||
|
||||
Returns:
|
||||
String containing the name of the module as well as the higher levels parameters.
|
||||
"""
|
||||
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
||||
|
||||
class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training)
|
||||
|
@ -1092,7 +1262,21 @@ class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training)
|
|||
return "RandAugUDA(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
||||
|
||||
class Augmented_model(nn.Module):
|
||||
"""Wrapper for a Data Augmentation module and a model.
|
||||
|
||||
Attributes:
|
||||
_mods (nn.ModuleDict): A dictionary containing the modules.
|
||||
_data_augmentation (bool): Wether data augmentation is used.
|
||||
"""
|
||||
def __init__(self, data_augmenter, model):
|
||||
"""Init Augmented Model.
|
||||
|
||||
By default, data augmentation will be performed.
|
||||
|
||||
Args:
|
||||
data_augmenter (nn.Module): Data augmentation module.
|
||||
model (nn.Module): Network.
|
||||
"""
|
||||
super(Augmented_model, self).__init__()
|
||||
|
||||
self._mods = nn.ModuleDict({
|
||||
|
@ -1103,13 +1287,33 @@ class Augmented_model(nn.Module):
|
|||
self.augment(mode=True)
|
||||
|
||||
def forward(self, x):
|
||||
""" Main method of the Augmented model.
|
||||
|
||||
Perform the forward pass of both modules.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of data.
|
||||
|
||||
Returns:
|
||||
Tensor : Output of the networks. Should be logits.
|
||||
"""
|
||||
return self._mods['model'](self._mods['data_aug'](x))
|
||||
|
||||
def augment(self, mode=True):
|
||||
""" Set the augmentation mode.
|
||||
|
||||
Args:
|
||||
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
|
||||
"""
|
||||
self._data_augmentation=mode
|
||||
self._mods['data_aug'].augment(mode)
|
||||
|
||||
def train(self, mode=None):
|
||||
""" Set the module training mode.
|
||||
|
||||
Args:
|
||||
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
|
||||
"""
|
||||
if mode is None :
|
||||
mode=self._data_augmentation
|
||||
self._mods['data_aug'].augment(mode)
|
||||
|
@ -1117,6 +1321,8 @@ class Augmented_model(nn.Module):
|
|||
return self
|
||||
|
||||
def eval(self):
|
||||
""" Set the module to evaluation mode.
|
||||
"""
|
||||
return self.train(mode=False)
|
||||
|
||||
def items(self):
|
||||
|
@ -1125,21 +1331,48 @@ class Augmented_model(nn.Module):
|
|||
return self._mods.items()
|
||||
|
||||
def update(self, modules):
|
||||
"""Update the module dictionnary.
|
||||
|
||||
The new dictionnary should keep the same structure.
|
||||
"""
|
||||
assert(self._mods.keys()==modules.keys())
|
||||
self._mods.update(modules)
|
||||
|
||||
def is_augmenting(self):
|
||||
""" Return wether data augmentation is applied.
|
||||
|
||||
Returns:
|
||||
bool : True if data augmentation is applied.
|
||||
"""
|
||||
return self._data_augmentation
|
||||
|
||||
def TF_names(self):
|
||||
""" Get the transformations names used by the data augmentation module.
|
||||
|
||||
Returns:
|
||||
list : names of the transformations of the data augmentation module.
|
||||
"""
|
||||
try:
|
||||
return self._mods['data_aug']._TF
|
||||
except:
|
||||
return None
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Access to the modules.
|
||||
Args:
|
||||
key (string): Name of the module to access.
|
||||
|
||||
Returns:
|
||||
nn.Module.
|
||||
"""
|
||||
return self._mods[key]
|
||||
|
||||
def __str__(self):
|
||||
"""Name of the module
|
||||
|
||||
Returns:
|
||||
String containing the name of the module as well as the higher levels parameters.
|
||||
"""
|
||||
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
|
||||
|
||||
'''
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue