diff --git a/higher/train_utils.py b/higher/train_utils.py index f4c5045..4866249 100644 --- a/higher/train_utils.py +++ b/higher/train_utils.py @@ -682,8 +682,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f model.augment(mode=True) if inner_it != 0: high_grad_track = True - #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) - #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch)) + viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) + viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch)) #print("Copy ", countcopy) return log