mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Dataugv5- Modification des TF pour propagation du gradient (mag)
This commit is contained in:
parent
05f81787d6
commit
994d657a28
5 changed files with 94 additions and 21 deletions
|
@ -583,19 +583,33 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
|
|
||||||
def apply_TF(self, x, sampled_TF):
|
def apply_TF(self, x, sampled_TF):
|
||||||
device = x.device
|
device = x.device
|
||||||
|
batch_size, channels, h, w = x.shape
|
||||||
smps_x=[]
|
smps_x=[]
|
||||||
masks=[]
|
|
||||||
for tf_idx in range(self._nb_tf):
|
for tf_idx in range(self._nb_tf):
|
||||||
mask = sampled_TF==tf_idx #Create selection mask
|
mask = sampled_TF==tf_idx #Create selection mask
|
||||||
smp_x = x[mask] #torch.masked_select() ?
|
smp_x = x[mask] #torch.masked_select() ? (NEcessite d'expand le mask au meme dim)
|
||||||
|
|
||||||
if smp_x.shape[0]!=0: #if there's data to TF
|
if smp_x.shape[0]!=0: #if there's data to TF
|
||||||
magnitude=self._params["mag"][tf_idx]*10
|
magnitude=self._params["mag"][tf_idx]*10
|
||||||
tf=self._TF[tf_idx]
|
tf=self._TF[tf_idx]
|
||||||
#print(magnitude)
|
#print(magnitude)
|
||||||
|
|
||||||
x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Refusionner eviter x[mask] : in place
|
#x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Refusionner eviter x[mask] : in place
|
||||||
|
smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
|
||||||
|
|
||||||
|
idx= mask.nonzero()
|
||||||
|
#print('-'*8)
|
||||||
|
idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||||
|
#print(idx.shape, smp_x.shape)
|
||||||
|
#print(idx[0], tf_idx)
|
||||||
|
#print(smp_x[0,])
|
||||||
|
#x=x.view(-1,3*32*32)
|
||||||
|
#smp_x=smp_x.view(-1,3*32*32)
|
||||||
|
x=x.scatter(dim=0, index=idx, src=smp_x)
|
||||||
|
#x=x.view(-1,3,32,32)
|
||||||
|
#print(x[0,])
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def adjust_prob(self, soft=False): #Detach from gradient ?
|
def adjust_prob(self, soft=False): #Detach from gradient ?
|
||||||
|
|
|
@ -5,9 +5,9 @@ from train_utils import *
|
||||||
|
|
||||||
tf_names = [
|
tf_names = [
|
||||||
## Geometric TF ##
|
## Geometric TF ##
|
||||||
'Identity',
|
#'Identity',
|
||||||
'FlipUD',
|
#'FlipUD',
|
||||||
'FlipLR',
|
#'FlipLR',
|
||||||
'Rotate',
|
'Rotate',
|
||||||
'TranslateX',
|
'TranslateX',
|
||||||
'TranslateY',
|
'TranslateY',
|
||||||
|
@ -37,7 +37,7 @@ else:
|
||||||
##########################################
|
##########################################
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
n_inner_iter = 10
|
n_inner_iter = 1
|
||||||
epochs = 2
|
epochs = 2
|
||||||
dataug_epoch_start=0
|
dataug_epoch_start=0
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ if __name__ == "__main__":
|
||||||
t0 = time.process_time()
|
t0 = time.process_time()
|
||||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||||
#tf_dict = TF.TF_dict
|
#tf_dict = TF.TF_dict
|
||||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5), LeNet(3,10)).to(device)
|
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, glob_mag=False), LeNet(3,10)).to(device)
|
||||||
#aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
#aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
||||||
print(str(aug_model), 'on', device_name)
|
print(str(aug_model), 'on', device_name)
|
||||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||||
|
|
|
@ -623,8 +623,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
||||||
|
|
||||||
tf = time.process_time()
|
tf = time.process_time()
|
||||||
|
|
||||||
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||||
#viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||||
|
|
||||||
if(not high_grad_track):
|
if(not high_grad_track):
|
||||||
countcopy+=1
|
countcopy+=1
|
||||||
|
@ -648,8 +648,9 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
||||||
print('Accuracy :', accuracy)
|
print('Accuracy :', accuracy)
|
||||||
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
|
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
|
||||||
print('TF Proba :', model['data_aug']['prob'].data)
|
print('TF Proba :', model['data_aug']['prob'].data)
|
||||||
#print('proba grad',aug_model['data_aug']['prob'].grad)
|
#print('proba grad',model['data_aug']['prob'].grad)
|
||||||
print('TF Mag :', model['data_aug']['mag'].data)
|
print('TF Mag :', model['data_aug']['mag'].data)
|
||||||
|
print('Mag grad',model['data_aug']['mag'].grad)
|
||||||
#############
|
#############
|
||||||
#### Log ####
|
#### Log ####
|
||||||
data={
|
data={
|
||||||
|
|
|
@ -28,6 +28,7 @@ TF_dict={ #f(mag_normalise)=mag_reelle
|
||||||
#'Equalize': (lambda mag: None),
|
#'Equalize': (lambda mag: None),
|
||||||
}
|
}
|
||||||
'''
|
'''
|
||||||
|
'''
|
||||||
TF_dict={
|
TF_dict={
|
||||||
## Geometric TF ##
|
## Geometric TF ##
|
||||||
'Identity' : (lambda x, mag: x),
|
'Identity' : (lambda x, mag: x),
|
||||||
|
@ -42,7 +43,7 @@ TF_dict={
|
||||||
## Color TF (Expect image in the range of [0, 1]) ##
|
## Color TF (Expect image in the range of [0, 1]) ##
|
||||||
'Contrast': (lambda x, mag: contrast(x, contrast_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
'Contrast': (lambda x, mag: contrast(x, contrast_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
||||||
'Color':(lambda x, mag: color(x, color_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
'Color':(lambda x, mag: color(x, color_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
||||||
'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=1., maxval=1.9) for _ in x], device=x.device))),
|
'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
||||||
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
||||||
'Posterize': (lambda x, mag: posterize(x, bits=torch.tensor([rand_int(mag, minval=4, maxval=8) for _ in x], device=x.device))),
|
'Posterize': (lambda x, mag: posterize(x, bits=torch.tensor([rand_int(mag, minval=4, maxval=8) for _ in x], device=x.device))),
|
||||||
'Solarize': (lambda x, mag: solarize(x, thresholds=torch.tensor([rand_int(mag,minval=1, maxval=256)/256. for _ in x], device=x.device))) , #=>Image entre [0,1] #Pas opti pour des batch
|
'Solarize': (lambda x, mag: solarize(x, thresholds=torch.tensor([rand_int(mag,minval=1, maxval=256)/256. for _ in x], device=x.device))) , #=>Image entre [0,1] #Pas opti pour des batch
|
||||||
|
@ -51,6 +52,27 @@ TF_dict={
|
||||||
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
||||||
#'Equalize': (lambda mag: None),
|
#'Equalize': (lambda mag: None),
|
||||||
}
|
}
|
||||||
|
'''
|
||||||
|
TF_dict={
|
||||||
|
## Geometric TF ##
|
||||||
|
'Identity' : (lambda x, mag: x),
|
||||||
|
'FlipUD' : (lambda x, mag: flipUD(x)),
|
||||||
|
'FlipLR' : (lambda x, mag: flipLR(x)),
|
||||||
|
'Rotate': (lambda x, mag: rotate(x, angle=rand_float(size=x.shape[0], mag=mag, maxval=30))),
|
||||||
|
'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_float(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
|
||||||
|
'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_float(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
|
||||||
|
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_float(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
|
||||||
|
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_float(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
|
||||||
|
|
||||||
|
## Color TF (Expect image in the range of [0, 1]) ##
|
||||||
|
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_float(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Color':(lambda x, mag: color(x, color_factor=rand_float(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_float(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_float(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Posterize': (lambda x, mag: posterize(x, bits=rand_float(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||||
|
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_float(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Pas opti pour des batch
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
def int_image(float_image): #ATTENTION : legere perte d'info (granularite : 1/256 = 0.0039)
|
def int_image(float_image): #ATTENTION : legere perte d'info (granularite : 1/256 = 0.0039)
|
||||||
return (float_image*255.).type(torch.uint8)
|
return (float_image*255.).type(torch.uint8)
|
||||||
|
@ -71,6 +93,19 @@ def rand_float(mag, maxval, minval=None): #[(-maxval,minval), maxval]
|
||||||
if not minval : minval = -real_max
|
if not minval : minval = -real_max
|
||||||
return random.uniform(minval, real_max)
|
return random.uniform(minval, real_max)
|
||||||
|
|
||||||
|
def rand_float(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
|
||||||
|
real_max = float_parameter(mag, maxval=maxval)
|
||||||
|
if not minval : minval = -real_max
|
||||||
|
#return random.uniform(minval, real_max)
|
||||||
|
return minval +(real_max-minval) * torch.rand(size, device=mag.device)
|
||||||
|
|
||||||
|
def zero_stack(tensor, zero_pos):
|
||||||
|
if zero_pos==0:
|
||||||
|
return torch.stack((tensor, torch.zeros((tensor.shape[0],), device=tensor.device)), dim=1)
|
||||||
|
if zero_pos==1:
|
||||||
|
return torch.stack((torch.zeros((tensor.shape[0],), device=tensor.device), tensor), dim=1)
|
||||||
|
else:
|
||||||
|
raise Exception("Invalid zero_pos : ", zero_pos)
|
||||||
|
|
||||||
#https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
|
#https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
|
||||||
PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted
|
PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted
|
||||||
|
@ -83,7 +118,9 @@ def float_parameter(level, maxval):
|
||||||
Returns:
|
Returns:
|
||||||
A float that results from scaling `maxval` according to `level`.
|
A float that results from scaling `maxval` according to `level`.
|
||||||
"""
|
"""
|
||||||
return float(level) * maxval / PARAMETER_MAX
|
|
||||||
|
#return float(level) * maxval / PARAMETER_MAX
|
||||||
|
return (level * maxval / PARAMETER_MAX)#.to(torch.float32)
|
||||||
|
|
||||||
def int_parameter(level, maxval):
|
def int_parameter(level, maxval):
|
||||||
"""Helper function to scale `val` between 0 and maxval .
|
"""Helper function to scale `val` between 0 and maxval .
|
||||||
|
@ -94,7 +131,11 @@ def int_parameter(level, maxval):
|
||||||
Returns:
|
Returns:
|
||||||
An int that results from scaling `maxval` according to `level`.
|
An int that results from scaling `maxval` according to `level`.
|
||||||
"""
|
"""
|
||||||
return int(level * maxval / PARAMETER_MAX)
|
#return int(level * maxval / PARAMETER_MAX)
|
||||||
|
print(level)
|
||||||
|
res= (level * maxval / PARAMETER_MAX).to(torch.int8).requires_grad_()#.type(torch.int8)
|
||||||
|
print(res)
|
||||||
|
return res
|
||||||
|
|
||||||
def flipLR(x):
|
def flipLR(x):
|
||||||
device = x.device
|
device = x.device
|
||||||
|
@ -119,10 +160,11 @@ def flipUD(x):
|
||||||
return kornia.warp_perspective(x, M, dsize=(h, w))
|
return kornia.warp_perspective(x, M, dsize=(h, w))
|
||||||
|
|
||||||
def rotate(x, angle):
|
def rotate(x, angle):
|
||||||
return kornia.rotate(x, angle=angle.type(torch.float32)) #Kornia ne supporte pas les int
|
return kornia.rotate(x, angle=angle)#.type(torch.float32)) #Kornia ne supporte pas les int
|
||||||
|
|
||||||
def translate(x, translation):
|
def translate(x, translation):
|
||||||
return kornia.translate(x, translation=translation.type(torch.float32)) #Kornia ne supporte pas les int
|
#print(translation)
|
||||||
|
return kornia.translate(x, translation=translation)#.type(torch.float32)) #Kornia ne supporte pas les int
|
||||||
|
|
||||||
def shear(x, shear):
|
def shear(x, shear):
|
||||||
return kornia.shear(x, shear=shear)
|
return kornia.shear(x, shear=shear)
|
||||||
|
@ -156,6 +198,7 @@ def sharpeness(x, sharpness_factor):
|
||||||
|
|
||||||
#https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
|
#https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
|
||||||
def posterize(x, bits):
|
def posterize(x, bits):
|
||||||
|
bits = bits.type(torch.uint8) #Perte du gradient
|
||||||
x = int_image(x) #Expect image in the range of [0, 1]
|
x = int_image(x) #Expect image in the range of [0, 1]
|
||||||
|
|
||||||
mask = ~(2 ** (8 - bits) - 1).type(torch.uint8)
|
mask = ~(2 ** (8 - bits) - 1).type(torch.uint8)
|
||||||
|
@ -217,10 +260,25 @@ def equalize(x): #PAS OPTIMISE POUR DES BATCH
|
||||||
|
|
||||||
def solarize(x, thresholds): #PAS OPTIMISE POUR DES BATCH
|
def solarize(x, thresholds): #PAS OPTIMISE POUR DES BATCH
|
||||||
# Optimisation : Mask direct sur toute les donnees (Mask = (B,C,H,W)> (B))
|
# Optimisation : Mask direct sur toute les donnees (Mask = (B,C,H,W)> (B))
|
||||||
|
batch_size, channels, h, w = x.shape
|
||||||
|
imgs=[]
|
||||||
for idx, t in enumerate(thresholds): #Operation par image
|
for idx, t in enumerate(thresholds): #Operation par image
|
||||||
mask = x[idx] > t.item()
|
mask = x[idx] > t #Perte du gradient
|
||||||
inv_x = 1-x[idx][mask]
|
#In place
|
||||||
x[idx][mask]=inv_x
|
#inv_x = 1-x[idx][mask]
|
||||||
|
#x[idx][mask]=inv_x
|
||||||
|
#
|
||||||
|
|
||||||
|
#Out of place
|
||||||
|
im = x[idx]
|
||||||
|
inv_x = 1-im[mask]
|
||||||
|
|
||||||
|
imgs.append(im.masked_scatter(mask,inv_x))
|
||||||
|
|
||||||
|
idxs=torch.tensor(range(x.shape[0]), device=x.device)
|
||||||
|
idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||||
|
x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
|
||||||
|
#
|
||||||
return x
|
return x
|
||||||
|
|
||||||
#https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818
|
#https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818
|
||||||
|
|
|
@ -170,7 +170,7 @@ def viz_sample_data(imgs, labels, fig_name='data_sample'):
|
||||||
plt.xticks([])
|
plt.xticks([])
|
||||||
plt.yticks([])
|
plt.yticks([])
|
||||||
plt.grid(False)
|
plt.grid(False)
|
||||||
plt.imshow(sample[i,], cmap=plt.cm.binary)
|
plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
|
||||||
plt.xlabel(labels[i].item())
|
plt.xlabel(labels[i].item())
|
||||||
|
|
||||||
plt.savefig(fig_name)
|
plt.savefig(fig_name)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue