From d21a6bbf5cf6b8dc9299c2dde472ff12f9f51b52 Mon Sep 17 00:00:00 2001 From: "Harle, Antoine (Contracteur)" Date: Mon, 20 Jan 2020 17:09:31 -0500 Subject: [PATCH] Fix etat Train/Eval pour augmentation differee (Retester !) --- higher/datasets.py | 22 +++++++++++----------- higher/dataug.py | 26 +++++++++++++++----------- higher/test_dataug.py | 5 ++--- higher/train_utils.py | 14 ++++++++++---- 4 files changed, 38 insertions(+), 29 deletions(-) diff --git a/higher/datasets.py b/higher/datasets.py index a3473d2..2c06c7e 100755 --- a/higher/datasets.py +++ b/higher/datasets.py @@ -7,31 +7,31 @@ TEST_SIZE = 300 #TEST_SIZE = 10000 #legerement +Rapide / + Consomation memoire ! download_data=False -num_workers=4 #4 +num_workers=2 #4 pin_memory=False #True :+ GPU memory / + Lent #ATTENTION : Dataug (Kornia) Expect image in the range of [0, 1] #transform_train = torchvision.transforms.Compose([ # torchvision.transforms.RandomHorizontalFlip(), # torchvision.transforms.ToTensor(), - #torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #CIFAR10 +# torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #CIFAR10 #]) transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), #torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #CIFAR10 ]) -''' -data_train = torchvision.datasets.MNIST( - "./data", train=True, download=True, - transform=torchvision.transforms.Compose([ - #torchvision.transforms.RandomAffine(degrees=180, translate=None, scale=None, shear=None, resample=False, fillcolor=0), - torchvision.transforms.ToTensor() - ]) -) + +#data_train = torchvision.datasets.MNIST( +# "./data", train=True, download=True, +# transform=torchvision.transforms.Compose([ +# #torchvision.transforms.RandomAffine(degrees=180, translate=None, scale=None, shear=None, resample=False, fillcolor=0), +# torchvision.transforms.ToTensor() +# ]) +#) data_test = torchvision.datasets.MNIST( "./data", train=False, download=True, transform=torchvision.transforms.ToTensor() ) -''' + from torchvision.datasets.vision import VisionDataset from PIL import Image diff --git a/higher/dataug.py b/higher/dataug.py index 3255ae8..b33256a 100755 --- a/higher/dataug.py +++ b/higher/dataug.py @@ -747,21 +747,22 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba) max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') return max_mag_reg - def train(self, mode=None): + def train(self, mode=True): """ Set the module training mode. Args: mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None) """ - if mode is None : - mode=self._data_augmentation + #if mode is None : + # mode=self._data_augmentation self.augment(mode=mode) #Inutile si mode=None super(Data_augV5, self).train(mode) + return self def eval(self): """ Set the module to evaluation mode. """ - self.train(mode=False) + return self.train(mode=False) def augment(self, mode=True): """ Set the augmentation mode. @@ -1266,7 +1267,7 @@ class Augmented_model(nn.Module): Attributes: _mods (nn.ModuleDict): A dictionary containing the modules. - _data_augmentation (bool): Wether data augmentation is used. + _data_augmentation (bool): Wether data augmentation should be used. """ def __init__(self, data_augmenter, model): """Init Augmented Model. @@ -1308,22 +1309,25 @@ class Augmented_model(nn.Module): self._data_augmentation=mode self._mods['data_aug'].augment(mode) - def train(self, mode=None): + def train(self, mode=True): """ Set the module training mode. Args: - mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None) + mode (bool): Wether to learn the parameter of the module. (default: None) """ - if mode is None : - mode=self._data_augmentation - self._mods['data_aug'].augment(mode) + #if mode is None : + # mode=self._data_augmentation super(Augmented_model, self).train(mode) + self._mods['data_aug'].augment(mode=self._data_augmentation) #Restart if needed data augmentation return self def eval(self): """ Set the module to evaluation mode. """ - return self.train(mode=False) + #return self.train(mode=False) + super(Augmented_model, self).train(mode=False) + self._mods['data_aug'].augment(mode=False) + return self def items(self): """Return an iterable of the ModuleDict key/value pairs. diff --git a/higher/test_dataug.py b/higher/test_dataug.py index 0faac55..b58dcf7 100755 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -68,7 +68,7 @@ if __name__ == "__main__": } n_inner_iter = 1 epochs = 150 - dataug_epoch_start=0 + dataug_epoch_start=10 optim_param={ 'Meta':{ 'optim':'Adam', @@ -87,8 +87,6 @@ if __name__ == "__main__": #model = MobileNetV2(num_classes=10) #model = WideResNet(num_classes=10, wrn_size=32) - model = Higher_model(model) #run_dist_dataugV3 - #### Classic #### if 'classic' in tasks: t0 = time.process_time() @@ -171,6 +169,7 @@ if __name__ == "__main__": t0 = time.process_time() tf_dict = {k: TF.TF_dict[k] for k in tf_names} + model = Higher_model(model) #run_dist_dataugV3 aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device) #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device) diff --git a/higher/train_utils.py b/higher/train_utils.py index c99ddfa..b547568 100755 --- a/higher/train_utils.py +++ b/higher/train_utils.py @@ -827,6 +827,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start device = next(model.parameters()).device log = [] dl_val_it = iter(dl_val) + val_loss=None high_grad_track = True if inner_it == 0: #No HP optimization @@ -909,10 +910,8 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step #print("meta") - val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss() #print_graph(val_loss) #to visualize computational graph - val_loss.backward() torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN @@ -920,7 +919,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start meta_opt.step() #Adjust Hyper-parameters - model['data_aug'].adjust_param(soft=True) #Contrainte sum(proba)=1 + model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1 if hp_opt: for param_group in diffopt.param_groups: for param in list(opt_param['Inner'].keys())[1:]: @@ -949,6 +948,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start accuracy, test_loss =test(model) model.train() + print(model['data_aug']._data_augmentation) #### Log #### param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])] data={ @@ -989,7 +989,13 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start print('Starting Data Augmention...') dataug_epoch_start = epoch model.augment(mode=True) - if inner_it != 0: high_grad_track = True + if inner_it != 0: #Rebuild diffopt if needed + high_grad_track = True + diffopt = model['model'].get_diffopt( + inner_opt, + grad_callback=(lambda grads: clip_norm(grads, max_norm=10)), + track_higher_grads=high_grad_track) + #Data sample saving try: