diff --git a/higher/smart_aug/dataug.py b/higher/smart_aug/dataug.py
index 2f44397..1e7a100 100755
--- a/higher/smart_aug/dataug.py
+++ b/higher/smart_aug/dataug.py
@@ -19,7 +19,7 @@
 import copy
 
 import transformations as TF
-
+### Data augmenter ###
 
 class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
     """Data augmentation module with learnable parameters.
@@ -125,18 +125,20 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
 
             x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
 
+            ## Echantillonage ##
+            uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
+
+            if not self._mix_dist:
+                self._distrib = uniforme_dist
+            else:
+                prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
+                mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
+                self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
+
+            cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
+
             for _ in range(self._N_seqTF):
-                ## Echantillonage ##
-                uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
-
-                if not self._mix_dist:
-                    self._distrib = uniforme_dist
-                else:
-                    prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
-                    mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
-                    self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
-
-                cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
+                sample = cat_distrib.sample()
 
                 self._samples.append(sample)
 
@@ -205,12 +207,12 @@
             Compute the weights for the loss of each inputs depending on wich TF was applied to them.
             Should be applied to the loss before reduction.
 
-            Do nottake into account the order of application of the TF. See Data_augV7.
+            Do not take into account the order of application of the TF. See Data_augV7.
 
             Returns:
                 Tensor : Loss weights.
         """
-        if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
+        if len(self._samples)==0 : return torch.tensor(1, device=self._params["prob"].device) #Pas d'echantillon = pas de ponderation
 
         prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
 
@@ -769,6 +771,7 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
         """
         return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
 
+### Models ###
 import higher
 class Higher_model(nn.Module):
     """Model wrapper for higher gradient tracking.
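Note on the dataug.py hunk above: nothing inside the `for _ in range(self._N_seqTF)` loop changes the distribution, so its construction is hoisted out of the loop and only `cat_distrib.sample()` stays inside. A minimal standalone sketch of that sampling scheme, with hypothetical toy sizes and with `prob` / `mix_dist` standing in for `self._params["prob"]` / `self._params["mix_dist"]` (plain tensors here, not the module's parameters):

import torch
from torch.distributions import Categorical

batch_size, nb_tf, N_seqTF = 4, 3, 2                       # toy sizes, not the repo's defaults
device = torch.device("cpu")

prob = torch.rand(nb_tf, device=device).softmax(dim=0)     # stands in for the learned TF probabilities
mix_dist = torch.tensor(0.5, device=device)                # stands in for the learned mixing factor

uniforme_dist = torch.ones(1, nb_tf, device=device).softmax(dim=1)
distrib = mix_dist*prob + (1-mix_dist)*uniforme_dist       # mix of learned / uniform distribution

# Built once: a batch of categorical distributions, one row of TF probabilities per image.
cat_distrib = Categorical(probs=torch.ones((batch_size, nb_tf), device=device)*distrib)

# Only the draw stays in the per-sequence loop: one TF index per image and per step.
samples = [cat_distrib.sample() for _ in range(N_seqTF)]   # N_seqTF tensors of shape (batch_size,)
print(samples)

Since every iteration draws from the same distribution within one forward pass, building `cat_distrib` once and sampling it N_seqTF times is presumably equivalent to the old per-iteration construction, just cheaper.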
diff --git a/higher/smart_aug/transformations.py b/higher/smart_aug/transformations.py
index 8534584..a46eb1a 100755
--- a/higher/smart_aug/transformations.py
+++ b/higher/smart_aug/transformations.py
@@ -429,31 +429,31 @@ def auto_contrast(x):
     x = int_image(x) #Expect image in the range of [0, 1]
     #print('Start',x[0])
     for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image
-        #print(img.shape)
-        for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel
-            #print(chan.shape)
-            hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
+        #print(img.shape)
+        for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel
+            #print(chan.shape)
+            hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
 
-            # find lowest/highest samples after preprocessing
-            for lo in range(256):
-                if hist[lo]:
-                    break
-            for hi in range(255, -1, -1):
-                if hist[hi]:
-                    break
-            if hi <= lo:
-                # don't bother
-                pass
-            else:
-                scale = 255.0 / (hi - lo)
-                offset = -lo * scale
-                for ix in range(256):
-                    n_ix = int(ix * scale + offset)
-                    if n_ix < 0: n_ix = 0
-                    elif n_ix > 255: n_ix = 255
+            # find lowest/highest samples after preprocessing
+            for lo in range(256):
+                if hist[lo]:
+                    break
+            for hi in range(255, -1, -1):
+                if hist[hi]:
+                    break
+            if hi <= lo:
+                # don't bother
+                pass
+            else:
+                scale = 255.0 / (hi - lo)
+                offset = -lo * scale
+                for ix in range(256):
+                    n_ix = int(ix * scale + offset)
+                    if n_ix < 0: n_ix = 0
+                    elif n_ix > 255: n_ix = 255
 
-                    chan[chan==ix]=n_ix
-            x[im_idx, chan_idx]=chan
+                    chan[chan==ix]=n_ix
+            x[im_idx, chan_idx]=chan
     #print('End',x[0])
     return float_image(x)
 
@@ -468,9 +468,9 @@ def equalize(x):
     x = int_image(x) #Expect image in the range of [0, 1]
     #print('Start',x[0])
     for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image
-        #print(img.shape)
-        for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel
-            #print(chan.shape)
-            hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
+        #print(img.shape)
+        for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel
+            #print(chan.shape)
+            hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
 
     return float_image(x)
\ No newline at end of file
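Note on the transformations.py hunks above: the edits are whitespace/indentation fixes only; the per-channel logic is unchanged (a histogram via `torch.histc`, which the comment flags as non-differentiable, then a rescale of the populated intensity range). A rough, vectorized sketch of that auto-contrast step for a single channel, assuming 8-bit integer values in [0, 255]; it approximates the explicit per-intensity lookup table built by the loop above and is not the repo's API:

import torch

def autocontrast_channel(chan: torch.Tensor) -> torch.Tensor:
    # Histogram of the integer intensities; torch.histc is not differentiable.
    hist = torch.histc(chan.float(), bins=256, min=0, max=255)
    populated = torch.nonzero(hist).flatten()
    if populated.numel() == 0:
        return chan
    lo, hi = int(populated[0]), int(populated[-1])    # lowest / highest populated bins
    if hi <= lo:
        return chan                                   # flat channel: nothing to stretch
    scale = 255.0 / (hi - lo)
    offset = -lo * scale
    # Affine stretch instead of the original per-value LUT; clamp to the valid range.
    return (chan.float() * scale + offset).clamp(0, 255).to(chan.dtype)

chan = torch.randint(60, 180, (8, 8))                 # toy channel occupying [60, 179]
stretched = autocontrast_channel(chan)
print(int(chan.min()), int(chan.max()), int(stretched.min()), int(stretched.max()))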