2019-11-08 11:28:06 -05:00
import torch
import kornia
import random
### Available TF for Dataug ###
# Each entry maps a transformation name to a callable ``f(x, mag)``:
#   x   -- image batch; the helpers below unpack it as (batch_size, channels, h, w).
#   mag -- magnitude level, forwarded to ``rand_floats``, whose upper bound is
#          float_parameter(mag, maxval) = mag * maxval / PARAMETER_MAX.
# One random parameter is drawn per image (size = x.shape[0]).
# NOTE(review): where a fixed ``minval`` is given (Color TF), a small ``mag`` can
# push the scaled upper bound below ``minval``; ``rand_floats`` then samples
# between the two bounds (e.g. Posterize with mag < 0.5 yields fewer than 4 bits,
# Contrast with mag ~ 0 yields factors below 0.1). Confirm this is intended.
# The former "Dataugv4" table (parameters drawn image by image with Python's
# ``random``) was only kept as an unused string literal and has been dropped.
TF_dict = {  # Dataugv5
    ## Geometric TF ##
    'Identity': (lambda x, mag: x),
    'FlipUD': (lambda x, mag: flipUD(x)),
    'FlipLR': (lambda x, mag: flipLR(x)),
    'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=(x.shape[0],), mag=mag, maxval=30))),
    'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
    'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
    'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
    'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),

    ## Color TF (Expect image in the range of [0, 1]) ##
    'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=(x.shape[0],), mag=mag, minval=0.1, maxval=1.9))),
    'Color': (lambda x, mag: color(x, color_factor=rand_floats(size=(x.shape[0],), mag=mag, minval=0.1, maxval=1.9))),
    'Brightness': (lambda x, mag: brightness(x, brightness_factor=rand_floats(size=(x.shape[0],), mag=mag, minval=0.1, maxval=1.9))),
    'Sharpness': (lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=(x.shape[0],), mag=mag, minval=0.1, maxval=1.9))),
    # Gradient is lost: posterize casts the number of bits to uint8.
    'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=(x.shape[0],), mag=mag, minval=4., maxval=8.))),
    # Gradient is lost (hard threshold mask). Expects images in [0, 1].
    'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=(x.shape[0],), mag=mag, minval=1 / 256., maxval=256 / 256.))),

    # Not functional yet:
    # 'Auto_Contrast': (lambda mag: None),  # not optimized for batches (very slow)
    # 'Equalize': (lambda mag: None),
}
def int_image(float_image):
    """Convert a float image in [0, 1] to uint8 in [0, 255].

    WARNING: slight loss of information. Values are quantized to steps of
    1/255 (~0.0039), truncating toward zero as the original cast did.

    Args:
        float_image: floating-point tensor, expected in the range [0, 1].

    Returns:
        ``torch.uint8`` tensor of the same shape.
    """
    scaled = float_image * 255.
    # Clamp before the cast: converting out-of-range floats to uint8 is
    # undefined in PyTorch (values may wrap), so e.g. 1.5 would not map to 255.
    return scaled.clamp(min=0., max=255.).type(torch.uint8)
def float_image(int_image):
    """Map an integer image (0..255) back to float32 values in [0, 1].

    Args:
        int_image: integer tensor, typically ``torch.uint8``.

    Returns:
        ``torch.float32`` tensor equal to ``int_image / 255``.
    """
    as_float32 = int_image.float()
    return as_float32 / 255.
def rand_int(mag, maxval, minval=None):  # range: [minval (default -max), max]
    """Draw one random integer whose upper bound scales with ``mag``.

    The upper bound is ``int(int_parameter(mag, maxval))``. The lower bound is
    ``minval`` when given (an explicit ``0`` is honored), otherwise the negated
    upper bound. Both ends are inclusive.

    Args:
        mag: magnitude level (Python number or one-element tensor).
        maxval: upper bound reached when ``mag == PARAMETER_MAX``.
        minval: optional fixed lower bound (int).

    Returns:
        A Python ``int``.

    Raises:
        ValueError: from ``random.randint`` when ``minval`` exceeds the scaled
            upper bound.
    """
    # int_parameter returns an untruncated value (no int() cast there), but
    # random.randint only accepts integer bounds.
    real_max = int(int_parameter(mag, maxval=maxval))
    if minval is None:
        minval = -real_max
    return random.randint(minval, real_max)
