""" Data augmentation modules.

    Features a custom implementation of RandAugment (RandAug), as well as data augmentation modules allowing gradient propagation.

    Typical usage:

        aug_model = Augmented_model(Data_augV5(), model)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import *

#import kornia
#import random

import numpy as np
import copy

import transformations as TF
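
# Usage sketch (illustrative, kept as comments so nothing runs at import): wrap a data
# augmentation module and a network in Augmented_model, then train as usual. The names
# `my_network` and `dl_train` are placeholders, not part of this module.
#
#   aug_model = Augmented_model(Data_augV5(N_TF=2, mix_dist=0.5, fixed_prob=False), my_network)
#   aug_model.augment(mode=True)
#   for xs, ys in dl_train:
#       logits = aug_model(xs)                                       #Data augmentation + model forward
#       loss = F.cross_entropy(logits, ys, reduction='none')         #Per-sample loss
#       loss = (loss * aug_model['data_aug'].loss_weight()).mean()   #Weight the loss by the sampled TF
#       loss.backward()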

### Data augmenter ###
class Data_augV5(nn.Module): #Joint optimisation (mag, prob)
    """Data augmentation module with learnable parameters.

        Applies transformations (TF) to a batch of data.
        Each TF is defined by a (name, probability of application, magnitude of distortion) tuple which can be learned. For the full definition of the TF, see transformations.py.
        The TF probabilities define the distribution from which the applied TF are sampled.

        Be wary that the order of sequential application of TF is not taken into account. See Data_augV7.

        Attributes:
            _data_augmentation (bool): Whether TF will be applied during the forward pass.
            _TF_dict (dict): A dictionary containing the data transformations (TF) to be applied.
            _TF (list): List of TF names.
            _nb_tf (int): Number of TF used.
            _N_seqTF (int): Number of TF to be applied sequentially to each input.
            _shared_mag (bool): Whether to share a single magnitude parameter for all TF.
            _fixed_mag (bool): Whether to lock the TF magnitudes.
            _fixed_prob (bool): Whether to lock the TF probabilities.
            _samples (list): TF indexes sampled during the last forward pass.
            _mix_dist (bool): Whether to use a mix of a uniform distribution and the real distribution (TF probabilities). If False, only a uniform distribution is used.
            _fixed_mix (bool): Whether to lock the mix distribution factor.
            _params (nn.ParameterDict): Learnable parameters.
            _reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to False (i.e. the magnitudes are learned).
            _reg_mask (list): Mask selecting the TF considered for the regularisation.
    """
    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
        """Init Data_augV5.

            Args:
                TF_dict (dict): A dictionary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
                N_TF (int): Number of TF to be applied sequentially to each input. (default: 1)
                mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
                fixed_prob (bool): Whether to lock the TF probabilities. (default: False)
                fixed_mag (bool): Whether to lock the TF magnitudes. (default: True)
                shared_mag (bool): Whether to share a single magnitude parameter for all TF. (default: True)
        """
        super(Data_augV5, self).__init__()
        assert len(TF_dict)>0
        assert N_TF>=0

        self._data_augmentation = True

        #TF
        self._TF_dict = TF_dict
        self._TF= list(self._TF_dict.keys())
        self._nb_tf= len(self._TF)
        self._N_seqTF = N_TF

        #Mag
        self._shared_mag = shared_mag
        self._fixed_mag = fixed_mag
        if not self._fixed_mag and len([tf for tf in self._TF if tf not in TF.TF_ignore_mag])==0:
            print("WARNING: Mag will be fixed as the current TF don't allow gradient propagation:",self._TF)
            self._fixed_mag=True

        #Distribution
        self._fixed_prob=fixed_prob
        self._samples = []

        self._mix_dist = False
        if mix_dist != 0.0: #Mix dist
            self._mix_dist = True

        self._fixed_mix=True
        if mix_dist is None: #Learn Mix dist
            self._fixed_mix = False
            mix_dist=0.5

        #Params
        init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
        self._params = nn.ParameterDict({
            "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform probability distribution
            "mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
                else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
            "mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
        })

        #for tf in TF.TF_no_grad :
        #    if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixed at max parameter
        #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Magnitude is useless for ignore_mag TF

        #Mag regularisation
        if not self._fixed_mag:
            if self._shared_mag :
                self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage maximum magnitude
            else:
                self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
                self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage maximum magnitude

    def forward(self, x):
        """ Main method of the Data augmentation module.

            Args:
                x (Tensor): Batch of data.

            Returns:
                Tensor : Batch of transformed data.
        """
        self._samples = []
        if self._data_augmentation:# and TF.random.random() < 0.5:
            device = x.device
            batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]

            x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)

            ## Sampling ##
            uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)

            if not self._mix_dist:
                self._distrib = uniforme_dist
            else:
                prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
                mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
                self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix of the real/uniform distributions with mix_factor

            cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)

            for _ in range(self._N_seqTF):
                sample = cat_distrib.sample()
                self._samples.append(sample)

                ## Transformations ##
                x = self.apply_TF(x, sample)
        return x

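    # Sampling sketch (illustrative): the per-image TF index is drawn from a Categorical
    # distribution that mixes a uniform distribution with the learned probabilities.
    # The numbers below are made up for the example.
    #
    #   nb_tf, batch_size, mix = 3, 4, 0.5
    #   prob = torch.tensor([0.7, 0.2, 0.1])                  #Learned TF probabilities
    #   uniform = torch.ones(nb_tf)/nb_tf
    #   distrib = mix*prob + (1-mix)*uniform                  #approx. [0.517, 0.267, 0.217]
    #   sample = Categorical(probs=distrib.expand(batch_size, -1)).sample()  #One TF index per image
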
    def apply_TF(self, x, sampled_TF):
        """ Applies the sampled transformations.

            Args:
                x (Tensor): Batch of data.
                sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.

            Returns:
                Tensor: Batch of transformed data.
        """
        device = x.device
        batch_size, channels, h, w = x.shape
        smps_x=[]

        for tf_idx in range(self._nb_tf):
            mask = sampled_TF==tf_idx #Create selection mask
            smp_x = x[mask] #torch.masked_select() ? (Would require expanding the mask to the same dimensions)

            if smp_x.shape[0]!=0: #if there's data to TF
                magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
                if self._fixed_mag: magnitude=magnitude.detach() #Fmodel systematically tries to track the gradients of every parameter

                tf=self._TF[tf_idx]

                #In place
                #x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)

                #Out of place
                smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
                idx= mask.nonzero()
                idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There is surely a simpler way...
                x=x.scatter(dim=0, index=idx, src=smp_x)

        return x

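    # Out-of-place write-back sketch (illustrative): images selected by `mask` are
    # transformed, then scattered back into the batch without in-place assignment, so
    # gradients can flow through the magnitude. Shapes are made up for the example.
    #
    #   x = torch.rand(4, 3, 8, 8)                    #Batch of 4 RGB 8x8 images
    #   mask = torch.tensor([True, False, True, False])
    #   smp_x = x[mask] * 0.5                         #Stand-in for a differentiable TF
    #   idx = mask.nonzero()
    #   idx = idx.expand(-1, 3).unsqueeze(2).expand(-1, 3, 8).unsqueeze(3).expand(-1, 3, 8, 8)
    #   x = x.scatter(dim=0, index=idx, src=smp_x)    #Same effect as x[mask]=smp_x, but out of place
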
    def adjust_param(self, soft=False): #Detach from gradient ?
        """ Enforce limitations to the learned parameters.

            Ensure that the parameters value stays in the right intervals. This should be called after each update of those parameters.

            Args:
                soft (bool): Whether to use a softmax function for the TF probabilities. Not recommended as it tends to lock the probabilities, preventing them from being learned. (default: False)
        """
        if not self._fixed_prob:
            if soft :
                self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck on a uniform distribution if the learning rate is too small
            else:
                self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
                self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1

        if not self._fixed_mag:
            self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)

        if not self._fixed_mix:
            self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)

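    # Projection sketch (illustrative): after an optimizer step the probabilities can
    # leave the simplex; they are clamped away from zero and renormalised so they sum
    # to 1. The values below are made up for the example.
    #
    #   p = torch.tensor([1.3, -0.2, 0.4])
    #   p = p.clamp(min=1/(3*100), max=1.0)     #[1.0, 0.0033, 0.4]
    #   p = p/p.sum()                           #approx. [0.71, 0.002, 0.28], sums to 1
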
    def loss_weight(self):
        """ Weights for the loss.

            Compute the weight of the loss of each input depending on which TF was applied to it.
            Should be applied to the loss before reduction.

            Does not take into account the order of application of the TF. See Data_augV7.

            Returns:
                Tensor : Loss weights.
        """
        if len(self._samples)==0 : return torch.tensor(1, device=self._params["prob"].device) #No samples = no weighting

        prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]

        #Several sequential TF (warning: does not take the order into account!)
        w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
        for sample in self._samples:
            tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
            w_loss += tmp_w

        w_loss = w_loss * prob/self._distrib #Weight by the probabilities (divided by the sampling distribution so the loss isn't shrunk)
        w_loss = torch.sum(w_loss,dim=1)
        return w_loss

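    # Usage sketch (illustrative): the weights are meant to multiply the per-sample loss
    # before reduction. `model`, `aug`, `xs`, `ys` are placeholders.
    #
    #   logits = model(aug(xs))
    #   loss = F.cross_entropy(logits, ys, reduction='none')   #Per-sample loss
    #   loss = (loss * aug.loss_weight()).mean()               #Weight, then reduce
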
    def reg_loss(self, reg_factor=0.005):
        """ Regularisation term used to learn the magnitudes.

            Use an L2 loss to encourage high magnitude TF.

            Args:
                reg_factor (float): Factor by which the regularisation loss is multiplied. (default: 0.005)

            Returns:
                Tensor containing the regularisation loss value.
        """
        if self._fixed_mag:
            return torch.tensor(0)
        else:
            #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
            mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
            max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
            return max_mag_reg

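    # Usage sketch (illustrative): when magnitudes are learned, this regularisation term
    # is simply added to the objective that updates them (as done in Augmented_model.step).
    # `compute_validation_loss` is a placeholder.
    #
    #   val_loss = compute_validation_loss(...)
    #   val_loss = val_loss + aug.reg_loss()     #Pull magnitudes towards PARAMETER_MAX
    #   val_loss.backward()
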
    def train(self, mode=True):
        """ Set the module training mode.

            Args:
                mode (bool): Whether to set the module in training mode. (default: True)
        """
        #if mode is None :
        #    mode=self._data_augmentation
        self.augment(mode=mode) #Useless if mode=None
        super(Data_augV5, self).train(mode)
        return self

    def eval(self):
        """ Set the module to evaluation mode.
        """
        return self.train(mode=False)

    def augment(self, mode=True):
        """ Set the augmentation mode.

            Args:
                mode (bool): Whether to perform data augmentation on the forward pass. (default: True)
        """
        self._data_augmentation=mode

    def __getitem__(self, key):
        """Access to the learnable parameters

            Args:
                key (string): Name of the learnable parameter to access.

            Returns:
                nn.Parameter.
        """
        return self._params[key]

    def __str__(self):
        """Name of the module

            Returns:
                String containing the name of the module as well as the higher level parameters.
        """
        dist_param=''
        if self._fixed_prob: dist_param+='Fx'
        mag_param='Mag'
        if self._fixed_mag: mag_param+= 'Fx'
        if self._shared_mag: mag_param+= 'Sh'
        if not self._mix_dist:
            return "Data_augV5(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
        elif self._fixed_mix:
            return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
        else:
            return "Data_augV5(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)

class Data_augV7(nn.Module): #Sequential probabilities
    """Data augmentation module with learnable parameters.

        Applies transformations (TF) to a batch of data.
        Each TF is defined by a (name, probability of application, magnitude of distortion) tuple which can be learned. For the full definition of the TF, see transformations.py.
        The TF probabilities define the distribution from which the applied TF are sampled.

        Replaces individual TF with TF sets, which are combinations of classic TF.

        Attributes:
            _data_augmentation (bool): Whether TF will be applied during the forward pass.
            _TF_dict (dict): A dictionary containing the data transformations (TF) to be applied.
            _TF (list): List of TF names.
            _nb_tf (int): Number of TF used.
            _N_seqTF (int): Number of TF to be applied sequentially to each input.
            _shared_mag (bool): Whether to share a single magnitude parameter for all TF.
            _fixed_mag (bool): Whether to lock the TF magnitudes.
            _fixed_prob (bool): Whether to lock the TF probabilities.
            _samples (list): TF indexes sampled during the last forward pass.
            _mix_dist (bool): Whether to use a mix of a uniform distribution and the real distribution (TF probabilities). If False, only a uniform distribution is used.
            _fixed_mix (bool): Whether to lock the mix distribution factor.
            _params (nn.ParameterDict): Learnable parameters.
            _reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to False (i.e. the magnitudes are learned).
            _reg_mask (list): Mask selecting the TF considered for the regularisation.
    """
    def __init__(self, TF_dict=TF.TF_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
        """Init Data_augV7.

            Args:
                TF_dict (dict): A dictionary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
                N_TF (int): Number of TF to be applied sequentially to each input. Minimum 2, otherwise prefer using Data_augV5. (default: 2)
                mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
                fixed_prob (bool): Whether to lock the TF probabilities. (default: False)
                fixed_mag (bool): Whether to lock the TF magnitudes. (default: True)
                shared_mag (bool): Whether to share a single magnitude parameter for all TF. (default: True)
        """
        super(Data_augV7, self).__init__()
        assert len(TF_dict)>0
        assert N_TF>=0

        if N_TF<2:
            print("WARNING: Data_augV7 isn't designed to use fewer than 2 sequential TF. Please use Data_augV5 instead.")

        self._data_augmentation = True

        #TF
        self._TF_dict = TF_dict
        self._TF= list(self._TF_dict.keys())
        self._nb_tf= len(self._TF)
        self._N_seqTF = N_TF

        #Mag
        self._shared_mag = shared_mag
        self._fixed_mag = fixed_mag
        if not self._fixed_mag and len([tf for tf in self._TF if tf not in TF.TF_ignore_mag])==0:
            print("WARNING: Mag will be fixed as the current TF don't allow gradient propagation:",self._TF)
            self._fixed_mag=True

        #Distribution
        self._fixed_prob=fixed_prob
        self._samples = []

        self._mix_dist = False
        if mix_dist != 0.0: #Mix dist
            self._mix_dist = True

        self._fixed_mix=True
        if mix_dist is None: #Learn Mix dist
            self._fixed_mix = False
            mix_dist=0.5

        #TF sets
        #import itertools
        #itertools.product(range(self._nb_tf), repeat=self._N_seqTF)

        #no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}} #Specific no-consecutive ops
        no_consecutive={idx for idx, t in enumerate(self._TF) if t not in {'Identity'}} #No consecutive identical ops (except Identity)
        cons_test = (lambda i, idxs: i in no_consecutive and len(idxs)!=0 and i==idxs[-1]) #Exclude selected consecutive
        def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF (excluding cons_test arrangements)
            TF_sets=[]
            if set_size>1:
                for i in range(n_TF):
                    if not cons_test(i, idx_prefix):
                        TF_sets += generate_TF_sets(n_TF, set_size=set_size-1, idx_prefix=idx_prefix+[i])
            else:
                TF_sets+=[[idx_prefix+[i]] for i in range(n_TF) if not cons_test(i, idx_prefix)]
            return TF_sets

        self._TF_sets=torch.ByteTensor(generate_TF_sets(self._nb_tf, self._N_seqTF)).squeeze()
        self._nb_TF_sets=len(self._TF_sets)
        print("Number of TF sets:",self._nb_TF_sets)
        #print(self._TF_sets)
        self._prob_mem=torch.zeros(self._nb_TF_sets)

        #Params
        init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
        self._params = nn.ParameterDict({
            "prob": nn.Parameter(torch.ones(self._nb_TF_sets)/self._nb_TF_sets), #Uniform probability distribution
            "mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
                else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
            "mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
        })

        #for tf in TF.TF_no_grad :
        #    if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixed at max parameter
        #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Magnitude is useless for ignore_mag TF

        #Mag regularisation
        if not self._fixed_mag:
            if self._shared_mag :
                self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage maximum magnitude (fixed: torch.FloatTensor(int) would allocate an uninitialised tensor)
            else:
                self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in TF.TF_ignore_mag]
                self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage maximum magnitude

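    # TF-set enumeration sketch (illustrative): generate_TF_sets is roughly
    # itertools.product with consecutive duplicates filtered out (Identity excepted).
    # The TF names below are made up for the example; with 3 TF and N_TF=2:
    #
    #   import itertools
    #   names = ['Identity', 'FlipLR', 'Rotate']
    #   sets = [s for s in itertools.product(range(3), repeat=2)
    #           if not (s[0]==s[1] and names[s[0]]!='Identity')]
    #   #[(0,0), (0,1), (0,2), (1,0), (1,2), (2,0), (2,1)] -> 7 sets instead of 9
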
    def forward(self, x):
        """ Main method of the Data augmentation module.

            Args:
                x (Tensor): Batch of data.

            Returns:
                Tensor : Batch of transformed data.
        """
        self._samples = None
        if self._data_augmentation:# and TF.random.random() < 0.5:
            device = x.device
            batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]

            x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)

            ## Sampling ##
            uniforme_dist = torch.ones(1,self._nb_TF_sets,device=device).softmax(dim=1)

            if not self._mix_dist:
                self._distrib = uniforme_dist
            else:
                prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
                mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
                self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix of the real/uniform distributions with mix_factor

            cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_TF_sets), device=device)*self._distrib)
            sample = cat_distrib.sample()

            self._samples=sample
            TF_samples=self._TF_sets[sample,:].to(device) #[Batch_size, TFseq]

            for i in range(self._N_seqTF):
                ## Transformations ##
                x = self.apply_TF(x, TF_samples[:,i])
        return x

    def apply_TF(self, x, sampled_TF):
        """ Applies the sampled transformations.

            Args:
                x (Tensor): Batch of data.
                sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.

            Returns:
                Tensor: Batch of transformed data.
        """
        device = x.device
        batch_size, channels, h, w = x.shape
        smps_x=[]

        for tf_idx in range(self._nb_tf):
            mask = sampled_TF==tf_idx #Create selection mask
            smp_x = x[mask] #torch.masked_select() ? (Would require expanding the mask to the same dimensions)

            if smp_x.shape[0]!=0: #if there's data to TF
                magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
                if self._fixed_mag: magnitude=magnitude.detach() #Fmodel systematically tries to track the gradients of every parameter

                tf=self._TF[tf_idx]

                #In place
                #x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)

                #Out of place
                smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
                idx= mask.nonzero()
                idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There is surely a simpler way...
                x=x.scatter(dim=0, index=idx, src=smp_x)

        return x

    def adjust_param(self, soft=False): #Detach from gradient ?
        """ Enforce limitations to the learned parameters.

            Ensure that the parameters value stays in the right intervals. This should be called after each update of those parameters.

            Args:
                soft (bool): Whether to use a softmax function for the TF probabilities. Not recommended as it tends to lock the probabilities, preventing them from being learned. (default: False)
        """
        if not self._fixed_prob:
            if soft :
                self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck on a uniform distribution if the learning rate is too small
            else:
                self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
                self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1

        if not self._fixed_mag:
            self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)

        if not self._fixed_mix:
            self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)

    def loss_weight(self):
        """ Weights for the loss.

            Compute the weight of the loss of each input depending on which TF was applied to it.
            Should be applied to the loss before reduction.

            Returns:
                Tensor : Loss weights.
        """
        if self._samples is None : return 1 #No samples = no weighting

        prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]

        w_loss = torch.zeros((self._samples.shape[0],self._nb_TF_sets), device=self._samples.device)
        w_loss.scatter_(1, self._samples.view(-1,1), 1)
        w_loss = w_loss * prob/self._distrib #Weight by the probabilities (divided by the sampling distribution so the loss isn't shrunk)
        w_loss = torch.sum(w_loss,dim=1)
        return w_loss

    def reg_loss(self, reg_factor=0.005):
        """ Regularisation term used to learn the magnitudes.

            Use an L2 loss to encourage high magnitude TF.

            Args:
                reg_factor (float): Factor by which the regularisation loss is multiplied. (default: 0.005)

            Returns:
                Tensor containing the regularisation loss value.
        """
        if self._fixed_mag:
            return torch.tensor(0)
        else:
            #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
            mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
            max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
            return max_mag_reg

    def TF_prob(self):
        """ Gives an estimation of the individual TF probabilities.

            Be wary that the probabilities returned aren't exact: the TF distribution isn't fully represented by them.
            Each probability should be taken individually. They only represent the chance for a specific TF to be picked at least once.

            Returns:
                Tensor containing the single TF probabilities of application.
        """
        if torch.all(self._params['prob']!=self._prob_mem.to(self._params['prob'].device)): #Prevent recomputation if the original prob didn't change
            self._prob_mem=self._params['prob'].data.detach_()
            self._single_TF_prob=torch.zeros(self._nb_tf)
            for idx_tf in range(self._nb_tf):
                for i, t_set in enumerate(self._TF_sets):
                    #uni, count = np.unique(t_set, return_counts=True)
                    #if idx_tf in uni:
                    #    res[idx_tf]+=self._params['prob'][i]*int(count[np.where(uni==idx_tf)])
                    if idx_tf in t_set:
                        self._single_TF_prob[idx_tf]+=self._params['prob'][i]

        return self._single_TF_prob

    def train(self, mode=True):
        """ Set the module training mode.

            Args:
                mode (bool): Whether to set the module in training mode. (default: True)
        """
        #if mode is None :
        #    mode=self._data_augmentation
        self.augment(mode=mode) #Useless if mode=None
        super(Data_augV7, self).train(mode)
        return self

    def eval(self):
        """ Set the module to evaluation mode.
        """
        return self.train(mode=False)

    def augment(self, mode=True):
        """ Set the augmentation mode.

            Args:
                mode (bool): Whether to perform data augmentation on the forward pass. (default: True)
        """
        self._data_augmentation=mode

    def __getitem__(self, key):
        """Access to the learnable parameters

            Args:
                key (string): Name of the learnable parameter to access.

            Returns:
                nn.Parameter.
        """
        if key == 'prob': #Override prob access
            return self.TF_prob()
        return self._params[key]

    def __str__(self):
        """Name of the module

            Returns:
                String containing the name of the module as well as the higher level parameters.
        """
        dist_param=''
        if self._fixed_prob: dist_param+='Fx'
        mag_param='Mag'
        if self._fixed_mag: mag_param+= 'Fx'
        if self._shared_mag: mag_param+= 'Sh'
        if not self._mix_dist:
            return "Data_augV7(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
        elif self._fixed_mix:
            return "Data_augV7(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
        else:
            return "Data_augV7(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)

class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
    """RandAugment implementation.

        Applies transformations (TF) to a batch of data.
        Each TF is defined by a (name, probability of application, magnitude of distortion) tuple. For the full definition of the TF, see transformations.py.
        The TF probabilities are ignored; TF are instead selected uniformly at random.

        Attributes:
            _data_augmentation (bool): Whether TF will be applied during the forward pass.
            _TF_dict (dict): A dictionary containing the data transformations (TF) to be applied.
            _TF (list): List of TF names.
            _nb_tf (int): Number of TF used.
            _N_seqTF (int): Number of TF to be applied sequentially to each input.
            _shared_mag (bool): Whether to share a single magnitude parameter for all TF. Should be True.
            _fixed_mag (bool): Whether to lock the TF magnitudes. Should be True.
            _params (nn.ParameterDict): Data augmentation parameters.
    """
    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
        """Init RandAug.

            Args:
                TF_dict (dict): A dictionary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
                N_TF (int): Number of TF to be applied sequentially to each input. (default: 1)
                mag (float): Magnitude of the TF. Should be between [PARAMETER_MIN, PARAMETER_MAX] defined in transformations.py. (default: PARAMETER_MAX)
        """
        super(RandAug, self).__init__()

        self._data_augmentation = True

        self._TF_dict = TF_dict
        self._TF= list(self._TF_dict.keys())
        self._nb_tf= len(self._TF)
        self._N_seqTF = N_TF

        self.mag=nn.Parameter(torch.tensor(float(mag)))
        self._params = nn.ParameterDict({
            "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Ignored
            "mag" : nn.Parameter(torch.tensor(float(mag))),
        })
        self._shared_mag = True
        self._fixed_mag = True

        self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)

    def forward(self, x):
        """ Main method of the Data augmentation module.

            Args:
                x (Tensor): Batch of data.

            Returns:
                Tensor : Batch of transformed data.
        """
        if self._data_augmentation:# and TF.random.random() < 0.5:
            device = x.device
            batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]

            x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)

            for _ in range(self._N_seqTF):
                ## Sampling ## == sampled_ops = np.random.choice(transforms, N)
                uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
                cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*uniforme_dist)
                sample = cat_distrib.sample()

                ## Transformations ##
                x = self.apply_TF(x, sample)
        return x

    def apply_TF(self, x, sampled_TF):
        """ Applies the sampled transformations.

            Args:
                x (Tensor): Batch of data.
                sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.

            Returns:
                Tensor: Batch of transformed data.
        """
        smps_x=[]

        for tf_idx in range(self._nb_tf):
            mask = sampled_TF==tf_idx #Create selection mask
            smp_x = x[mask] #torch.masked_select() ? (Would require expanding the mask to the same dimensions)

            if smp_x.shape[0]!=0: #if there's data to TF
                magnitude=self._params["mag"].detach()

                tf=self._TF[tf_idx]
                #print(magnitude)

                #In place
                x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)

        return x

    def adjust_param(self, soft=False):
        """Not used
        """
        pass #No parameters to optimise

    def loss_weight(self):
        """Not used
        """
        return 1 #No samples = no weighting

    def reg_loss(self, reg_factor=0.005):
        """Not used
        """
        return torch.tensor(0) #No regularisation

    def train(self, mode=None):
        """ Set the module training mode.

            Args:
                mode (bool): Whether to set the module in training mode. None doesn't change the mode. (default: None)
        """
        if mode is None :
            mode=self._data_augmentation
        self.augment(mode=mode) #Useless if mode=None
        super(RandAug, self).train(mode)

    def eval(self):
        """ Set the module to evaluation mode.
        """
        self.train(mode=False)

    def augment(self, mode=True):
        """ Set the augmentation mode.

            Args:
                mode (bool): Whether to perform data augmentation on the forward pass. (default: True)
        """
        self._data_augmentation=mode

    def __getitem__(self, key):
        """Access to the learnable parameters

            Args:
                key (string): Name of the learnable parameter to access.

            Returns:
                nn.Parameter.
        """
        return self._params[key]

    def __str__(self):
        """Name of the module

            Returns:
                String containing the name of the module as well as the higher level parameters.
        """
        return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)

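# Usage sketch (illustrative): RandAug is a drop-in, non-learned augmenter that can
# replace Data_augV5 inside Augmented_model. `my_network` is a placeholder.
#
#   aug_model = Augmented_model(RandAug(TF_dict=TF.TF_dict, N_TF=2, mag=TF.PARAMETER_MAX), my_network)
#   y = aug_model(x)     #Each image gets 2 uniformly-sampled TF at fixed magnitude
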
### Models ###
import higher

class Higher_model(nn.Module):
    """Model wrapper for higher gradient tracking.

        Keep in memory the original model and its functional (higher) version.

        Might not be needed anymore if higher implements detach for fmodel.

        See : https://github.com/facebookresearch/higher

        TODO: Get rid of the original model if not needed by the user.

        Attributes:
            _name (string): Name of the model.
            _mods (nn.ModuleDict): Models (original and higher version).
    """
    def __init__(self, model):
        """Init Higher_model.

            Args:
                model (nn.Module): Network for which higher gradients can be tracked.
        """
        super(Higher_model, self).__init__()

        self._name = model.__str__()
        self._mods = nn.ModuleDict({
            'original': model,
            'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
        })

    def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
        """Get a differentiable version of an Optimizer.

            A higher/differentiable optimizer is required for higher gradient tracking.
            Usage : diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step)

            Be wary that if track_higher_grads is set to True, a new state of the model will be saved each time diffopt.step() is called,
            increasing memory consumption. The detach_() method should be called to reset the gradient tape and prevent memory saturation.

            Args:
                opt (torch.optim): Optimizer to make differentiable.
                grad_callback (fct(grads)=grads): Function applied to the list of gradients parameters (ex: clipping). (default: None)
                track_higher_grads (bool): Whether higher gradients are tracked. If True, the graph/states will be retained to allow backpropagation. (default: True)

            Returns:
                (higher.DifferentiableOptimizer): Differentiable version of the optimizer.
        """
        return higher.optim.get_diff_optim(opt,
            self._mods['original'].parameters(),
            fmodel=self._mods['functional'],
            grad_callback=grad_callback,
            track_higher_grads=track_higher_grads)

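    # Usage sketch (illustrative): an inner training loop with a differentiable optimizer.
    # `my_network`, `dl_train` and the meta period `K` are placeholders.
    #
    #   hmodel = Higher_model(my_network)
    #   inner_opt = torch.optim.SGD(hmodel['original'].parameters(), lr=1e-2, momentum=0.9)
    #   diffopt = hmodel.get_diffopt(inner_opt, track_higher_grads=True)
    #   for step, (xs, ys) in enumerate(dl_train):
    #       loss = F.cross_entropy(hmodel(xs), ys)
    #       diffopt.step(loss)                          #== opt.zero_grad, loss.backward, opt.step
    #       if step % K == K-1:
    #           ...                                     #Meta update on the retained graph
    #           diffopt.detach_(); hmodel.detach_()     #Reset the gradient tape
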
    def forward(self, x):
        """ Main method of the model.

            Args:
                x (Tensor): Batch of data.

            Returns:
                Tensor : Output of the network. Should be logits.
        """
        return self._mods['functional'](x)

    def detach_(self):
        """Detach from the graph.

            Needed to limit the number of states kept in memory.
        """
        tmp = self._mods['functional'].fast_params
        self._mods['functional']._fast_params=[]
        self._mods['functional'].update_params(tmp)
        for p in self._mods['functional'].fast_params:
            p.detach_().requires_grad_()

    def state_dict(self):
        """Returns a dictionary containing a whole state of the module.
        """
        return self._mods['functional'].state_dict()

    def __getitem__(self, key):
        """Access to modules

            Args:
                key (string): Name of the module to access.

            Returns:
                nn.Module.
        """
        return self._mods[key]

    def __str__(self):
        """Name of the module

            Returns:
                String containing the name of the module.
        """
        return self._name

from utils import clip_norm
from train_utils import compute_vaLoss

class Augmented_model(nn.Module):
    """Wrapper for a Data Augmentation module and a model.

        Attributes:
            _mods (nn.ModuleDict): A dictionary containing the modules.
            _data_augmentation (bool): Whether data augmentation should be used.
    """
    def __init__(self, data_augmenter, model):
        """Init Augmented Model.

            By default, data augmentation will be performed.

            Args:
                data_augmenter (nn.Module): Data augmentation module.
                model (nn.Module): Network.
        """
        super(Augmented_model, self).__init__()

        self._mods = nn.ModuleDict({
            'data_aug': data_augmenter,
            'model': model
        })

        self.augment(mode=True)

    def forward(self, x):
        """ Main method of the Augmented model.

            Perform the forward pass of both modules.

            Args:
                x (Tensor): Batch of data.

            Returns:
                Tensor : Output of the networks. Should be logits.
        """
        return self._mods['model'](self._mods['data_aug'](x))

    def augment(self, mode=True):
        """ Set the augmentation mode.

            Args:
                mode (bool): Whether to perform data augmentation on the forward pass. (default: True)
        """
        self._data_augmentation=mode
        self._mods['data_aug'].augment(mode)

    #### Encapsulation of the meta-optimisation ####
    def start_bilevel_opt(self, inner_it, hp_list, opt_param, dl_val):
        """ Set up the Augmented Model for bi-level optimisation.

            Create and keep in the Augmented Model the objects necessary for meta-optimisation.
            This allows for almost transparent use by hiding the bi-level optimisation (see ``run_dist_dataugV3``) behind ::

                model.step(loss)

            See ``run_simple_smartaug`` for a complete example.

            Args:
                inner_it (int): Number of inner iterations before a meta-step. 0 inner iterations means there is no meta-step.
                hp_list (list): List of hyper-parameters to be learned.
                opt_param (dict): Dictionary containing the optimizers' parameters.
                dl_val (DataLoader): Data loader of validation data.
        """
        self._it_count=0
        self._in_it=inner_it

        self._opt_param=opt_param
        #Inner Opt
        inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9

        #Validation data
        self._dl_val=dl_val
        self._dl_val_it=iter(dl_val)
        self._val_loss=0.

        if inner_it==0 or len(hp_list)==0: #No meta-opt
            print("No meta optimization")

            #Inner Opt
            self._diffopt = self._mods['model'].get_diffopt(
                inner_opt,
                grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
                track_higher_grads=False)

            self._meta_opt=None

        else: #Bi-level opt
            print("Bi-Level optimization")

            #Inner Opt
            self._diffopt = self._mods['model'].get_diffopt(
                inner_opt,
                grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
                track_higher_grads=True)

            #Meta Opt
            self._meta_opt = torch.optim.Adam(hp_list, lr=opt_param['Meta']['lr'])
            self._meta_opt.zero_grad()

    def step(self, loss):
        """ Perform a model update.

            The ``start_bilevel_opt`` method needs to be called once before using this method.

            Perform a step of inner optimisation and, if needed, a step of meta optimisation.
            Replace ::

                opt.zero_grad()
                loss.backward()
                opt.step()

                val_loss=...
                val_loss.backward()
                meta_opt.step()
                adjust_param()
                detach()
                meta_opt.zero_grad()

            By ::

                model.step(loss)

            Args:
                loss (Tensor): the training loss tensor.
        """
        self._it_count+=1
        self._diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)

        if(self._meta_opt and self._it_count>0 and self._it_count%self._in_it==0): #Perform Meta step
            #print("meta")
            self._val_loss = compute_vaLoss(model=self._mods['model'], dl_it=self._dl_val_it, dl=self._dl_val) + self._mods['data_aug'].reg_loss()
            #print_graph(val_loss) #to visualize computational graph
            self._val_loss.backward()

            torch.nn.utils.clip_grad_norm_(self._mods['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN

            self._meta_opt.step()

            #Adjust Hyper-parameters
            self._mods['data_aug'].adjust_param(soft=False) #Constraint: sum(prob)=1

            #For optimizer parameters, if needed
            #for param_group in self._diffopt.param_groups:
            #    for param in list(self._opt_param['Inner'].keys())[1:]:
            #        param_group[param].data = param_group[param].data.clamp(min=1e-4)

            #Reset gradients
            self._diffopt.detach_()
            self._mods['model'].detach_()
            self._meta_opt.zero_grad()

            self._it_count=0

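    # Usage sketch (illustrative): a simplified bi-level training loop built on
    # start_bilevel_opt/step. `my_network`, `dl_train`, `dl_val`, `opt_param` are
    # placeholders; opt_param is assumed to hold {'Inner': {'lr', 'momentum'}, 'Meta': {'lr'}}.
    #
    #   aug_model = Augmented_model(Data_augV5(N_TF=2, mix_dist=0.5), Higher_model(my_network))
    #   aug_model.start_bilevel_opt(inner_it=1, hp_list=list(aug_model['data_aug'].parameters()),
    #                               opt_param=opt_param, dl_val=dl_val)
    #   for xs, ys in dl_train:
    #       loss = F.cross_entropy(aug_model(xs), ys, reduction='none')
    #       loss = (loss * aug_model['data_aug'].loss_weight()).mean()
    #       aug_model.step(loss)        #Inner step + (every inner_it steps) a meta step
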
    def val_loss(self):
        """ Get the validation loss.

            Compute, if needed, the validation loss and return it.

            The ``start_bilevel_opt`` method needs to be called once before using this method.

            Returns:
                (Tensor) Validation loss on a single batch of data.
        """
        if(self._meta_opt): #Bi-level opt
            return self._val_loss
        else:
            return compute_vaLoss(model=self._mods['model'], dl_it=self._dl_val_it, dl=self._dl_val)

    ##########################

    def train(self, mode=True):
        """ Set the module training mode.

            Args:
                mode (bool): Whether to set the module in training mode. (default: True)
        """
        #if mode is None :
        #    mode=self._data_augmentation
        super(Augmented_model, self).train(mode)
        self._mods['data_aug'].augment(mode=self._data_augmentation) #Restart data augmentation if needed
        return self

    def eval(self):
        """ Set the module to evaluation mode.
        """
        #return self.train(mode=False)
        super(Augmented_model, self).train(mode=False)
        self._mods['data_aug'].augment(mode=False)
        return self

    def items(self):
        """Return an iterable of the ModuleDict key/value pairs.
        """
        return self._mods.items()

    def update(self, modules):
        """Update the module dictionary.

            The new dictionary should keep the same structure.
        """
        assert(self._mods.keys()==modules.keys())
        self._mods.update(modules)

    def is_augmenting(self):
        """ Return whether data augmentation is applied.

            Returns:
                bool : True if data augmentation is applied.
        """
        return self._data_augmentation

    def TF_names(self):
        """ Get the transformation names used by the data augmentation module.

            Returns:
                list : names of the transformations of the data augmentation module.
        """
        try:
            return self._mods['data_aug']._TF
        except:
            return None

    def __getitem__(self, key):
        """Access to the modules.

            Args:
                key (string): Name of the module to access.

            Returns:
                nn.Module.
        """
        return self._mods[key]

    def __str__(self):
        """Name of the module

            Returns:
                String containing the name of the module as well as the higher level parameters.
        """
        return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"