Avoid needless redefinition of prob dist + Minor fix in transformations

Harle, Antoine (Contracteur) 2020-01-27 17:29:25 -05:00
parent 923ef7b85e
commit a2135e4709
2 changed files with 44 additions and 41 deletions


@@ -19,7 +19,7 @@ import copy
 import transformations as TF

 ### Data augmenter ###
 class Data_augV5(nn.Module): #Joint optimisation (mag, prob)
     """Data augmentation module with learnable parameters.
@@ -125,18 +125,20 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, prob)
             x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)

+            ## Sampling ##
+            uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
+
+            if not self._mix_dist:
+                self._distrib = uniforme_dist
+            else:
+                prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
+                mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
+                self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix real / uniform distrib with mix_factor
+
+            cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
+
             for _ in range(self._N_seqTF):
-                ## Sampling ##
-                uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
-
-                if not self._mix_dist:
-                    self._distrib = uniforme_dist
-                else:
-                    prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
-                    mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
-                    self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix real / uniform distrib with mix_factor
-
-                cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
-
                 sample = cat_distrib.sample()
                 self._samples.append(sample)
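The net effect of this hunk: uniforme_dist, self._distrib and cat_distrib are now built once per forward pass instead of being rebuilt at every one of the _N_seqTF sequential steps, while the per-step cat_distrib.sample() calls stay unchanged. A minimal standalone sketch of the same sampling scheme, with illustrative names (prob, mix_dist, n_seq_tf are stand-ins for the module's learned parameters, not the repository's exact attributes):

import torch
from torch.distributions import Categorical

def sample_tf_indices(prob, mix_dist, batch_size, n_seq_tf):
    # prob: learnable distribution over the nb_tf transforms; mix_dist: scalar in [0, 1]
    nb_tf = prob.shape[-1]
    uniform = torch.ones(1, nb_tf, device=prob.device).softmax(dim=1)
    distrib = mix_dist * prob + (1 - mix_dist) * uniform           # mix learned / uniform
    # Build the Categorical once and reuse it for every sequential TF step.
    cat = Categorical(probs=torch.ones(batch_size, nb_tf, device=prob.device) * distrib)
    return [cat.sample() for _ in range(n_seq_tf)]                 # one TF index per image, per step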
@@ -205,12 +207,12 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, prob)
         Compute the weights for the loss of each input depending on which TF was applied to it.
         Should be applied to the loss before reduction.

-        Do nottake into account the order of application of the TF. See Data_augV7.
+        Do not take into account the order of application of the TF. See Data_augV7.

         Returns:
             Tensor : Loss weights.
         """
-        if len(self._samples)==0 : return 1 #No sample = no weighting
+        if len(self._samples)==0 : return torch.tensor(1, device=self._params["prob"].device) #No sample = no weighting

         prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
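Returning torch.tensor(1, device=...) rather than a bare Python 1 keeps the "no samples yet" case on the same device as the per-sample loss, so the caller can multiply unconditionally. A sketch of the intended call pattern, assuming a hypothetical loss_weight() accessor standing in for the method edited in this hunk (which is not fully shown):

import torch

def weighted_mean_loss(per_sample_loss, aug_module):
    # per_sample_loss: shape (batch,), e.g. from CrossEntropyLoss(reduction='none')
    w = aug_module.loss_weight()          # scalar tensor 1, or a (batch,) weight tensor
    return (per_sample_loss * w).mean()   # broadcasts correctly in both cases, same device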
@@ -769,6 +771,7 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
         """
         return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)

+### Models ###
 import higher
 class Higher_model(nn.Module):
     """Model wrapper for higher gradient tracking.


@@ -429,31 +429,31 @@ def auto_contrast(x):
     x = int_image(x) #Expect image in the range of [0, 1]
     #print('Start',x[0])
     for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Per-image operation
         #print(img.shape)
         for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Per-channel operation
             #print(chan.shape)
             hist = torch.histc(chan, bins=256, min=0, max=255) #NOT DIFFERENTIABLE

             # find lowest/highest samples after preprocessing
             for lo in range(256):
                 if hist[lo]:
                     break
             for hi in range(255, -1, -1):
                 if hist[hi]:
                     break
             if hi <= lo:
                 # don't bother
                 pass
             else:
                 scale = 255.0 / (hi - lo)
                 offset = -lo * scale
                 for ix in range(256):
                     n_ix = int(ix * scale + offset)
                     if n_ix < 0: n_ix = 0
                     elif n_ix > 255: n_ix = 255
                     chan[chan==ix]=n_ix

             x[im_idx, chan_idx]=chan
     #print('End',x[0])
     return float_image(x)
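The per-value remapping loop above (chan[chan==ix]=n_ix for every ix) is essentially a linear stretch of the occupied range [lo, hi] onto [0, 255]. A vectorized per-channel sketch of that operation, assuming integer-valued inputs in [0, 255] (still non-differentiable because of the histogram and the rounding), not the repository's exact implementation:

import torch

def auto_contrast_channel(chan):
    # chan: integer-valued tensor in [0, 255] for one image channel
    hist = torch.histc(chan.float(), bins=256, min=0, max=255)
    nz = torch.nonzero(hist).flatten()
    if nz.numel() < 2:
        return chan                                    # flat channel: nothing to stretch
    lo, hi = nz[0].item(), nz[-1].item()
    scale = 255.0 / (hi - lo)
    return ((chan.float() - lo) * scale).round().clamp(0, 255)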
@@ -468,9 +468,9 @@ def equalize(x):
     x = int_image(x) #Expect image in the range of [0, 1]
     #print('Start',x[0])
     for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Per-image operation
         #print(img.shape)
         for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Per-channel operation
             #print(chan.shape)
             hist = torch.histc(chan, bins=256, min=0, max=255) #NOT DIFFERENTIABLE

     return float_image(x)
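Only the histogram computation of equalize is visible in this hunk; the part not shown would map each value through the normalized cumulative histogram. For reference, a standard per-channel histogram-equalization sketch, again assuming integer values in [0, 255] and not claiming to match the repository's exact implementation:

import torch

def equalize_channel(chan):
    # chan: integer-valued tensor in [0, 255] for one image channel
    hist = torch.histc(chan.float(), bins=256, min=0, max=255)
    cdf = hist.cumsum(0)
    nonzero = cdf[hist > 0]
    if nonzero.numel() == 0:
        return chan
    cdf_min = nonzero[0]
    denom = (cdf[-1] - cdf_min).clamp(min=1.0)         # avoid division by zero on flat channels
    lut = ((cdf - cdf_min) / denom * 255.0).round().clamp(0, 255)
    return lut[chan.long()]                            # lookup-table remap of every pixel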