From 383f63c7b8078ba975fd3414ca640f6e2d524d21 Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Mon, 10 Feb 2020 16:32:59 -0500
Subject: [PATCH] Add weight decay / Nesterov option on inner optimizer

---
 higher/smart_aug/dataug.py      |  6 +++++-
 higher/smart_aug/test_dataug.py | 20 +++++++++---------
 higher/smart_aug/train_utils.py | 12 ++++++++++--
 3 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/higher/smart_aug/dataug.py b/higher/smart_aug/dataug.py
index 0b459f1..7c85473 100755
--- a/higher/smart_aug/dataug.py
+++ b/higher/smart_aug/dataug.py
@@ -972,7 +972,11 @@ class Augmented_model(nn.Module):
         self._opt_param=opt_param
 
         #Inner Opt
-        inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+        inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(),
+                                    lr=opt_param['Inner']['lr'],
+                                    momentum=opt_param['Inner']['momentum'],
+                                    weight_decay=opt_param['Inner']['decay'],
+                                    nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
 
         #Validation data
         self._dl_val=dl_val
diff --git a/higher/smart_aug/test_dataug.py b/higher/smart_aug/test_dataug.py
index 8084fa2..67f042b 100755
--- a/higher/smart_aug/test_dataug.py
+++ b/higher/smart_aug/test_dataug.py
@@ -77,12 +77,12 @@ if __name__ == "__main__":
 
     #Task to perform
     tasks={
-        #'classic',
-        'aug_model'
+        'classic',
+        #'aug_model'
     }
     #Parameters
    n_inner_iter = 1
-    epochs = 2
+    epochs = 150
    dataug_epoch_start=0
    optim_param={
        'Meta':{
@@ -93,16 +93,18 @@ if __name__ == "__main__":
            'optim': 'SGD',
            'lr':1e-2, #1e-2
            'momentum':0.9, #0.9
+            'decay':0.0001,
+            'nesterov':True,
        }
    }
 
    #Models
-    model = LeNet(3,10)
+    #model = LeNet(3,10)
    #model = ResNet(num_classes=10)
-    #import torchvision.models as models
+    import torchvision.models as models
    #model=models.resnet18()
-    model_name = str(model) #'wide_resnet50_2' #'resnet18' #str(model)
-    #model = getattr(models.resnet, model_name)(pretrained=False)
+    model_name = 'resnet50' #'wide_resnet50_2' #'resnet18' #str(model)
+    model = getattr(models.resnet, model_name)(pretrained=False)
 
    #### Classic ####
    if 'classic' in tasks:
@@ -111,7 +113,7 @@ if __name__ == "__main__":
 
        print("{} on {} for {} epochs".format(model_name, device_name, epochs))
 
-        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1)
+        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10)
        #log= train_classic_higher(model=model, epochs=epochs)
 
        exec_time=time.perf_counter() - t0
@@ -161,7 +163,7 @@ if __name__ == "__main__":
            inner_it=n_inner_iter,
            dataug_epoch_start=dataug_epoch_start,
            opt_param=optim_param,
-            print_freq=1,
+            print_freq=10,
            unsup_loss=1,
            hp_opt=False,
            save_sample_freq=None)
diff --git a/higher/smart_aug/train_utils.py b/higher/smart_aug/train_utils.py
index 6f36951..680c8a6 100755
--- a/higher/smart_aug/train_utils.py
+++ b/higher/smart_aug/train_utils.py
@@ -144,7 +144,11 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
    """
    device = next(model.parameters()).device
    #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
-    optim = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+    optim = torch.optim.SGD(model.parameters(),
+                            lr=opt_param['Inner']['lr'],
+                            momentum=opt_param['Inner']['momentum'],
+                            weight_decay=opt_param['Inner']['decay'],
+                            nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
 
    model.train()
    dl_val_it = iter(dl_val)
@@ -232,7 +236,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
 
    ## Optimizers ##
    #Inner Opt
-    inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+    inner_opt = torch.optim.SGD(model['model']['original'].parameters(),
+                                lr=opt_param['Inner']['lr'],
+                                momentum=opt_param['Inner']['momentum'],
+                                weight_decay=opt_param['Inner']['decay'],
+                                nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
 
    diffopt = model['model'].get_diffopt(
        inner_opt,
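
As a quick sanity check outside the patch, here is a minimal sketch of how the extended `optim_param['Inner']` dictionary (now carrying `'decay'` and `'nesterov'` in addition to `'lr'` and `'momentum'`) can be consumed by `torch.optim.SGD`. The `build_inner_opt` helper and the toy `nn.Linear` model are illustrative assumptions, not part of the repository; only the dictionary keys and optimizer arguments come from the diff above.

```python
import torch
import torch.nn as nn

# Hypothetical helper (not in the repo): builds the inner optimizer from the
# 'Inner' config keys used by this patch ('lr', 'momentum', 'decay', 'nesterov').
def build_inner_opt(params, inner_cfg):
    return torch.optim.SGD(params,
                           lr=inner_cfg['lr'],
                           momentum=inner_cfg['momentum'],
                           weight_decay=inner_cfg['decay'],    # L2 penalty
                           nesterov=inner_cfg['nesterov'])     # Nesterov momentum

if __name__ == "__main__":
    optim_param = {
        'Inner': {
            'optim': 'SGD',
            'lr': 1e-2,
            'momentum': 0.9,
            'decay': 0.0001,
            'nesterov': True,
        }
    }
    model = nn.Linear(10, 2)  # toy stand-in for LeNet / ResNet
    opt = build_inner_opt(model.parameters(), optim_param['Inner'])

    # One dummy training step to confirm the optimizer is usable.
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    opt.step()  # SGD step with weight decay and Nesterov momentum
```

Note that PyTorch only accepts `nesterov=True` when momentum is non-zero and dampening is zero, which the config above (momentum=0.9, default dampening) satisfies.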