mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 04:00:46 +02:00
Ajout Meta-scheduler a run_dist_dataugV3
This commit is contained in:
parent
2cbe3d09aa
commit
e2691a1c38
2 changed files with 15 additions and 3 deletions
|
@ -8,7 +8,7 @@ from dataug import *
|
||||||
from train_utils import *
|
from train_utils import *
|
||||||
from transformations import TF_loader
|
from transformations import TF_loader
|
||||||
|
|
||||||
postfix=''
|
postfix='-metaScheduler'
|
||||||
TF_loader=TF_loader()
|
TF_loader=TF_loader()
|
||||||
|
|
||||||
device = torch.device('cuda') #Select device to use
|
device = torch.device('cuda') #Select device to use
|
||||||
|
@ -40,9 +40,10 @@ if __name__ == "__main__":
|
||||||
optim_param={
|
optim_param={
|
||||||
'Meta':{
|
'Meta':{
|
||||||
'optim':'Adam',
|
'optim':'Adam',
|
||||||
'lr':1e-2, #1e-2
|
'lr':1e-4, #1e-2
|
||||||
'epoch_start': 2, #0 / 2 (Resnet?)
|
'epoch_start': 2, #0 / 2 (Resnet?)
|
||||||
'reg_factor': 0.001,
|
'reg_factor': 0.001,
|
||||||
|
'scheduler': 'multiStep', #None, 'multiStep'
|
||||||
},
|
},
|
||||||
'Inner':{
|
'Inner':{
|
||||||
'optim': 'SGD',
|
'optim': 'SGD',
|
||||||
|
@ -138,7 +139,7 @@ if __name__ == "__main__":
|
||||||
inner_it=n_inner_iter,
|
inner_it=n_inner_iter,
|
||||||
dataug_epoch_start=dataug_epoch_start,
|
dataug_epoch_start=dataug_epoch_start,
|
||||||
opt_param=optim_param,
|
opt_param=optim_param,
|
||||||
print_freq=10,
|
print_freq=20,
|
||||||
unsup_loss=1,
|
unsup_loss=1,
|
||||||
hp_opt=False,
|
hp_opt=False,
|
||||||
save_sample_freq=None)
|
save_sample_freq=None)
|
||||||
|
|
|
@ -292,6 +292,14 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
||||||
hyper_param += [param_group[param]]
|
hyper_param += [param_group[param]]
|
||||||
meta_opt = torch.optim.Adam(hyper_param, lr=opt_param['Meta']['lr']) #lr=1e-2
|
meta_opt = torch.optim.Adam(hyper_param, lr=opt_param['Meta']['lr']) #lr=1e-2
|
||||||
|
|
||||||
|
meta_scheduler=None
|
||||||
|
if opt_param['Meta']['scheduler']=='multiStep':
|
||||||
|
meta_scheduler=torch.optim.lr_scheduler.MultiStepLR(meta_opt,
|
||||||
|
milestones=[int(epochs/3), int(epochs*2/3), int(epochs*2.7/3)],
|
||||||
|
gamma=10)
|
||||||
|
elif opt_param['Meta']['scheduler'] is not None:
|
||||||
|
raise ValueError("Lr scheduler unknown : %s"%opt_param['Meta']['scheduler'])
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
meta_opt.zero_grad()
|
meta_opt.zero_grad()
|
||||||
|
|
||||||
|
@ -356,12 +364,15 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
||||||
|
|
||||||
tf = time.perf_counter()
|
tf = time.perf_counter()
|
||||||
|
|
||||||
|
#Schedulers
|
||||||
if inner_scheduler is not None:
|
if inner_scheduler is not None:
|
||||||
inner_scheduler.step()
|
inner_scheduler.step()
|
||||||
#Transfer inner_opt lr to diffopt
|
#Transfer inner_opt lr to diffopt
|
||||||
for diff_param_group in diffopt.param_groups:
|
for diff_param_group in diffopt.param_groups:
|
||||||
for param_group in inner_opt.param_groups:
|
for param_group in inner_opt.param_groups:
|
||||||
diff_param_group['lr'] = param_group['lr']
|
diff_param_group['lr'] = param_group['lr']
|
||||||
|
if meta_scheduler is not None:
|
||||||
|
meta_scheduler.step()
|
||||||
|
|
||||||
if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
|
if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue