Dataugv5- Modification des TF pour propagation du gradient (mag)

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-18 12:53:23 -05:00
parent 05f81787d6
commit 994d657a28
5 changed files with 94 additions and 21 deletions

View file

@ -28,6 +28,7 @@ TF_dict={ #f(mag_normalise)=mag_reelle
#'Equalize': (lambda mag: None),
}
'''
'''
TF_dict={
## Geometric TF ##
'Identity' : (lambda x, mag: x),
@ -42,7 +43,7 @@ TF_dict={
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast': (lambda x, mag: contrast(x, contrast_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
'Color':(lambda x, mag: color(x, color_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=1., maxval=1.9) for _ in x], device=x.device))),
'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
'Posterize': (lambda x, mag: posterize(x, bits=torch.tensor([rand_int(mag, minval=4, maxval=8) for _ in x], device=x.device))),
'Solarize': (lambda x, mag: solarize(x, thresholds=torch.tensor([rand_int(mag,minval=1, maxval=256)/256. for _ in x], device=x.device))) , #=>Image entre [0,1] #Pas opti pour des batch
@ -51,6 +52,27 @@ TF_dict={
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
#'Equalize': (lambda mag: None),
}
'''
# Registry of the available transformations.
# Each entry maps a transformation name to a callable `f(x, mag)` that returns
# the transformed `x`. Random parameters are drawn with `rand_float`, one per
# element along the first dimension of `x` (presumably the batch dimension --
# TODO confirm), with a range driven by the magnitude `mag`.
TF_dict = {
    ## Geometric TF ##
    'Identity': (lambda x, mag: x),
    'FlipUD': (lambda x, mag: flipUD(x)),
    'FlipLR': (lambda x, mag: flipLR(x)),
    'Rotate': (lambda x, mag: rotate(
        x,
        angle=rand_float(size=x.shape[0], mag=mag, maxval=30),
    )),
    'TranslateX': (lambda x, mag: translate(
        x,
        translation=zero_stack(
            rand_float(size=(x.shape[0],), mag=mag, maxval=20),
            zero_pos=0,
        ),
    )),
    'TranslateY': (lambda x, mag: translate(
        x,
        translation=zero_stack(
            rand_float(size=(x.shape[0],), mag=mag, maxval=20),
            zero_pos=1,
        ),
    )),
    'ShearX': (lambda x, mag: shear(
        x,
        shear=zero_stack(
            rand_float(size=(x.shape[0],), mag=mag, maxval=0.3),
            zero_pos=0,
        ),
    )),
    'ShearY': (lambda x, mag: shear(
        x,
        shear=zero_stack(
            rand_float(size=(x.shape[0],), mag=mag, maxval=0.3),
            zero_pos=1,
        ),
    )),

    ## Color TF (Expect image in the range of [0, 1]) ##
    'Contrast': (lambda x, mag: contrast(
        x,
        contrast_factor=rand_float(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9),
    )),
    'Color': (lambda x, mag: color(
        x,
        color_factor=rand_float(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9),
    )),
    'Brightness': (lambda x, mag: brightness(
        x,
        brightness_factor=rand_float(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9),
    )),
    'Sharpness': (lambda x, mag: sharpeness(
        x,
        sharpness_factor=rand_float(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9),
    )),
    # Gradient is lost
    'Posterize': (lambda x, mag: posterize(
        x,
        bits=rand_float(size=x.shape[0], mag=mag, minval=4., maxval=8.),
    )),
    # Gradient is lost; image expected in [0, 1]; not optimized for batches
    'Solarize': (lambda x, mag: solarize(
        x,
        thresholds=rand_float(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.),
    )),
}
def int_image(float_image):
    """Convert a float image to its 8-bit integer representation.

    Args:
        float_image: Tensor to convert. Presumably holds values in [0, 1]
            -- TODO confirm against callers.

    Returns:
        ``float_image * 255`` cast to ``torch.uint8``.

    Note:
        WARNING: slight loss of information (granularity: 1/256 = 0.0039).
    """
    scaled = float_image * 255.
    as_uint8 = scaled.type(torch.uint8)
    return as_uint8
@ -71,6 +93,19 @@ def rand_float(mag, maxval, minval=None): #[(-maxval,minval), maxval]
if not minval : minval = -real_max
return random.uniform(minval, real_max)
def rand_float(size, mag, maxval, minval=None):
    """Draw random floats whose upper bound scales with the magnitude.

    The upper bound is ``real_max = float_parameter(mag, maxval=maxval)``.
    Samples are computed as ``minval + (real_max - minval) * torch.rand(...)``,
    i.e. they span the interval from ``minval`` to ``real_max``.

    Args:
        size: Shape of the sample, forwarded to ``torch.rand``.
        mag: Magnitude. Must expose ``.device`` (a tensor), since the sample
            is created on ``mag.device``.
        maxval: Largest value reachable, scaled by ``float_parameter``.
        minval: Lower bound of the interval. When ``None`` the interval is
            symmetric: ``[-real_max, real_max]``.

    Returns:
        A tensor of shape ``size`` located on ``mag.device``.
    """
    real_max = float_parameter(mag, maxval=maxval)
    # Test against None explicitly: a lower bound of 0 is falsy and would
    # otherwise be silently replaced by -real_max.
    if minval is None:
        minval = -real_max
    # Tensor arithmetic on real_max (instead of random.uniform) keeps the
    # result a function of mag -- presumably what lets the gradient reach mag.
    return minval + (real_max - minval) * torch.rand(size, device=mag.device)
def zero_stack(tensor, zero_pos):
    """Pair every element of a 1-D tensor with a zero, giving an (N, 2) tensor.

    Args:
        tensor: 1-D tensor of N values (the zeros are built with shape
            ``(tensor.shape[0],)`` and ``torch.stack`` needs matching shapes).
        zero_pos: Column layout selector.
            ``0`` -> rows are ``(value, 0)``;
            ``1`` -> rows are ``(0, value)``.

    Returns:
        Tensor of shape ``(N, 2)`` stacked along ``dim=1``.

    Raises:
        ValueError: If ``zero_pos`` is neither 0 nor 1.
    """
    if zero_pos not in (0, 1):
        # ValueError (a subclass of Exception) with a formatted message; passing
        # several args to Exception would render the message as a tuple.
        raise ValueError("Invalid zero_pos : {}".format(zero_pos))

    zeros = torch.zeros((tensor.shape[0],), device=tensor.device)
    columns = (tensor, zeros) if zero_pos == 0 else (zeros, tensor)
    return torch.stack(columns, dim=1)
#https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
PARAMETER_MAX = 10 # Maximum 'level' a transform can be given; parameters are scaled as level * maxval / PARAMETER_MAX
@ -83,7 +118,9 @@ def float_parameter(level, maxval):
Returns:
A float that results from scaling `maxval` according to `level`.
"""
return float(level) * maxval / PARAMETER_MAX
#return float(level) * maxval / PARAMETER_MAX
return (level * maxval / PARAMETER_MAX)#.to(torch.float32)
def int_parameter(level, maxval):
"""Helper function to scale `val` between 0 and maxval .
@ -94,7 +131,11 @@ def int_parameter(level, maxval):
Returns:
An int that results from scaling `maxval` according to `level`.
"""
return int(level * maxval / PARAMETER_MAX)
#return int(level * maxval / PARAMETER_MAX)
print(level)
res= (level * maxval / PARAMETER_MAX).to(torch.int8).requires_grad_()#.type(torch.int8)
print(res)
return res
def flipLR(x):
device = x.device
@ -119,10 +160,11 @@ def flipUD(x):
return kornia.warp_perspective(x, M, dsize=(h, w))
#Rotate x by `angle`, delegating to kornia.rotate.
#NOTE(review): two return statements are listed here. The first casts `angle`
#to float32 before the call, the second forwards `angle` unchanged (cast
#commented out). Confirm which one is meant to be live and drop the other.
def rotate(x, angle):
return kornia.rotate(x, angle=angle.type(torch.float32)) #Kornia does not support ints
return kornia.rotate(x, angle=angle)#.type(torch.float32)) #Kornia does not support ints
#Translate x by `translation`, delegating to kornia.translate.
#NOTE(review): two return statements are listed here. The first casts
#`translation` to float32 before the call, the second forwards it unchanged
#(cast commented out). Confirm which one is meant to be live and drop the other.
def translate(x, translation):
return kornia.translate(x, translation=translation.type(torch.float32)) #Kornia does not support ints
#print(translation)
return kornia.translate(x, translation=translation)#.type(torch.float32)) #Kornia does not support ints
def shear(x, shear):
    """Shear ``x`` by delegating to ``kornia.shear``.

    Args:
        x: Input forwarded as the first positional argument.
        shear: Shear amounts, forwarded as the ``shear`` keyword argument.

    Returns:
        Whatever ``kornia.shear`` returns for these arguments.
    """
    sheared = kornia.shear(x, shear=shear)
    return sheared
@ -156,6 +198,7 @@ def sharpeness(x, sharpness_factor):
#https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
def posterize(x, bits):
bits = bits.type(torch.uint8) #Perte du gradient
x = int_image(x) #Expect image in the range of [0, 1]
mask = ~(2 ** (8 - bits) - 1).type(torch.uint8)
@ -217,10 +260,25 @@ def equalize(x): #PAS OPTIMISE POUR DES BATCH
#Solarize: for each image x[idx], the values greater than that image's
#threshold are replaced by (1 - value).
#  x          : 4-D tensor, unpacked below as (batch_size, channels, h, w).
#  thresholds : iterable yielding one threshold per image (iterated with enumerate).
#  returns    : x with the selected values inverted.
#NOTE(review): this span lists both an in-place variant (mask built from
#t.item(), result written back into x[idx]) and an out-of-place variant
#(masked_scatter per image, then scatter of the stacked images along dim 0).
#Confirm which statements are live; indentation is not visible here, so also
#confirm which statements belong to the loop body.
def solarize(x, thresholds): #NOT OPTIMIZED FOR BATCHES
# Optimization idea: build the mask directly over all the data (Mask = (B,C,H,W)> (B))
batch_size, channels, h, w = x.shape
imgs=[]
for idx, t in enumerate(thresholds): #Per-image operation
mask = x[idx] > t.item()
inv_x = 1-x[idx][mask]
x[idx][mask]=inv_x
mask = x[idx] > t #Gradient is lost
#In place
#inv_x = 1-x[idx][mask]
#x[idx][mask]=inv_x
#
#Out of place
im = x[idx]
inv_x = 1-im[mask]
imgs.append(im.masked_scatter(mask,inv_x))
idxs=torch.tensor(range(x.shape[0]), device=x.device)
idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There must be a simpler way ...
x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
#
return x
#https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818