def rand_float(mag, maxval, minval=None):  # range: [minval (default -max), max]
    """Draw one random float whose upper bound scales with ``mag``.

    The upper bound is ``float_parameter(mag, maxval)``. The lower bound is
    ``minval`` when given (an explicit ``0`` is honored), otherwise the negated
    upper bound. ``random.uniform`` accepts the bounds in either order.

    Args:
        mag: magnitude level (Python number or tensor).
        maxval: upper bound reached when ``mag == PARAMETER_MAX``.
        minval: optional fixed lower bound.

    Returns:
        The value produced by ``random.uniform`` (a tensor if the bounds are
        tensors, since uniform only does arithmetic on them).
    """
    real_max = float_parameter(mag, maxval=maxval)
    if minval is None:
        minval = -real_max
    return random.uniform(minval, real_max)
def rand_floats(size, mag, maxval, minval=None):  # range: [minval (default -max), max]
    """Draw a tensor of uniform random floats whose upper bound scales with ``mag``.

    Samples are ``minval + (real_max - minval) * U[0, 1)`` with
    ``real_max = float_parameter(mag, maxval)``. If ``minval`` exceeds
    ``real_max``, samples fall between ``real_max`` and ``minval``.

    Args:
        size: output shape forwarded to ``torch.rand`` (int or tuple).
        mag: magnitude level. A tensor keeps the result differentiable with
            respect to it, and its device is used for the samples. A Python
            number is also accepted (samples on the default device).
        maxval: upper bound reached when ``mag == PARAMETER_MAX``.
        minval: optional fixed lower bound (an explicit ``0`` is honored).
            Defaults to ``-real_max``.

    Returns:
        Float tensor of shape ``size``.
    """
    real_max = float_parameter(mag, maxval=maxval)
    if minval is None:
        minval = -real_max
    device = mag.device if isinstance(mag, torch.Tensor) else None
    return minval + (real_max - minval) * torch.rand(size, device=device)
def zero_stack(tensor, zero_pos):
    """Pair every value of ``tensor`` with a zero, along a new dim 1.

    Used to build per-image ``[dx, 0]`` / ``[0, dy]`` translation or shear
    parameters.

    Args:
        tensor: values, one per image (shape ``(batch_size,)``).
        zero_pos: 0 puts the values first and the zeros second, i.e. rows of
            ``[v, 0]``. 1 puts the zeros first, i.e. rows of ``[0, v]``.

    Returns:
        Tensor of shape ``(batch_size, 2)`` with the same dtype and device as
        ``tensor``.

    Raises:
        ValueError: if ``zero_pos`` is neither 0 nor 1.
    """
    # zeros_like keeps dtype/device/shape aligned with the values being stacked.
    zeros = torch.zeros_like(tensor)
    if zero_pos == 0:
        columns = (tensor, zeros)
    elif zero_pos == 1:
        columns = (zeros, tensor)
    else:
        raise ValueError("Invalid zero_pos : ", zero_pos)
    return torch.stack(columns, dim=1)
# Magnitude scaling follows the TensorFlow AutoAugment reference implementation:
# https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
PARAMETER_MAX = 1  # What is the max 'level' a transform could be predicted


def float_parameter(level, maxval):
    """Scale ``maxval`` linearly by ``level / PARAMETER_MAX``.

    Args:
        level: magnitude level, nominally in [0, PARAMETER_MAX]. May be a
            Python number or a tensor. No cast is applied, so a tensor level
            keeps its gradient.
        maxval: value returned when ``level == PARAMETER_MAX``.

    Returns:
        ``level * maxval / PARAMETER_MAX``: a float, or a tensor when
        ``level`` is one.
    """
    scaled = level * maxval
    return scaled / PARAMETER_MAX
def int_parameter(level, maxval):  # No int() cast here: casting would lose the gradient
    """Scale ``maxval`` by ``level / PARAMETER_MAX`` (AutoAugment-style level helper).

    Despite its name, the result is NOT truncated to ``int``. It is returned
    unrounded (same formula as ``float_parameter``) so that a tensor ``level``
    keeps its gradient. Callers needing an integer must convert it themselves.

    Args:
        level: magnitude level, nominally in [0, PARAMETER_MAX].
        maxval: value returned when ``level == PARAMETER_MAX``.

    Returns:
        ``level * maxval / PARAMETER_MAX``: a float (or tensor), not an int.
    """
    product = level * maxval
    return product / PARAMETER_MAX
def flipLR(x):
    """Mirror every image of the batch left <-> right.

    The flip is a 3x3 homography mapping column ``c`` to ``w - 1 - c``. It is
    applied with ``kornia.warp_perspective`` at the original size.

    Args:
        x: image batch unpacked as ``(batch_size, channels, h, w)``.

    Returns:
        The horizontally flipped batch, spatial size ``(h, w)``.
    """
    n_imgs, _, height, width = x.shape
    mirror_columns = [[-1., 0., width - 1],
                      [0., 1., 0.],
                      [0., 0., 1.]]
    # Same homography for every image of the batch.
    homography = torch.tensor([mirror_columns], device=x.device).expand(n_imgs, -1, -1)
    return kornia.warp_perspective(x, homography, dsize=(height, width))
def flipUD(x):
    """Mirror every image of the batch top <-> bottom.

    The flip is a 3x3 homography mapping row ``r`` to ``h - 1 - r``. It is
    applied with ``kornia.warp_perspective`` at the original size.

    Args:
        x: image batch unpacked as ``(batch_size, channels, h, w)``.

    Returns:
        The vertically flipped batch, spatial size ``(h, w)``.
    """
    n_imgs, _, height, width = x.shape
    mirror_rows = [[1., 0., 0.],
                   [0., -1., height - 1],
                   [0., 0., 1.]]
    # Same homography for every image of the batch.
    homography = torch.tensor([mirror_rows], device=x.device).expand(n_imgs, -1, -1)
    return kornia.warp_perspective(x, homography, dsize=(height, width))
def rotate(x, angle):
    """Rotate each image of the batch by its own angle with ``kornia.rotate``.

    ``angle`` is cast to float32 first because Kornia does not accept integer
    tensors.
    """
    float_angle = angle.float()
    return kornia.rotate(x, angle=float_angle)
def translate(x, translation):
    """Shift each image of the batch with ``kornia.translate``.

    ``translation`` is cast to float32 first because Kornia does not accept
    integer tensors. In this file it is built with ``zero_stack`` as one
    ``[dx, dy]`` row per image.
    """
    float_translation = translation.float()
    return kornia.translate(x, translation=float_translation)
def shear(x, shear):
    """Shear each image of the batch with ``kornia.shear``.

    Unlike ``rotate``/``translate``, the ``shear`` tensor is forwarded without
    a dtype cast.
    """
    result = kornia.shear(x, shear=shear)
    return result
def contrast(x, contrast_factor):
    """Adjust the contrast of the batch with ``kornia.adjust_contrast``.

    Expects images in the range [0, 1]. In this file ``contrast_factor`` holds
    one factor per image.
    """
    adjusted = kornia.adjust_contrast(x, contrast_factor=contrast_factor)
    return adjusted
# Enhancement ops below mirror PIL's ImageEnhance:
# https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageEnhance.py
def color(x, color_factor):
    """Adjust color saturation by blending each image with its grayscale version.

    With ``blend``, a factor of 0 gives the grayscale image and 1 gives the
    original. The result is clamped to [0, 1].

    Args:
        x: image batch ``(batch_size, channels, h, w)`` in [0, 1]; must be
            accepted by ``kornia.rgb_to_grayscale``.
        color_factor: per-image blend factors, forwarded to ``blend``.
    """
    _, n_channels, _, _ = x.shape
    # Grayscale has one channel; replicate it so it matches x channel-wise.
    grayscale = kornia.rgb_to_grayscale(x).repeat_interleave(n_channels, dim=1)
    mixed = blend(grayscale, x, color_factor)
    return mixed.clamp(min=0.0, max=1.0)  # Expect image in the range of [0, 1]
