diff --git a/higher/smart_aug/train_utils.py b/higher/smart_aug/train_utils.py index b83863e..853c6c4 100755 --- a/higher/smart_aug/train_utils.py +++ b/higher/smart_aug/train_utils.py @@ -327,7 +327,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start #print(len(model['model']['functional']._fast_params),"step", time.process_time()-t) - if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step + if(high_grad_track and i>0 and i%inner_it==0 and epoch>=opt_param['Meta']['epoch_start']): #Perform Meta step #print("meta") val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss() #print_graph(val_loss) #to visualize computational graph @@ -349,9 +349,10 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start model['model'].detach_() meta_opt.zero_grad() - elif not high_grad_track: - #diffopt.detach_() + elif not high_grad_track or epoch