mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Ajout RandAugment
This commit is contained in:
parent
3c2022de32
commit
4a7e73088d
4 changed files with 249 additions and 37 deletions
|
@ -64,6 +64,27 @@ TF_dict={ #Dataugv5
|
|||
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
|
||||
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
|
||||
|
||||
## Color TF (Expect image in the range of [0, 1]) ##
|
||||
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Pas opti pour des batch
|
||||
|
||||
#Color TF (Common mag scale)
|
||||
'+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||
'+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||
'+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||
'+Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||
'-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||
'-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||
'-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||
'-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||
'=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||
'=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Pas opti pour des batch
|
||||
|
||||
|
||||
'BRotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30*3))),
|
||||
'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20*3), zero_pos=0))),
|
||||
'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20*3), zero_pos=1))),
|
||||
|
@ -74,14 +95,11 @@ TF_dict={ #Dataugv5
|
|||
'BadTranslateX_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=0))),
|
||||
'BadTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=20*2, maxval=20*3), zero_pos=1))),
|
||||
'BadTranslateY_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=1))),
|
||||
|
||||
## Color TF (Expect image in the range of [0, 1]) ##
|
||||
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Pas opti pour des batch
|
||||
|
||||
'BadColor':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
|
||||
'BadSharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
|
||||
'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
|
||||
'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
|
||||
|
||||
#Non fonctionnel
|
||||
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
||||
|
@ -111,10 +129,15 @@ def float_image(int_image):
|
|||
# return random.uniform(minval, real_max)
|
||||
|
||||
def rand_floats(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
|
||||
real_max = float_parameter(mag, maxval=maxval)
|
||||
if not minval : minval = -real_max
|
||||
real_mag = float_parameter(mag, maxval=maxval)
|
||||
if not minval : minval = -real_mag
|
||||
#return random.uniform(minval, real_max)
|
||||
return minval +(real_max-minval) * torch.rand(size, device=mag.device)
|
||||
return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[min_val, real_mag]
|
||||
|
||||
def invScale_rand_floats(size, mag, maxval, minval):
|
||||
#Mag=[0,PARAMETER_MAX] => [PARAMETER_MAX, 0] = [maxval, minval]
|
||||
real_mag = float_parameter(float(PARAMETER_MAX) - mag, maxval=maxval-minval)+minval
|
||||
return real_mag + (maxval-real_mag) * torch.rand(size, device=mag.device) #[real_mag, max_val]
|
||||
|
||||
def zero_stack(tensor, zero_pos):
|
||||
if zero_pos==0:
|
||||
|
@ -139,7 +162,7 @@ def float_parameter(level, maxval):
|
|||
#return float(level) * maxval / PARAMETER_MAX
|
||||
return (level * maxval / PARAMETER_MAX)#.to(torch.float)
|
||||
|
||||
def int_parameter(level, maxval): #Perte de gradient
|
||||
#def int_parameter(level, maxval): #Perte de gradient
|
||||
"""Helper function to scale `val` between 0 and maxval .
|
||||
Args:
|
||||
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
|
||||
|
@ -149,7 +172,7 @@ def int_parameter(level, maxval): #Perte de gradient
|
|||
An int that results from scaling `maxval` according to `level`.
|
||||
"""
|
||||
#return int(level * maxval / PARAMETER_MAX)
|
||||
return (level * maxval / PARAMETER_MAX)
|
||||
# return (level * maxval / PARAMETER_MAX)
|
||||
|
||||
def flipLR(x):
|
||||
device = x.device
|
||||
|
@ -279,19 +302,19 @@ def solarize(x, thresholds): #PAS OPTIMISE POUR DES BATCH
|
|||
for idx, t in enumerate(thresholds): #Operation par image
|
||||
mask = x[idx] > t #Perte du gradient
|
||||
#In place
|
||||
#inv_x = 1-x[idx][mask]
|
||||
#x[idx][mask]=inv_x
|
||||
inv_x = 1-x[idx][mask]
|
||||
x[idx][mask]=inv_x
|
||||
#
|
||||
|
||||
#Out of place
|
||||
im = x[idx]
|
||||
inv_x = 1-im[mask]
|
||||
# im = x[idx]
|
||||
# inv_x = 1-im[mask]
|
||||
|
||||
imgs.append(im.masked_scatter(mask,inv_x))
|
||||
# imgs.append(im.masked_scatter(mask,inv_x))
|
||||
|
||||
idxs=torch.tensor(range(x.shape[0]), device=x.device)
|
||||
idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||
x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
|
||||
#idxs=torch.tensor(range(x.shape[0]), device=x.device)
|
||||
#idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||
#x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
|
||||
#
|
||||
return x
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue