Data_aug main comments

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-20 16:10:17 -05:00
parent 2fe5070b09
commit 2d6d2f7397


@@ -531,9 +531,42 @@ class Data_augV4(nn.Module): #Transformations with mask
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF) return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba) class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
"""Data augmentation module with learnable parameters.
Applies transformations (TF) to batch of data.
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple which can be learned. For the full definiton of the TF, see transformations.py.
The TF probabilities defines a distribution from which we sample the TF applied.
Attributes:
_data_augmentation (bool): Wether TF will be applied during forward pass.
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
_TF (list) : List of TF names.
_nb_tf (int) : Number of TF used.
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF.
_fixed_mag (bool): Wether to lock the TF magnitudes.
_fixed_prob (bool): Wether to lock the TF probabilies.
_samples (list): Sampled TF index during last forward pass.
_mix_dist (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
_fixed_mix (bool): Wether we lock the mix distribution factor.
_params (nn.ParameterDict): Learnable parameters.
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
_reg_mask (list): Mask selecting the TF considered for the regularisation.
"""
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
"""Init Data_augv5.
Args:
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
"""
super(Data_augV5, self).__init__()
assert len(TF_dict)>0
assert N_TF>=0
self._data_augmentation = True
@@ -582,6 +615,14 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude
def forward(self, x):
""" Main method of the Data augmentation module.
Args:
x (Tensor): Batch of data.
Returns:
Tensor : Batch of tranformed data.
"""
self._samples = []
if self._data_augmentation:# and TF.random.random() < 0.5:
device = x.device
@@ -609,6 +650,15 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
return x
def apply_TF(self, x, sampled_TF):
""" Applies the sampled transformations.
Args:
x (Tensor): Batch of data.
sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.
Returns:
Tensor: Batch of tranformed data.
"""
device = x.device
batch_size, channels, h, w = x.shape
smps_x=[]
@@ -635,6 +685,13 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
return x
def adjust_param(self, soft=False): #Detach from gradient ?
""" Enforce limitations to the learned parameters.
Ensure that the parameters value stays in the right intevals. This should be called after each update of those parameters.
Args:
soft (bool): Wether to use a softmax function for TF probabilites. Not Recommended as it tends to lock the probabilities, preventing them to be learned. (default: False)
"""
if not self._fixed_prob:
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': locks into a uniform distribution if the lr is too low
@@ -649,6 +706,15 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
def loss_weight(self):
""" Weights for the loss.
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
Should be applied to the loss before reduction.
TODO: Take into account the order of application of the TF.
Returns:
Tensor : Loss weights.
"""
if len(self._samples)==0 : return 1 #No samples = no weighting
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
@@ -665,6 +731,14 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
return w_loss
def reg_loss(self, reg_factor=0.005):
""" Regularisation term used to learn the magnitudes.
Use an L2 loss to encourage high magnitudes TF.
Args:
reg_factor (float): Factor by wich the regularisation loss is multiplied. (default: 0.005)
Returns:
Tensor containing the regularisation loss value.
"""
if self._fixed_mag:
return torch.tensor(0)
else:
@@ -674,21 +748,45 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
return max_mag_reg
def train(self, mode=None):
""" Set the module training mode.
Args:
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
"""
if mode is None :
mode=self._data_augmentation
self.augment(mode=mode) #Redundant if mode=None
super(Data_augV5, self).train(mode)
def eval(self):
""" Set the module to evaluation mode.
"""
self.train(mode=False)
def augment(self, mode=True):
""" Set the augmentation mode.
Args:
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
"""
self._data_augmentation=mode
def __getitem__(self, key):
"""Access to the learnable parameters
Args:
key (string): Name of the learnable parameter to access.
Returns:
nn.Parameter.
"""
return self._params[key]
def __str__(self):
"""Name of the module
Returns:
String containing the name of the module as well as the higher levels parameters.
"""
dist_param=''
if self._fixed_prob: dist_param+='Fx'
mag_param='Mag'
@@ -925,7 +1023,30 @@ class Data_augV6(nn.Module): #Sequential optimization #Poor results
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
"""RandAugment implementation.
Applies transformations (TF) to batch of data.
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple. For the full definiton of the TF, see transformations.py.
The TF probabilities are ignored and, instead selected randomly.
Attributes:
_data_augmentation (bool): Wether TF will be applied during forward pass.
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
_TF (list) : List of TF names.
_nb_tf (int) : Number of TF used.
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF. Should be True.
_fixed_mag (bool): Wether to lock the TF magnitudes. Should be True.
_params (nn.ParameterDict): Data augmentation parameters.
"""
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
"""Init RandAug.
Args:
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
mag (float): Magnitude of the TF. Should be between [PARAMETER_MIN, PARAMETER_MAX] defined in transformations.py. (default: PARAMETER_MAX)
"""
super(RandAug, self).__init__()
self._data_augmentation = True
@@ -937,13 +1058,23 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
self.mag=nn.Parameter(torch.tensor(float(mag)))
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Ignored
"mag" : nn.Parameter(torch.tensor(float(mag))),
})
self._shared_mag = True
self._fixed_mag = True
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
def forward(self, x):
""" Main method of the Data augmentation module.
Args:
x (Tensor): Batch of data.
Returns:
Tensor : Batch of tranformed data.
"""
if self._data_augmentation:# and TF.random.random() < 0.5:
device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
@@ -961,6 +1092,15 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
return x
def apply_TF(self, x, sampled_TF):
""" Applies the sampled transformations.
Args:
x (Tensor): Batch of data.
sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.
Returns:
Tensor: Batch of tranformed data.
"""
smps_x=[]
for tf_idx in range(self._nb_tf):
@@ -979,30 +1119,60 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
return x
def adjust_param(self, soft=False):
"""Not used
"""
pass #No parameters to optimize
def loss_weight(self):
"""Not used
"""
return 1 #No samples = no weighting
def reg_loss(self, reg_factor=0.005):
"""Not used
"""
return torch.tensor(0) #No regularisation
def train(self, mode=None):
""" Set the module training mode.
Args:
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
"""
if mode is None :
mode=self._data_augmentation
self.augment(mode=mode) #Redundant if mode=None
super(RandAug, self).train(mode)
def eval(self):
""" Set the module to evaluation mode.
"""
self.train(mode=False)
def augment(self, mode=True):
""" Set the augmentation mode.
Args:
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
"""
self._data_augmentation=mode
def __getitem__(self, key):
"""Access to the learnable parameters
Args:
key (string): Name of the learnable parameter to access.
Returns:
nn.Parameter.
"""
return self._params[key]
def __str__(self):
"""Name of the module
Returns:
String containing the name of the module as well as the higher levels parameters.
"""
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag) return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training) class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training)
@@ -1092,7 +1262,21 @@ class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training)
return "RandAugUDA(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag) return "RandAugUDA(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
class Augmented_model(nn.Module): class Augmented_model(nn.Module):
"""Wrapper for a Data Augmentation module and a model.
Attributes:
_mods (nn.ModuleDict): A dictionary containing the modules.
_data_augmentation (bool): Wether data augmentation is used.
"""
def __init__(self, data_augmenter, model):
"""Init Augmented Model.
By default, data augmentation will be performed.
Args:
data_augmenter (nn.Module): Data augmentation module.
model (nn.Module): Network.
"""
super(Augmented_model, self).__init__()
self._mods = nn.ModuleDict({
@@ -1103,13 +1287,33 @@ class Augmented_model(nn.Module):
self.augment(mode=True)
def forward(self, x):
""" Main method of the Augmented model.
Perform the forward pass of both modules.
Args:
x (Tensor): Batch of data.
Returns:
Tensor : Output of the networks. Should be logits.
"""
return self._mods['model'](self._mods['data_aug'](x))
def augment(self, mode=True):
""" Set the augmentation mode.
Args:
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
"""
self._data_augmentation=mode
self._mods['data_aug'].augment(mode)
def train(self, mode=None):
""" Set the module training mode.
Args:
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
"""
if mode is None :
mode=self._data_augmentation
self._mods['data_aug'].augment(mode)
@@ -1117,6 +1321,8 @@ class Augmented_model(nn.Module):
return self
def eval(self):
""" Set the module to evaluation mode.
"""
return self.train(mode=False)
def items(self):
@@ -1125,21 +1331,48 @@ class Augmented_model(nn.Module):
return self._mods.items()
def update(self, modules):
"""Update the module dictionnary.
The new dictionnary should keep the same structure.
"""
assert(self._mods.keys()==modules.keys())
self._mods.update(modules)
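Hypothetical usage, keeping the same keys as the wrapped dictionary:

# aug_model.update({'data_aug': new_augmenter, 'model': new_network})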
def is_augmenting(self):
""" Return wether data augmentation is applied.
Returns:
bool : True if data augmentation is applied.
"""
return self._data_augmentation
def TF_names(self):
""" Get the transformations names used by the data augmentation module.
Returns:
list : names of the transformations of the data augmentation module.
"""
try:
return self._mods['data_aug']._TF
except:
return None
def __getitem__(self, key):
"""Access to the modules.
Args:
key (string): Name of the module to access.
Returns:
nn.Module.
"""
return self._mods[key]
def __str__(self):
"""Name of the module
Returns:
String containing the name of the module as well as the higher levels parameters.
"""
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")" return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
''' '''