Minor improvement + Comments

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-21 13:53:07 -05:00
parent d21a6bbf5c
commit c1ad787d97
5 changed files with 165 additions and 62 deletions

View file

@@ -6,15 +6,51 @@ import torch.nn.functional as F
 import higher

 class Higher_model(nn.Module):
+    """Model wrapper for higher gradient tracking.
+
+        Keeps in memory the original model and its functional (higher) version.
+        Might not be needed anymore if higher implements detach for fmodel.
+        See: https://github.com/facebookresearch/higher
+
+        TODO: Get rid of the original model if not needed by the user.
+
+        Attributes:
+            _name (string): Name of the model.
+            _mods (nn.ModuleDict): Models (original and higher version).
+    """
     def __init__(self, model):
+        """Init Higher_model.
+
+            Args:
+                model (nn.Module): Network for which higher gradients can be tracked.
+        """
         super(Higher_model, self).__init__()

+        self._name = model.__str__()
         self._mods = nn.ModuleDict({
             'original': model,
             'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
             })

     def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
+        """Get a differentiable version of an Optimizer.
+
+            A higher/differentiable optimizer is required for higher gradient tracking.
+            Usage: diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step)
+
+            Be wary that if track_higher_grads is set to True, a new state of the model is saved each time diffopt.step() is called,
+            increasing memory consumption. The detach_() method should be called to reset the gradient tape and prevent memory saturation.
+
+            Args:
+                opt (torch.optim): Optimizer to make differentiable.
+                grad_callback (fct(grads)=grads): Function applied to the list of parameter gradients (e.g. clipping). (default: None)
+                track_higher_grads (bool): Whether higher gradients are tracked. If True, the graph/states will be retained to allow backpropagation. (default: True)
+
+            Returns:
+                (higher.DifferentiableOptimizer): Differentiable version of the optimizer.
+        """
         return higher.optim.get_diff_optim(opt,
             self._mods['original'].parameters(),
             fmodel=self._mods['functional'],
@@ -22,20 +58,49 @@ class Higher_model(nn.Module):
             track_higher_grads=track_higher_grads)

     def forward(self, x):
+        """Main method of the model.
+
+            Args:
+                x (Tensor): Batch of data.
+
+            Returns:
+                Tensor: Output of the network. Should be logits.
+        """
         return self._mods['functional'](x)

     def detach_(self):
+        """Detach from the graph.
+
+            Needed to limit the number of states kept in memory.
+        """
         tmp = self._mods['functional'].fast_params
         self._mods['functional']._fast_params=[]
         self._mods['functional'].update_params(tmp)
         for p in self._mods['functional'].fast_params:
             p.detach_().requires_grad_()

+    def state_dict(self):
+        """Returns a dictionary containing the whole state of the module.
+        """
+        return self._mods['functional'].state_dict()
+
     def __getitem__(self, key):
+        """Access to modules.
+
+            Args:
+                key (string): Name of the module to access.
+
+            Returns:
+                nn.Module.
+        """
         return self._mods[key]

     def __str__(self):
-        return self._mods['original'].__str__()
+        """Name of the module.
+
+            Returns:
+                String containing the name of the module.
+        """
+        return self._name

 ## Basic CNN ##
 class LeNet_F(nn.Module):
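
The docstrings above describe the intended workflow: wrap a network, request a differentiable optimizer, step on the loss, and periodically detach to bound memory. A minimal usage sketch; the network constructor arguments, SGD settings, data loader, and detach frequency are illustrative assumptions, not part of this commit:

    import torch
    import torch.nn.functional as F

    model = Higher_model(LeNet_F(3, 10))                  # wrap the network (args assumed)
    inner_opt = torch.optim.SGD(model['original'].parameters(), lr=1e-2)
    diffopt = model.get_diffopt(inner_opt, track_higher_grads=True)

    for step, (xs, ys) in enumerate(dl_train):            # dl_train is assumed to exist
        loss = F.cross_entropy(model(xs), ys)
        diffopt.step(loss)     # == opt.zero_grad(); loss.backward(); opt.step()
        if step % 10 == 0:
            model.detach_()    # reset the gradient tape to avoid memory saturation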

View file

@@ -50,7 +50,7 @@ if __name__ == "__main__":
     }

     #model = LeNet(3,10)
-    model = ResNet(num_classes=10)
+    #model = ResNet(num_classes=10)
     #model = MobileNetV2(num_classes=10)
     #model = WideResNet(num_classes=10, wrn_size=32)
@@ -126,6 +126,8 @@ if __name__ == "__main__":
         t0 = time.process_time()

+        model = ResNet(num_classes=10)
+        model = Higher_model(model) #run_dist_dataugV3
         aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=p_setup, fixed_mag=m_setup[0], shared_mag=m_setup[1]), model).to(device)
         #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
@@ -136,8 +138,7 @@ if __name__ == "__main__":
             dataug_epoch_start=dataug_epoch_start,
             opt_param=optim_param,
             print_freq=50,
-            KLdiv=True,
-            loss_patience=None)
+            KLdiv=True)

         exec_time=time.process_time() - t0
         ####

View file

@@ -68,7 +68,7 @@ if __name__ == "__main__":
     }

     n_inner_iter = 1
     epochs = 150
-    dataug_epoch_start=10
+    dataug_epoch_start=0
     optim_param={
         'Meta':{
             'optim':'Adam',

View file

@@ -876,7 +876,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
         else:
             #KL div method
             # Supervised loss (classic)
-            if model._data_augmentation :
+            if model.is_augmenting() :
                 model.augment(mode=False)
                 sup_logits = model(xs)
                 model.augment(mode=True)
@@ -886,7 +886,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
             loss = F.cross_entropy(log_sup, ys)

             # Unsupervised loss (KLdiv)
-            if model._data_augmentation:
+            if model.is_augmenting() :
                 aug_logits = model(xs)
                 log_aug=F.log_softmax(aug_logits, dim=1)
                 aug_loss=0
@@ -948,7 +948,6 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
         accuracy, test_loss =test(model)
         model.train()

-        print(model['data_aug']._data_augmentation)
         #### Log ####
         param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
         data={
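
The hunks above only switch the augmentation test from the private _data_augmentation flag to is_augmenting(). For context, here is a sketch of the KL-divergence objective those lines belong to. The diff truncates right after aug_loss=0, so the KL term and its weight w are assumptions inferred from the surrounding code, not the commit's exact implementation:

    # Supervised term on a clean (non-augmented) pass
    model.augment(mode=False)
    sup_logits = model(xs)
    model.augment(mode=True)
    log_sup = F.log_softmax(sup_logits, dim=1)
    loss = F.cross_entropy(log_sup, ys)

    # Consistency term between clean and augmented predictions
    if model.is_augmenting():
        aug_logits = model(xs)
        log_aug = F.log_softmax(aug_logits, dim=1)
        aug_loss = F.kl_div(log_aug, F.softmax(sup_logits, dim=1),
                            reduction='batchmean')   # KL(sup || aug); reduction assumed
        loss = loss + w * aug_loss                   # w: assumed weighting factor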

View file

@@ -53,64 +53,91 @@ TF_dict={ #Dataugv5 #AutoAugment
     #'Equalize': (lambda mag: None),
 }
 '''
+# Dictionary mapping transformation identifiers to their function.
+# Each value of the dict should be a lambda taking (batch of data, magnitude of transformations) as input and returning a batch of data.
 TF_dict={ #Dataugv5
     ## Geometric TF ##
     'Identity' : (lambda x, mag: x),
     'FlipUD' : (lambda x, mag: flipUD(x)),
     'FlipLR' : (lambda x, mag: flipLR(x)),
     'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
     'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
     'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
     'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
     'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),

     ## Color TF (Expect image in the range of [0, 1]) ##
     'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
     'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
     'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
     'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
     'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))), #Gradient is lost
     'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient is lost #=>Image in [0,1]

     #Color TF (Common mag scale)
     '+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
     '+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
     '+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
     '+Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
     '-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
     '-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
     '-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
     '-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
     '=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))), #Gradient is lost
     '=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient is lost #=>Image in [0,1]

+    ## Bad Transformations ##
+    # Bad Geometric TF #
     'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=0))),
     'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=1))),
     'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=0))),
     'BTranslateX-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=0))),
     'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=1))),
     'BTranslateY-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=1))),

+    # Bad Color TF #
     'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9*2, maxval=2*4))),
     'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),

+    # Random TF #
     'Random':(lambda x, mag: torch.rand_like(x)),
     'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))),

     #Not functional
     #'Auto_Contrast': (lambda mag: None), #Not optimized for batches (very slow)
     #'Equalize': (lambda mag: None),
 }

-TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'}
+TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'} #TF that have no use for the magnitude parameter.
-TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'}
+TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'} #TF whose implementation doesn't allow gradient propagation.
-TF_ignore_mag= TF_no_mag | TF_no_grad
+TF_ignore_mag= TF_no_mag | TF_no_grad #TF for which magnitude should be ignored (magnitude fixed).
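
Every entry in TF_dict follows the (batch, magnitude) -> batch contract described in the new comment, so transformations can be probed in isolation. A small sketch; the batch shape and magnitude value are illustrative, and the magnitude scale depends on PARAMETER_MAX:

    import torch

    x = torch.rand(8, 3, 32, 32)          # float image batch in [0, 1]
    mag = torch.tensor(1.0)               # magnitude level (scale is an assumption)

    rotated = TF_dict['Rotate'](x, mag)   # each entry: f(batch, mag) -> batch
    assert rotated.shape == x.shape

    flipped = TF_dict['FlipLR'](x, mag)   # 'FlipLR' is in TF_no_mag: mag is unused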
-def int_image(float_image): #WARNING: slight loss of information (granularity: 1/256 = 0.0039)
-    return (float_image*255.).type(torch.uint8)
+def int_image(float_image):
+    """Convert a float Tensor/Image to an int Tensor/Image.
+
+        Be wary that this transformation isn't bijective: each conversion results in a small loss of information.
+        Granularity: 1/256 = 0.0039.
+
+        This also results in the loss of the gradient associated with the input, as gradients cannot be tracked on int Tensors.
+
+        Args:
+            float_image (torch.float): Image tensor.
+
+        Returns:
+            (torch.uint8) Converted tensor.
+    """
+    return (float_image*255.).type(torch.uint8)

 def float_image(int_image):
-    return int_image.type(torch.float)/255.
+    """Convert an int Tensor/Image to a float Tensor/Image.
+
+        Args:
+            int_image (torch.uint8): Image tensor.
+
+        Returns:
+            (torch.float) Converted tensor.
+    """
+    return int_image.type(torch.float)/255.
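
As the int_image docstring warns, the float -> uint8 -> float round trip is lossy (and drops the gradient). A quick illustration of the quantization bound:

    import torch

    x = torch.rand(4, 3, 32, 32)       # float image batch in [0, 1)
    x_back = float_image(int_image(x))

    # Truncation error of the round trip is strictly below one quantization step.
    assert (x - x_back).abs().max() < 1/255.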
 #def rand_inverse(value):
 #    return value if random.random() < 0.5 else -value
@@ -125,11 +152,22 @@ def float_image(int_image):
 #    if not minval : minval = -real_max
 #    return random.uniform(minval, real_max)

-def rand_floats(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
+def rand_floats(size, mag, maxval, minval=None):
+    """Generate a batch of random values.
+
+        Args:
+            size (int): Number of values to generate.
+            mag (float): Level of the operation, in [PARAMETER_MIN, PARAMETER_MAX].
+            maxval (float): Maximum value that can be generated. This will be scaled by mag/PARAMETER_MAX.
+            minval (float): Minimum value that can be generated. (default: -maxval)
+
+        Returns:
+            Generated batch of float values between [minval, maxval].
+    """
     real_mag = float_parameter(mag, maxval=maxval)
     if not minval : minval = -real_mag
     #return random.uniform(minval, real_max)
     return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[minval, real_mag]
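
To make the scaling concrete: real_mag = float_parameter(mag, maxval) shrinks the range with the magnitude level, and minval defaults to -real_mag. A usage sketch, assuming PARAMETER_MAX = 10 so that mag = 10 yields the full range:

    import torch

    mag = torch.tensor(10.)                             # full magnitude (scale assumed)
    angles = rand_floats(size=16, mag=mag, maxval=30)   # ~Uniform[-30, 30]
    factors = rand_floats(size=16, mag=mag, minval=0.1, maxval=1.9)  # ~Uniform[0.1, 1.9]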
 def invScale_rand_floats(size, mag, maxval, minval):
     #Mag=[0,PARAMETER_MAX] => [PARAMETER_MAX, 0] = [maxval, minval]