Ajout Meta-scheduler a run_dist_dataugV3

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-26 12:19:44 -05:00
parent 2cbe3d09aa
commit e2691a1c38
2 changed files with 15 additions and 3 deletions

View file

@ -292,6 +292,14 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
hyper_param += [param_group[param]]
meta_opt = torch.optim.Adam(hyper_param, lr=opt_param['Meta']['lr']) #lr=1e-2
meta_scheduler=None
if opt_param['Meta']['scheduler']=='multiStep':
meta_scheduler=torch.optim.lr_scheduler.MultiStepLR(meta_opt,
milestones=[int(epochs/3), int(epochs*2/3), int(epochs*2.7/3)],
gamma=10)
elif opt_param['Meta']['scheduler'] is not None:
raise ValueError("Lr scheduler unknown : %s"%opt_param['Meta']['scheduler'])
model.train()
meta_opt.zero_grad()
@ -356,12 +364,15 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
tf = time.perf_counter()
#Schedulers
if inner_scheduler is not None:
inner_scheduler.step()
#Transfer inner_opt lr to diffopt
for diff_param_group in diffopt.param_groups:
for param_group in inner_opt.param_groups:
diff_param_group['lr'] = param_group['lr']
if meta_scheduler is not None:
meta_scheduler.step()
if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
try: