From 9513483893bf1b20315bb68ccb22a688e35eab96 Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Wed, 19 Feb 2020 11:59:04 -0500
Subject: [PATCH] deferred meta-learning (train_utils.py)

---
 higher/smart_aug/train_utils.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/higher/smart_aug/train_utils.py b/higher/smart_aug/train_utils.py
index b83863e..853c6c4 100755
--- a/higher/smart_aug/train_utils.py
+++ b/higher/smart_aug/train_utils.py
@@ -327,7 +327,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
             #print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
 
-            if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
+            if(high_grad_track and i>0 and i%inner_it==0 and epoch>=opt_param['Meta']['epoch_start']): #Perform Meta step
                 #print("meta")
 
                 val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss()
                 #print_graph(val_loss) #to visualize computational graph
@@ -349,9 +349,10 @@
                 model['model'].detach_()
                 meta_opt.zero_grad()
 
-            elif not high_grad_track:
-                #diffopt.detach_()
+            elif not high_grad_track or epoch
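
For context, the sketch below illustrates the epoch-gated meta step this patch introduces: the meta-optimizer only updates the augmentation parameters once epoch reaches opt_param['Meta']['epoch_start']; earlier epochs run only the inner optimization. This is a minimal, self-contained sketch, not the repository's code: the toy model, data, aug_param, and the regularizer are hypothetical stand-ins (only opt_param['Meta']['epoch_start'], inner_it, high_grad_track, epoch, and meta_opt come from the diff), and the actual train_utils.py backpropagates the validation loss through the unrolled inner steps (the diffopt/high_grad_track machinery) rather than through a direct dependence as here.

    import torch
    import torch.nn.functional as F

    opt_param = {'Meta': {'epoch_start': 2}}  # hypothetical value, for illustration
    inner_it = 2
    high_grad_track = True

    model = torch.nn.Linear(4, 1)                   # stand-in task model
    aug_param = torch.zeros(1, requires_grad=True)  # stand-in augmentation parameter
    inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
    meta_opt = torch.optim.Adam([aug_param], lr=0.01)

    x_train, y_train = torch.randn(8, 4), torch.randn(8, 1)
    x_val, y_val = torch.randn(8, 4), torch.randn(8, 1)

    for epoch in range(1, 5):
        for i in range(1, 5):
            # Inner step: train the task model on (toy) augmented data.
            loss = F.mse_loss(model(x_train + aug_param), y_train)
            inner_opt.zero_grad()
            loss.backward()
            inner_opt.step()

            # The gating added by the patch: no meta step before
            # opt_param['Meta']['epoch_start'], even when i % inner_it == 0.
            if high_grad_track and i > 0 and i % inner_it == 0 \
                    and epoch >= opt_param['Meta']['epoch_start']:
                # Meta step: a validation loss updates the augmentation
                # parameter. In this toy, only the regularizer term carries
                # gradient to aug_param; the real code differentiates through
                # the unrolled inner optimization.
                val_loss = F.mse_loss(model(x_val), y_val) + aug_param.pow(2).sum()
                meta_opt.zero_grad()
                val_loss.backward()
                meta_opt.step()

Deferring the meta step this way presumably lets the task model's weights stabilize before the augmentation parameters start being tuned, and avoids paying for higher-order gradient tracking during the earliest epochs.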