mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 20:20:46 +02:00
Data_augV7 : Proba sequentielles + Minor improvements
This commit is contained in:
parent
f2019aae4a
commit
dc18397660
4 changed files with 317 additions and 11 deletions
306
higher/dataug.py
306
higher/dataug.py
|
@ -19,6 +19,7 @@ import copy
|
||||||
|
|
||||||
import transformations as TF
|
import transformations as TF
|
||||||
|
|
||||||
|
|
||||||
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
"""Data augmentation module with learnable parameters.
|
"""Data augmentation module with learnable parameters.
|
||||||
|
|
||||||
|
@ -68,6 +69,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
#Mag
|
#Mag
|
||||||
self._shared_mag = shared_mag
|
self._shared_mag = shared_mag
|
||||||
self._fixed_mag = fixed_mag
|
self._fixed_mag = fixed_mag
|
||||||
|
if not self._fixed_mag and len([tf for tf in self._TF if tf not in TF.TF_ignore_mag])==0:
|
||||||
|
print("WARNING: Mag would be fixed as current TF doesn't allow gradient propagation:",self._TF)
|
||||||
|
self._fixed_mag=True
|
||||||
|
|
||||||
#Distribution
|
#Distribution
|
||||||
self._fixed_prob=fixed_prob
|
self._fixed_prob=fixed_prob
|
||||||
|
@ -289,6 +293,308 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
else:
|
else:
|
||||||
return "Data_augV5(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
return "Data_augV5(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||||
|
|
||||||
|
class Data_augV7(nn.Module): #Proba sequentielles
|
||||||
|
"""Data augmentation module with learnable parameters.
|
||||||
|
|
||||||
|
Applies transformations (TF) to batch of data.
|
||||||
|
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple which can be learned. For the full definiton of the TF, see transformations.py.
|
||||||
|
The TF probabilities defines a distribution from which we sample the TF applied.
|
||||||
|
|
||||||
|
Replace the use of TF by TF sets which are combinaisons of classic TF.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_data_augmentation (bool): Wether TF will be applied during forward pass.
|
||||||
|
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
|
||||||
|
_TF (list) : List of TF names.
|
||||||
|
_nb_tf (int) : Number of TF used.
|
||||||
|
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
|
||||||
|
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF.
|
||||||
|
_fixed_mag (bool): Wether to lock the TF magnitudes.
|
||||||
|
_fixed_prob (bool): Wether to lock the TF probabilies.
|
||||||
|
_samples (list): Sampled TF index during last forward pass.
|
||||||
|
_mix_dist (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
|
||||||
|
_fixed_mix (bool): Wether we lock the mix distribution factor.
|
||||||
|
_params (nn.ParameterDict): Learnable parameters.
|
||||||
|
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
|
||||||
|
_reg_mask (list): Mask selecting the TF considered for the regularisation.
|
||||||
|
"""
|
||||||
|
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
|
||||||
|
"""Init Data_augv7.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
|
||||||
|
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
|
||||||
|
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
|
||||||
|
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
|
||||||
|
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
|
||||||
|
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
|
||||||
|
"""
|
||||||
|
super(Data_augV7, self).__init__()
|
||||||
|
assert len(TF_dict)>0
|
||||||
|
assert N_TF>=0
|
||||||
|
|
||||||
|
self._data_augmentation = True
|
||||||
|
|
||||||
|
#TF
|
||||||
|
self._TF_dict = TF_dict
|
||||||
|
self._TF= list(self._TF_dict.keys())
|
||||||
|
self._nb_tf= len(self._TF)
|
||||||
|
self._N_seqTF = N_TF
|
||||||
|
|
||||||
|
#Mag
|
||||||
|
self._shared_mag = shared_mag
|
||||||
|
self._fixed_mag = fixed_mag
|
||||||
|
if not self._fixed_mag and len([tf for tf in self._TF if tf not in TF.TF_ignore_mag])==0:
|
||||||
|
print("WARNING: Mag would be fixed as current TF doesn't allow gradient propagation:",self._TF)
|
||||||
|
self._fixed_mag=True
|
||||||
|
|
||||||
|
#Distribution
|
||||||
|
self._fixed_prob=fixed_prob
|
||||||
|
self._samples = []
|
||||||
|
|
||||||
|
self._mix_dist = False
|
||||||
|
if mix_dist != 0.0: #Mix dist
|
||||||
|
self._mix_dist = True
|
||||||
|
|
||||||
|
self._fixed_mix=True
|
||||||
|
if mix_dist is None: #Learn Mix dist
|
||||||
|
self._fixed_mix = False
|
||||||
|
mix_dist=0.5
|
||||||
|
|
||||||
|
#TF sets
|
||||||
|
no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}}
|
||||||
|
cons_test = (lambda i, idxs: i in no_consecutive and len(idxs)!=0 and i==idxs[-1]) #Exclude selected consecutive
|
||||||
|
def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF
|
||||||
|
TF_sets=[]
|
||||||
|
if set_size>1:
|
||||||
|
for i in range(n_TF):
|
||||||
|
if not cons_test(i, idx_prefix):
|
||||||
|
TF_sets += generate_TF_sets(n_TF, set_size=set_size-1, idx_prefix=idx_prefix+[i])
|
||||||
|
else:
|
||||||
|
TF_sets+=[[idx_prefix+[i]] for i in range(n_TF) if not cons_test(i, idx_prefix)]
|
||||||
|
return TF_sets
|
||||||
|
|
||||||
|
self._TF_sets=torch.ByteTensor(generate_TF_sets(self._nb_tf, self._N_seqTF)).squeeze()
|
||||||
|
self._nb_TF_sets=len(self._TF_sets)
|
||||||
|
print("Number of TF sets:",self._nb_TF_sets)
|
||||||
|
|
||||||
|
#Params
|
||||||
|
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
|
||||||
|
self._params = nn.ParameterDict({
|
||||||
|
"prob": nn.Parameter(torch.ones(self._nb_TF_sets)/self._nb_TF_sets), #Distribution prob uniforme
|
||||||
|
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
|
||||||
|
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
|
||||||
|
"mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
|
||||||
|
})
|
||||||
|
|
||||||
|
#for tf in TF.TF_no_grad :
|
||||||
|
# if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
|
||||||
|
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
|
||||||
|
|
||||||
|
#Mag regularisation
|
||||||
|
if not self._fixed_mag:
|
||||||
|
if self._shared_mag :
|
||||||
|
self._reg_tgt = torch.FloatTensor(TF.PARAMETER_MAX) #Encourage amplitude max
|
||||||
|
else:
|
||||||
|
self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in TF.TF_ignore_mag]
|
||||||
|
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
""" Main method of the Data augmentation module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Batch of data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor : Batch of tranformed data.
|
||||||
|
"""
|
||||||
|
self._samples = None
|
||||||
|
if self._data_augmentation:# and TF.random.random() < 0.5:
|
||||||
|
device = x.device
|
||||||
|
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
|
||||||
|
|
||||||
|
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||||
|
|
||||||
|
|
||||||
|
## Echantillonage ##
|
||||||
|
uniforme_dist = torch.ones(1,self._nb_TF_sets,device=device).softmax(dim=1)
|
||||||
|
|
||||||
|
if not self._mix_dist:
|
||||||
|
self._distrib = uniforme_dist
|
||||||
|
else:
|
||||||
|
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||||
|
mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
|
||||||
|
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||||
|
|
||||||
|
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_TF_sets), device=device)*self._distrib)
|
||||||
|
sample = cat_distrib.sample()
|
||||||
|
|
||||||
|
self._samples=sample
|
||||||
|
TF_samples=self._TF_sets[sample,:].to(device) #[Batch_size, TFseq]
|
||||||
|
|
||||||
|
for i in range(self._N_seqTF):
|
||||||
|
## Transformations ##
|
||||||
|
x = self.apply_TF(x, TF_samples[:,i])
|
||||||
|
return x
|
||||||
|
|
||||||
|
def apply_TF(self, x, sampled_TF):
|
||||||
|
""" Applies the sampled transformations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Batch of data.
|
||||||
|
sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Batch of tranformed data.
|
||||||
|
"""
|
||||||
|
device = x.device
|
||||||
|
batch_size, channels, h, w = x.shape
|
||||||
|
smps_x=[]
|
||||||
|
|
||||||
|
for tf_idx in range(self._nb_tf):
|
||||||
|
mask = sampled_TF==tf_idx #Create selection mask
|
||||||
|
smp_x = x[mask] #torch.masked_select() ? (Necessite d'expand le mask au meme dim)
|
||||||
|
|
||||||
|
if smp_x.shape[0]!=0: #if there's data to TF
|
||||||
|
magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
|
||||||
|
if self._fixed_mag: magnitude=magnitude.detach() #Fmodel tente systematiquement de tracker les gradient de tout les param
|
||||||
|
|
||||||
|
tf=self._TF[tf_idx]
|
||||||
|
|
||||||
|
#In place
|
||||||
|
#x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
|
||||||
|
|
||||||
|
#Out of place
|
||||||
|
smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
|
||||||
|
idx= mask.nonzero()
|
||||||
|
idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||||
|
x=x.scatter(dim=0, index=idx, src=smp_x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def adjust_param(self, soft=False): #Detach from gradient ?
|
||||||
|
""" Enforce limitations to the learned parameters.
|
||||||
|
|
||||||
|
Ensure that the parameters value stays in the right intevals. This should be called after each update of those parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
soft (bool): Wether to use a softmax function for TF probabilites. Not Recommended as it tends to lock the probabilities, preventing them to be learned. (default: False)
|
||||||
|
"""
|
||||||
|
if not self._fixed_prob:
|
||||||
|
if soft :
|
||||||
|
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
|
||||||
|
else:
|
||||||
|
self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
|
||||||
|
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
|
||||||
|
|
||||||
|
if not self._fixed_mag:
|
||||||
|
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
|
||||||
|
|
||||||
|
if not self._fixed_mix:
|
||||||
|
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
|
||||||
|
|
||||||
|
def loss_weight(self):
|
||||||
|
""" Weights for the loss.
|
||||||
|
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
|
||||||
|
Should be applied to the loss before reduction.
|
||||||
|
|
||||||
|
TODO: Take into account the order of application of the TF.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor : Loss weights.
|
||||||
|
"""
|
||||||
|
if self._samples is None : return 1 #Pas d'echantillon = pas de ponderation
|
||||||
|
|
||||||
|
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||||
|
|
||||||
|
w_loss = torch.zeros((self._samples.shape[0],self._nb_TF_sets), device=self._samples.device)
|
||||||
|
w_loss.scatter_(1, self._samples.view(-1,1), 1)
|
||||||
|
w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||||
|
w_loss = torch.sum(w_loss,dim=1)
|
||||||
|
return w_loss
|
||||||
|
|
||||||
|
def reg_loss(self, reg_factor=0.005):
|
||||||
|
""" Regularisation term used to learn the magnitudes.
|
||||||
|
Use an L2 loss to encourage high magnitudes TF.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reg_factor (float): Factor by wich the regularisation loss is multiplied. (default: 0.005)
|
||||||
|
Returns:
|
||||||
|
Tensor containing the regularisation loss value.
|
||||||
|
"""
|
||||||
|
if self._fixed_mag:
|
||||||
|
return torch.tensor(0)
|
||||||
|
else:
|
||||||
|
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
|
||||||
|
mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
|
||||||
|
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
|
||||||
|
return max_mag_reg
|
||||||
|
|
||||||
|
def TF_prob(self): #Eviter recalcul si pas de changement des proba
|
||||||
|
#print("WARNING: Calcul de proba inexact")
|
||||||
|
res=torch.zeros(self._nb_tf)
|
||||||
|
for idx_tf in range(self._nb_tf):
|
||||||
|
for i, t_set in enumerate(self._TF_sets):
|
||||||
|
if idx_tf in t_set:
|
||||||
|
res[idx_tf]+=self._params['prob'][i]
|
||||||
|
|
||||||
|
return res/sum(res) #*(self._nb_tf/self._nb_TF_sets)
|
||||||
|
|
||||||
|
def train(self, mode=True):
|
||||||
|
""" Set the module training mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
|
||||||
|
"""
|
||||||
|
#if mode is None :
|
||||||
|
# mode=self._data_augmentation
|
||||||
|
self.augment(mode=mode) #Inutile si mode=None
|
||||||
|
super(Data_augV7, self).train(mode)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
""" Set the module to evaluation mode.
|
||||||
|
"""
|
||||||
|
return self.train(mode=False)
|
||||||
|
|
||||||
|
def augment(self, mode=True):
|
||||||
|
""" Set the augmentation mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
|
||||||
|
"""
|
||||||
|
self._data_augmentation=mode
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
"""Access to the learnable parameters
|
||||||
|
Args:
|
||||||
|
key (string): Name of the learnable parameter to access.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Parameter.
|
||||||
|
"""
|
||||||
|
if key == 'prob': #Override prob access
|
||||||
|
return self.TF_prob()
|
||||||
|
return self._params[key]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""Name of the module
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
String containing the name of the module as well as the higher levels parameters.
|
||||||
|
"""
|
||||||
|
dist_param=''
|
||||||
|
if self._fixed_prob: dist_param+='Fx'
|
||||||
|
mag_param='Mag'
|
||||||
|
if self._fixed_mag: mag_param+= 'Fx'
|
||||||
|
if self._shared_mag: mag_param+= 'Sh'
|
||||||
|
if not self._mix_dist:
|
||||||
|
return "Data_augV7(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||||
|
elif self._fixed_mix:
|
||||||
|
return "Data_augV7(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||||
|
else:
|
||||||
|
return "Data_augV7(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||||
|
|
||||||
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
||||||
"""RandAugment implementation.
|
"""RandAugment implementation.
|
||||||
|
|
||||||
|
|
|
@ -585,7 +585,7 @@ class Data_augV6(nn.Module): #Optimisation sequentielle #Mauvais resultats
|
||||||
print("Warning : using only fixed set of TF : ", self._fixed_TF)
|
print("Warning : using only fixed set of TF : ", self._fixed_TF)
|
||||||
self._TF_sets=torch.tensor([self._fixed_TF])
|
self._TF_sets=torch.tensor([self._fixed_TF])
|
||||||
else:
|
else:
|
||||||
def generate_TF_sets(n_TF, set_size, idx_prefix=[]):
|
def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every combinaison (without reuse) of TF
|
||||||
TF_sets=[]
|
TF_sets=[]
|
||||||
if len(idx_prefix)!=0:
|
if len(idx_prefix)!=0:
|
||||||
if set_size>2:
|
if set_size>2:
|
||||||
|
|
|
@ -170,7 +170,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||||
model = Higher_model(model) #run_dist_dataugV3
|
model = Higher_model(model) #run_dist_dataugV3
|
||||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
|
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
|
||||||
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
|
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
|
||||||
|
|
||||||
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
|
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
|
||||||
|
@ -179,7 +179,7 @@ if __name__ == "__main__":
|
||||||
inner_it=n_inner_iter,
|
inner_it=n_inner_iter,
|
||||||
dataug_epoch_start=dataug_epoch_start,
|
dataug_epoch_start=dataug_epoch_start,
|
||||||
opt_param=optim_param,
|
opt_param=optim_param,
|
||||||
print_freq=1,
|
print_freq=20,
|
||||||
KLdiv=True,
|
KLdiv=True,
|
||||||
hp_opt=False)
|
hp_opt=False)
|
||||||
|
|
||||||
|
|
|
@ -200,7 +200,7 @@ def float_parameter(level, maxval):
|
||||||
An int that results from scaling `maxval` according to `level`.
|
An int that results from scaling `maxval` according to `level`.
|
||||||
"""
|
"""
|
||||||
#return int(level * maxval / PARAMETER_MAX)
|
#return int(level * maxval / PARAMETER_MAX)
|
||||||
# return (level * maxval / PARAMETER_MAX)
|
# return (level * maxval / PARAMETER_MAX)
|
||||||
|
|
||||||
def flipLR(x):
|
def flipLR(x):
|
||||||
"""Flip horizontaly/Left-Right images.
|
"""Flip horizontaly/Left-Right images.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue