mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
LR scheduler + Resolution pb ResNet50/WRN
This commit is contained in:
parent
383f63c7b8
commit
79de0191a8
4 changed files with 78 additions and 21 deletions
|
@ -143,6 +143,8 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
(list) Logs of training. Each item is a dict containing the results of an epoch.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
|
||||
#Optimizer
|
||||
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
optim = torch.optim.SGD(model.parameters(),
|
||||
lr=opt_param['Inner']['lr'],
|
||||
|
@ -150,11 +152,28 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
weight_decay=opt_param['Inner']['decay'],
|
||||
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
#Scheduler
|
||||
inner_scheduler=None
|
||||
if opt_param['Inner']['scheduler']=='cosine':
|
||||
inner_scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=epochs, eta_min=0.)
|
||||
elif opt_param['Inner']['scheduler']=='multiStep':
|
||||
#Multistep milestones inspired by AutoAugment
|
||||
inner_scheduler=torch.optim.lr_scheduler.MultiStepLR(optim,
|
||||
milestones=[int(epochs/3), int(epochs*2/3), int(epochs*2.7/3)],
|
||||
gamma=0.1)
|
||||
elif opt_param['Inner']['scheduler']=='exponential':
|
||||
#inner_scheduler=torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.1) #Wrong gamma
|
||||
inner_scheduler=torch.optim.lr_scheduler.LambdaLR(optim, lambda epoch: (1 - epoch / epochs) ** 0.9)
|
||||
elif opt_param['Inner']['scheduler'] is not None:
|
||||
raise ValueError("Lr scheduler unknown : %s"%opt_param['Inner']['scheduler'])
|
||||
|
||||
#Training
|
||||
model.train()
|
||||
dl_val_it = iter(dl_val)
|
||||
log = []
|
||||
for epoch in range(epochs):
|
||||
#print_torch_mem("Start epoch")
|
||||
#print(optim.param_groups[0]['lr'])
|
||||
t0 = time.perf_counter()
|
||||
for i, (features, labels) in enumerate(dl_train):
|
||||
#viz_sample_data(imgs=features, labels=labels, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
|
@ -168,6 +187,10 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
|
||||
if inner_scheduler is not None:
|
||||
inner_scheduler.step()
|
||||
|
||||
#### Tests ####
|
||||
tf = time.perf_counter()
|
||||
|
||||
|
@ -175,15 +198,6 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
accuracy, f1 =test(model)
|
||||
model.train()
|
||||
|
||||
#### Print ####
|
||||
if(print_freq and epoch%print_freq==0):
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy max:', accuracy)
|
||||
print('F1 :', ["{0:0.4f}".format(i) for i in f1])
|
||||
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
|
@ -196,6 +210,14 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
"param": None,
|
||||
}
|
||||
log.append(data)
|
||||
#### Print ####
|
||||
if(print_freq and epoch%print_freq==0):
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy max:', max([x["acc"] for x in log]))
|
||||
print('F1 :', ["{0:0.4f}".format(i) for i in f1])
|
||||
|
||||
return log
|
||||
|
||||
|
@ -236,8 +258,8 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
|
||||
## Optimizers ##
|
||||
#Inner Opt
|
||||
optim = torch.optim.SGD(model.parameters(),
|
||||
lr=opt_param['Inner']['lr'],
|
||||
inner_opt = torch.optim.SGD(model['model']['original'].parameters(),
|
||||
lr=opt_param['Inner']['lr'],
|
||||
momentum=opt_param['Inner']['momentum'],
|
||||
weight_decay=opt_param['Inner']['decay'],
|
||||
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
|
||||
|
@ -247,6 +269,21 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
|
||||
track_higher_grads=high_grad_track)
|
||||
|
||||
#Scheduler
|
||||
inner_scheduler=None
|
||||
if opt_param['Inner']['scheduler']=='cosine':
|
||||
inner_scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=epochs, eta_min=0.)
|
||||
elif opt_param['Inner']['scheduler']=='multiStep':
|
||||
#Multistep milestones inspired by AutoAugment
|
||||
inner_scheduler=torch.optim.lr_scheduler.MultiStepLR(optim,
|
||||
milestones=[int(epochs/3), int(epochs*2/3), int(epochs*2.7/3)],
|
||||
gamma=0.1)
|
||||
elif opt_param['Inner']['scheduler']=='exponential':
|
||||
#inner_scheduler=torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.1) #Wrong gamma
|
||||
inner_scheduler=torch.optim.lr_scheduler.LambdaLR(optim, lambda epoch: (1 - epoch / epochs) ** 0.9)
|
||||
elif opt_param['Inner']['scheduler'] is not None:
|
||||
raise ValueError("Lr scheduler unknown : %s"%opt_param['Inner']['scheduler'])
|
||||
|
||||
#Meta Opt
|
||||
hyper_param = list(model['data_aug'].parameters())
|
||||
if hp_opt :
|
||||
|
@ -286,7 +323,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
#print_graph(loss) #to visualize computational graph
|
||||
|
||||
#t = time.process_time()
|
||||
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
|
||||
diffopt.step(loss)#(opt.zero_grad, loss.backward, opt.step)
|
||||
#print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
|
||||
|
||||
|
||||
|
@ -318,6 +355,13 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
|
||||
tf = time.perf_counter()
|
||||
|
||||
if inner_scheduler is not None:
|
||||
inner_scheduler.step()
|
||||
#Transfer inner_opt lr to diffopt
|
||||
for diff_param_group in diffopt.param_groups:
|
||||
for param_group in inner_opt.param_groups:
|
||||
diff_param_group['lr'] = param_group['lr']
|
||||
|
||||
if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
|
||||
try:
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
|
@ -396,6 +440,8 @@ def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, un
|
|||
Training loss can either be computed directly from augmented inputs (unsup_loss=0).
|
||||
However, it is recommended to use the mixed loss computation, which combines original and augmented inputs to compute the loss (unsup_loss>0).
|
||||
|
||||
Does not support LR schedulers.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Augmented model to train.
|
||||
opt_param (dict): Dictionary containing optimizer parameters.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue