diff --git a/higher/smart_aug/test_dataug.py b/higher/smart_aug/test_dataug.py index 9858d4e..ce1cc0e 100755 --- a/higher/smart_aug/test_dataug.py +++ b/higher/smart_aug/test_dataug.py @@ -8,7 +8,7 @@ from dataug import * from train_utils import * from transformations import TF_loader -postfix='' +postfix='-metaScheduler' TF_loader=TF_loader() device = torch.device('cuda') #Select device to use @@ -40,9 +40,10 @@ if __name__ == "__main__": optim_param={ 'Meta':{ 'optim':'Adam', - 'lr':1e-2, #1e-2 + 'lr':1e-4, #1e-2 'epoch_start': 2, #0 / 2 (Resnet?) 'reg_factor': 0.001, + 'scheduler': 'multiStep', #None, 'multiStep' }, 'Inner':{ 'optim': 'SGD', @@ -138,7 +139,7 @@ if __name__ == "__main__": inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, opt_param=optim_param, - print_freq=10, + print_freq=20, unsup_loss=1, hp_opt=False, save_sample_freq=None) diff --git a/higher/smart_aug/train_utils.py b/higher/smart_aug/train_utils.py index ce4295e..1df18c8 100755 --- a/higher/smart_aug/train_utils.py +++ b/higher/smart_aug/train_utils.py @@ -292,6 +292,14 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start hyper_param += [param_group[param]] meta_opt = torch.optim.Adam(hyper_param, lr=opt_param['Meta']['lr']) #lr=1e-2 + meta_scheduler=None + if opt_param['Meta']['scheduler']=='multiStep': + meta_scheduler=torch.optim.lr_scheduler.MultiStepLR(meta_opt, + milestones=[int(epochs/3), int(epochs*2/3), int(epochs*2.7/3)], + gamma=10) + elif opt_param['Meta']['scheduler'] is not None: + raise ValueError("Lr scheduler unknown : %s"%opt_param['Meta']['scheduler']) + model.train() meta_opt.zero_grad() @@ -356,12 +364,15 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start tf = time.perf_counter() + #Schedulers if inner_scheduler is not None: inner_scheduler.step() #Transfer inner_opt lr to diffopt for diff_param_group in diffopt.param_groups: for param_group in inner_opt.param_groups: diff_param_group['lr'] = param_group['lr'] + if meta_scheduler is not None: + meta_scheduler.step() if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving try: