Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 20:20:46 +02:00)

Commit b89dac9084: "Changes since Teledyne"
Parent: 03ffd7fe05
185 changed files with 16668 additions and 484 deletions
@@ -18,13 +18,17 @@ import numpy as np
import copy

import transformations as TF
import torchvision

import higher
import higher_patch

from utils import clip_norm
from utils import clip_norm
from train_utils import compute_vaLoss

from datasets import MEAN, STD
norm = TF.Normalizer(MEAN, STD)

### Data augmenter ###
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
"""Data augmentation module with learnable parameters.
@@ -46,19 +50,19 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
_fixed_mag (bool): Wether to lock the TF magnitudes.
_fixed_prob (bool): Wether to lock the TF probabilies.
_samples (list): Sampled TF index during last forward pass.
_mix_dist (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
_fixed_mix (bool): Wether we lock the mix distribution factor.
_temp (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
_fixed_temp (bool): Wether we lock the mix distribution factor.
_params (nn.ParameterDict): Learnable parameters.
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
_reg_mask (list): Mask selecting the TF considered for the regularisation.
"""
def __init__(self, TF_dict, N_TF=1, mix_dist=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
def __init__(self, TF_dict, N_TF=1, temp=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
"""Init Data_augv5.

Args:
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0.5)
temp (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-temp)*Uniform_distribution + temp*Real_distribution. If None is given, try to learn this parameter. (default: 0.5)
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
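The temp argument replaces mix_dist in this commit: rather than linearly mixing the learned probabilities with a uniform distribution, the new forward pass appears to push prob through a temperature-scaled softmax, so temp=0 yields a uniform sampling distribution and larger values sharpen the learned preferences. A minimal stand-alone sketch of that behaviour, with invented values and torch assumed available:

import torch
import torch.nn.functional as F

prob = torch.tensor([0.5, 1.0, 2.0, 4.0])        # hypothetical learned "prob" logits
for temp in (0.0, 0.5, 5.0):
    distrib = F.softmax(prob * temp, dim=0)      # temp=0 -> uniform; larger temp -> sharper
    print(temp, distrib.tolist())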
@@ -88,27 +92,30 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
self._fixed_prob=fixed_prob
self._samples = []

self._mix_dist = False
if mix_dist != 0.0: #Mix dist
self._mix_dist = True
# self._temp = False
# if temp != 0.0: #Mix dist
# self._temp = True

self._fixed_mix=True
if mix_dist is None: #Learn Mix dist
self._fixed_mix = False
mix_dist=0.5
self._fixed_temp=True
if temp is None: #Learn Temp
print("WARNING: Learning Temperature parameter isn't working with this version (No grad)")
self._fixed_temp = False
temp=0.5

#Params
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
#"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
"prob": nn.Parameter(torch.ones(self._nb_tf)),
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
"mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
"temp": nn.Parameter(torch.tensor(temp))#.clamp(min=0.0,max=0.999))
})

for tf in self._TF_ignore_mag :
self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
if not self._shared_mag:
for tf in self._TF_ignore_mag :
self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag

#Mag regularisation
if not self._fixed_mag:
@@ -117,7 +124,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
else:
TF_mag=[t for t in self._TF if t not in self._TF_ignore_mag] #TF w/ differentiable mag
self._reg_mask=[self._TF.index(t) for t in TF_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX, dtype=self._params['mag'].dtype) #Encourage amplitude max

#Prevent Identity
#print(TF.TF_identity)
@@ -137,28 +144,44 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
Tensor : Batch of tranformed data.
"""
self._samples = torch.Tensor([])
if self._data_augmentation:# and TF.random.random() < 0.5:
if self._data_augmentation:
device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]

x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
# x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)

## Echantillonage ##
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)

if not self._mix_dist:
self._distrib = uniforme_dist
else:
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
temp = self._params["temp"].detach() if self._fixed_temp else self._params["temp"]
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
self._distrib = F.softmax(prob*temp, dim=0)
# prob = F.softmax(prob[1:], dim=0) #Bernouilli

cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
self._samples=cat_distrib.sample([self._N_seqTF])

#Bernoulli (Requiert Identité en position 0)
#assert(self._TF[0]=="Identity")
# cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf-1), device=device)*self._distrib)
# bern_distrib = Bernoulli(torch.tensor([0.5], device=device))
# mask = bern_distrib.sample([self._N_seqTF, batch_size]).squeeze()
# self._samples=(cat_distrib.sample([self._N_seqTF])+1)*mask

for sample in self._samples:
## Transformations ##
x = self.apply_TF(x, sample)

# self._samples.to(device)
# for n in range(self._N_seqTF):
# # print('temp', (temp+0.3*n))
# self._distrib = F.softmax(prob*(temp+0.2*n), dim=0)
# # prob = F.softmax(prob[1:], dim=0) #Bernouilli

# cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
# new_sample=cat_distrib.sample()
# self._samples=torch.cat((self._samples.to(device).to(new_sample.dtype), new_sample.unsqueeze(dim=0)), dim=0)

# x = self.apply_TF(x, new_sample)
# print('sample',self._samples.shape)
return x

def apply_TF(self, x, sampled_TF):
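A hedged, self-contained sketch of the sampling step in the forward pass above (sizes are made up): one Categorical distribution per image is built from self._distrib, a TF index is drawn for every image at each of the N_seqTF steps, and each row of indices would then be handed to apply_TF.

import torch
from torch.distributions import Categorical

batch_size, nb_tf, N_seqTF = 8, 4, 2                    # hypothetical sizes
distrib = torch.softmax(torch.randn(nb_tf), dim=0)      # stands in for self._distrib

cat = Categorical(probs=torch.ones(batch_size, nb_tf) * distrib)
samples = cat.sample([N_seqTF])                         # (N_seqTF, batch_size) TF indices
print(samples.shape)                                    # torch.Size([2, 8])
for step in samples:                                    # one TF index per image at this step
    pass                                                # apply_TF(x, step) would transform the batch here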
@@ -204,20 +227,20 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
Args:
soft (bool): Wether to use a softmax function for TF probabilites. Tends to lock the probabilities if the learning rate is low, preventing them to be learned. (default: False)
"""
if not self._fixed_prob:
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
else:
self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
# if not self._fixed_prob:
# if soft :
# self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
# else:
# self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
# self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1

if not self._fixed_mag:
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)

if not self._fixed_mix:
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
if not self._fixed_temp:
self._params['temp'].data = self._params['temp'].data.clamp(min=0.0, max=0.999)

def loss_weight(self, mean_norm=False):
def loss_weight(self, batch_norm=True):
""" Weights for the loss.
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
Should be applied to the loss before reduction.
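The new clamp keeps the temperature inside [0, 0.999] whenever it is learned; mag is handled the same way against [PARAMETER_MIN, PARAMETER_MAX]. A tiny stand-alone illustration of the same re-projection after a hypothetical optimiser step (values invented):

import torch

temp = torch.nn.Parameter(torch.tensor(1.3))            # pretend an update pushed temp out of range
temp.data = temp.data.clamp(min=0.0, max=0.999)         # same re-projection as adjust_param()
print(temp.item())                                      # 0.999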
@@ -225,30 +248,37 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
Do not take into account the order of application of the TF. See Data_augV7.

Args:
mean_norm (bool): Wether to normalize weights by mean or by distribution. (Default: Normalize by distribution.)
Normalizing by mean, would lend an exact normalization but can lead to unstable behavior of probabilities.
Normalizing by distribution is a statistical approximation of the exact normalization. It lead to more smooth probabilities evolution but will only return 1 if mix_dist=1.
batch_norm (bool): Wether to normalize mean of the weights. (Default: True)

Returns:
Tensor : Loss weights.
"""
if len(self._samples)==0 : return torch.tensor(1, device=self._params["prob"].device) #Pas d'echantillon = pas de ponderation

prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
#prob = F.softmax(prob, dim=0)

#Plusieurs TF sequentielles (Attention ne prend pas en compte ordre !)
w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
for sample in self._samples:
for sample in self._samples.to(torch.long):
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
w_loss += tmp_w

if mean_norm:
w_loss = w_loss * prob
w_loss = torch.sum(w_loss,dim=1)
#w_loss=w_loss/w_loss.sum(dim=1, keepdim=True) #Bernoulli

#Normalizing by mean, would lend an exact normalization but can lead to unstable behavior of probabilities.
w_loss = w_loss * prob
w_loss = torch.sum(w_loss,dim=1)

if batch_norm:
w_min = w_loss.min()
w_loss = w_loss-w_min if w_min<0 else w_loss
w_loss = w_loss/w_loss.mean() #mean(w_loss)=1
else:
w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
w_loss = torch.sum(w_loss,dim=1)

#Normalizing by distribution is a statistical approximation of the exact normalization. It lead to more smooth probabilities evolution but will only return 1 if temp=1.
# w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
# w_loss = torch.sum(w_loss,dim=1)
return w_loss

def reg_loss(self, reg_factor=0.005):
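A usage sketch for the reworked loss_weight (names and sizes are illustrative, not taken from the diff): the weights are meant to multiply an unreduced per-sample loss, and with batch_norm=True they are shifted and rescaled so their mean is 1 before the final reduction.

import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)                    # hypothetical model output
target = torch.randint(0, 10, (8,))
w = torch.rand(8) + 0.1                        # stand-in for aug.loss_weight(batch_norm=True)
w = w / w.mean()                               # batch_norm=True guarantees mean(w) == 1

per_sample = F.cross_entropy(logits, target, reduction='none')
loss = (w * per_sample).mean()                 # weight each sample, then reduce
print(loss.item())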
@@ -310,6 +340,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
Returns:
nn.Parameter.
"""
if key == 'prob': #Override prob access
return F.softmax(self._params["prob"]*self._params["temp"], dim=0)
return self._params[key]

def __str__(self):
@@ -323,22 +355,20 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
mag_param='Mag'
if self._fixed_mag: mag_param+= 'Fx'
if self._shared_mag: mag_param+= 'Sh'
if not self._mix_dist:
return "Data_augV5(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
elif self._fixed_mix:
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
# if not self._temp:
# return "Data_augV5(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
if self._fixed_temp:
return "Data_augV5(T%.1f%s-%dTFx%d-%s)" % (self._params['temp'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
else:
return "Data_augV5(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
return "Data_augV5(T%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)

class Data_augV7(nn.Module): #Proba sequentielles
class Data_augV8(nn.Module): #Apprentissage proba sequentielles
"""Data augmentation module with learnable parameters.

Applies transformations (TF) to batch of data.
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple which can be learned. For the full definiton of the TF, see transformations.py.
The TF probabilities defines a distribution from which we sample the TF applied.

Replace the use of TF by TF sets which are combinaisons of classic TF.

Attributes:
_data_augmentation (bool): Wether TF will be applied during forward pass.
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
@@ -350,37 +380,34 @@ class Data_augV7(nn.Module): #Proba sequentielles
_fixed_mag (bool): Wether to lock the TF magnitudes.
_fixed_prob (bool): Wether to lock the TF probabilies.
_samples (list): Sampled TF index during last forward pass.
_mix_dist (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
_fixed_mix (bool): Wether we lock the mix distribution factor.
_temp (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
_fixed_temp (bool): Wether we lock the mix distribution factor.
_params (nn.ParameterDict): Learnable parameters.
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
_reg_mask (list): Mask selecting the TF considered for the regularisation.
"""
def __init__(self, TF_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
"""Init Data_augv7.
def __init__(self, TF_dict, N_TF=1, temp=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
"""Init Data_augv8.

Args:
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
N_TF (int): Number of TF to be applied sequentially to each inputs. Minimum 2, otherwise prefer using Data_augV5. (default: 2)
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
temp (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-temp)*Uniform_distribution + temp*Real_distribution. If None is given, try to learn this parameter. (default: 0.5)
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
"""
super(Data_augV7, self).__init__()
super(Data_augV8, self).__init__()
assert len(TF_dict)>0
assert N_TF>=0

if N_TF<2:
print("WARNING: Data_augv7 isn't designed to use less than 2 sequentials TF. Please use Data_augv5 instead.")

self._data_augmentation = True

#TF
self._TF_dict = TF_dict
self._TF= list(self._TF_dict.keys())
self._TF_ignore_mag= TF_ignore_mag
self._TF_ignore_mag=TF_ignore_mag
self._nb_tf= len(self._TF)
self._N_seqTF = N_TF

@@ -395,58 +422,50 @@ class Data_augV7(nn.Module): #Proba sequentielles
self._fixed_prob=fixed_prob
self._samples = []

self._mix_dist = False
if mix_dist != 0.0: #Mix dist
self._mix_dist = True
# self._temp = False
# if temp != 0.0: #Mix dist
# self._temp = True

self._fixed_mix=True
if mix_dist is None: #Learn Mix dist
self._fixed_mix = False
mix_dist=0.5
self._fixed_temp=True
if temp is None: #Learn temp
print("WARNING: Learning Temperature parameter isn't working with this version (No grad)")
self._fixed_temp = False
temp=0.5

#TF sets
#import itertools
#itertools.product(range(self._nb_tf), repeat=self._N_seqTF)

#no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}} #Specific No consecutive ops
no_consecutive={idx for idx, t in enumerate(self._TF) if t not in {'Identity'}} #No consecutive same ops (except Identity)
cons_test = (lambda i, idxs: i in no_consecutive and len(idxs)!=0 and i==idxs[-1]) #Exclude selected consecutive
def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF (exclude cons_test arrangement)
TF_sets=[]
if set_size>1:
for i in range(n_TF):
if not cons_test(i, idx_prefix):
TF_sets += generate_TF_sets(n_TF, set_size=set_size-1, idx_prefix=idx_prefix+[i])
else:
TF_sets+=[[idx_prefix+[i]] for i in range(n_TF) if not cons_test(i, idx_prefix)]
return TF_sets

self._TF_sets=torch.ByteTensor(generate_TF_sets(self._nb_tf, self._N_seqTF)).squeeze()
self._nb_TF_sets=len(self._TF_sets)
print("Number of TF sets:",self._nb_TF_sets)
#print(self._TF_sets)
self._prob_mem=torch.zeros(self._nb_TF_sets)

#Params
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_TF_sets)/self._nb_TF_sets), #Distribution prob uniforme
#"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
# "prob": nn.Parameter(torch.ones([self._nb_tf for _ in range(self._N_seqTF)])),
"prob": nn.Parameter(torch.ones(self._nb_tf**self._N_seqTF)),
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
"mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
"temp": nn.Parameter(torch.tensor(temp))#.clamp(min=0.0,max=0.999))
})

#for tf in TF.TF_no_grad :
# if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
self._prob_mem=torch.zeros(self._nb_tf**self._N_seqTF)

if not self._shared_mag:
for tf in self._TF_ignore_mag :
self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag

#Mag regularisation
if not self._fixed_mag:
if self._shared_mag :
self._reg_tgt = torch.FloatTensor(TF.PARAMETER_MAX) #Encourage amplitude max
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max
else:
self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in self._TF_ignore_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
TF_mag=[t for t in self._TF if t not in self._TF_ignore_mag] #TF w/ differentiable mag
self._reg_mask=[self._TF.index(t) for t in TF_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX, dtype=self._params['mag'].dtype) #Encourage amplitude max

#Prevent Identity
#print(TF.TF_identity)
#self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=0.0)
#for val in TF.TF_identity.keys():
# idx=[self._reg_mask.index(self._TF.index(t)) for t in TF_mag if t in TF.TF_identity[val]]
# self._reg_tgt[idx]=val
#print(TF_mag, self._reg_tgt)

def forward(self, x):
""" Main method of the Data augmentation module.
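In Data_augV8 the single "prob" parameter now covers every ordered TF sequence, so it has nb_tf**N_seqTF entries. A small sketch, with made-up sizes, of how a flat sequence index corresponds to its per-step TF indices, matching the view/nonzero decoding used in forward() below:

import torch

nb_tf, N_seqTF = 3, 2
prob = torch.ones(nb_tf ** N_seqTF)            # one logit per ordered TF sequence (9 here)
flat_idx = 5                                   # hypothetical sampled sequence index
seq = [(flat_idx // nb_tf ** k) % nb_tf for k in reversed(range(N_seqTF))]
print(seq)                                     # [1, 2] -> apply TF 1, then TF 2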
@@ -457,32 +476,54 @@ class Data_augV7(nn.Module): #Proba sequentielles
Returns:
Tensor : Batch of tranformed data.
"""
self._samples = None
if self._data_augmentation:# and TF.random.random() < 0.5:
self._samples = torch.Tensor([])
if self._data_augmentation:
device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]

x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
# x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)

## Echantillonage ##
uniforme_dist = torch.ones(1,self._nb_TF_sets,device=device).softmax(dim=1)
# if not self._temp:
# self._distrib = torch.ones(1,self._nb_tf**self._N_seqTF,device=device).softmax(dim=1)
# else:
# prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"] #Uniform dist
# # print(prob.shape)
# # prob = prob.view(1, -1)
# # prob = F.softmax(prob, dim=0)

if not self._mix_dist:
self._distrib = uniforme_dist
else:
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
# temp = self._params["temp"].detach() if self._fixed_temp else self._params["temp"] #Temperature
# self._distrib = F.softmax(temp*prob, dim=0)
# # self._distrib = (temp*prob+(1-temp)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
# # print(prob.shape)

cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_TF_sets), device=device)*self._distrib)
sample = cat_distrib.sample()

self._samples=sample
TF_samples=self._TF_sets[sample,:].to(device) #[Batch_size, TFseq]
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
temp = self._params["temp"].detach() if self._fixed_temp else self._params["temp"] #Temperature
self._distrib = F.softmax(temp*prob, dim=0)

for i in range(self._N_seqTF):
cat_distrib= Categorical(probs=torch.ones((self._nb_tf**self._N_seqTF), device=device)*self._distrib)
samples=cat_distrib.sample([batch_size]) # (batch_size)
# print(samples.shape)
samples=torch.zeros((batch_size, self._nb_tf**self._N_seqTF), dtype=torch.bool, device=device).scatter_(dim=1, index=samples.unsqueeze(dim=1), value=1)
self._samples=samples
# print(samples.shape)
# print(samples)
samples=samples.view((batch_size,)+tuple([self._nb_tf for _ in range(self._N_seqTF)]))
# print(samples.shape)
# print(samples)
samples= torch.nonzero(samples)[:,1:].T #Find indexes (TF sequence) => (N_seqTF, batch_size)
# print(samples.shape)

#Bernoulli (Requiert Identité en position 0)
#assert(self._TF[0]=="Identity")
# cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf-1), device=device)*self._distrib)
# bern_distrib = Bernoulli(torch.tensor([0.5], device=device))
# mask = bern_distrib.sample([self._N_seqTF, batch_size]).squeeze()
# self._samples=(cat_distrib.sample([self._N_seqTF])+1)*mask

for sample in samples:
## Transformations ##
x = self.apply_TF(x, TF_samples[:,i])
x = self.apply_TF(x, sample)
return x

def apply_TF(self, x, sampled_TF):
@@ -526,37 +567,55 @@ class Data_augV7(nn.Module): #Proba sequentielles
Ensure that the parameters value stays in the right intevals. This should be called after each update of those parameters.

Args:
soft (bool): Wether to use a softmax function for TF probabilites. Not Recommended as it tends to lock the probabilities, preventing them to be learned. (default: False)
soft (bool): Wether to use a softmax function for TF probabilites. Tends to lock the probabilities if the learning rate is low, preventing them to be learned. (default: False)
"""
if not self._fixed_prob:
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
else:
self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
# if not self._fixed_prob:
# if soft :
# self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
# else:
# self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
# self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1

if not self._fixed_mag:
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)

if not self._fixed_mix:
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
if not self._fixed_temp:
self._params['temp'].data = self._params['temp'].data.clamp(min=0.0, max=0.999)

def loss_weight(self):
def loss_weight(self, batch_norm=True):
""" Weights for the loss.
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
Should be applied to the loss before reduction.

Args:
batch_norm (bool): Wether to normalize mean of the weights. (Default: True)

Returns:
Tensor : Loss weights.
"""
if self._samples is None : return 1 #Pas d'echantillon = pas de ponderation
device=self._params["prob"].device
if len(self._samples)==0 : return torch.tensor(1, device=device) #Pas d'echantillon = pas de ponderation

prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]

w_loss = torch.zeros((self._samples.shape[0],self._nb_TF_sets), device=self._samples.device)
w_loss.scatter_(1, self._samples.view(-1,1), 1)
w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
# print("prob",prob.shape)
# print(self._samples.shape)

#w_loss=w_loss/w_loss.sum(dim=1, keepdim=True) #Bernoulli

#Normalizing by mean, would lend an exact normalization but can lead to unstable behavior of probabilities.
w_loss = self._samples * prob
w_loss = torch.sum(w_loss,dim=1)
# print("W_loss",w_loss.shape)
# print(w_loss)

if batch_norm:
w_min = w_loss.min()
w_loss = w_loss-w_min if w_min<0 else w_loss
w_loss = w_loss/w_loss.mean() #mean(w_loss)=1

#Normalizing by distribution is a statistical approximation of the exact normalization. It lead to more smooth probabilities evolution but will only return 1 if temp=1.
# w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
# w_loss = torch.sum(w_loss,dim=1)
return w_loss

def reg_loss(self, reg_factor=0.005):
@@ -573,30 +632,43 @@ class Data_augV7(nn.Module): #Proba sequentielles
else:
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') #Close to target ?
#max_mag_reg = - reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') #Far from target ?
return max_mag_reg

def TF_prob(self):
""" Gives an estimation of the individual TF probabilities.

Be warry that the probability returned isn't exact. The TF distribution isn't fully represented by those.
Each probability should be taken individualy. They only represent the chance for a specific TF to be picked at least once.

Returms:
Tensor containing the single TF probabilities of applications.
"""
if torch.all(self._params['prob']!=self._prob_mem.to(self._params['prob'].device)): #Prevent recompute if originial prob didn't changed
self._prob_mem=self._params['prob'].data.detach_()
self._single_TF_prob=torch.zeros(self._nb_tf)
for idx_tf in range(self._nb_tf):
for i, t_set in enumerate(self._TF_sets):
#uni, count = np.unique(t_set, return_counts=True)
#if idx_tf in uni:
# res[idx_tf]+=self._params['prob'][i]*int(count[np.where(uni==idx_tf)])
if idx_tf in t_set:
self._single_TF_prob[idx_tf]+=self._params['prob'][i]
# if not torch.all(self._params['prob']==self._prob_mem.to(self._params['prob'].device)): #Prevent recompute if originial prob didn't changed
# self._prob_mem=self._params['prob'].data.detach_()

return self._single_TF_prob
# p = self._params['prob'].view([self._nb_tf for _ in range(self._N_seqTF)])
# # print('prob',p)
# self._single_TF_prob=p.mean(dim=[i+1 for i in range(self._N_seqTF-1)]) #Reduce to 1D tensor
# # print(self._single_TF_prob)
# self._single_TF_prob=F.softmax(self._single_TF_prob, dim=0)
# print('Soft',self._single_TF_prob)

p=F.softmax(self._params['prob']*self._params["temp"], dim=0) #Sampling dist
p=p.view([self._nb_tf for _ in range(self._N_seqTF)])
p=p.mean(dim=[i+1 for i in range(self._N_seqTF-1)]) #Reduce to 1D tensor

#Means over each dim
# dim_idx=tuple(range(self._N_seqTF))
# means=[]
# for d in dim_idx:
# dim_mean=list(dim_idx)
# dim_mean.remove(d)
# means.append(p.mean(dim=dim_mean).unsqueeze(dim=1))
# means=torch.cat(means,dim=1)
# print(means)
# p=means.mean(dim=1)
# print(p)

return p

def train(self, mode=True):
""" Set the module training mode.
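A stand-alone sketch of the estimate computed by the new TF_prob (sizes invented): the joint sequence distribution is reshaped with one axis per sequence position, then averaged over the later positions to get a per-TF 1D summary; a sum over those axes would give the exact first-position marginal, the diff averages instead.

import torch
import torch.nn.functional as F

nb_tf, N_seqTF = 3, 2
prob = torch.randn(nb_tf ** N_seqTF)           # joint logits over ordered TF sequences
temp = torch.tensor(0.5)

p = F.softmax(prob * temp, dim=0)              # sampling distribution, as in TF_prob
p = p.view(nb_tf, nb_tf)                       # one axis per sequence position
estimate = p.mean(dim=1)                       # per-TF estimate (first position)
marginal = p.sum(dim=1)                        # exact first-position marginal, for comparison
print(estimate.tolist(), marginal.tolist())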
@@ -607,7 +679,7 @@ class Data_augV7(nn.Module): #Proba sequentielles
#if mode is None :
# mode=self._data_augmentation
self.augment(mode=mode) #Inutile si mode=None
super(Data_augV7, self).train(mode)
super(Data_augV8, self).train(mode)
return self

def eval(self):
@@ -654,12 +726,13 @@ class Data_augV7(nn.Module): #Proba sequentielles
mag_param='Mag'
if self._fixed_mag: mag_param+= 'Fx'
if self._shared_mag: mag_param+= 'Sh'
if not self._mix_dist:
return "Data_augV7(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
elif self._fixed_mix:
return "Data_augV7(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
# if not self._temp:
# return "Data_augV8(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
if self._fixed_temp:
return "Data_augV8(T%.1f%s-%dTFx%d-%s)" % (self._params['temp'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
else:
return "Data_augV7(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
return "Data_augV8(T%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)


class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
"""RandAugment implementation.
@@ -703,7 +776,7 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
self._shared_mag = True
self._fixed_mag = True
self._fixed_prob=True
self._fixed_mix=True
self._fixed_temp=True

self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)

@@ -716,17 +789,24 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
Returns:
Tensor : Batch of tranformed data.
"""
if self._data_augmentation:# and TF.random.random() < 0.5:
if self._data_augmentation:
device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]

x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
# x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)

## Echantillonage ## == sampled_ops = np.random.choice(transforms, N)
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*uniforme_dist)
self._samples=cat_distrib.sample([self._N_seqTF])

#Bernoulli (Requiert Identité en position 0)
# uniforme_dist = torch.ones(1,self._nb_tf-1,device=device).softmax(dim=1)
# cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf-1), device=device)*uniforme_dist)
# bern_distrib = Bernoulli(torch.tensor([0.5], device=device))
# mask = bern_distrib.sample([self._N_seqTF, batch_size]).squeeze()
# self._samples=(cat_distrib.sample([self._N_seqTF])+1)*mask

for sample in self._samples:
## Transformations ##
x = self.apply_TF(x, sample)
@@ -765,10 +845,10 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
"""
pass #Pas de parametre a opti

def loss_weight(self):
def loss_weight(self, batch_norm=False):
"""Not used
"""
return 1 #Pas d'echantillon = pas de ponderation
return torch.tensor([1], device=self._params["prob"].device) #Pas d'echantillon = pas de ponderation

def reg_loss(self, reg_factor=0.005):
"""Not used
@@ -949,18 +1029,22 @@ class Augmented_model(nn.Module):

self.augment(mode=True)

def forward(self, x):
def forward(self, x, copy=False):
""" Main method of the Augmented model.

Perform the forward pass of both modules.

Args:
x (Tensor): Batch of data.
copy (Bool): Wether to alter a copy or the original input. It's recommended to use a copy for parallel use of the input. (Default: False)

Returns:
Tensor : Output of the networks. Should be logits.
"""
return self._mods['model'](self._mods['data_aug'](x))
if copy:
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
return self._mods['model'](norm(self._mods['data_aug'](x)))
# return self._mods['model'](self._mods['data_aug'](x))

def augment(self, mode=True):
""" Set the augmentation mode.
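The new forward wires normalisation between augmentation and the task network (model(norm(data_aug(x)))) and adds a copy flag; as written in the diff, that boolean argument appears to shadow the imported copy module, so the deepcopy call would likely fail when copy=True. A minimal stand-in with identity modules and invented sizes, not the repository classes, showing the same wiring:

import copy as copy_module
import torch
import torch.nn as nn

data_aug = nn.Identity()                                  # stand-in for Data_augV5/V8
norm = nn.Identity()                                      # stand-in for TF.Normalizer(MEAN, STD)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

def forward(x, copy=False):
    if copy:
        x = copy_module.deepcopy(x)                       # keep the caller's batch untouched
    return model(norm(data_aug(x)))                       # augmentation, normalisation, then task model

logits = forward(torch.randn(4, 3, 32, 32), copy=True)
print(logits.shape)                                       # torch.Size([4, 10])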
@@ -970,6 +1054,12 @@ class Augmented_model(nn.Module):
"""
self._data_augmentation=mode
self._mods['data_aug'].augment(mode)

#ABN
# if mode :
# self._mods['model']['functional'].set_mode('augmented')
# else :
# self._mods['model']['functional'].set_mode('clean')

#### Encapsulation Meta Opt ####
def start_bilevel_opt(self, inner_it, hp_list, opt_param, dl_val):