mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 04:00:46 +02:00
Ajout option Weight decay / Nesterov sur inner opt
This commit is contained in:
parent
65e67addf6
commit
383f63c7b8
3 changed files with 26 additions and 12 deletions
|
@ -77,12 +77,12 @@ if __name__ == "__main__":
|
|||
|
||||
#Task to perform
|
||||
tasks={
|
||||
#'classic',
|
||||
'aug_model'
|
||||
'classic',
|
||||
#'aug_model'
|
||||
}
|
||||
#Parameters
|
||||
n_inner_iter = 1
|
||||
epochs = 2
|
||||
epochs = 150
|
||||
dataug_epoch_start=0
|
||||
optim_param={
|
||||
'Meta':{
|
||||
|
@ -93,16 +93,18 @@ if __name__ == "__main__":
|
|||
'optim': 'SGD',
|
||||
'lr':1e-2, #1e-2
|
||||
'momentum':0.9, #0.9
|
||||
'decay':0.0001,
|
||||
'nesterov':True,
|
||||
}
|
||||
}
|
||||
|
||||
#Models
|
||||
model = LeNet(3,10)
|
||||
#model = LeNet(3,10)
|
||||
#model = ResNet(num_classes=10)
|
||||
#import torchvision.models as models
|
||||
import torchvision.models as models
|
||||
#model=models.resnet18()
|
||||
model_name = str(model) #'wide_resnet50_2' #'resnet18' #str(model)
|
||||
#model = getattr(models.resnet, model_name)(pretrained=False)
|
||||
model_name = 'resnet50' #'wide_resnet50_2' #'resnet18' #str(model)
|
||||
model = getattr(models.resnet, model_name)(pretrained=False)
|
||||
|
||||
#### Classic ####
|
||||
if 'classic' in tasks:
|
||||
|
@ -111,7 +113,7 @@ if __name__ == "__main__":
|
|||
|
||||
|
||||
print("{} on {} for {} epochs".format(model_name, device_name, epochs))
|
||||
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1)
|
||||
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10)
|
||||
#log= train_classic_higher(model=model, epochs=epochs)
|
||||
|
||||
exec_time=time.perf_counter() - t0
|
||||
|
@ -161,7 +163,7 @@ if __name__ == "__main__":
|
|||
inner_it=n_inner_iter,
|
||||
dataug_epoch_start=dataug_epoch_start,
|
||||
opt_param=optim_param,
|
||||
print_freq=1,
|
||||
print_freq=10,
|
||||
unsup_loss=1,
|
||||
hp_opt=False,
|
||||
save_sample_freq=None)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue