diff --git a/higher/test_dataug.py b/higher/test_dataug.py index d7c7716..405db63 100644 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -91,7 +91,7 @@ if __name__ == "__main__": print('-'*9) ''' #### Augmented Model #### - #''' + ''' t0 = time.process_time() tf_dict = {k: TF.TF_dict[k] for k in tf_names} #tf_dict = TF.TF_dict @@ -117,9 +117,9 @@ if __name__ == "__main__": print('Execution Time : %.00f '%(time.process_time() - t0)) print('-'*9) - #''' - #### TF tests #### ''' + #### TF tests #### + #''' res_folder="res/brutus-tests/" epochs= 150 inner_its = [1] @@ -169,4 +169,4 @@ if __name__ == "__main__": #plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names) print('-'*9) - ''' + #''' diff --git a/higher/train_utils.py b/higher/train_utils.py index d3451bf..72bda2b 100644 --- a/higher/train_utils.py +++ b/higher/train_utils.py @@ -616,10 +616,10 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f model_copy(src=fmodel, dst=model) optim_copy(dopt=diffopt, opt=inner_opt) - if epoch>50: - meta_opt.step() - model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1 - #model['data_aug'].next_TF_set() + #if epoch>50: + meta_opt.step() + model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1 + #model['data_aug'].next_TF_set() fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track) @@ -683,8 +683,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f model.augment(mode=True) if inner_it != 0: high_grad_track = True - viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) - viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch)) + #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) + #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch)) #print("Copy ", countcopy) return log