mirror of https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00

Minor improvement + Comments

This commit is contained in:
parent d21a6bbf5c
commit c1ad787d97

5 changed files with 165 additions and 62 deletions
@@ -6,15 +6,51 @@ import torch.nn.functional as F

import higher

class Higher_model(nn.Module):
    """Model wrapper for higher gradient tracking.

        Keeps in memory the original model and its functional (higher) version.

        Might not be needed anymore if higher implements detach for fmodel.

        See: https://github.com/facebookresearch/higher

        TODO: Get rid of the original model if not needed by user.

        Attributes:
            _name (string): Name of the model.
            _mods (nn.ModuleDict): Models (original and higher versions).
    """
    def __init__(self, model):
        """Init Higher_model.

            Args:
                model (nn.Module): Network for which higher gradients can be tracked.
        """
        super(Higher_model, self).__init__()

        self._name = model.__str__()
        self._mods = nn.ModuleDict({
            'original': model,
            'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
            })

    def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
        """Get a differentiable version of an Optimizer.

            A higher/differentiable optimizer is required for higher gradient tracking.
            Usage: diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step)

            Be wary that if track_higher_grads is set to True, a new state of the model is saved each time diffopt.step() is called,
            increasing memory consumption. The detach_() method should be called to reset the gradient tape and prevent memory saturation.

            Args:
                opt (torch.optim): Optimizer to make differentiable.
                grad_callback (fct(grads)=grads): Function applied to the list of parameter gradients (e.g. clipping). (default: None)
                track_higher_grads (bool): Whether higher gradients are tracked. If True, the graph/states will be retained to allow backpropagation. (default: True)

            Returns:
                (higher.DifferentiableOptimizer): Differentiable version of the optimizer.
        """
        return higher.optim.get_diff_optim(opt,
            self._mods['original'].parameters(),
            fmodel=self._mods['functional'],
@@ -22,20 +58,49 @@ class Higher_model(nn.Module):
            track_higher_grads=track_higher_grads)

    def forward(self, x):
        """Main method of the model.

            Args:
                x (Tensor): Batch of data.

            Returns:
                Tensor: Output of the network. Should be logits.
        """
        return self._mods['functional'](x)

    def detach_(self):
        """Detach from the graph.

            Needed to limit the number of states kept in memory.
        """
        tmp = self._mods['functional'].fast_params
        self._mods['functional']._fast_params=[]
        self._mods['functional'].update_params(tmp)
        for p in self._mods['functional'].fast_params:
            p.detach_().requires_grad_()

    def state_dict(self):
        """Returns a dictionary containing a whole state of the module.
        """
        return self._mods['functional'].state_dict()

    def __getitem__(self, key):
        """Access to modules.

            Args:
                key (string): Name of the module to access.

            Returns:
                nn.Module.
        """
        return self._mods[key]

    def __str__(self):
-        return self._mods['original'].__str__()
+        """Name of the module.
+
+            Returns:
+                String containing the name of the module.
+        """
+        return self._name

## Basic CNN ##
class LeNet_F(nn.Module):
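# --- Example (not part of the diff) ---------------------------------------
# A rough usage sketch of the Higher_model wrapper above. The toy network,
# the SGD settings, and the data `loader` are assumptions for illustration;
# only Higher_model, get_diffopt, forward and detach_ come from the class.
import torch
import torch.nn as nn
import torch.nn.functional as F

net = nn.Sequential(nn.Flatten(), nn.Linear(3*32*32, 10))  # toy classifier
model = Higher_model(net)
inner_opt = torch.optim.SGD(model['original'].parameters(), lr=1e-2)
diffopt = model.get_diffopt(inner_opt, track_higher_grads=True)

for step, (xs, ys) in enumerate(loader):
    logits = model(xs)                  # forward through the functional model
    loss = F.cross_entropy(logits, ys)
    diffopt.step(loss)                  # == (opt.zero_grad, loss.backward, opt.step)
    if step % 10 == 0:
        model.detach_()                 # cut the gradient tape to bound memory use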
@@ -50,7 +50,7 @@ if __name__ == "__main__":
    }

    #model = LeNet(3,10)
-    model = ResNet(num_classes=10)
+    #model = ResNet(num_classes=10)
    #model = MobileNetV2(num_classes=10)
    #model = WideResNet(num_classes=10, wrn_size=32)
@@ -126,6 +126,8 @@ if __name__ == "__main__":

        t0 = time.process_time()

+        model = ResNet(num_classes=10)
+        model = Higher_model(model) #run_dist_dataugV3
        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=p_setup, fixed_mag=m_setup[0], shared_mag=m_setup[1]), model).to(device)
        #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
@@ -136,8 +138,7 @@ if __name__ == "__main__":
            dataug_epoch_start=dataug_epoch_start,
            opt_param=optim_param,
            print_freq=50,
-            KLdiv=True,
-            loss_patience=None)
+            KLdiv=True)

        exec_time=time.process_time() - t0
        ####
@@ -68,7 +68,7 @@ if __name__ == "__main__":
    }
    n_inner_iter = 1
    epochs = 150
-    dataug_epoch_start=10
+    dataug_epoch_start=0
    optim_param={
        'Meta':{
            'optim':'Adam',
@@ -876,7 +876,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                else:
-                    # KL div method
-                    if model._data_augmentation :
+                    # Supervised loss (classic)
+                    if model.is_augmenting() :
                        model.augment(mode=False)
                        sup_logits = model(xs)
                        model.augment(mode=True)
@@ -886,7 +886,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                    loss = F.cross_entropy(log_sup, ys)

                    # Unsupervised loss (KLdiv)
-                    if model._data_augmentation:
+                    if model.is_augmenting() :
                        aug_logits = model(xs)
                        log_aug=F.log_softmax(aug_logits, dim=1)
                        aug_loss=0
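# --- Example (not part of the diff) ---------------------------------------
# Minimal sketch of the supervised + KL-divergence objective that the two
# hunks above are editing. `model`, `xs`, `ys` follow the source; the mixing
# weight `lambda_kl` and the detach on the clean predictions are assumptions.
import torch.nn.functional as F

def kldiv_objective(model, xs, ys, lambda_kl=1.0):
    # Supervised term, computed with augmentation switched off.
    model.augment(mode=False)
    sup_logits = model(xs)
    model.augment(mode=True)
    loss = F.cross_entropy(sup_logits, ys)

    # Consistency term: KL between augmented and clean predictions.
    if model.is_augmenting():
        aug_logits = model(xs)                      # same batch, augmentation on
        log_aug = F.log_softmax(aug_logits, dim=1)
        sup_probs = F.softmax(sup_logits, dim=1).detach()
        loss = loss + lambda_kl * F.kl_div(log_aug, sup_probs, reduction='batchmean')
    return loss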
@@ -948,7 +948,6 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
        accuracy, test_loss =test(model)
        model.train()

-        print(model['data_aug']._data_augmentation)
        #### Log ####
        param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
        data={
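# --- Example (not part of the diff) ---------------------------------------
# Structure of the `param` list built above for logging, with invented
# values: one dict per transformation, probability 'p' and magnitude 'm'.
param_shared = [{'p': 0.12, 'm': 3.4}, {'p': 0.08, 'm': 3.4}]  # _shared_mag: one magnitude repeated
param_per_tf = [{'p': 0.12, 'm': 2.1}, {'p': 0.08, 'm': 5.6}]  # one magnitude per TF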
@@ -53,64 +53,91 @@ TF_dict={ #Dataugv5 #AutoAugment
    #'Equalize': (lambda mag: None),
    }
'''
+# Dictionary mapping transformation identifiers to their function.
+# Each value of the dict should be a lambda function taking a (batch of data, magnitude of transformations) tuple as input and returning a batch of data.
TF_dict={ #Dataugv5
    ## Geometric TF ##
    'Identity' : (lambda x, mag: x),
    'FlipUD' : (lambda x, mag: flipUD(x)),
    'FlipLR' : (lambda x, mag: flipLR(x)),
    'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
    'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
    'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
    'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
    'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),

    ## Color TF (Expect image in the range of [0, 1]) ##
    'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
    'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
    'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
    'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
    'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))), # Gradient is lost
    'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), # Gradient is lost # => Image in [0, 1]

    #Color TF (Common mag scale)
    '+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
    '+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
    '+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
    '+Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
    '-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
    '-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
    '-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
    '-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
    '=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))), # Gradient is lost
    '=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), # Gradient is lost # => Image in [0, 1]

+    ## Bad Transformations ##
+    # Bad Geometric TF #
    'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=0))),
    'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=1))),
    'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=0))),
    'BTranslateX-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=0))),
    'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=1))),
    'BTranslateY-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=1))),

+    # Bad Color TF #
    'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9*2, maxval=2*4))),
    'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),

+    # Random TF #
    'Random':(lambda x, mag: torch.rand_like(x)),
    'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))),

    # Not functional
    #'Auto_Contrast': (lambda mag: None), # Not optimized for batches (very slow)
    #'Equalize': (lambda mag: None),
}

-TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'}
-TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'}
-TF_ignore_mag= TF_no_mag | TF_no_grad
+TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'} #TF that have no use for the magnitude parameter.
+TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'} #TF whose implementation doesn't allow gradient propagation.
+TF_ignore_mag= TF_no_mag | TF_no_grad #TF for which magnitude should be ignored (magnitude fixed).
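# --- Example (not part of the diff) ---------------------------------------
# Each TF_dict entry maps a (batch, magnitude) pair to a transformed batch,
# as the new comment above states. A hypothetical call (values invented;
# magnitude is a tensor so the lambdas can read mag.device):
import torch

batch = torch.rand(8, 3, 32, 32)          # 8 RGB 32x32 images in [0, 1]
mag = torch.tensor(5.0)
rotated = TF_dict['Rotate'](batch, mag)   # random angles, up to +/-30 deg at max magnitude
flipped = TF_dict['FlipLR'](batch, mag)   # in TF_no_mag: magnitude is ignored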
-def int_image(float_image): # WARNING: slight loss of information (granularity: 1/256 = 0.0039)
-    return (float_image*255.).type(torch.uint8)
+def int_image(float_image):
+    """Convert a float Tensor/Image to an int Tensor/Image.
+
+        Be wary that this transformation isn't bijective: each conversion results in a small loss of information.
+        Granularity: 1/256 = 0.0039.
+
+        This also loses the gradient associated with the input, as gradients cannot be tracked on int Tensors.
+
+        Args:
+            float_image (torch.float): Image tensor.
+
+        Returns:
+            (torch.uint8) Converted tensor.
+    """
+    return (float_image*255.).type(torch.uint8)

-def float_image(int_image):
-    return int_image.type(torch.float)/255.
+def float_image(int_image):
+    """Convert an int Tensor/Image to a float Tensor/Image.
+
+        Args:
+            int_image (torch.uint8): Image tensor.
+
+        Returns:
+            (torch.float) Converted tensor.
+    """
+    return int_image.type(torch.float)/255.

#def rand_inverse(value):
#    return value if random.random() < 0.5 else -value
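# --- Example (not part of the diff) ---------------------------------------
# Quick check of the non-bijectivity warning in int_image's docstring:
# float -> uint8 -> float quantizes values to multiples of 1/255.
import torch

x = torch.tensor([0.5001, 0.9999])
roundtrip = float_image(int_image(x))         # tensor([0.4980, 0.9961])
assert (x - roundtrip).abs().max() <= 1/255.  # error bounded by the granularity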
@@ -125,11 +152,22 @@ def float_image(int_image):
#    if not minval : minval = -real_max
#    return random.uniform(minval, real_max)

-def rand_floats(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
+def rand_floats(size, mag, maxval, minval=None):
+    """Generate a batch of random values.
+
+        Args:
+            size (int): Number of values to generate.
+            mag (float): Level of the operation, between [PARAMETER_MIN, PARAMETER_MAX].
+            maxval (float): Maximum value that can be generated. This will be scaled to mag/PARAMETER_MAX.
+            minval (float): Minimum value that can be generated. (default: -maxval)
+
+        Returns:
+            Generated batch of float values between [minval, maxval].
+    """
    real_mag = float_parameter(mag, maxval=maxval)
    if not minval : minval = -real_mag
    #return random.uniform(minval, real_max)
    return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[minval, real_mag]

def invScale_rand_floats(size, mag, maxval, minval):
    #Mag=[0,PARAMETER_MAX] => [PARAMETER_MAX, 0] = [maxval, minval]
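# --- Example (not part of the diff) ---------------------------------------
# Usage sketch for rand_floats. float_parameter and PARAMETER_MAX come from
# this module; PARAMETER_MAX == 10 is assumed here for the range comments.
import torch

mag = torch.tensor(10.0)                           # full magnitude level
angles = rand_floats(size=4, mag=mag, maxval=30)   # 4 values in [-30, 30]
small = rand_floats(size=4, mag=mag/2, maxval=30)  # range shrinks with mag: [-15, 15]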