Ajout option Weight decay / Nesterov sur inner opt

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-10 16:32:59 -05:00
parent 65e67addf6
commit 383f63c7b8
3 changed files with 26 additions and 12 deletions

View file

@ -972,7 +972,11 @@ class Augmented_model(nn.Module):
self._opt_param=opt_param
#Inner Opt
inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(),
lr=opt_param['Inner']['lr'],
momentum=opt_param['Inner']['momentum'],
weight_decay=opt_param['Inner']['decay'],
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
#Validation data
self._dl_val=dl_val

View file

@ -77,12 +77,12 @@ if __name__ == "__main__":
#Task to perform
tasks={
#'classic',
'aug_model'
'classic',
#'aug_model'
}
#Parameters
n_inner_iter = 1
epochs = 2
epochs = 150
dataug_epoch_start=0
optim_param={
'Meta':{
@ -93,16 +93,18 @@ if __name__ == "__main__":
'optim': 'SGD',
'lr':1e-2, #1e-2
'momentum':0.9, #0.9
'decay':0.0001,
'nesterov':True,
}
}
#Models
model = LeNet(3,10)
#model = LeNet(3,10)
#model = ResNet(num_classes=10)
#import torchvision.models as models
import torchvision.models as models
#model=models.resnet18()
model_name = str(model) #'wide_resnet50_2' #'resnet18' #str(model)
#model = getattr(models.resnet, model_name)(pretrained=False)
model_name = 'resnet50' #'wide_resnet50_2' #'resnet18' #str(model)
model = getattr(models.resnet, model_name)(pretrained=False)
#### Classic ####
if 'classic' in tasks:
@ -111,7 +113,7 @@ if __name__ == "__main__":
print("{} on {} for {} epochs".format(model_name, device_name, epochs))
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1)
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10)
#log= train_classic_higher(model=model, epochs=epochs)
exec_time=time.perf_counter() - t0
@ -161,7 +163,7 @@ if __name__ == "__main__":
inner_it=n_inner_iter,
dataug_epoch_start=dataug_epoch_start,
opt_param=optim_param,
print_freq=1,
print_freq=10,
unsup_loss=1,
hp_opt=False,
save_sample_freq=None)

View file

@ -144,7 +144,11 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
"""
device = next(model.parameters()).device
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
optim = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
optim = torch.optim.SGD(model.parameters(),
lr=opt_param['Inner']['lr'],
momentum=opt_param['Inner']['momentum'],
weight_decay=opt_param['Inner']['decay'],
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
model.train()
dl_val_it = iter(dl_val)
@ -232,7 +236,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
## Optimizers ##
#Inner Opt
inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
optim = torch.optim.SGD(model.parameters(),
lr=opt_param['Inner']['lr'],
momentum=opt_param['Inner']['momentum'],
weight_decay=opt_param['Inner']['decay'],
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
diffopt = model['model'].get_diffopt(
inner_opt,