mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 04:00:46 +02:00
Ajout meta-learning differee
This commit is contained in:
parent
f2cf244801
commit
53f6600ff6
4 changed files with 70 additions and 64 deletions
|
@ -4,6 +4,7 @@
|
||||||
from dataug import *
|
from dataug import *
|
||||||
#from utils import *
|
#from utils import *
|
||||||
from train_utils import *
|
from train_utils import *
|
||||||
|
from transformations import TF_loader
|
||||||
|
|
||||||
import torchvision.models as models
|
import torchvision.models as models
|
||||||
|
|
||||||
|
@ -13,6 +14,7 @@ optim_param={
|
||||||
'Meta':{
|
'Meta':{
|
||||||
'optim':'Adam',
|
'optim':'Adam',
|
||||||
'lr':1e-2, #1e-2
|
'lr':1e-2, #1e-2
|
||||||
|
'epoch_start': 2, #0 / 2 (Resnet?)
|
||||||
},
|
},
|
||||||
'Inner':{
|
'Inner':{
|
||||||
'optim': 'SGD',
|
'optim': 'SGD',
|
||||||
|
@ -26,9 +28,9 @@ optim_param={
|
||||||
|
|
||||||
res_folder="../res/benchmark/CIFAR10/"
|
res_folder="../res/benchmark/CIFAR10/"
|
||||||
#res_folder="../res/HPsearch/"
|
#res_folder="../res/HPsearch/"
|
||||||
epochs= 200
|
epochs= 300
|
||||||
dataug_epoch_start=0
|
dataug_epoch_start=0
|
||||||
nb_run= 3
|
nb_run= 1
|
||||||
|
|
||||||
tf_config='../config/base_tf_config.json'
|
tf_config='../config/base_tf_config.json'
|
||||||
TF_loader=TF_loader()
|
TF_loader=TF_loader()
|
||||||
|
|
|
@ -964,7 +964,7 @@ class Augmented_model(nn.Module):
|
||||||
|
|
||||||
model.step(loss)
|
model.step(loss)
|
||||||
|
|
||||||
Does not support LR scheduler.
|
Lacking epoch informations, this does not support LR scheduler and delayed meta-optimisation(Meta-optimizer: epoch_start>1).
|
||||||
|
|
||||||
See ''run_simple_smartaug'' for a complete example.
|
See ''run_simple_smartaug'' for a complete example.
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,10 @@ from LeNet import *
|
||||||
from dataug import *
|
from dataug import *
|
||||||
#from utils import *
|
#from utils import *
|
||||||
from train_utils import *
|
from train_utils import *
|
||||||
#from transformations import TF_loader
|
from transformations import TF_loader
|
||||||
|
|
||||||
postfix=''
|
postfix=''
|
||||||
TF_loader=TF.TF_loader()
|
TF_loader=TF_loader()
|
||||||
|
|
||||||
device = torch.device('cuda') #Select device to use
|
device = torch.device('cuda') #Select device to use
|
||||||
|
|
||||||
|
@ -40,6 +40,7 @@ if __name__ == "__main__":
|
||||||
'Meta':{
|
'Meta':{
|
||||||
'optim':'Adam',
|
'optim':'Adam',
|
||||||
'lr':1e-2, #1e-2
|
'lr':1e-2, #1e-2
|
||||||
|
'epoch_start': 2, #0 / 2 (Resnet?)
|
||||||
},
|
},
|
||||||
'Inner':{
|
'Inner':{
|
||||||
'optim': 'SGD',
|
'optim': 'SGD',
|
||||||
|
|
|
@ -31,64 +31,6 @@ PARAMETER_MAX = 1
|
||||||
# What is the min 'level' a transform could be predicted
|
# What is the min 'level' a transform could be predicted
|
||||||
PARAMETER_MIN = 0.1
|
PARAMETER_MIN = 0.1
|
||||||
|
|
||||||
'''
|
|
||||||
# Dictionnary mapping tranformations identifiers to their function.
|
|
||||||
# Each value of the dict should be a lambda function taking a (batch of data, magnitude of transformations) tuple as input and returns a batch of data.
|
|
||||||
TF_dict={ #Dataugv5+
|
|
||||||
## Geometric TF ##
|
|
||||||
'Identity' : (lambda x, mag: x),
|
|
||||||
'FlipUD' : (lambda x, mag: flipUD(x)),
|
|
||||||
'FlipLR' : (lambda x, mag: flipLR(x)),
|
|
||||||
'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
|
|
||||||
'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=x.shape[2]*0.33), zero_pos=0))),
|
|
||||||
'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=x.shape[3]*0.33), zero_pos=1))),
|
|
||||||
'TranslateXabs': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
|
|
||||||
'TranslateYabs': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
|
|
||||||
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
|
|
||||||
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
|
|
||||||
|
|
||||||
## Color TF (Expect image in the range of [0, 1]) ##
|
|
||||||
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
|
||||||
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
|
||||||
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
|
||||||
'Sharpness':(lambda x, mag: sharpness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
|
||||||
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
|
||||||
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
|
|
||||||
|
|
||||||
#Color TF (Common mag scale)
|
|
||||||
'+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
|
||||||
'+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
|
||||||
'+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
|
||||||
'+Sharpness':(lambda x, mag: sharpness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
|
||||||
'-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
|
||||||
'-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
|
||||||
'-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
|
||||||
'-Sharpness':(lambda x, mag: sharpness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
|
||||||
'=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
|
||||||
'=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
|
|
||||||
|
|
||||||
## Bad Tranformations ##
|
|
||||||
# Bad Geometric TF #
|
|
||||||
'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=0))),
|
|
||||||
'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=1))),
|
|
||||||
'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=0))),
|
|
||||||
'BTranslateX-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=0))),
|
|
||||||
'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=1))),
|
|
||||||
'BTranslateY-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=1))),
|
|
||||||
|
|
||||||
# Bad Color TF #
|
|
||||||
'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9*2, maxval=2*4))),
|
|
||||||
'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),
|
|
||||||
|
|
||||||
# Random TF #
|
|
||||||
'Random':(lambda x, mag: torch.rand_like(x)),
|
|
||||||
'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))),
|
|
||||||
|
|
||||||
#Not ready for use
|
|
||||||
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
|
||||||
#'Equalize': (lambda mag: None),
|
|
||||||
}
|
|
||||||
'''
|
|
||||||
class TF_loader(object):
|
class TF_loader(object):
|
||||||
""" Transformations builder.
|
""" Transformations builder.
|
||||||
|
|
||||||
|
@ -155,6 +97,8 @@ class TF_loader(object):
|
||||||
def build_lambda(self, fct_name, rand_fct_name, minval, maxval, absolute=True, axis=None):
|
def build_lambda(self, fct_name, rand_fct_name, minval, maxval, absolute=True, axis=None):
|
||||||
""" Build a lambda function performing transformations.
|
""" Build a lambda function performing transformations.
|
||||||
|
|
||||||
|
Force different context for creation of each lambda function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fct_name (str): Name of the transformations to use (see transformations.py).
|
fct_name (str): Name of the transformations to use (see transformations.py).
|
||||||
rand_fct_name (str): Name of the random mapping function to use (see transformations.py).
|
rand_fct_name (str): Name of the random mapping function to use (see transformations.py).
|
||||||
|
@ -620,4 +564,63 @@ def equalize(x):
|
||||||
#print(chan.shape)
|
#print(chan.shape)
|
||||||
hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
|
hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
|
||||||
|
|
||||||
return float_image(x)
|
return float_image(x)
|
||||||
|
|
||||||
|
'''
|
||||||
|
# Dictionnary mapping tranformations identifiers to their function.
|
||||||
|
# Each value of the dict should be a lambda function taking a (batch of data, magnitude of transformations) tuple as input and returns a batch of data.
|
||||||
|
TF_dict={ #Dataugv5+
|
||||||
|
## Geometric TF ##
|
||||||
|
'Identity' : (lambda x, mag: x),
|
||||||
|
'FlipUD' : (lambda x, mag: flipUD(x)),
|
||||||
|
'FlipLR' : (lambda x, mag: flipLR(x)),
|
||||||
|
'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
|
||||||
|
'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=x.shape[2]*0.33), zero_pos=0))),
|
||||||
|
'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=x.shape[3]*0.33), zero_pos=1))),
|
||||||
|
'TranslateXabs': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
|
||||||
|
'TranslateYabs': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
|
||||||
|
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
|
||||||
|
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
|
||||||
|
|
||||||
|
## Color TF (Expect image in the range of [0, 1]) ##
|
||||||
|
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Sharpness':(lambda x, mag: sharpness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||||
|
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
|
||||||
|
|
||||||
|
#Color TF (Common mag scale)
|
||||||
|
'+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||||
|
'+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||||
|
'+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||||
|
'+Sharpness':(lambda x, mag: sharpness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||||
|
'-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||||
|
'-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||||
|
'-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||||
|
'-Sharpness':(lambda x, mag: sharpness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||||
|
'=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||||
|
'=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
|
||||||
|
|
||||||
|
## Bad Tranformations ##
|
||||||
|
# Bad Geometric TF #
|
||||||
|
'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=0))),
|
||||||
|
'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=1))),
|
||||||
|
'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=0))),
|
||||||
|
'BTranslateX-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=0))),
|
||||||
|
'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=1))),
|
||||||
|
'BTranslateY-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=1))),
|
||||||
|
|
||||||
|
# Bad Color TF #
|
||||||
|
'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9*2, maxval=2*4))),
|
||||||
|
'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),
|
||||||
|
|
||||||
|
# Random TF #
|
||||||
|
'Random':(lambda x, mag: torch.rand_like(x)),
|
||||||
|
'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))),
|
||||||
|
|
||||||
|
#Not ready for use
|
||||||
|
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
||||||
|
#'Equalize': (lambda mag: None),
|
||||||
|
}
|
||||||
|
'''
|
Loading…
Add table
Add a link
Reference in a new issue