Ajout option Weight decay / Nesterov sur inner opt

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-10 16:32:59 -05:00
parent 65e67addf6
commit 383f63c7b8
3 changed files with 26 additions and 12 deletions

View file

@ -144,7 +144,11 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
"""
device = next(model.parameters()).device
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
optim = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
optim = torch.optim.SGD(model.parameters(),
lr=opt_param['Inner']['lr'],
momentum=opt_param['Inner']['momentum'],
weight_decay=opt_param['Inner']['decay'],
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
model.train()
dl_val_it = iter(dl_val)
@ -232,7 +236,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
## Optimizers ##
#Inner Opt
inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
optim = torch.optim.SGD(model.parameters(),
lr=opt_param['Inner']['lr'],
momentum=opt_param['Inner']['momentum'],
weight_decay=opt_param['Inner']['decay'],
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
diffopt = model['model'].get_diffopt(
inner_opt,