Amelioration Dataugv7

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-24 11:50:45 -05:00
parent 2e09f07f52
commit f83c73ec17

View file

@ -27,6 +27,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple which can be learned. For the full definiton of the TF, see transformations.py.
The TF probabilities defines a distribution from which we sample the TF applied.
Be warry, that the order of sequential application of TF is not taken into account. See Data_augV7.
Attributes:
_data_augmentation (bool): Wether TF will be applied during forward pass.
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
@ -203,7 +205,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
Should be applied to the loss before reduction.
TODO: Take into account the order of application of the TF.
Do nottake into account the order of application of the TF. See Data_augV7.
Returns:
Tensor : Loss weights.
@ -318,12 +320,12 @@ class Data_augV7(nn.Module): #Proba sequentielles
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
_reg_mask (list): Mask selecting the TF considered for the regularisation.
"""
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
def __init__(self, TF_dict=TF.TF_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
"""Init Data_augv7.
Args:
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
N_TF (int): Number of TF to be applied sequentially to each inputs. Minimum 2, otherwise prefer using Data_augV5. (default: 2)
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
@ -333,6 +335,9 @@ class Data_augV7(nn.Module): #Proba sequentielles
assert len(TF_dict)>0
assert N_TF>=0
if N_TF<2:
print("WARNING: Data_augv7 isn't designed to use less than 2 sequentials TF. Please use Data_augv5 instead.")
self._data_augmentation = True
#TF
@ -362,9 +367,13 @@ class Data_augV7(nn.Module): #Proba sequentielles
mix_dist=0.5
#TF sets
no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}}
#import itertools
#itertools.product(range(self._nb_tf), repeat=self._N_seqTF)
#no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}} #Specific No consecutive ops
no_consecutive={idx for idx, t in enumerate(self._TF) if t not in {'Identity'}} #No consecutive same ops (except Identity)
cons_test = (lambda i, idxs: i in no_consecutive and len(idxs)!=0 and i==idxs[-1]) #Exclude selected consecutive
def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF
def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF (exclude cons_test arrangement)
TF_sets=[]
if set_size>1:
for i in range(n_TF):
@ -377,6 +386,8 @@ class Data_augV7(nn.Module): #Proba sequentielles
self._TF_sets=torch.ByteTensor(generate_TF_sets(self._nb_tf, self._N_seqTF)).squeeze()
self._nb_TF_sets=len(self._TF_sets)
print("Number of TF sets:",self._nb_TF_sets)
#print(self._TF_sets)
self._prob_mem=torch.zeros(self._nb_TF_sets)
#Params
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
@ -498,8 +509,6 @@ class Data_augV7(nn.Module): #Proba sequentielles
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
Should be applied to the loss before reduction.
TODO: Take into account the order of application of the TF.
Returns:
Tensor : Loss weights.
"""
@ -530,15 +539,27 @@ class Data_augV7(nn.Module): #Proba sequentielles
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
return max_mag_reg
def TF_prob(self): #Eviter recalcul si pas de changement des proba
#print("WARNING: Calcul de proba inexact")
res=torch.zeros(self._nb_tf)
def TF_prob(self):
""" Gives an estimation of the individual TF probabilities.
Be warry that the probability returned isn't exact. The TF distribution isn't fully represented by those.
Each probability should be taken individualy. They only represent the chance for a specific TF to be picked at least once.
Returms:
Tensor containing the single TF probabilities of applications.
"""
if torch.all(self._params['prob']!=self._prob_mem.to(self._params['prob'].device)): #Prevent recompute if originial prob didn't changed
self._prob_mem=self._params['prob'].data.detach_()
self._single_TF_prob=torch.zeros(self._nb_tf)
for idx_tf in range(self._nb_tf):
for i, t_set in enumerate(self._TF_sets):
#uni, count = np.unique(t_set, return_counts=True)
#if idx_tf in uni:
# res[idx_tf]+=self._params['prob'][i]*int(count[np.where(uni==idx_tf)])
if idx_tf in t_set:
res[idx_tf]+=self._params['prob'][i]
self._single_TF_prob[idx_tf]+=self._params['prob'][i]
return res/sum(res) #*(self._nb_tf/self._nb_TF_sets)
return self._single_TF_prob
def train(self, mode=True):
""" Set the module training mode.