From da711d17cdac390e75ed6e7d3d3b5c14965f814e Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Wed, 22 Jan 2020 17:30:27 -0500
Subject: [PATCH] Transformations comments

---
 higher/transformations.py | 198 ++++++++++++++++++++------------------
 1 file changed, 106 insertions(+), 92 deletions(-)

diff --git a/higher/transformations.py b/higher/transformations.py
index c6cefde..f7b661e 100755
--- a/higher/transformations.py
+++ b/higher/transformations.py
@@ -18,11 +18,18 @@ import torch
 import kornia
 import random
 
-### Available TF for Dataug ###
+TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'} #TF that don't use a magnitude parameter.
+TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'} #TF whose implementation doesn't allow gradient propagation.
+TF_ignore_mag= TF_no_mag | TF_no_grad #TF for which the magnitude should be ignored (magnitude fixed).
+
+PARAMETER_MAX = 1 # Maximum 'level' a transform can be predicted
+PARAMETER_MIN = 0.1 # Minimum 'level' a transform can be predicted
+
+### Available TF for Dataug ###
 # Dictionnary mapping tranformations identifiers to their function.
 # Each value of the dict should be a lambda function taking a (batch of data, magnitude of transformations) tuple as input and returns a batch of data.
-TF_dict={ #Dataugv5
+TF_dict={ #Dataugv5+
     ## Geometric TF ##
     'Identity' : (lambda x, mag: x),
     'FlipUD' : (lambda x, mag: flipUD(x)),
@@ -70,18 +77,12 @@ TF_dict={ #Dataugv5
     'Random':(lambda x, mag: torch.rand_like(x)),
     'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))),
 
-    #Non fonctionnel
+    #Not ready for use
     #'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
     #'Equalize': (lambda mag: None),
 }
 
-TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'} #TF that don't have use for magnitude parameter.
-TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'} #TF which implemetation doesn't allow gradient propagaition.
-TF_ignore_mag= TF_no_mag | TF_no_grad #TF for which magnitude should be ignored (Magnitude fixed).
-
-PARAMETER_MAX = 1 # What is the max 'level' a transform could be predicted
-PARAMETER_MIN = 0.1 # What is the min 'level' a transform could be predicted
-
+## Image type cast ##
 def int_image(float_image):
     """Convert a float Tensor/Image to an int Tensor/Image.
 
@@ -109,19 +110,7 @@ def float_image(int_image):
     """
     return int_image.type(torch.float)/255.
 
-#def rand_inverse(value):
-#    return value if random.random() < 0.5 else -value
-
-#def rand_int(mag, maxval, minval=None): #[(-maxval,minval), maxval]
-#    real_max = int_parameter(mag, maxval=maxval)
-#    if not minval : minval = -real_max
-#    return random.randint(minval, real_max)
-
-#def rand_float(mag, maxval, minval=None): #[(-maxval,minval), maxval]
-#    real_max = float_parameter(mag, maxval=maxval)
-#    if not minval : minval = -real_max
-#    return random.uniform(minval, real_max)
-
+## Parameter utils ##
 def rand_floats(size, mag, maxval, minval=None):
     """Generate a batch of random values.
 
@@ -188,20 +177,9 @@ def float_parameter(level, maxval):
     """
     #return float(level) * maxval / PARAMETER_MAX
-    return (level * maxval / PARAMETER_MAX)#.to(torch.float)
-
-#def int_parameter(level, maxval): #Perte de gradient
-    """Helper function to scale `val` between 0 and maxval .
-    Args:
-        level: Level of the operation that will be between [0, `PARAMETER_MAX`].
-        maxval: Maximum value that the operation can have. This will be scaled
-        to level/PARAMETER_MAX.
-    Returns:
-        An int that results from scaling `maxval` according to `level`.
-    """
-    #return int(level * maxval / PARAMETER_MAX)
-    # return (level * maxval / PARAMETER_MAX)
+    return (level * maxval / PARAMETER_MAX)#.to(torch.float)
 
+## Transformations ##
 def flipLR(x):
     """Flip horizontaly/Left-Right images.
 
@@ -343,16 +321,101 @@ def sharpeness(x, sharpness_factor):
     return blend(smooth_x, x, sharpness_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1]
 
 def posterize(x, bits):
-    bits = bits.type(torch.uint8) #Perte du gradient
-    x = int_image(x) #Expect image in the range of [0, 1]
+    """Reduce the number of bits for each color channel.
 
-    mask = ~(2 ** (8 - bits) - 1).type(torch.uint8)
+        Be wary that the cast to integers blocks gradient propagation.
+        Args:
+            x (Tensor): Batch of images.
+            bits (Tensor): The number of bits to keep for each channel (1-8).
 
-    (batch_size, channels, h, w) = x.shape
-    mask = mask.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
+        Returns:
+            (Tensor): Batch of posterized images.
+    """
+    bits = bits.type(torch.uint8) #Gradient is lost by the integer cast
+    x = int_image(x) #Expect image in the range of [0, 1]
 
-    return float_image(x & mask)
+    mask = ~(2 ** (8 - bits) - 1).type(torch.uint8)
 
+    (batch_size, channels, h, w) = x.shape
+    mask = mask.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There is surely a simpler way...
+
+    return float_image(x & mask)
+
+def solarize(x, thresholds):
+    """Invert all pixel values above a threshold.
+
+        Be wary that the use of the inequality (x>thresholds) blocks gradient propagation.
+        Args:
+            x (Tensor): Batch of images.
+            thresholds (Tensor): All pixels above this level are inverted.
+
+        Returns:
+            (Tensor): Batch of solarized images.
+    """
+    batch_size, channels, h, w = x.shape
+    #imgs=[]
+    #for idx, t in enumerate(thresholds): #Per-image operation
+    #    mask = x[idx] > t #Gradient is lost by the comparison
+        #In place
+    #    inv_x = 1-x[idx][mask]
+    #    x[idx][mask]=inv_x
+    #
+
+        #Out of place
+    #    im = x[idx]
+    #    inv_x = 1-im[mask]
+
+    #    imgs.append(im.masked_scatter(mask,inv_x))
+
+    #idxs=torch.tensor(range(x.shape[0]), device=x.device)
+    #idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There is surely a simpler way...
+    #x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
+    #
+
+    thresholds = thresholds.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There is surely a simpler way...
+    x=torch.where(x>thresholds,1-x, x)
+
+    #x=x.min(thresholds)
+    #inv_x = 1-x[mask]
+    #x=x.where(x<thresholds, inv_x)
+
+    return x
 
-def solarize(x, thresholds):
-    batch_size, channels, h, w = x.shape
-    #imgs=[]
-    #for idx, t in enumerate(thresholds): #Operation par image
-    #    mask = x[idx] > t #Perte du gradient
-        #In place
-    #    inv_x = 1-x[idx][mask]
-    #    x[idx][mask]=inv_x
-    #
-
-        #Out of place
-    #    im = x[idx]
-    #    inv_x = 1-im[mask]
-
-    #    imgs.append(im.masked_scatter(mask,inv_x))
-
-    #idxs=torch.tensor(range(x.shape[0]), device=x.device)
-    #idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
-    #x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
-    #
-
-    thresholds = thresholds.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
-    #print(thresholds.grad_fn)
-    x=torch.where(x>thresholds,1-x, x)
-    #print(mask.grad_fn)
-
-    #x=x.min(thresholds)
-    #inv_x = 1-x[mask]
-    #x=x.where(x
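
Usage sketch (not part of the patch itself): a minimal example of how the TF_dict entries documented above are meant to be called, assuming the higher/ directory is importable as a package and torch is installed. Only entries visible in this diff are used; per the dict comment, each entry takes a batch of data and a magnitude tensor and returns a batch of data.

    import torch
    from higher import transformations

    x = torch.rand(8, 3, 32, 32)             # batch of images with values in [0, 1]
    mag = torch.full((x.shape[0],), 0.5)     # one magnitude per image, in [PARAMETER_MIN, PARAMETER_MAX]

    out = transformations.TF_dict['Identity'](x, mag)    # magnitude ignored (listed in TF_no_mag)
    out = transformations.TF_dict['RandBlend'](x, mag)   # only mag.device is used, not its value

Entries listed in TF_no_grad (e.g. 'Solarize', 'Posterize') are called the same way, but as the new docstrings warn, their integer casts and comparisons block gradient propagation through the magnitude.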