smart_augmentation/higher/smart_aug/dataug.py

""" Data augmentation modules.
Features a custom implementaiton of RandAugment (RandAug), as well as a data augmentation modules allowing gradient propagation.
Typical usage:
aug_model = Augmented_model(Data_AugV5, model)
"""
2019-11-08 11:28:06 -05:00
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import *
#import kornia
#import random
import numpy as np
import copy
import transformations as TF
### Data augmenter ###
class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
"""Data augmentation module with learnable parameters.
Applies transformations (TF) to batch of data.
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple which can be learned. For the full definiton of the TF, see transformations.py.
The TF probabilities defines a distribution from which we sample the TF applied.
2020-01-24 11:50:45 -05:00
Be warry, that the order of sequential application of TF is not taken into account. See Data_augV7.
2020-01-20 16:10:17 -05:00
Attributes:
_data_augmentation (bool): Wether TF will be applied during forward pass.
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
_TF (list) : List of TF names.
_nb_tf (int) : Number of TF used.
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
2020-01-31 10:34:44 -05:00
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF. Beware using shared mag with basic color TF as their lowest magnitude is at PARAMETER_MAX/2.
2020-01-20 16:10:17 -05:00
_fixed_mag (bool): Wether to lock the TF magnitudes.
_fixed_prob (bool): Wether to lock the TF probabilies.
_samples (list): Sampled TF index during last forward pass.
_mix_dist (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
_fixed_mix (bool): Wether we lock the mix distribution factor.
_params (nn.ParameterDict): Learnable parameters.
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
_reg_mask (list): Mask selecting the TF considered for the regularisation.
"""
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
"""Init Data_augV5.
Args:
TF_dict (dict): A dictionary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
N_TF (int): Number of TF applied sequentially to each input. (default: 1)
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, the module tries to learn this parameter. (default: 0)
fixed_prob (bool): Whether to lock the TF probabilities. (default: False)
fixed_mag (bool): Whether to lock the TF magnitudes. (default: True)
shared_mag (bool): Whether to share a single magnitude parameter for all TF. (default: True)
"""
super(Data_augV5, self).__init__()
assert len(TF_dict)>0
assert N_TF>=0
self._data_augmentation = True
#TF
self._TF_dict = TF_dict
self._TF= list(self._TF_dict.keys())
self._nb_tf= len(self._TF)
self._N_seqTF = N_TF
#Mag
self._shared_mag = shared_mag
self._fixed_mag = fixed_mag
if not self._fixed_mag and len([tf for tf in self._TF if tf not in TF.TF_ignore_mag])==0:
print("WARNING: Mag would be fixed as current TF doesn't allow gradient propagation:",self._TF)
self._fixed_mag=True
#Distribution
self._fixed_prob=fixed_prob
self._samples = []
self._mix_dist = False
if mix_dist != 0.0: #Mix dist
self._mix_dist = True
self._fixed_mix=True
if mix_dist is None: #Learn Mix dist
self._fixed_mix = False
mix_dist=0.5
#Params
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform probability distribution
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
"mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
})
#for tf in TF.TF_no_grad :
# if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #Fix the TF at max parameter
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag is useless for ignore_mag TF
#Mag regularisation
if not self._fixed_mag:
if self._shared_mag :
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage max amplitude
else:
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude
def forward(self, x):
""" Main method of the Data augmentation module.
Args:
x (Tensor): Batch of data.
Returns:
Tensor : Batch of tranformed data.
"""
2019-11-25 16:36:35 +00:00
self._samples = []
if self._data_augmentation:# and TF.random.random() < 0.5:
device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
## Sampling ##
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
if not self._mix_dist:
self._distrib = uniforme_dist
else:
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix real / uniform distribution with mix_factor
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
for _ in range(self._N_seqTF):
sample = cat_distrib.sample()
self._samples.append(sample)
## Transformations ##
x = self.apply_TF(x, sample)
return x
def apply_TF(self, x, sampled_TF):
""" Applies the sampled transformations.
Args:
x (Tensor): Batch of data.
sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.
Returns:
Tensor: Batch of transformed data.
"""
device = x.device
batch_size, channels, h, w = x.shape
smps_x=[]
for tf_idx in range(self._nb_tf):
mask = sampled_TF==tf_idx #Create selection mask
smp_x = x[mask] #torch.masked_select()? (Requires expanding the mask to the same dim)
if smp_x.shape[0]!=0: #if there's data to TF
magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
if self._fixed_mag: magnitude=magnitude.detach() #Fmodel systematically tries to track the gradients of all params
tf=self._TF[tf_idx]
#In place
#x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
#Out of place
smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
idx= mask.nonzero()
idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There is surely a simpler way...
x=x.scatter(dim=0, index=idx, src=smp_x)
return x
def adjust_param(self, soft=False): #Detach from gradient ?
""" Enforce limitations on the learned parameters.
Ensures that the parameter values stay within the right intervals. This should be called after each update of those parameters.
Args:
soft (bool): Whether to use a softmax function for the TF probabilities. Tends to lock the probabilities if the learning rate is low, preventing them from being learned. (default: False)
"""
if not self._fixed_prob:
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
else:
self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
if not self._fixed_mag:
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
if not self._fixed_mix:
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
def loss_weight(self):
""" Weights for the loss.
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
Should be applied to the loss before reduction.
Do not take into account the order of application of the TF. See Data_augV7.
2020-01-20 16:10:17 -05:00
Returns:
Tensor : Loss weights.
"""
if len(self._samples)==0 : return torch.tensor(1, device=self._params["prob"].device) #No sample = no weighting
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
#Several sequential TF (Warning: does not take order into account!)
w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
for sample in self._samples:
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
w_loss += tmp_w
w_loss = w_loss * prob/self._distrib #Weighting by the probabilities (divided by the distribution so the loss is not lowered)
w_loss = torch.sum(w_loss,dim=1)
return w_loss
def reg_loss(self, reg_factor=0.005):
""" Regularisation term used to learn the magnitudes.
Uses an L2 loss to encourage high TF magnitudes.
Args:
reg_factor (float): Factor by which the regularisation loss is multiplied. (default: 0.005)
Returns:
Tensor containing the regularisation loss value.
"""
if self._fixed_mag:
return torch.tensor(0)
else:
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
return max_mag_reg
def train(self, mode=True):
""" Set the module training mode.
Args:
mode (bool): Whether to set the module in training mode. (default: True)
"""
#if mode is None :
# mode=self._data_augmentation
self.augment(mode=mode) #Useless if mode=None
super(Data_augV5, self).train(mode)
return self
def eval(self):
""" Set the module to evaluation mode.
"""
return self.train(mode=False)
def augment(self, mode=True):
""" Set the augmentation mode.
Args:
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
"""
2019-11-14 21:17:54 -05:00
self._data_augmentation=mode
2020-01-30 11:21:25 -05:00
def is_augmenting(self):
""" Return wether data augmentation is applied.
Returns:
bool : True if data augmentation is applied.
"""
return self._data_augmentation
def __getitem__(self, key):
"""Access to the learnable parameters
Args:
key (string): Name of the learnable parameter to access.
Returns:
nn.Parameter.
"""
return self._params[key]
def __str__(self):
"""Name of the module
Returns:
String containing the name of the module as well as its higher-level parameters.
"""
dist_param=''
if self._fixed_prob: dist_param+='Fx'
mag_param='Mag'
if self._fixed_mag: mag_param+= 'Fx'
if self._shared_mag: mag_param+= 'Sh'
if not self._mix_dist:
return "Data_augV5(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
elif self._fixed_mix:
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
else:
return "Data_augV5(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
class Data_augV7(nn.Module): #Sequential probabilities
"""Data augmentation module with learnable parameters.
Applies transformations (TF) to batch of data.
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple which can be learned. For the full definiton of the TF, see transformations.py.
The TF probabilities defines a distribution from which we sample the TF applied.
Replace the use of TF by TF sets which are combinaisons of classic TF.
Attributes:
_data_augmentation (bool): Wether TF will be applied during forward pass.
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
_TF (list) : List of TF names.
_nb_tf (int) : Number of TF used.
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
2020-01-31 10:34:44 -05:00
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF. Beware using shared mag with basic color TF as their lowest magnitude is at PARAMETER_MAX/2.
_fixed_mag (bool): Wether to lock the TF magnitudes.
_fixed_prob (bool): Wether to lock the TF probabilies.
_samples (list): Sampled TF index during last forward pass.
_mix_dist (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
_fixed_mix (bool): Wether we lock the mix distribution factor.
_params (nn.ParameterDict): Learnable parameters.
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
_reg_mask (list): Mask selecting the TF considered for the regularisation.
"""
def __init__(self, TF_dict=TF.TF_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
"""Init Data_augv7.
Args:
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
2020-01-24 11:50:45 -05:00
N_TF (int): Number of TF to be applied sequentially to each inputs. Minimum 2, otherwise prefer using Data_augV5. (default: 2)
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
"""
super(Data_augV7, self).__init__()
assert len(TF_dict)>0
assert N_TF>=0
if N_TF<2:
print("WARNING: Data_augV7 isn't designed to use fewer than 2 sequential TF. Please use Data_augV5 instead.")
self._data_augmentation = True
#TF
self._TF_dict = TF_dict
self._TF= list(self._TF_dict.keys())
self._nb_tf= len(self._TF)
self._N_seqTF = N_TF
#Mag
self._shared_mag = shared_mag
self._fixed_mag = fixed_mag
if not self._fixed_mag and len([tf for tf in self._TF if tf not in TF.TF_ignore_mag])==0:
print("WARNING: Mag would be fixed as current TF doesn't allow gradient propagation:",self._TF)
self._fixed_mag=True
#Distribution
self._fixed_prob=fixed_prob
self._samples = []
self._mix_dist = False
if mix_dist != 0.0: #Mix dist
self._mix_dist = True
self._fixed_mix=True
if mix_dist is None: #Learn Mix dist
self._fixed_mix = False
mix_dist=0.5
#TF sets
#import itertools
#itertools.product(range(self._nb_tf), repeat=self._N_seqTF)
#no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}} #Specific No consecutive ops
no_consecutive={idx for idx, t in enumerate(self._TF) if t not in {'Identity'}} #No consecutive same ops (except Identity)
cons_test = (lambda i, idxs: i in no_consecutive and len(idxs)!=0 and i==idxs[-1]) #Exclude selected consecutive
def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF (exclude cons_test arrangement)
TF_sets=[]
if set_size>1:
for i in range(n_TF):
if not cons_test(i, idx_prefix):
TF_sets += generate_TF_sets(n_TF, set_size=set_size-1, idx_prefix=idx_prefix+[i])
else:
TF_sets+=[[idx_prefix+[i]] for i in range(n_TF) if not cons_test(i, idx_prefix)]
return TF_sets
self._TF_sets=torch.ByteTensor(generate_TF_sets(self._nb_tf, self._N_seqTF)).squeeze()
self._nb_TF_sets=len(self._TF_sets)
print("Number of TF sets:",self._nb_TF_sets)
#print(self._TF_sets)
self._prob_mem=torch.zeros(self._nb_TF_sets)
#Params
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_TF_sets)/self._nb_TF_sets), #Distribution prob uniforme
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
"mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
})
#for tf in TF.TF_no_grad :
# if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #Fix the TF at max parameter
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag is useless for ignore_mag TF
#Mag regularisation
if not self._fixed_mag:
if self._shared_mag :
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage max amplitude (torch.FloatTensor(PARAMETER_MAX) would build an uninitialized tensor of that size)
else:
self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in TF.TF_ignore_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude
def forward(self, x):
""" Main method of the Data augmentation module.
Args:
x (Tensor): Batch of data.
Returns:
Tensor : Batch of transformed data.
"""
self._samples = None
if self._data_augmentation:# and TF.random.random() < 0.5:
device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
## Sampling ##
uniforme_dist = torch.ones(1,self._nb_TF_sets,device=device).softmax(dim=1)
if not self._mix_dist:
self._distrib = uniforme_dist
else:
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix real / uniform distribution with mix_factor
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_TF_sets), device=device)*self._distrib)
sample = cat_distrib.sample()
self._samples=sample
TF_samples=self._TF_sets[sample,:].to(device) #[Batch_size, TFseq]
for i in range(self._N_seqTF):
## Transformations ##
x = self.apply_TF(x, TF_samples[:,i])
return x
def apply_TF(self, x, sampled_TF):
""" Applies the sampled transformations.
Args:
x (Tensor): Batch of data.
sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.
Returns:
Tensor: Batch of transformed data.
"""
device = x.device
batch_size, channels, h, w = x.shape
smps_x=[]
for tf_idx in range(self._nb_tf):
mask = sampled_TF==tf_idx #Create selection mask
smp_x = x[mask] #torch.masked_select()? (Requires expanding the mask to the same dim)
if smp_x.shape[0]!=0: #if there's data to TF
magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
if self._fixed_mag: magnitude=magnitude.detach() #Fmodel systematically tries to track the gradients of all params
tf=self._TF[tf_idx]
#In place
#x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
#Out of place
smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
idx= mask.nonzero()
idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There is surely a simpler way...
x=x.scatter(dim=0, index=idx, src=smp_x)
return x
def adjust_param(self, soft=False): #Detach from gradient ?
""" Enforce limitations to the learned parameters.
Ensure that the parameters value stays in the right intevals. This should be called after each update of those parameters.
Args:
soft (bool): Wether to use a softmax function for TF probabilites. Not Recommended as it tends to lock the probabilities, preventing them to be learned. (default: False)
"""
if not self._fixed_prob:
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck in a uniform dist if the lr is too low
else:
self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
if not self._fixed_mag:
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
if not self._fixed_mix:
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
def loss_weight(self):
""" Weights for the loss.
Compute the weight of the loss of each input depending on which TF was applied to it.
Should be applied to the loss before reduction.
Returns:
Tensor : Loss weights.
"""
if self._samples is None : return 1 #No sample = no weighting
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
w_loss = torch.zeros((self._samples.shape[0],self._nb_TF_sets), device=self._samples.device)
w_loss.scatter_(1, self._samples.view(-1,1), 1)
w_loss = w_loss * prob/self._distrib #Weighting by the probabilities (divided by the distribution so the loss is not lowered)
w_loss = torch.sum(w_loss,dim=1)
return w_loss
def reg_loss(self, reg_factor=0.005):
""" Regularisation term used to learn the magnitudes.
Uses an L2 loss to encourage high TF magnitudes.
Args:
reg_factor (float): Factor by which the regularisation loss is multiplied. (default: 0.005)
Returns:
Tensor containing the regularisation loss value.
"""
if self._fixed_mag:
return torch.tensor(0)
else:
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
return max_mag_reg
def TF_prob(self):
""" Gives an estimation of the individual TF probabilities.
Be wary that the returned probabilities aren't exact. The TF distribution isn't fully represented by them.
Each probability should be taken individually. They only represent the chance for a specific TF to be picked at least once.
Returns:
Tensor containing the single TF probabilities of application.
"""
if torch.all(self._params['prob']!=self._prob_mem.to(self._params['prob'].device)): #Prevent recomputation if the original prob didn't change
self._prob_mem=self._params['prob'].data.detach_()
self._single_TF_prob=torch.zeros(self._nb_tf)
for idx_tf in range(self._nb_tf):
for i, t_set in enumerate(self._TF_sets):
#uni, count = np.unique(t_set, return_counts=True)
#if idx_tf in uni:
# res[idx_tf]+=self._params['prob'][i]*int(count[np.where(uni==idx_tf)])
if idx_tf in t_set:
self._single_TF_prob[idx_tf]+=self._params['prob'][i]
return self._single_TF_prob
def train(self, mode=True):
""" Set the module training mode.
Args:
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
"""
#if mode is None :
# mode=self._data_augmentation
self.augment(mode=mode) #Useless if mode=None
super(Data_augV7, self).train(mode)
return self
def eval(self):
""" Set the module to evaluation mode.
"""
return self.train(mode=False)
def augment(self, mode=True):
""" Set the augmentation mode.
Args:
mode (bool): Whether to perform data augmentation on the forward pass. (default: True)
"""
self._data_augmentation=mode
def is_augmenting(self):
""" Return whether data augmentation is applied.
Returns:
bool : True if data augmentation is applied.
"""
return self._data_augmentation
def __getitem__(self, key):
"""Access to the learnable parameters
Args:
key (string): Name of the learnable parameter to access.
Returns:
nn.Parameter.
"""
if key == 'prob': #Override prob access
return self.TF_prob()
return self._params[key]
def __str__(self):
"""Name of the module
Returns:
String containing the name of the module as well as its higher-level parameters.
"""
dist_param=''
if self._fixed_prob: dist_param+='Fx'
mag_param='Mag'
if self._fixed_mag: mag_param+= 'Fx'
if self._shared_mag: mag_param+= 'Sh'
if not self._mix_dist:
return "Data_augV7(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
elif self._fixed_mix:
return "Data_augV7(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
else:
return "Data_augV7(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
"""RandAugment implementation.
Applies transformations (TF) to batch of data.
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple. For the full definiton of the TF, see transformations.py.
The TF probabilities are ignored and, instead selected randomly.
Attributes:
_data_augmentation (bool): Wether TF will be applied during forward pass.
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
_TF (list) : List of TF names.
_nb_tf (int) : Number of TF used.
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF. Should be True.
_fixed_mag (bool): Wether to lock the TF magnitudes. Should be True.
_params (nn.ParameterDict): Data augmentation parameters.
"""
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
"""Init RandAug.
Args:
TF_dict (dict): A dictionary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
N_TF (int): Number of TF applied sequentially to each input. (default: 1)
mag (float): Magnitude of the TF. Should be within [PARAMETER_MIN, PARAMETER_MAX] defined in transformations.py. (default: PARAMETER_MAX)
"""
super(RandAug, self).__init__()
self._data_augmentation = True
self._TF_dict = TF_dict
self._TF= list(self._TF_dict.keys())
self._nb_tf= len(self._TF)
self._N_seqTF = N_TF
self.mag=nn.Parameter(torch.tensor(float(mag)))
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Ignored
"mag" : nn.Parameter(torch.tensor(float(mag))),
})
self._shared_mag = True
self._fixed_mag = True
self._fixed_prob=True
self._fixed_mix=True
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
def forward(self, x):
""" Main method of the Data augmentation module.
Args:
x (Tensor): Batch of data.
Returns:
Tensor : Batch of transformed data.
"""
if self._data_augmentation:# and TF.random.random() < 0.5:
device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
for _ in range(self._N_seqTF):
## Sampling ## == sampled_ops = np.random.choice(transforms, N)
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*uniforme_dist)
sample = cat_distrib.sample()
## Transformations ##
x = self.apply_TF(x, sample)
return x
def apply_TF(self, x, sampled_TF):
""" Applies the sampled transformations.
Args:
x (Tensor): Batch of data.
sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.
Returns:
Tensor: Batch of transformed data.
"""
smps_x=[]
for tf_idx in range(self._nb_tf):
mask = sampled_TF==tf_idx #Create selection mask
smp_x = x[mask] #torch.masked_select()? (Requires expanding the mask to the same dim)
if smp_x.shape[0]!=0: #if there's data to TF
magnitude=self._params["mag"].detach()
tf=self._TF[tf_idx]
#print(magnitude)
#In place
x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
return x
def adjust_param(self, soft=False):
"""Not used
"""
pass #No parameter to optimise
def loss_weight(self):
"""Not used
"""
return 1 #No sample = no weighting
def reg_loss(self, reg_factor=0.005):
"""Not used
"""
return torch.tensor(0) #No regularisation
def train(self, mode=None):
""" Set the module training mode.
Args:
mode (bool): Whether to set the module in training mode. If None, the current augmentation mode is used. (default: None)
"""
if mode is None :
mode=self._data_augmentation
self.augment(mode=mode) #Useless if mode=None
super(RandAug, self).train(mode)
def eval(self):
""" Set the module to evaluation mode.
"""
self.train(mode=False)
def augment(self, mode=True):
""" Set the augmentation mode.
Args:
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
"""
2019-11-27 12:54:19 -05:00
self._data_augmentation=mode
2020-01-30 11:21:25 -05:00
def is_augmenting(self):
""" Return wether data augmentation is applied.
Returns:
bool : True if data augmentation is applied.
"""
return self._data_augmentation
def __getitem__(self, key):
"""Access to the learnable parameters
Args:
key (string): Name of the learnable parameter to access.
Returns:
nn.Parameter.
"""
return self._params[key]
def __str__(self):
"""Name of the module
Returns:
String containing the name of the module as well as its higher-level parameters.
"""
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
### Models ###
import higher
class Higher_model(nn.Module):
"""Model wrapper for higher gradient tracking.
2019-11-27 12:54:19 -05:00
2020-01-22 11:15:56 -05:00
Keep in memory the orginial model and it's functionnal, higher, version.
2019-11-27 12:54:19 -05:00
2020-01-22 11:15:56 -05:00
Might not be needed anymore if Higher implement detach for fmodel.
2019-11-27 12:54:19 -05:00
2020-01-22 11:15:56 -05:00
see : https://github.com/facebookresearch/higher
2019-11-27 12:54:19 -05:00
2020-01-22 11:15:56 -05:00
TODO: Get rid of the original model if not needed by user.
2019-11-27 12:54:19 -05:00
2020-01-22 11:15:56 -05:00
Attributes:
_name (string): Name of the model.
_mods (nn.ModuleDict): Models (Orginial and Higher version).
"""
def __init__(self, model, model_name=None):
"""Init Higher_model.
Args:
model (nn.Module): Network for which higher gradients can be tracked.
model_name (string): Model name. (default: class name of model)
"""
super(Higher_model, self).__init__()
self._name = model_name if model_name else model.__class__.__name__ #model.__str__()
self._mods = nn.ModuleDict({
'original': model,
'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
})
def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
"""Get a differentiable version of an Optimizer.
A higher/differentiable optimizer is required for higher gradient tracking.
Usage : diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step)
Be wary that if track_higher_grads is set to True, a new state of the model will be saved each time diffopt.step() is called, increasing memory consumption. The detach_() method should be called to reset the gradient tape and prevent memory saturation.
Args:
opt (torch.optim): Optimizer to make differentiable.
grad_callback (fct(grads)=grads): Function applied to the list of parameter gradients (e.g. clipping). (default: None)
track_higher_grads (bool): Whether higher gradients are tracked. If True, the graph/states will be retained to allow backpropagation. (default: True)
Returns:
(Higher.DifferentiableOptimizer): Differentiable version of the optimizer.
"""
return higher.optim.get_diff_optim(opt,
self._mods['original'].parameters(),
fmodel=self._mods['functional'],
grad_callback=grad_callback,
track_higher_grads=track_higher_grads)
def forward(self, x):
""" Main method of the model.
Args:
x (Tensor): Batch of data.
Returns:
Tensor : Output of the network. Should be logits.
"""
return self._mods['functional'](x)
def detach_(self):
"""Detach from the graph.
Needed to limit the number of states kept in memory.
"""
tmp = self._mods['functional'].fast_params
self._mods['functional']._fast_params=[]
self._mods['functional'].update_params(tmp)
for p in self._mods['functional'].fast_params:
p.detach_().requires_grad_()
def state_dict(self):
"""Returns a dictionary containing a whole state of the module.
"""
return self._mods['functional'].state_dict()
def __getitem__(self, key):
"""Access to modules
Args:
key (string): Name of the module to access.
Returns:
nn.Module.
"""
return self._mods[key]
def __str__(self):
"""Name of the module
Returns:
String containing the name of the module.
"""
return self._name
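# Example (sketch): inner optimisation through Higher_model. `model`, `inputs`, `labels`
# and `criterion` are placeholders. diffopt.step(loss) replaces the usual
# zero_grad/backward/step sequence; detach_() cuts the gradient tape to bound memory.
#
#   hmodel = Higher_model(model)
#   inner_opt = torch.optim.SGD(hmodel['original'].parameters(), lr=1e-2, momentum=0.9)
#   diffopt = hmodel.get_diffopt(inner_opt, track_higher_grads=True)
#   loss = criterion(hmodel(inputs), labels)
#   diffopt.step(loss)                                   #== (opt.zero_grad, loss.backward, opt.step)
#   ...                                                  #After the meta update:
#   diffopt.detach_()
#   hmodel.detach_()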
from utils import clip_norm
from train_utils import compute_vaLoss
class Augmented_model(nn.Module):
"""Wrapper for a Data Augmentation module and a model.
Attributes:
_mods (nn.ModuleDict): A dictionary containing the modules.
_data_augmentation (bool): Whether data augmentation should be used.
"""
2019-11-08 11:28:06 -05:00
def __init__(self, data_augmenter, model):
2020-01-20 16:10:17 -05:00
"""Init Augmented Model.
By default, data augmentation will be performed.
Args:
data_augmenter (nn.Module): Data augmentation module.
model (nn.Module): Network.
"""
super(Augmented_model, self).__init__()
self._mods = nn.ModuleDict({
'data_aug': data_augmenter,
'model': model
})
self.augment(mode=True)
def forward(self, x):
""" Main method of the Augmented model.
Perform the forward pass of both modules.
Args:
x (Tensor): Batch of data.
Returns:
Tensor : Output of the networks. Should be logits.
"""
return self._mods['model'](self._mods['data_aug'](x))
def augment(self, mode=True):
""" Set the augmentation mode.
Args:
mode (bool): Whether to perform data augmentation on the forward pass. (default: True)
"""
self._data_augmentation=mode
self._mods['data_aug'].augment(mode)
#### Encapsulation Meta Opt ####
def start_bilevel_opt(self, inner_it, hp_list, opt_param, dl_val):
""" Set up the Augmented Model for bi-level optimisation.
Creates and keeps in the Augmented Model the objects necessary for meta-optimisation.
This allows for an almost transparent use by hiding the bi-level optimisation (see ''run_dist_dataugV3'') behind ::
model.step(loss)
See ''run_simple_smartaug'' for a complete example.
Args:
inner_it (int): Number of inner iterations before a meta-step. 0 inner iterations means there is no meta-step.
hp_list (list): List of hyper-parameters to be learned.
opt_param (dict): Dictionary containing the optimizers parameters.
dl_val (DataLoader): Data loader of validation data.
"""
self._it_count=0
self._in_it=inner_it
self._opt_param=opt_param
#Inner Opt
inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
#Validation data
self._dl_val=dl_val
self._dl_val_it=iter(dl_val)
self._val_loss=0.
if inner_it==0 or len(hp_list)==0: #No meta-opt
print("No meta optimization")
#Inner Opt
self._diffopt = self._mods['model'].get_diffopt(
inner_opt,
grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
track_higher_grads=False)
self._meta_opt=None
else: #Bi-level opt
print("Bi-Level optimization")
#Inner Opt
self._diffopt = self._mods['model'].get_diffopt(
inner_opt,
grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
track_higher_grads=True)
#Meta Opt
self._meta_opt = torch.optim.Adam(hp_list, lr=opt_param['Meta']['lr'])
self._meta_opt.zero_grad()
def step(self, loss):
""" Perform a model update.
''start_bilevel_opt'' method needs to be called once before using this method.
Perform a step of inner optimization and, if needed, a step of meta optimization.
Replace ::
opt.zero_grad()
loss.backward()
opt.step()
val_loss=...
val_loss.backward()
meta_opt.step()
adjust_param()
detach()
meta_opt.zero_grad()
By ::
model.step(loss)
Args:
loss (Tensor): the training loss tensor.
"""
self._it_count+=1
self._diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
if(self._meta_opt and self._it_count>0 and self._it_count%self._in_it==0): #Perform Meta step
#print("meta")
self._val_loss = compute_vaLoss(model=self._mods['model'], dl_it=self._dl_val_it, dl=self._dl_val) + self._mods['data_aug'].reg_loss()
#print_graph(val_loss) #to visualize computational graph
self._val_loss.backward()
torch.nn.utils.clip_grad_norm_(self._mods['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
self._meta_opt.step()
#Adjust Hyper-parameters
self._mods['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
#For optimizer parameters, if needed
#for param_group in self._diffopt.param_groups:
# for param in list(self._opt_param['Inner'].keys())[1:]:
# param_group[param].data = param_group[param].data.clamp(min=1e-4)
#Reset gradients
self._diffopt.detach_()
self._mods['model'].detach_()
self._meta_opt.zero_grad()
self._it_count=0
def val_loss(self):
""" Get the validation loss.
Compute, if needed, the validation loss and return it.
''start_bilevel_opt'' method needs to be called once before using this method.
Returns:
(Tensor) Validation loss on a single batch of data.
"""
if(self._meta_opt): #Bilevel opti
return self._val_loss
else:
return compute_vaLoss(model=self._mods['model'], dl_it=self._dl_val_it, dl=self._dl_val)
##########################
def train(self, mode=True):
""" Set the module training mode.
Args:
mode (bool): Whether to set the module in training mode. (default: True)
"""
#if mode is None :
# mode=self._data_augmentation
super(Augmented_model, self).train(mode)
self._mods['data_aug'].augment(mode=self._data_augmentation) #Restart data augmentation if needed
return self
def eval(self):
""" Set the module to evaluation mode.
"""
#return self.train(mode=False)
super(Augmented_model, self).train(mode=False)
self._mods['data_aug'].augment(mode=False)
return self
def items(self):
"""Return an iterable of the ModuleDict key/value pairs.
"""
return self._mods.items()
def update(self, modules):
"""Update the module dictionary.
The new dictionary should keep the same structure.
"""
assert(self._mods.keys()==modules.keys())
self._mods.update(modules)
def is_augmenting(self):
""" Return whether data augmentation is applied.
Returns:
bool : True if data augmentation is applied.
"""
return self._data_augmentation
def TF_names(self):
""" Get the transformations names used by the data augmentation module.
Returns:
list : names of the transformations of the data augmentation module.
"""
try:
return self._mods['data_aug']._TF
except:
return None
def __getitem__(self, key):
"""Access to the modules.
Args:
key (string): Name of the module to access.
Returns:
nn.Module.
"""
return self._mods[key]
def __str__(self):
"""Name of the module
Returns:
String containing the name of the module as well as its higher-level parameters.
"""
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"