mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-03 11:40:46 +02:00
Ajout option Weight decay / Nesterov sur inner opt
This commit is contained in:
parent
65e67addf6
commit
383f63c7b8
3 changed files with 26 additions and 12 deletions
|
@ -972,7 +972,11 @@ class Augmented_model(nn.Module):
|
|||
|
||||
self._opt_param=opt_param
|
||||
#Inner Opt
|
||||
inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
|
||||
inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(),
|
||||
lr=opt_param['Inner']['lr'],
|
||||
momentum=opt_param['Inner']['momentum'],
|
||||
weight_decay=opt_param['Inner']['decay'],
|
||||
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
#Validation data
|
||||
self._dl_val=dl_val
|
||||
|
|
|
@ -77,12 +77,12 @@ if __name__ == "__main__":
|
|||
|
||||
#Task to perform
|
||||
tasks={
|
||||
#'classic',
|
||||
'aug_model'
|
||||
'classic',
|
||||
#'aug_model'
|
||||
}
|
||||
#Parameters
|
||||
n_inner_iter = 1
|
||||
epochs = 2
|
||||
epochs = 150
|
||||
dataug_epoch_start=0
|
||||
optim_param={
|
||||
'Meta':{
|
||||
|
@ -93,16 +93,18 @@ if __name__ == "__main__":
|
|||
'optim': 'SGD',
|
||||
'lr':1e-2, #1e-2
|
||||
'momentum':0.9, #0.9
|
||||
'decay':0.0001,
|
||||
'nesterov':True,
|
||||
}
|
||||
}
|
||||
|
||||
#Models
|
||||
model = LeNet(3,10)
|
||||
#model = LeNet(3,10)
|
||||
#model = ResNet(num_classes=10)
|
||||
#import torchvision.models as models
|
||||
import torchvision.models as models
|
||||
#model=models.resnet18()
|
||||
model_name = str(model) #'wide_resnet50_2' #'resnet18' #str(model)
|
||||
#model = getattr(models.resnet, model_name)(pretrained=False)
|
||||
model_name = 'resnet50' #'wide_resnet50_2' #'resnet18' #str(model)
|
||||
model = getattr(models.resnet, model_name)(pretrained=False)
|
||||
|
||||
#### Classic ####
|
||||
if 'classic' in tasks:
|
||||
|
@ -111,7 +113,7 @@ if __name__ == "__main__":
|
|||
|
||||
|
||||
print("{} on {} for {} epochs".format(model_name, device_name, epochs))
|
||||
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1)
|
||||
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10)
|
||||
#log= train_classic_higher(model=model, epochs=epochs)
|
||||
|
||||
exec_time=time.perf_counter() - t0
|
||||
|
@ -161,7 +163,7 @@ if __name__ == "__main__":
|
|||
inner_it=n_inner_iter,
|
||||
dataug_epoch_start=dataug_epoch_start,
|
||||
opt_param=optim_param,
|
||||
print_freq=1,
|
||||
print_freq=10,
|
||||
unsup_loss=1,
|
||||
hp_opt=False,
|
||||
save_sample_freq=None)
|
||||
|
|
|
@ -144,7 +144,11 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
"""
|
||||
device = next(model.parameters()).device
|
||||
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
optim = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
|
||||
optim = torch.optim.SGD(model.parameters(),
|
||||
lr=opt_param['Inner']['lr'],
|
||||
momentum=opt_param['Inner']['momentum'],
|
||||
weight_decay=opt_param['Inner']['decay'],
|
||||
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
model.train()
|
||||
dl_val_it = iter(dl_val)
|
||||
|
@ -232,7 +236,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
|
||||
## Optimizers ##
|
||||
#Inner Opt
|
||||
inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
|
||||
optim = torch.optim.SGD(model.parameters(),
|
||||
lr=opt_param['Inner']['lr'],
|
||||
momentum=opt_param['Inner']['momentum'],
|
||||
weight_decay=opt_param['Inner']['decay'],
|
||||
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
diffopt = model['model'].get_diffopt(
|
||||
inner_opt,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue