mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 20:20:46 +02:00
Amelioration Dataugv7
This commit is contained in:
parent
2e09f07f52
commit
f83c73ec17
1 changed files with 36 additions and 15 deletions
|
@ -27,6 +27,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple which can be learned. For the full definiton of the TF, see transformations.py.
|
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple which can be learned. For the full definiton of the TF, see transformations.py.
|
||||||
The TF probabilities defines a distribution from which we sample the TF applied.
|
The TF probabilities defines a distribution from which we sample the TF applied.
|
||||||
|
|
||||||
|
Be warry, that the order of sequential application of TF is not taken into account. See Data_augV7.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
_data_augmentation (bool): Wether TF will be applied during forward pass.
|
_data_augmentation (bool): Wether TF will be applied during forward pass.
|
||||||
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
|
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
|
||||||
|
@ -203,7 +205,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
|
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
|
||||||
Should be applied to the loss before reduction.
|
Should be applied to the loss before reduction.
|
||||||
|
|
||||||
TODO: Take into account the order of application of the TF.
|
Do nottake into account the order of application of the TF. See Data_augV7.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor : Loss weights.
|
Tensor : Loss weights.
|
||||||
|
@ -318,12 +320,12 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
||||||
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
|
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
|
||||||
_reg_mask (list): Mask selecting the TF considered for the regularisation.
|
_reg_mask (list): Mask selecting the TF considered for the regularisation.
|
||||||
"""
|
"""
|
||||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
|
def __init__(self, TF_dict=TF.TF_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
|
||||||
"""Init Data_augv7.
|
"""Init Data_augv7.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
|
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
|
||||||
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
|
N_TF (int): Number of TF to be applied sequentially to each inputs. Minimum 2, otherwise prefer using Data_augV5. (default: 2)
|
||||||
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
|
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
|
||||||
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
|
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
|
||||||
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
|
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
|
||||||
|
@ -333,6 +335,9 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
||||||
assert len(TF_dict)>0
|
assert len(TF_dict)>0
|
||||||
assert N_TF>=0
|
assert N_TF>=0
|
||||||
|
|
||||||
|
if N_TF<2:
|
||||||
|
print("WARNING: Data_augv7 isn't designed to use less than 2 sequentials TF. Please use Data_augv5 instead.")
|
||||||
|
|
||||||
self._data_augmentation = True
|
self._data_augmentation = True
|
||||||
|
|
||||||
#TF
|
#TF
|
||||||
|
@ -362,9 +367,13 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
||||||
mix_dist=0.5
|
mix_dist=0.5
|
||||||
|
|
||||||
#TF sets
|
#TF sets
|
||||||
no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}}
|
#import itertools
|
||||||
|
#itertools.product(range(self._nb_tf), repeat=self._N_seqTF)
|
||||||
|
|
||||||
|
#no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}} #Specific No consecutive ops
|
||||||
|
no_consecutive={idx for idx, t in enumerate(self._TF) if t not in {'Identity'}} #No consecutive same ops (except Identity)
|
||||||
cons_test = (lambda i, idxs: i in no_consecutive and len(idxs)!=0 and i==idxs[-1]) #Exclude selected consecutive
|
cons_test = (lambda i, idxs: i in no_consecutive and len(idxs)!=0 and i==idxs[-1]) #Exclude selected consecutive
|
||||||
def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF
|
def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF (exclude cons_test arrangement)
|
||||||
TF_sets=[]
|
TF_sets=[]
|
||||||
if set_size>1:
|
if set_size>1:
|
||||||
for i in range(n_TF):
|
for i in range(n_TF):
|
||||||
|
@ -377,6 +386,8 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
||||||
self._TF_sets=torch.ByteTensor(generate_TF_sets(self._nb_tf, self._N_seqTF)).squeeze()
|
self._TF_sets=torch.ByteTensor(generate_TF_sets(self._nb_tf, self._N_seqTF)).squeeze()
|
||||||
self._nb_TF_sets=len(self._TF_sets)
|
self._nb_TF_sets=len(self._TF_sets)
|
||||||
print("Number of TF sets:",self._nb_TF_sets)
|
print("Number of TF sets:",self._nb_TF_sets)
|
||||||
|
#print(self._TF_sets)
|
||||||
|
self._prob_mem=torch.zeros(self._nb_TF_sets)
|
||||||
|
|
||||||
#Params
|
#Params
|
||||||
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
|
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
|
||||||
|
@ -498,8 +509,6 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
||||||
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
|
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
|
||||||
Should be applied to the loss before reduction.
|
Should be applied to the loss before reduction.
|
||||||
|
|
||||||
TODO: Take into account the order of application of the TF.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor : Loss weights.
|
Tensor : Loss weights.
|
||||||
"""
|
"""
|
||||||
|
@ -530,15 +539,27 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
||||||
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
|
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
|
||||||
return max_mag_reg
|
return max_mag_reg
|
||||||
|
|
||||||
def TF_prob(self): #Eviter recalcul si pas de changement des proba
|
def TF_prob(self):
|
||||||
#print("WARNING: Calcul de proba inexact")
|
""" Gives an estimation of the individual TF probabilities.
|
||||||
res=torch.zeros(self._nb_tf)
|
|
||||||
for idx_tf in range(self._nb_tf):
|
|
||||||
for i, t_set in enumerate(self._TF_sets):
|
|
||||||
if idx_tf in t_set:
|
|
||||||
res[idx_tf]+=self._params['prob'][i]
|
|
||||||
|
|
||||||
return res/sum(res) #*(self._nb_tf/self._nb_TF_sets)
|
Be warry that the probability returned isn't exact. The TF distribution isn't fully represented by those.
|
||||||
|
Each probability should be taken individualy. They only represent the chance for a specific TF to be picked at least once.
|
||||||
|
|
||||||
|
Returms:
|
||||||
|
Tensor containing the single TF probabilities of applications.
|
||||||
|
"""
|
||||||
|
if torch.all(self._params['prob']!=self._prob_mem.to(self._params['prob'].device)): #Prevent recompute if originial prob didn't changed
|
||||||
|
self._prob_mem=self._params['prob'].data.detach_()
|
||||||
|
self._single_TF_prob=torch.zeros(self._nb_tf)
|
||||||
|
for idx_tf in range(self._nb_tf):
|
||||||
|
for i, t_set in enumerate(self._TF_sets):
|
||||||
|
#uni, count = np.unique(t_set, return_counts=True)
|
||||||
|
#if idx_tf in uni:
|
||||||
|
# res[idx_tf]+=self._params['prob'][i]*int(count[np.where(uni==idx_tf)])
|
||||||
|
if idx_tf in t_set:
|
||||||
|
self._single_TF_prob[idx_tf]+=self._params['prob'][i]
|
||||||
|
|
||||||
|
return self._single_TF_prob
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
""" Set the module training mode.
|
""" Set the module training mode.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue