Transformations comments

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-22 17:30:27 -05:00
parent dc18397660
commit da711d17cd

View file

@ -18,11 +18,18 @@ import torch
import kornia import kornia
import random import random
### Available TF for Dataug ###
TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'} #TF that don't have use for magnitude parameter.
TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'} #TF which implemetation doesn't allow gradient propagaition.
TF_ignore_mag= TF_no_mag | TF_no_grad #TF for which magnitude should be ignored (Magnitude fixed).
PARAMETER_MAX = 1 # What is the max 'level' a transform could be predicted
PARAMETER_MIN = 0.1 # What is the min 'level' a transform could be predicted
### Available TF for Dataug ###
# Dictionnary mapping tranformations identifiers to their function. # Dictionnary mapping tranformations identifiers to their function.
# Each value of the dict should be a lambda function taking a (batch of data, magnitude of transformations) tuple as input and returns a batch of data. # Each value of the dict should be a lambda function taking a (batch of data, magnitude of transformations) tuple as input and returns a batch of data.
TF_dict={ #Dataugv5 TF_dict={ #Dataugv5+
## Geometric TF ## ## Geometric TF ##
'Identity' : (lambda x, mag: x), 'Identity' : (lambda x, mag: x),
'FlipUD' : (lambda x, mag: flipUD(x)), 'FlipUD' : (lambda x, mag: flipUD(x)),
@ -70,18 +77,12 @@ TF_dict={ #Dataugv5
'Random':(lambda x, mag: torch.rand_like(x)), 'Random':(lambda x, mag: torch.rand_like(x)),
'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))), 'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))),
#Non fonctionnel #Not ready for use
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent) #'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
#'Equalize': (lambda mag: None), #'Equalize': (lambda mag: None),
} }
TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'} #TF that don't have use for magnitude parameter. ## Image type cast ##
TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'} #TF which implemetation doesn't allow gradient propagaition.
TF_ignore_mag= TF_no_mag | TF_no_grad #TF for which magnitude should be ignored (Magnitude fixed).
PARAMETER_MAX = 1 # What is the max 'level' a transform could be predicted
PARAMETER_MIN = 0.1 # What is the min 'level' a transform could be predicted
def int_image(float_image): def int_image(float_image):
"""Convert a float Tensor/Image to an int Tensor/Image. """Convert a float Tensor/Image to an int Tensor/Image.
@ -109,19 +110,7 @@ def float_image(int_image):
""" """
return int_image.type(torch.float)/255. return int_image.type(torch.float)/255.
#def rand_inverse(value): ## Parameters utils ##
# return value if random.random() < 0.5 else -value
#def rand_int(mag, maxval, minval=None): #[(-maxval,minval), maxval]
# real_max = int_parameter(mag, maxval=maxval)
# if not minval : minval = -real_max
# return random.randint(minval, real_max)
#def rand_float(mag, maxval, minval=None): #[(-maxval,minval), maxval]
# real_max = float_parameter(mag, maxval=maxval)
# if not minval : minval = -real_max
# return random.uniform(minval, real_max)
def rand_floats(size, mag, maxval, minval=None): def rand_floats(size, mag, maxval, minval=None):
"""Generate a batch of random values. """Generate a batch of random values.
@ -188,20 +177,9 @@ def float_parameter(level, maxval):
""" """
#return float(level) * maxval / PARAMETER_MAX #return float(level) * maxval / PARAMETER_MAX
return (level * maxval / PARAMETER_MAX)#.to(torch.float) return (level * maxval / PARAMETER_MAX)#.to(torch.float)
#def int_parameter(level, maxval): #Perte de gradient
"""Helper function to scale `val` between 0 and maxval .
Args:
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
maxval: Maximum value that the operation can have. This will be scaled
to level/PARAMETER_MAX.
Returns:
An int that results from scaling `maxval` according to `level`.
"""
#return int(level * maxval / PARAMETER_MAX)
# return (level * maxval / PARAMETER_MAX)
## Tranformations ##
def flipLR(x): def flipLR(x):
"""Flip horizontaly/Left-Right images. """Flip horizontaly/Left-Right images.
@ -343,16 +321,101 @@ def sharpeness(x, sharpness_factor):
return blend(smooth_x, x, sharpness_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1] return blend(smooth_x, x, sharpness_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1]
def posterize(x, bits): def posterize(x, bits):
bits = bits.type(torch.uint8) #Perte du gradient """Reduce the number of bits for each color channel.
x = int_image(x) #Expect image in the range of [0, 1]
mask = ~(2 ** (8 - bits) - 1).type(torch.uint8) Be warry that the cast to integers block the gradient propagation.
Args:
x (Tensor): Batch of images.
bits (Tensor): The number of bits to keep for each channel (1-8).
(batch_size, channels, h, w) = x.shape Returns:
mask = mask.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ... (Tensor): Batch of posterized images.
"""
bits = bits.type(torch.uint8) #Perte du gradient
x = int_image(x) #Expect image in the range of [0, 1]
return float_image(x & mask) mask = ~(2 ** (8 - bits) - 1).type(torch.uint8)
(batch_size, channels, h, w) = x.shape
mask = mask.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
return float_image(x & mask)
def solarize(x, thresholds):
"""Invert all pixel values above a threshold.
Be warry that the use of the inequality (x>tresholds) block the gradient propagation.
Args:
x (Tensor): Batch of images.
thresholds (Tensor): All pixels above this level are inverted
Returns:
(Tensor): Batch of solarized images.
"""
batch_size, channels, h, w = x.shape
#imgs=[]
#for idx, t in enumerate(thresholds): #Operation par image
# mask = x[idx] > t #Perte du gradient
#In place
# inv_x = 1-x[idx][mask]
# x[idx][mask]=inv_x
#
#Out of place
# im = x[idx]
# inv_x = 1-im[mask]
# imgs.append(im.masked_scatter(mask,inv_x))
#idxs=torch.tensor(range(x.shape[0]), device=x.device)
#idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
#x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
#
thresholds = thresholds.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
x=torch.where(x>thresholds,1-x, x)
#x=x.min(thresholds)
#inv_x = 1-x[mask]
#x=x.where(x<thresholds,1-x)
#x[mask]=inv_x
#x=x.masked_scatter(mask, inv_x)
return x
def blend(x,y,alpha):
"""Creates a new images by interpolating between two input images, using a constant alpha.
x and y should have the same size.
alpha should have the same batch size as the images.
Apply batch wise :
out = image1 * (1.0 - alpha) + image2 * alpha
Args:
x (Tensor): Batch of images.
y (Tensor): Batch of images.
alpha (Tensor): The interpolation alpha factor for each images.
Returns:
(Tensor): Batch of solarized images.
"""
#return kornia.add_weighted(src1=x, alpha=(1-alpha), src2=y, beta=alpha, gamma=0) #out=src1alpha+src2beta+gamma #Ne fonctionne pas pour des batch de alpha
if not isinstance(x, torch.Tensor):
raise TypeError("x should be a tensor. Got {}".format(type(x)))
if not isinstance(y, torch.Tensor):
raise TypeError("y should be a tensor. Got {}".format(type(y)))
assert(x.shape==y.shape and x.shape[0]==alpha.shape[0])
(batch_size, channels, h, w) = x.shape
alpha = alpha.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
res = x*(1-alpha) + y*alpha
return res
#Not working
def auto_contrast(x): #PAS OPTIMISE POUR DES BATCH #EXTRA LENT def auto_contrast(x): #PAS OPTIMISE POUR DES BATCH #EXTRA LENT
# Optimisation : Application de LUT efficace / Calcul d'histogramme par batch/channel # Optimisation : Application de LUT efficace / Calcul d'histogramme par batch/channel
print("Warning : Pas encore check !") print("Warning : Pas encore check !")
@ -401,53 +464,4 @@ def equalize(x): #PAS OPTIMISE POUR DES BATCH
#print(chan.shape) #print(chan.shape)
hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
return float_image(x) return float_image(x)
def solarize(x, thresholds):
batch_size, channels, h, w = x.shape
#imgs=[]
#for idx, t in enumerate(thresholds): #Operation par image
# mask = x[idx] > t #Perte du gradient
#In place
# inv_x = 1-x[idx][mask]
# x[idx][mask]=inv_x
#
#Out of place
# im = x[idx]
# inv_x = 1-im[mask]
# imgs.append(im.masked_scatter(mask,inv_x))
#idxs=torch.tensor(range(x.shape[0]), device=x.device)
#idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
#x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
#
thresholds = thresholds.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
#print(thresholds.grad_fn)
x=torch.where(x>thresholds,1-x, x)
#print(mask.grad_fn)
#x=x.min(thresholds)
#inv_x = 1-x[mask]
#x=x.where(x<thresholds,1-x)
#x[mask]=inv_x
#x=x.masked_scatter(mask, inv_x)
return x
def blend(x,y,alpha): #out = image1 * (1.0 - alpha) + image2 * alpha
#return kornia.add_weighted(src1=x, alpha=(1-alpha), src2=y, beta=alpha, gamma=0) #out=src1alpha+src2beta+gamma #Ne fonctionne pas pour des batch de alpha
if not isinstance(x, torch.Tensor):
raise TypeError("x should be a tensor. Got {}".format(type(x)))
if not isinstance(y, torch.Tensor):
raise TypeError("y should be a tensor. Got {}".format(type(y)))
(batch_size, channels, h, w) = x.shape
alpha = alpha.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
res = x*(1-alpha) + y*alpha
return res