From 2d6d2f7397cd66ecaa63b500f45555d4f20dd19d Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Mon, 20 Jan 2020 16:10:17 -0500
Subject: [PATCH] Data_aug main comments

---
 higher/dataug.py | 235 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 234 insertions(+), 1 deletion(-)

diff --git a/higher/dataug.py b/higher/dataug.py
index 2c6b623..3255ae8 100755
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -531,9 +531,42 @@ class Data_augV4(nn.Module): #Transformations avec mask
         return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
 
 class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
+    """Data augmentation module with learnable parameters.
+
+        Applies transformations (TF) to a batch of data.
+        Each TF is defined by a (name, probability of application, magnitude of distortion) tuple, which can be learned. For the full definition of the TF, see transformations.py.
+        The TF probabilities define the distribution from which the applied TF are sampled.
+
+        Attributes:
+            _data_augmentation (bool): Whether TF will be applied during the forward pass.
+            _TF_dict (dict): A dictionary containing the data transformations (TF) to be applied.
+            _TF (list): List of TF names.
+            _nb_tf (int): Number of TF used.
+            _N_seqTF (int): Number of TF to be applied sequentially to each input.
+            _shared_mag (bool): Whether to share a single magnitude parameter for all TF.
+            _fixed_mag (bool): Whether to lock the TF magnitudes.
+            _fixed_prob (bool): Whether to lock the TF probabilities.
+            _samples (list): Indices of the TF sampled during the last forward pass.
+            _mix_dist (bool): Whether to use a mix of a uniform distribution and the real distribution (TF probabilities). If False, only a uniform distribution is used.
+            _fixed_mix (bool): Whether to lock the mix distribution factor.
+            _params (nn.ParameterDict): Learnable parameters.
+            _reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to False (i.e. the magnitudes are learned).
+            _reg_mask (list): Mask selecting the TF considered for the regularisation.
+    """
     def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
+        """Init Data_augV5.
+
+            Args:
+                TF_dict (dict): A dictionary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
+                N_TF (int): Number of TF to be applied sequentially to each input. (default: 1)
+                mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, this parameter is learned. (default: 0)
+                fixed_prob (bool): Whether to lock the TF probabilities. (default: False)
+                fixed_mag (bool): Whether to lock the TF magnitudes. (default: True)
+                shared_mag (bool): Whether to share a single magnitude parameter for all TF. (default: True)
+        """
         super(Data_augV5, self).__init__()
         assert len(TF_dict)>0
+        assert N_TF>=0
 
         self._data_augmentation = True
@@ -582,6 +615,14 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
             self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
 
     def forward(self, x):
+        """Main method of the data augmentation module.
+
+            Args:
+                x (Tensor): Batch of data.
+
+            Returns:
+                Tensor: Batch of transformed data.
+        """
         self._samples = []
         if self._data_augmentation:# and TF.random.random() < 0.5:
             device = x.device
@@ -609,6 +650,15 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         return x
 
     def apply_TF(self, x, sampled_TF):
+        """Applies the sampled transformations.
+
+            Args:
+                x (Tensor): Batch of data.
+                sampled_TF (Tensor): Indices of the TF to be applied to each element of data.
+
+            Returns:
+                Tensor: Batch of transformed data.
+        """
         device = x.device
         batch_size, channels, h, w = x.shape
         smps_x=[]
@@ -635,6 +685,13 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         return x
 
     def adjust_param(self, soft=False): #Detach from gradient ?
+        """Enforce limitations on the learned parameters.
+
+            Ensures that the parameter values stay within the right intervals. This should be called after each update of those parameters.
+
+            Args:
+                soft (bool): Whether to use a softmax function on the TF probabilities. Not recommended, as it tends to lock the probabilities, preventing them from being learned. (default: False)
+        """
         if not self._fixed_prob:
             if soft :
                 self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
@@ -649,6 +706,15 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
             self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
 
     def loss_weight(self):
+        """Weights for the loss.
+
+            Computes the weight of the loss for each input, depending on which TF was applied to it.
+            Should be applied to the loss before reduction.
+
+            TODO: Take into account the order of application of the TF.
+
+            Returns:
+                Tensor: Loss weights.
+        """
         if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
 
         prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
@@ -665,6 +731,14 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         return w_loss
 
     def reg_loss(self, reg_factor=0.005):
+        """Regularisation term used to learn the magnitudes.
+
+            Uses an L2 loss to encourage high TF magnitudes.
+
+            Args:
+                reg_factor (float): Factor by which the regularisation loss is multiplied. (default: 0.005)
+            Returns:
+                Tensor containing the regularisation loss value.
+        """
         if self._fixed_mag:
             return torch.tensor(0)
         else:
@@ -674,21 +748,45 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
             return max_mag_reg
 
     def train(self, mode=None):
+        """Set the module training mode.
+
+            Args:
+                mode (bool): Whether to set the module in training mode and learn its parameters. If None, the current augmentation mode is kept. (default: None)
+        """
         if mode is None :
             mode=self._data_augmentation
         self.augment(mode=mode) #Inutile si mode=None
         super(Data_augV5, self).train(mode)
 
     def eval(self):
+        """Set the module to evaluation mode.
+        """
         self.train(mode=False)
 
     def augment(self, mode=True):
+        """Set the augmentation mode.
+
+            Args:
+                mode (bool): Whether to perform data augmentation on the forward pass. (default: True)
+        """
         self._data_augmentation=mode
 
     def __getitem__(self, key):
+        """Access to the learnable parameters.
+
+            Args:
+                key (string): Name of the learnable parameter to access.
+
+            Returns:
+                nn.Parameter.
+        """
         return self._params[key]
 
     def __str__(self):
+        """Name of the module.
+
+            Returns:
+                String containing the name of the module as well as its main parameters.
+        """
         dist_param=''
         if self._fixed_prob: dist_param+='Fx'
         mag_param='Mag'
@@ -925,7 +1023,30 @@ class Data_augV6(nn.Module): #Optimisation sequentielle #Mauvais resultats
 
 class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
+    """RandAugment implementation.
+
+        Applies transformations (TF) to a batch of data.
+        Each TF is defined by a (name, probability of application, magnitude of distortion) tuple. For the full definition of the TF, see transformations.py.
+        The TF probabilities are ignored; the TF are instead selected uniformly at random.
+
+        Attributes:
+            _data_augmentation (bool): Whether TF will be applied during the forward pass.
+            _TF_dict (dict): A dictionary containing the data transformations (TF) to be applied.
+            _TF (list): List of TF names.
+            _nb_tf (int): Number of TF used.
+            _N_seqTF (int): Number of TF to be applied sequentially to each input.
+            _shared_mag (bool): Whether to share a single magnitude parameter for all TF. Should be True.
+            _fixed_mag (bool): Whether to lock the TF magnitudes. Should be True.
+            _params (nn.ParameterDict): Data augmentation parameters.
+    """
     def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
+        """Init RandAug.
+
+            Args:
+                TF_dict (dict): A dictionary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
+                N_TF (int): Number of TF to be applied sequentially to each input. (default: 1)
+                mag (float): Magnitude of the TF. Should be in [PARAMETER_MIN, PARAMETER_MAX] as defined in transformations.py. (default: PARAMETER_MAX)
+        """
         super(RandAug, self).__init__()
 
         self._data_augmentation = True
@@ -937,13 +1058,23 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
         self.mag=nn.Parameter(torch.tensor(float(mag)))
         self._params = nn.ParameterDict({
-            "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #pas utilise
+            "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Ignored
             "mag" : nn.Parameter(torch.tensor(float(mag))),
         })
         self._shared_mag = True
         self._fixed_mag = True
 
+        self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
+
     def forward(self, x):
+        """Main method of the data augmentation module.
+
+            Args:
+                x (Tensor): Batch of data.
+
+            Returns:
+                Tensor: Batch of transformed data.
+        """
         if self._data_augmentation:# and TF.random.random() < 0.5:
             device = x.device
             batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
@@ -961,6 +1092,15 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
         return x
 
     def apply_TF(self, x, sampled_TF):
+        """Applies the sampled transformations.
+
+            Args:
+                x (Tensor): Batch of data.
+                sampled_TF (Tensor): Indices of the TF to be applied to each element of data.
+
+            Returns:
+                Tensor: Batch of transformed data.
+        """
         smps_x=[]
 
         for tf_idx in range(self._nb_tf):
@@ -979,30 +1119,60 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
         return x
 
     def adjust_param(self, soft=False):
+        """Not used.
+        """
         pass #Pas de parametre a opti
 
     def loss_weight(self):
+        """Not used.
+        """
         return 1 #Pas d'echantillon = pas de ponderation
 
     def reg_loss(self, reg_factor=0.005):
+        """Not used.
+        """
         return torch.tensor(0) #Pas de regularisation
 
     def train(self, mode=None):
+        """Set the module training mode.
+
+            Args:
+                mode (bool): Whether to set the module in training mode and learn its parameters. If None, the current augmentation mode is kept. (default: None)
+        """
         if mode is None :
             mode=self._data_augmentation
         self.augment(mode=mode) #Inutile si mode=None
         super(RandAug, self).train(mode)
 
     def eval(self):
+        """Set the module to evaluation mode.
+        """
         self.train(mode=False)
 
     def augment(self, mode=True):
+        """Set the augmentation mode.
+
+            Args:
+                mode (bool): Whether to perform data augmentation on the forward pass. (default: True)
+        """
         self._data_augmentation=mode
 
     def __getitem__(self, key):
+        """Access to the learnable parameters.
+
+            Args:
+                key (string): Name of the learnable parameter to access.
+
+            Returns:
+                nn.Parameter.
+        """
         return self._params[key]
 
     def __str__(self):
+        """Name of the module.
+
+            Returns:
+                String containing the name of the module as well as its main parameters.
+        """
         return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
 
 class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training)
@@ -1092,7 +1262,21 @@ class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training)
         return "RandAugUDA(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
 
 class Augmented_model(nn.Module):
+    """Wrapper for a data augmentation module and a model.
+
+        Attributes:
+            _mods (nn.ModuleDict): A dictionary containing the modules.
+            _data_augmentation (bool): Whether data augmentation is used.
+    """
     def __init__(self, data_augmenter, model):
+        """Init Augmented_model.
+
+            By default, data augmentation will be performed.
+
+            Args:
+                data_augmenter (nn.Module): Data augmentation module.
+                model (nn.Module): Network.
+        """
         super(Augmented_model, self).__init__()
 
         self._mods = nn.ModuleDict({
@@ -1103,13 +1287,33 @@ class Augmented_model(nn.Module):
         self.augment(mode=True)
 
     def forward(self, x):
+        """Main method of the augmented model.
+
+            Performs the forward pass of both modules.
+
+            Args:
+                x (Tensor): Batch of data.
+
+            Returns:
+                Tensor: Output of the network. Should be logits.
+        """
         return self._mods['model'](self._mods['data_aug'](x))
 
     def augment(self, mode=True):
+        """Set the augmentation mode.
+
+            Args:
+                mode (bool): Whether to perform data augmentation on the forward pass. (default: True)
+        """
         self._data_augmentation=mode
         self._mods['data_aug'].augment(mode)
 
     def train(self, mode=None):
+        """Set the module training mode.
+
+            Args:
+                mode (bool): Whether to set the module in training mode and learn its parameters. If None, the current augmentation mode is kept. (default: None)
+        """
         if mode is None :
             mode=self._data_augmentation
         self._mods['data_aug'].augment(mode)
@@ -1117,6 +1321,8 @@ class Augmented_model(nn.Module):
         return self
 
     def eval(self):
+        """Set the module to evaluation mode.
+        """
         return self.train(mode=False)
 
     def items(self):
@@ -1125,21 +1331,48 @@ class Augmented_model(nn.Module):
         return self._mods.items()
 
     def update(self, modules):
+        """Update the module dictionary.
+
+            The new dictionary should keep the same structure.
+        """
+        assert(self._mods.keys()==modules.keys())
         self._mods.update(modules)
 
     def is_augmenting(self):
+        """Return whether data augmentation is applied.
+
+            Returns:
+                bool: True if data augmentation is applied.
+        """
         return self._data_augmentation
 
     def TF_names(self):
+        """Get the names of the transformations used by the data augmentation module.
+
+            Returns:
+                list: Names of the transformations of the data augmentation module.
+        """
         try:
             return self._mods['data_aug']._TF
         except:
             return None
 
     def __getitem__(self, key):
+        """Access to the modules.
+
+            Args:
+                key (string): Name of the module to access.
+
+            Returns:
+                nn.Module.
+        """
         return self._mods[key]
 
     def __str__(self):
+        """Name of the module.
+
+            Returns:
+                String containing the name of the module as well as its main parameters.
+        """
         return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
 '''
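Usage sketch: how the Data_augV5 API documented above fits together in one training step. This is a minimal sketch rather than the repository's training script: it assumes it is run from the higher/ directory (so dataug.py and transformations.py import directly), the toy network, optimizer, and random batch are purely illustrative, and the repository presumably drives these same calls from its bi-level optimization loop rather than a single plain SGD step.

import torch
import torch.nn as nn
import torch.nn.functional as F

from dataug import Data_augV5, Augmented_model
import transformations as TF

# Toy classifier, illustrative only.
net = nn.Sequential(nn.Flatten(), nn.Linear(3*32*32, 10))

# Learn TF probabilities and per-TF magnitudes; mix_dist=None also makes
# the mix factor a learnable parameter (see the __init__ docstring).
aug_model = Augmented_model(
    Data_augV5(TF_dict=TF.TF_dict, N_TF=2, mix_dist=None,
               fixed_prob=False, fixed_mag=False, shared_mag=False),
    net)
aug_model.train()

opt = torch.optim.SGD(aug_model.parameters(), lr=1e-2)

x = torch.rand(8, 3, 32, 32)              # fake batch of images
y = torch.randint(0, 10, (8,))            # fake labels

logits = aug_model(x)                     # data_aug forward, then model forward
# Weight the un-reduced loss by the sampling probabilities of the applied TF
# (loss_weight), then add the term that encourages high magnitudes (reg_loss).
loss = (F.cross_entropy(logits, y, reduction='none')
        * aug_model['data_aug'].loss_weight()).mean()
loss = loss + aug_model['data_aug'].reg_loss()

opt.zero_grad()
loss.backward()
opt.step()

aug_model['data_aug'].adjust_param()      # clamp prob/mag/mix back into range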
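The RandAug variant documented above keeps the same interface but learns nothing: adjust_param, loss_weight, and reg_loss are no-ops, so it can be swapped into the same loop unchanged. A sketch, with the magnitude value illustrative:

from dataug import RandAug, Augmented_model
import transformations as TF
import torch.nn as nn

rand_model = Augmented_model(
    RandAug(TF_dict=TF.TF_dict, N_TF=2, mag=TF.PARAMETER_MAX),  # fixed, shared magnitude
    nn.Sequential(nn.Flatten(), nn.Linear(3*32*32, 10)))

rand_model['data_aug'].loss_weight()   # -> 1 (no per-sample weighting)
rand_model['data_aug'].reg_loss()      # -> tensor(0) (no regularisation)
rand_model['data_aug'].adjust_param()  # no-op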
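Finally, the Augmented_model mode helpers from the last hunks, continuing from the aug_model built above; the printed strings follow the __str__ formats defined in the patch:

aug_model.eval()                    # training mode off, augmentation off
print(aug_model.is_augmenting())    # False
aug_model.augment(mode=True)        # re-enable augmentation alone
print(aug_model.TF_names())         # list of TF names used by data_aug
print(aug_model)                    # e.g. Aug_mod(Data_augV5(...)-<model repr>)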