mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Evite redefinition inutile de prob dist + Fix mineur transformation
This commit is contained in:
parent
923ef7b85e
commit
a2135e4709
2 changed files with 44 additions and 41 deletions
|
@ -19,7 +19,7 @@ import copy
|
||||||
|
|
||||||
import transformations as TF
|
import transformations as TF
|
||||||
|
|
||||||
|
### Data augmenter ###
|
||||||
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
"""Data augmentation module with learnable parameters.
|
"""Data augmentation module with learnable parameters.
|
||||||
|
|
||||||
|
@ -125,18 +125,20 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
|
|
||||||
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||||
|
|
||||||
|
## Echantillonage ##
|
||||||
|
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
|
||||||
|
|
||||||
|
if not self._mix_dist:
|
||||||
|
self._distrib = uniforme_dist
|
||||||
|
else:
|
||||||
|
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||||
|
mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
|
||||||
|
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||||
|
|
||||||
|
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
|
||||||
|
|
||||||
for _ in range(self._N_seqTF):
|
for _ in range(self._N_seqTF):
|
||||||
## Echantillonage ##
|
|
||||||
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
|
|
||||||
|
|
||||||
if not self._mix_dist:
|
|
||||||
self._distrib = uniforme_dist
|
|
||||||
else:
|
|
||||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
|
||||||
mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
|
|
||||||
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
|
||||||
|
|
||||||
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
|
|
||||||
sample = cat_distrib.sample()
|
sample = cat_distrib.sample()
|
||||||
self._samples.append(sample)
|
self._samples.append(sample)
|
||||||
|
|
||||||
|
@ -205,12 +207,12 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
|
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
|
||||||
Should be applied to the loss before reduction.
|
Should be applied to the loss before reduction.
|
||||||
|
|
||||||
Do nottake into account the order of application of the TF. See Data_augV7.
|
Do not take into account the order of application of the TF. See Data_augV7.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor : Loss weights.
|
Tensor : Loss weights.
|
||||||
"""
|
"""
|
||||||
if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
|
if len(self._samples)==0 : return torch.tensor(1, device=self._params["prob"].device) #Pas d'echantillon = pas de ponderation
|
||||||
|
|
||||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||||
|
|
||||||
|
@ -769,6 +771,7 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
||||||
"""
|
"""
|
||||||
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
||||||
|
|
||||||
|
### Models ###
|
||||||
import higher
|
import higher
|
||||||
class Higher_model(nn.Module):
|
class Higher_model(nn.Module):
|
||||||
"""Model wrapper for higher gradient tracking.
|
"""Model wrapper for higher gradient tracking.
|
||||||
|
|
|
@ -429,31 +429,31 @@ def auto_contrast(x):
|
||||||
x = int_image(x) #Expect image in the range of [0, 1]
|
x = int_image(x) #Expect image in the range of [0, 1]
|
||||||
#print('Start',x[0])
|
#print('Start',x[0])
|
||||||
for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image
|
for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image
|
||||||
#print(img.shape)
|
#print(img.shape)
|
||||||
for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel
|
for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel
|
||||||
#print(chan.shape)
|
#print(chan.shape)
|
||||||
hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
|
hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
|
||||||
|
|
||||||
# find lowest/highest samples after preprocessing
|
# find lowest/highest samples after preprocessing
|
||||||
for lo in range(256):
|
for lo in range(256):
|
||||||
if hist[lo]:
|
if hist[lo]:
|
||||||
break
|
break
|
||||||
for hi in range(255, -1, -1):
|
for hi in range(255, -1, -1):
|
||||||
if hist[hi]:
|
if hist[hi]:
|
||||||
break
|
break
|
||||||
if hi <= lo:
|
if hi <= lo:
|
||||||
# don't bother
|
# don't bother
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
scale = 255.0 / (hi - lo)
|
scale = 255.0 / (hi - lo)
|
||||||
offset = -lo * scale
|
offset = -lo * scale
|
||||||
for ix in range(256):
|
for ix in range(256):
|
||||||
n_ix = int(ix * scale + offset)
|
n_ix = int(ix * scale + offset)
|
||||||
if n_ix < 0: n_ix = 0
|
if n_ix < 0: n_ix = 0
|
||||||
elif n_ix > 255: n_ix = 255
|
elif n_ix > 255: n_ix = 255
|
||||||
|
|
||||||
chan[chan==ix]=n_ix
|
chan[chan==ix]=n_ix
|
||||||
x[im_idx, chan_idx]=chan
|
x[im_idx, chan_idx]=chan
|
||||||
|
|
||||||
#print('End',x[0])
|
#print('End',x[0])
|
||||||
return float_image(x)
|
return float_image(x)
|
||||||
|
@ -468,9 +468,9 @@ def equalize(x):
|
||||||
x = int_image(x) #Expect image in the range of [0, 1]
|
x = int_image(x) #Expect image in the range of [0, 1]
|
||||||
#print('Start',x[0])
|
#print('Start',x[0])
|
||||||
for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image
|
for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image
|
||||||
#print(img.shape)
|
#print(img.shape)
|
||||||
for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel
|
for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel
|
||||||
#print(chan.shape)
|
#print(chan.shape)
|
||||||
hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
|
hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
|
||||||
|
|
||||||
return float_image(x)
|
return float_image(x)
|
Loading…
Add table
Add a link
Reference in a new issue