Sample groupee

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-21 11:32:44 -05:00
parent 9513483893
commit d0a49a9d61

View file

@ -127,7 +127,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
Returns: Returns:
Tensor : Batch of tranformed data. Tensor : Batch of tranformed data.
""" """
self._samples = [] self._samples = torch.Tensor([])
if self._data_augmentation:# and TF.random.random() < 0.5: if self._data_augmentation:# and TF.random.random() < 0.5:
device = x.device device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3] batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
@ -145,12 +145,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib) cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
self._samples=cat_distrib.sample([self._N_seqTF])
for _ in range(self._N_seqTF): for sample in self._samples:
sample = cat_distrib.sample()
self._samples.append(sample)
## Transformations ## ## Transformations ##
x = self.apply_TF(x, sample) x = self.apply_TF(x, sample)
return x return x
@ -448,7 +445,6 @@ class Data_augV7(nn.Module): #Proba sequentielles
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
## Echantillonage ## ## Echantillonage ##
uniforme_dist = torch.ones(1,self._nb_TF_sets,device=device).softmax(dim=1) uniforme_dist = torch.ones(1,self._nb_TF_sets,device=device).softmax(dim=1)
@ -707,14 +703,15 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
for _ in range(self._N_seqTF): ## Echantillonage ## == sampled_ops = np.random.choice(transforms, N)
## Echantillonage ## == sampled_ops = np.random.choice(transforms, N) uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1) cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*uniforme_dist)
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*uniforme_dist) self._samples=cat_distrib.sample([self._N_seqTF])
sample = cat_distrib.sample()
for sample in self._samples:
## Transformations ## ## Transformations ##
x = self.apply_TF(x, sample) x = self.apply_TF(x, sample)
return x return x
def apply_TF(self, x, sampled_TF): def apply_TF(self, x, sampled_TF):
@ -964,7 +961,7 @@ class Augmented_model(nn.Module):
model.step(loss) model.step(loss)
Lacking epoch informations, this does not support LR scheduler and delayed meta-optimisation(Meta-optimizer: epoch_start>1). Does not support LR scheduler.
See ''run_simple_smartaug'' for a complete example. See ''run_simple_smartaug'' for a complete example.