Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 04:00:46 +02:00)
Add weight decay / Nesterov options to the inner optimizer
parent 65e67addf6
commit 383f63c7b8
3 changed files with 26 additions and 12 deletions
@@ -972,7 +972,11 @@ class Augmented_model(nn.Module):
         self._opt_param=opt_param
         #Inner Opt
-        inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+        inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(),
+            lr=opt_param['Inner']['lr'],
+            momentum=opt_param['Inner']['momentum'],
+            weight_decay=opt_param['Inner']['decay'],
+            nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9

         #Validation data
         self._dl_val=dl_val

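For reference, a minimal sketch of how the updated call consumes the opt_param dict. The concrete 'Inner' values and the nn.Linear stand-in below are illustrative assumptions, not taken from the repository:

import torch
import torch.nn as nn

# Assumed example values; the repository reads these same keys from opt_param['Inner'].
opt_param = {'Inner': {'optim': 'SGD', 'lr': 1e-2, 'momentum': 0.9,
                       'decay': 0.0001, 'nesterov': True}}

model = nn.Linear(10, 2)  # stand-in for self._mods['model']['original']
inner_opt = torch.optim.SGD(model.parameters(),
                            lr=opt_param['Inner']['lr'],
                            momentum=opt_param['Inner']['momentum'],
                            weight_decay=opt_param['Inner']['decay'],
                            nesterov=opt_param['Inner']['nesterov'])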
@@ -77,12 +77,12 @@ if __name__ == "__main__":

     #Task to perform
     tasks={
-        #'classic',
-        'aug_model'
+        'classic',
+        #'aug_model'
     }
     #Parameters
     n_inner_iter = 1
-    epochs = 2
+    epochs = 150
     dataug_epoch_start=0
     optim_param={
         'Meta':{
@@ -93,16 +93,18 @@ if __name__ == "__main__":
             'optim': 'SGD',
             'lr':1e-2, #1e-2
             'momentum':0.9, #0.9
+            'decay':0.0001,
+            'nesterov':True,
         }
     }

     #Models
-    model = LeNet(3,10)
+    #model = LeNet(3,10)
     #model = ResNet(num_classes=10)
-    #import torchvision.models as models
+    import torchvision.models as models
     #model=models.resnet18()
-    model_name = str(model) #'wide_resnet50_2' #'resnet18' #str(model)
-    #model = getattr(models.resnet, model_name)(pretrained=False)
+    model_name = 'resnet50' #'wide_resnet50_2' #'resnet18' #str(model)
+    model = getattr(models.resnet, model_name)(pretrained=False)

     #### Classic ####
     if 'classic' in tasks:
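One caveat worth noting for the new config keys: torch.optim.SGD only accepts nesterov=True together with a non-zero momentum and zero dampening. A small illustrative check, with an arbitrary stand-in module:

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)  # arbitrary stand-in module
# Accepted: momentum is non-zero and dampening defaults to 0.
torch.optim.SGD(layer.parameters(), lr=1e-2, momentum=0.9,
                weight_decay=0.0001, nesterov=True)
# Rejected: nesterov=True without momentum raises a ValueError.
# torch.optim.SGD(layer.parameters(), lr=1e-2, nesterov=True)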
@@ -111,7 +113,7 @@ if __name__ == "__main__":


         print("{} on {} for {} epochs".format(model_name, device_name, epochs))
-        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1)
+        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10)
         #log= train_classic_higher(model=model, epochs=epochs)

         exec_time=time.perf_counter() - t0
@@ -161,7 +163,7 @@ if __name__ == "__main__":
             inner_it=n_inner_iter,
             dataug_epoch_start=dataug_epoch_start,
             opt_param=optim_param,
-            print_freq=1,
+            print_freq=10,
             unsup_loss=1,
             hp_opt=False,
             save_sample_freq=None)
@@ -144,7 +144,11 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
     """
     device = next(model.parameters()).device
     #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
-    optim = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+    optim = torch.optim.SGD(model.parameters(),
+        lr=opt_param['Inner']['lr'],
+        momentum=opt_param['Inner']['momentum'],
+        weight_decay=opt_param['Inner']['decay'],
+        nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9

     model.train()
     dl_val_it = iter(dl_val)
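For intuition on the new weight_decay argument: SGD adds decay * param to the gradient before the (Nesterov) momentum update, i.e. plain L2 regularisation. A tiny numeric check under those semantics, with made-up values:

import torch

p = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9, weight_decay=0.5, nesterov=True)
p.grad = torch.tensor([1.0])
opt.step()
# effective grad g = 1.0 + 0.5*1.0 = 1.5; first-step momentum buffer b = g = 1.5
# Nesterov step uses g + momentum*b = 1.5 + 0.9*1.5 = 2.85
print(p.data)  # tensor([0.7150]) = 1.0 - 0.1*2.85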
@@ -232,7 +236,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start

     ## Optimizers ##
     #Inner Opt
-    inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+    inner_opt = torch.optim.SGD(model['model']['original'].parameters(),
+        lr=opt_param['Inner']['lr'],
+        momentum=opt_param['Inner']['momentum'],
+        weight_decay=opt_param['Inner']['decay'],
+        nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9

     diffopt = model['model'].get_diffopt(
         inner_opt,