mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Ajout fonctionnalitees apprentissage parametre optimisateur + mix dist
This commit is contained in:
parent
75901b69b4
commit
cd4b0405b9
3 changed files with 70 additions and 37 deletions
|
@ -545,11 +545,17 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
self._shared_mag = shared_mag
|
||||
self._fixed_mag = fixed_mag
|
||||
|
||||
self._fixed_mix=True
|
||||
if mix_dist is None: #Learn Mix dist
|
||||
self._fixed_mix = False
|
||||
mix_dist=0.5
|
||||
|
||||
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
|
||||
self._params = nn.ParameterDict({
|
||||
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
|
||||
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
|
||||
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
|
||||
"mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
|
||||
})
|
||||
|
||||
for tf in TF.TF_no_grad :
|
||||
|
@ -560,9 +566,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
self._fixed_prob=fixed_prob
|
||||
self._samples = []
|
||||
self._mix_dist = False
|
||||
if mix_dist != 0.0:
|
||||
if mix_dist != 0.0: #Mix dist
|
||||
self._mix_dist = True
|
||||
self._mix_factor = max(min(mix_dist, 0.999), 0.0)
|
||||
#self._mix_factor = max(min(mix_dist, 0.999), 0.0)
|
||||
|
||||
#Mag regularisation
|
||||
if not self._fixed_mag:
|
||||
|
@ -588,7 +594,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
self._distrib = uniforme_dist
|
||||
else:
|
||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||
self._distrib = (self._mix_factor*prob+(1-self._mix_factor)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
|
||||
#self._distrib = (self._mix_factor*prob+(1-self._mix_factor)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
|
||||
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
|
||||
sample = cat_distrib.sample()
|
||||
|
@ -638,6 +646,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
|
||||
#self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
|
||||
|
||||
if not self._fixed_mix:
|
||||
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
|
||||
|
||||
def loss_weight(self):
|
||||
if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
|
||||
|
||||
|
@ -692,8 +703,10 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
if self._shared_mag: mag_param+= 'Sh'
|
||||
if not self._mix_dist:
|
||||
return "Data_augV5(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
elif self._fixed_mix:
|
||||
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
else:
|
||||
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
return "Data_augV5(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
|
||||
|
||||
class Data_augV6(nn.Module): #Optimisation sequentielle #Mauvais resultats
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue