Avoid unnecessary redefinition of prob dist + Minor transformation fix

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-27 17:29:25 -05:00
parent 923ef7b85e
commit a2135e4709
2 changed files with 44 additions and 41 deletions


@@ -19,7 +19,7 @@ import copy
import transformations as TF
### Data augmenter ###
class Data_augV5(nn.Module): #Joint optimization (mag, proba)
"""Data augmentation module with learnable parameters.
@@ -125,7 +125,6 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
for _ in range(self._N_seqTF):
## Sampling ##
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
@@ -137,6 +136,9 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix real / uniform distrib with mix_factor
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
for _ in range(self._N_seqTF):
sample = cat_distrib.sample()
self._samples.append(sample)
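This hunk captures the point of the commit: the uniform and mixed distributions are now built once per forward pass, and only the sampling step stays inside the per-sequence loop. A minimal, self-contained sketch of that pattern, assuming illustrative shapes and placeholder names (batch_size, nb_tf, n_seq_tf, mix_dist are stand-ins, not the module's exact attributes):

    import torch
    from torch.distributions import Categorical

    batch_size, nb_tf, n_seq_tf, mix_dist = 8, 14, 3, 0.5
    prob = torch.rand(nb_tf).softmax(dim=0)              # learnable TF probabilities (illustrative)
    uniform_dist = torch.ones(1, nb_tf).softmax(dim=1)   # uniform distribution over the TF
    distrib = mix_dist * prob + (1 - mix_dist) * uniform_dist   # mix real / uniform with mix_factor

    # Build the Categorical once, then only sample inside the loop.
    cat_distrib = Categorical(probs=torch.ones((batch_size, nb_tf)) * distrib)
    samples = [cat_distrib.sample() for _ in range(n_seq_tf)]   # one TF index per input, per step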
@@ -205,12 +207,12 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
Compute the weights for the loss of each input depending on which TF was applied to it.
Should be applied to the loss before reduction.
Do nottake into account the order of application of the TF. See Data_augV7.
Do not take into account the order of application of the TF. See Data_augV7.
Returns:
Tensor : Loss weights.
"""
if len(self._samples)==0 : return 1 #No sample = no weighting
if len(self._samples)==0 : return torch.tensor(1, device=self._params["prob"].device) #No sample = no weighting
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
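Since these weights are meant to scale the loss of each input before reduction, a typical usage would multiply an unreduced criterion output. The sketch below is hypothetical; loss_weights stands in for whatever this method returns and is not the module's actual API:

    import torch

    criterion = torch.nn.CrossEntropyLoss(reduction='none')    # keep one loss term per input
    logits, targets = torch.randn(8, 10), torch.randint(0, 10, (8,))
    loss_weights = torch.rand(8)     # stand-in for the per-input weights computed above
    loss = (criterion(logits, targets) * loss_weights).mean()  # weight, then reduce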
@@ -769,6 +771,7 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
"""
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
### Models ###
import higher
class Higher_model(nn.Module):
"""Model wrapper for higher gradient tracking.