Avoid needless redefinition of the prob distribution + minor transformation fix

Harle, Antoine (Contractor) 2020-01-27 17:29:25 -05:00
parent 923ef7b85e
commit a2135e4709
2 changed files with 44 additions and 41 deletions


@@ -19,7 +19,7 @@ import copy
 import transformations as TF
+### Data augmenter ###
 class Data_augV5(nn.Module): #Joint optimization (mag, proba)
     """Data augmentation module with learnable parameters.
@@ -125,7 +125,6 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
             x = copy.deepcopy(x) #Avoids modifying the samples by reference (problematic for parallel use)
-            for _ in range(self._N_seqTF):
             ## Sampling ##
             uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
@@ -137,6 +136,9 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
                 self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix real / uniform distribution with mix_factor
             cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
+            for _ in range(self._N_seqTF):
                 sample = cat_distrib.sample()
                 self._samples.append(sample)
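The point of the two hunks above is to hoist the distribution setup out of the per-transformation loop: the Categorical over transformations is built once per forward pass, and only the sampling repeats self._N_seqTF times. Below is a minimal sketch of the resulting pattern; batch_size, nb_tf and N_seqTF are hypothetical stand-ins for the module's attributes, and the uniform distribution stands in for the learned/mixed one.

```python
import torch
from torch.distributions import Categorical

batch_size, nb_tf, N_seqTF = 512, 14, 3  # hypothetical sizes

# Built once per forward pass: a per-example categorical distribution
# over the nb_tf transformations (uniform here, as in the unmixed case).
distrib = torch.ones(1, nb_tf).softmax(dim=1)
cat_distrib = Categorical(probs=torch.ones(batch_size, nb_tf) * distrib)

# Only the sampling sits inside the loop: one transformation index per
# example and per step of the transformation sequence.
samples = [cat_distrib.sample() for _ in range(N_seqTF)]  # each of shape (batch_size,)
```

Besides avoiding redundant allocations, sampling N times from one distribution object makes explicit that the mixing weights cannot change between steps of the same sequence.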
@@ -210,7 +212,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
             Returns:
                 Tensor : Loss weights.
         """
-        if len(self._samples)==0 : return 1 #No samples = no weighting
+        if len(self._samples)==0 : return torch.tensor(1, device=self._params["prob"].device) #No samples = no weighting
         prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
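The second fix makes the loss-weight method type-stable: the early return used to hand back a plain Python 1 while every other path returns a tensor. A short sketch of why that matters, using a simplified stand-in for the real method (which reads self._samples and self._params["prob"]; the weight computation below is illustrative, not the module's exact formula):

```python
import torch

def loss_weights(samples, prob):
    # Simplified stand-in for Data_augV5's loss-weight computation.
    if len(samples) == 0:
        # A plain `return 1` changes the return type: callers that use tensor
        # methods on the result (.detach(), .to(device), indexing) would fail
        # whenever no augmentation was sampled. A 0-dim tensor on prob's
        # device keeps the interface uniform.
        return torch.tensor(1, device=prob.device)
    w = torch.stack([prob[s] for s in samples]).prod(dim=0)  # weight per example
    return w / w.mean()                                      # normalized weights

prob = torch.rand(14).softmax(dim=0)
print(loss_weights([], prob))  # tensor(1), not the int 1
```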
@@ -769,6 +771,7 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
         """
         return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
+### Models ###
 import higher
 class Higher_model(nn.Module):
     """Model wrapper for higher gradient tracking.