Ajout option Weight decay / Nesterov sur inner opt

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-10 16:32:59 -05:00
parent 65e67addf6
commit 383f63c7b8
3 changed files with 26 additions and 12 deletions

View file

@@ -972,7 +972,11 @@ class Augmented_model(nn.Module):
self._opt_param=opt_param self._opt_param=opt_param
#Inner Opt #Inner Opt
inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9 inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(),
lr=opt_param['Inner']['lr'],
momentum=opt_param['Inner']['momentum'],
weight_decay=opt_param['Inner']['decay'],
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
#Validation data #Validation data
self._dl_val=dl_val self._dl_val=dl_val

View file

@@ -77,12 +77,12 @@ if __name__ == "__main__":
#Task to perform #Task to perform
tasks={ tasks={
#'classic', 'classic',
'aug_model' #'aug_model'
} }
#Parameters #Parameters
n_inner_iter = 1 n_inner_iter = 1
epochs = 2 epochs = 150
dataug_epoch_start=0 dataug_epoch_start=0
optim_param={ optim_param={
'Meta':{ 'Meta':{
@@ -93,16 +93,18 @@ if __name__ == "__main__":
'optim': 'SGD', 'optim': 'SGD',
'lr':1e-2, #1e-2 'lr':1e-2, #1e-2
'momentum':0.9, #0.9 'momentum':0.9, #0.9
'decay':0.0001,
'nesterov':True,
} }
} }
#Models #Models
model = LeNet(3,10) #model = LeNet(3,10)
#model = ResNet(num_classes=10) #model = ResNet(num_classes=10)
#import torchvision.models as models import torchvision.models as models
#model=models.resnet18() #model=models.resnet18()
model_name = str(model) #'wide_resnet50_2' #'resnet18' #str(model) model_name = 'resnet50' #'wide_resnet50_2' #'resnet18' #str(model)
#model = getattr(models.resnet, model_name)(pretrained=False) model = getattr(models.resnet, model_name)(pretrained=False)
#### Classic #### #### Classic ####
if 'classic' in tasks: if 'classic' in tasks:
@@ -111,7 +113,7 @@ if __name__ == "__main__":
print("{} on {} for {} epochs".format(model_name, device_name, epochs)) print("{} on {} for {} epochs".format(model_name, device_name, epochs))
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1) log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10)
#log= train_classic_higher(model=model, epochs=epochs) #log= train_classic_higher(model=model, epochs=epochs)
exec_time=time.perf_counter() - t0 exec_time=time.perf_counter() - t0
@@ -161,7 +163,7 @@ if __name__ == "__main__":
inner_it=n_inner_iter, inner_it=n_inner_iter,
dataug_epoch_start=dataug_epoch_start, dataug_epoch_start=dataug_epoch_start,
opt_param=optim_param, opt_param=optim_param,
print_freq=1, print_freq=10,
unsup_loss=1, unsup_loss=1,
hp_opt=False, hp_opt=False,
save_sample_freq=None) save_sample_freq=None)

View file

@@ -144,7 +144,11 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
""" """
device = next(model.parameters()).device device = next(model.parameters()).device
#opt = torch.optim.Adam(model.parameters(), lr=1e-3) #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
optim = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9 optim = torch.optim.SGD(model.parameters(),
lr=opt_param['Inner']['lr'],
momentum=opt_param['Inner']['momentum'],
weight_decay=opt_param['Inner']['decay'],
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
model.train() model.train()
dl_val_it = iter(dl_val) dl_val_it = iter(dl_val)
@@ -232,7 +236,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
## Optimizers ## ## Optimizers ##
#Inner Opt #Inner Opt
inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9 optim = torch.optim.SGD(model.parameters(),
lr=opt_param['Inner']['lr'],
momentum=opt_param['Inner']['momentum'],
weight_decay=opt_param['Inner']['decay'],
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
diffopt = model['model'].get_diffopt( diffopt = model['model'].get_diffopt(
inner_opt, inner_opt,