def brightness(x, brightness_factor):
    """Scale brightness by blending each image with a black image.

    With ``blend``, a factor of 0 gives black and 1 gives the original. The
    result is clamped to [0, 1]. Expects images in the range [0, 1].
    """
    black = torch.zeros(x.size(), device=x.device)
    brightened = blend(black, x, brightness_factor)
    return brightened.clamp(min=0.0, max=1.0)
def sharpeness(x, sharpness_factor):
    """Adjust sharpness by blending each image with a smoothed copy.

    With ``blend``, a factor of 0 gives the smoothed image and 1 gives the
    original; larger factors extrapolate (sharpen). The result is clamped to
    [0, 1]. Expects images in the range [0, 1].
    """
    _, _, _, _ = x.shape  # only enforces a 4-D (batch, channels, h, w) input
    # PIL SMOOTH filter kernel:
    # https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageFilter.py
    smooth_kernel = torch.tensor([[[1., 1., 1.],
                                   [1., 5., 1.],
                                   [1., 1., 1.]]], device=x.device)
    # An alpha channel may need to be handled differently.
    blurred = kornia.filter2D(x, kernel=smooth_kernel, border_type='reflect', normalized=True)
    sharpened = blend(blurred, x, sharpness_factor)
    return sharpened.clamp(min=0.0, max=1.0)
# Ops below mirror PIL's ImageOps:
# https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
def posterize(x, bits):
    """Keep only the ``bits`` most significant bits of every 8-bit pixel value.

    Args:
        x: float image batch ``(batch_size, channels, h, w)``, expected in [0, 1].
        bits: bits to keep. Either one value per image (shape ``(batch_size,)``)
            or a single value (Python number / 0-d tensor) shared by all
            images. Fractional values are truncated, and values are clamped
            to [0, 8]; 0 bits yields an all-zero image.

    Returns:
        Float image batch in [0, 1]. Not differentiable: the gradient is lost
        through the uint8 conversions.

    Raises:
        ValueError: if ``x`` is not 4-dimensional.
    """
    if x.dim() != 4:
        raise ValueError(
            "x should be a (batch_size, channels, h, w) tensor. Got shape {}".format(tuple(x.shape)))
    bits = torch.as_tensor(bits, device=x.device)
    # Clamp before the uint8 cast so the cast is well defined (gradient is lost here).
    bits = bits.clamp(min=0, max=8).type(torch.uint8)
    # Build the mask in int64: 2 ** 8 (bits == 0) does not fit in uint8.
    shift = 8 - bits.long()
    mask = ((~(2 ** shift - 1)) & 0xFF).type(torch.uint8)
    x = int_image(x)  # Expect image in the range of [0, 1]
    # One mask per image, broadcast over channels and pixels.
    return float_image(x & mask.reshape(-1, 1, 1, 1))
def _autocontrast_lut(lo, hi):
    """Return PIL's 256-entry autocontrast table stretching [lo, hi] onto [0, 255]."""
    scale = 255.0 / (hi - lo)
    offset = -lo * scale
    return [min(max(int(ix * scale + offset), 0), 255) for ix in range(256)]


def auto_contrast(x):  # NOT OPTIMIZED FOR BATCHES (slow: Python loop per image/channel)
    """Stretch each channel of each image so its values span [0, 255].

    Port of PIL ``ImageOps.autocontrast`` without cutoff. Per image and per
    channel, the lowest and highest values present define the input range,
    which is mapped linearly to [0, 255] through a lookup table. Flat channels
    (max <= min) are left unchanged.

    Args:
        x: float image batch ``(batch_size, channels, h, w)``, expected in [0, 1].

    Returns:
        Float image batch in [0, 1]. Not differentiable (uint8 round trip).
    """
    print("Warning : Pas encore check !")
    batch_size, channels, h, w = x.shape
    x = int_image(x)  # Expect image in the range of [0, 1]
    for im_idx in range(batch_size):
        for chan_idx in range(channels):
            chan = x[im_idx, chan_idx]
            if chan.numel() == 0:
                continue
            # Lowest/highest values present (= first/last non-empty histogram bins).
            lo = int(chan.min())
            hi = int(chan.max())
            if hi <= lo:
                continue  # flat channel: nothing to stretch
            lut = torch.tensor(_autocontrast_lut(lo, hi), dtype=torch.uint8, device=x.device)
            # A single gather remaps every pixel exactly once. The previous
            # per-value in-place loop re-mapped pixels already moved upward.
            x[im_idx, chan_idx] = lut[chan.long()]
    return float_image(x)
def equalize(x):
    """Histogram equalization. Not implemented yet.

    Args:
        x: image batch (unused).

    Raises:
        NotImplementedError: always. The previous ``raise Exception(self, ...)``
            referenced an undefined ``self`` and raised NameError instead.
    """
    # TODO: planned approach from the original draft: per image and per channel,
    # compute a 256-bin histogram of the uint8 image, then apply an efficient
    # LUT built from it.
    raise NotImplementedError("equalize: not implemented")
def solarize(x, thresholds):
    """Invert (``1 - value``) every pixel strictly above its image's threshold.

    Port of PIL ``ImageOps.solarize``, vectorized over the whole batch. Values
    at or below the threshold are kept. The input tensor is not modified.

    Args:
        x: float image batch ``(batch_size, channels, h, w)``, expected in [0, 1].
        thresholds: one threshold per image (shape ``(batch_size,)``) or a
            single threshold (Python number / 0-d tensor) shared by all images.
            Cast to ``x``'s dtype before comparing.

    Returns:
        New tensor with the same shape as ``x``. Gradient flows to ``x`` but
        not to ``thresholds`` (hard comparison mask).

    Raises:
        ValueError: if ``x`` is not 4-dimensional.
    """
    if x.dim() != 4:
        raise ValueError(
            "x should be a (batch_size, channels, h, w) tensor. Got shape {}".format(tuple(x.shape)))
    thresholds = torch.as_tensor(thresholds, dtype=x.dtype, device=x.device)
    # One threshold per image, broadcast over channels and pixels.
    mask = x > thresholds.reshape(-1, 1, 1, 1)
    return torch.where(mask, 1 - x, x)
# https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818
def blend(x, y, alpha):  # out = image1 * (1.0 - alpha) + image2 * alpha
    """Linearly interpolate between two image batches (PIL ``Image.blend``).

    ``alpha = 0`` returns ``x`` and ``alpha = 1`` returns ``y``. Other values
    interpolate or extrapolate; no clamping is done here.
    (``kornia.add_weighted`` is not used because it does not accept a batch of
    alphas.)

    Args:
        x: image batch ``(batch_size, channels, h, w)``.
        y: tensor expected to have ``x``'s shape.
        alpha: one factor per image (shape ``(batch_size,)``) or a single
            factor (Python number / 0-d tensor) shared by all images.

    Returns:
        ``x * (1 - alpha) + y * alpha`` with alpha broadcast per image.

    Raises:
        TypeError: if ``x`` or ``y`` is not a tensor.
        ValueError: if ``x`` is not 4-dimensional.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError("x should be a tensor. Got {}".format(type(x)))
    if not isinstance(y, torch.Tensor):
        raise TypeError("y should be a tensor. Got {}".format(type(y)))
    if x.dim() != 4:
        raise ValueError(
            "x should be a (batch_size, channels, h, w) tensor. Got shape {}".format(tuple(x.shape)))
    alpha = torch.as_tensor(alpha, device=x.device)
    # One factor per image, broadcast over channels and pixels.
    alpha = alpha.reshape(-1, 1, 1, 1)
    return x * (1 - alpha) + y * alpha