diff --git a/higher/model.py b/higher/model.py
index 84afeff..d51dd6a 100755
--- a/higher/model.py
+++ b/higher/model.py
@@ -6,15 +6,51 @@ import torch.nn.functional as F
 import higher
 
 class Higher_model(nn.Module):
+    """Model wrapper for higher gradient tracking.
+
+        Keeps in memory the original model and its functional (higher) version.
+
+        Might not be needed anymore if Higher implements detach for fmodel.
+
+        See: https://github.com/facebookresearch/higher
+
+        TODO: Get rid of the original model if not needed by the user.
+
+        Attributes:
+            _name (string): Name of the model.
+            _mods (nn.ModuleDict): Models (original and higher/functional versions).
+    """
    def __init__(self, model):
+        """Init Higher_model.
+
+            Args:
+                model (nn.Module): Network for which higher gradients can be tracked.
+        """
        super(Higher_model, self).__init__()
 
+        self._name = model.__str__()
        self._mods = nn.ModuleDict({
            'original': model,
            'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
            })
 
    def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
+        """Get a differentiable version of an Optimizer.
+
+            A differentiable (higher) optimizer is required for higher gradient tracking.
+            Usage: diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step)
+
+            Be wary that if track_higher_grads is set to True, a new state of the model is saved each time diffopt.step() is called, increasing memory consumption.
+            The detach_() method should be called to reset the gradient tape and prevent memory saturation.
+
+            Args:
+                opt (torch.optim): Optimizer to make differentiable.
+                grad_callback (fct(grads)=grads): Function applied to the list of parameter gradients (e.g. clipping). (default: None)
+                track_higher_grads (bool): Whether higher gradients are tracked. If True, the graph/states will be retained to allow backpropagation. (default: True)
+
+            Returns:
+                (higher.optim.DifferentiableOptimizer): Differentiable version of the optimizer.
+        """
        return higher.optim.get_diff_optim(opt,
            self._mods['original'].parameters(),
            fmodel=self._mods['functional'],
@@ -22,20 +58,49 @@ class Higher_model(nn.Module):
            track_higher_grads=track_higher_grads)
 
    def forward(self, x):
+        """ Main method of the model.
+
+            Args:
+                x (Tensor): Batch of data.
+
+            Returns:
+                Tensor: Output of the network (expected to be logits).
+        """
        return self._mods['functional'](x)
 
    def detach_(self):
+        """Detach from the graph.
+
+            Needed to limit the number of states kept in memory.
+        """
        tmp = self._mods['functional'].fast_params
        self._mods['functional']._fast_params=[]
        self._mods['functional'].update_params(tmp)
        for p in self._mods['functional'].fast_params:
            p.detach_().requires_grad_()
 
+    def state_dict(self):
+        """Returns a dictionary containing the whole state of the module.
+        """
+        return self._mods['functional'].state_dict()
+
    def __getitem__(self, key):
+        """Access an inner module.
+
+            Args:
+                key (string): Name of the module to access.
+
+            Returns:
+                nn.Module.
+        """
        return self._mods[key]
 
    def __str__(self):
-        return self._mods['original'].__str__()
+        """Name of the module.
+
+            Returns:
+                String containing the name of the module.
+        """
+        return self._name
 
 ## Basic CNN ##
 class LeNet_F(nn.Module):
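For context, a minimal sketch of the intended inner-loop usage of this wrapper (illustrative only: the network, optimizer, and data below are placeholders, not part of this patch):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    net = nn.Linear(10, 2)                              # any nn.Module works here
    model = Higher_model(net)
    inner_opt = torch.optim.SGD(net.parameters(), lr=0.1)
    diffopt = model.get_diffopt(inner_opt, track_higher_grads=True)

    for step in range(4):
        x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
        # == (opt.zero_grad, loss.backward, opt.step), but differentiable
        diffopt.step(F.cross_entropy(model(x), y))

    # After backpropagating a meta-objective through the unrolled steps,
    # detach_() resets the gradient tape so saved states don't accumulate.
    model.detach_()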
diff --git a/higher/test_brutus.py b/higher/test_brutus.py
index 860e4b8..e2c087b 100755
--- a/higher/test_brutus.py
+++ b/higher/test_brutus.py
@@ -50,7 +50,7 @@ if __name__ == "__main__":
    }
 
    #model = LeNet(3,10)
-    model = ResNet(num_classes=10)
+    #model = ResNet(num_classes=10)
    #model = MobileNetV2(num_classes=10)
    #model = WideResNet(num_classes=10, wrn_size=32)
 
@@ -126,6 +126,8 @@ if __name__ == "__main__":
 
        t0 = time.process_time()
 
+        model = ResNet(num_classes=10)
+        model = Higher_model(model) #run_dist_dataugV3
        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=p_setup, fixed_mag=m_setup[0], shared_mag=m_setup[1]), model).to(device)
        #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
 
@@ -136,8 +138,7 @@ if __name__ == "__main__":
            dataug_epoch_start=dataug_epoch_start,
            opt_param=optim_param,
            print_freq=50,
-            KLdiv=True,
-            loss_patience=None)
+            KLdiv=True)
 
        exec_time=time.process_time() - t0
        ####
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index b58dcf7..972c6aa 100755
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -68,7 +68,7 @@ if __name__ == "__main__":
    }
    n_inner_iter = 1
    epochs = 150
-    dataug_epoch_start=10
+    dataug_epoch_start=0
 
    optim_param={
        'Meta':{
            'optim':'Adam',
diff --git a/higher/train_utils.py b/higher/train_utils.py
index b547568..5a792bb 100755
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -876,7 +876,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                else: #KL div method
                    # Supervised loss (classic)
-                    if model._data_augmentation :
+                    if model.is_augmenting() :
                        model.augment(mode=False)
                        sup_logits = model(xs)
                        model.augment(mode=True)
@@ -886,7 +886,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                    loss = F.cross_entropy(log_sup, ys)
 
                    # Unsupervised loss (KLdiv)
-                    if model._data_augmentation:
+                    if model.is_augmenting() :
                        aug_logits = model(xs)
                        log_aug=F.log_softmax(aug_logits, dim=1)
                        aug_loss=0
@@ -948,7 +948,6 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
        accuracy, test_loss =test(model)
        model.train()
 
-        print(model['data_aug']._data_augmentation)
        #### Log ####
        param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
        data={
diff --git a/higher/transformations.py b/higher/transformations.py
index c4f4175..430e7e8 100755
--- a/higher/transformations.py
+++ b/higher/transformations.py
@@ -53,64 +53,91 @@ TF_dict={ #Dataugv5 #AutoAugment
 #'Equalize': (lambda mag: None),
 }
 '''
+# Dictionary mapping transformation identifiers to their function.
+# Each value of the dict should be a lambda function taking a (batch of data, magnitude of transformation) tuple as input and returning a batch of data.
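+# For illustration (hypothetical entry, not part of this patch), a custom TF following the
+# same (batch, magnitude) -> batch contract could look like:
+#   'HalfBright': (lambda x, mag: brightness(x, brightness_factor=torch.full((x.shape[0],), 0.5, device=x.device))),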
 TF_dict={ #Dataugv5
-	## Geometric TF ##
-	'Identity' : (lambda x, mag: x),
-	'FlipUD' : (lambda x, mag: flipUD(x)),
-	'FlipLR' : (lambda x, mag: flipLR(x)),
-	'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
-	'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
-	'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
-	'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
-	'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
+    ## Geometric TF ##
+    'Identity' : (lambda x, mag: x),
+    'FlipUD' : (lambda x, mag: flipUD(x)),
+    'FlipLR' : (lambda x, mag: flipLR(x)),
+    'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
+    'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
+    'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
+    'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
+    'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
 
-	## Color TF (Expect image in the range of [0, 1]) ##
-	'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
-	'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
-	'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
-	'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
-	'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))), #Gradient loss
-	'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient loss #=>Image in [0,1]
+    ## Color TF (Expect image in the range of [0, 1]) ##
+    'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))), #Gradient loss
+    'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient loss #=>Image in [0,1]
 
-	#Color TF (Common mag scale)
-	'+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
-	'+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
-	'+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
-	'+Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
-	'-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
-	'-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
-	'-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
-	'-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
-	'=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))), #Gradient loss
-	'=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient loss #=>Image in [0,1]
-
-	'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=0))),
-	'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=1))),
-	'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=0))),
-	'BTranslateX-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=0))),
-	'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=1))),
-	'BTranslateY-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=1))),
-
-	'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9*2, maxval=2*4))),
-	'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),
+    #Color TF (Common mag scale)
+    '+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
+    '+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
+    '+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
+    '+Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
+    '-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
+    '-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
+    '-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
+    '-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
+    '=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))), #Gradient loss
+    '=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient loss #=>Image in [0,1]
-	'Random':(lambda x, mag: torch.rand_like(x)),
-	'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))),
-
-	#Not functional
-	#'Auto_Contrast': (lambda mag: None), #Not batch-optimized (very slow)
-	#'Equalize': (lambda mag: None),
+    ## Bad Transformations ##
+    # Bad Geometric TF #
+    'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=0))),
+    'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=1))),
+    'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=0))),
+    'BTranslateX-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=0))),
+    'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=1))),
+    'BTranslateY-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=1))),
+
+    # Bad Color TF #
+    'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9*2, maxval=2*4))),
+    'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),
+
+    # Random TF #
+    'Random':(lambda x, mag: torch.rand_like(x)),
+    'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))),
+
+    #Not functional
+    #'Auto_Contrast': (lambda mag: None), #Not batch-optimized (very slow)
+    #'Equalize': (lambda mag: None),
 }
 
-TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'}
-TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'}
-TF_ignore_mag= TF_no_mag | TF_no_grad
+TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'} #TF with no use for a magnitude parameter.
+TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'} #TF whose implementation doesn't allow gradient propagation.
+TF_ignore_mag= TF_no_mag | TF_no_grad #TF for which the magnitude should be ignored (magnitude fixed).
 
-def int_image(float_image): #WARNING: slight loss of information (granularity: 1/256 = 0.0039)
-	return (float_image*255.).type(torch.uint8)
+def int_image(float_image):
+    """Convert a float Tensor/Image to an int Tensor/Image.
+
+        Be wary that this conversion isn't bijective: each conversion results in a small loss of information.
+        Granularity: 1/256 = 0.0039.
+
+        It also discards the gradient associated with the input, as gradients cannot be tracked on int Tensors.
+
+        Args:
+            float_image (torch.float): Image tensor.
+
+        Returns:
+            (torch.uint8) Converted tensor.
+    """
+    return (float_image*255.).type(torch.uint8)
 
 def float_image(int_image):
-	return int_image.type(torch.float)/255.
+    """Convert an int Tensor/Image to a float Tensor/Image.
+
+        Args:
+            int_image (torch.uint8): Image tensor.
+
+        Returns:
+            (torch.float) Converted tensor.
+    """
+    return int_image.type(torch.float)/255.
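+
+# Quick round-trip sketch of the quantization described above (illustrative, assuming a float batch in [0, 1]):
+#   x = torch.rand(3, 32, 32)              # float image in [0, 1]
+#   x2 = float_image(int_image(x))         # values snap down to multiples of 1/255
+#   assert (x - x2).abs().max() <= 1/255.  # bounded by the stated granularity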
 #def rand_inverse(value):
 #    return value if random.random() < 0.5 else -value
@@ -125,11 +152,22 @@ def float_image(int_image):
 #    if not minval : minval = -real_max
 #    return random.uniform(minval, real_max)
 
-def rand_floats(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
-	real_mag = float_parameter(mag, maxval=maxval)
-	if not minval : minval = -real_mag
-	#return random.uniform(minval, real_max)
-	return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[min_val, real_mag]
+def rand_floats(size, mag, maxval, minval=None):
+    """Generate a batch of random values.
+
+        Args:
+            size (int): Number of values to generate.
+            mag (float): Level of the operation, expected in [PARAMETER_MIN, PARAMETER_MAX].
+            maxval (float): Maximum value that can be generated. It is scaled by mag/PARAMETER_MAX.
+            minval (float): Minimum value that can be generated. (default: -real_mag, i.e. minus the scaled maximum)
+
+        Returns:
+            Generated batch of float values in [minval, real_mag], where real_mag is maxval scaled by mag/PARAMETER_MAX.
+    """
+    real_mag = float_parameter(mag, maxval=maxval)
+    if not minval : minval = -real_mag
+    #return random.uniform(minval, real_max)
+    return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[minval, real_mag]
 
 def invScale_rand_floats(size, mag, maxval, minval): #Mag=[0,PARAMETER_MAX] => [PARAMETER_MAX, 0] = [maxval, minval]