diff --git a/higher/smart_aug/dataug.py b/higher/smart_aug/dataug.py index 4df0e15..93cc598 100755 --- a/higher/smart_aug/dataug.py +++ b/higher/smart_aug/dataug.py @@ -127,7 +127,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba) Returns: Tensor : Batch of tranformed data. """ - self._samples = [] + self._samples = torch.Tensor([]) if self._data_augmentation:# and TF.random.random() < 0.5: device = x.device batch_size, h, w = x.shape[0], x.shape[2], x.shape[3] @@ -145,12 +145,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba) self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib) - - for _ in range(self._N_seqTF): - - sample = cat_distrib.sample() - self._samples.append(sample) + self._samples=cat_distrib.sample([self._N_seqTF]) + for sample in self._samples: ## Transformations ## x = self.apply_TF(x, sample) return x @@ -448,7 +445,6 @@ class Data_augV7(nn.Module): #Proba sequentielles x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) - ## Echantillonage ## uniforme_dist = torch.ones(1,self._nb_TF_sets,device=device).softmax(dim=1) @@ -707,14 +703,15 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) - for _ in range(self._N_seqTF): - ## Echantillonage ## == sampled_ops = np.random.choice(transforms, N) - uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1) - cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*uniforme_dist) - sample = cat_distrib.sample() + ## Echantillonage ## == sampled_ops = np.random.choice(transforms, N) + uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1) + cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*uniforme_dist) + self._samples=cat_distrib.sample([self._N_seqTF]) + for sample in self._samples: ## Transformations ## x = self.apply_TF(x, sample) + return x def apply_TF(self, x, sampled_TF): @@ -964,7 +961,7 @@ class Augmented_model(nn.Module): model.step(loss) - Lacking epoch informations, this does not support LR scheduler and delayed meta-optimisation(Meta-optimizer: epoch_start>1). + Does not support LR scheduler. See ''run_simple_smartaug'' for a complete example.