mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Ajout option Weight decay / Nesterov sur inner opt
This commit is contained in:
parent
65e67addf6
commit
383f63c7b8
3 changed files with 26 additions and 12 deletions
|
@ -144,7 +144,11 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
"""
|
||||
device = next(model.parameters()).device
|
||||
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
optim = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
|
||||
optim = torch.optim.SGD(model.parameters(),
|
||||
lr=opt_param['Inner']['lr'],
|
||||
momentum=opt_param['Inner']['momentum'],
|
||||
weight_decay=opt_param['Inner']['decay'],
|
||||
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
model.train()
|
||||
dl_val_it = iter(dl_val)
|
||||
|
@ -232,7 +236,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
|
||||
## Optimizers ##
|
||||
#Inner Opt
|
||||
inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
|
||||
optim = torch.optim.SGD(model.parameters(),
|
||||
lr=opt_param['Inner']['lr'],
|
||||
momentum=opt_param['Inner']['momentum'],
|
||||
weight_decay=opt_param['Inner']['decay'],
|
||||
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
diffopt = model['model'].get_diffopt(
|
||||
inner_opt,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue