minor changes

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-31 10:34:44 -05:00
parent bf29d4fb6d
commit cd6e159b77
6 changed files with 59 additions and 95 deletions

View file

@ -144,6 +144,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
#print_torch_mem("Start epoch")
t0 = time.process_time()
for i, (features, labels) in enumerate(dl_train):
#viz_sample_data(imgs=features, labels=labels, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
#print_torch_mem("Start iter")
features,labels = features.to(device), labels.to(device)
@ -277,7 +278,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
meta_opt.step()
#Adjust Hyper-parameters
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
model['data_aug'].adjust_param() #Contrainte sum(proba)=1
if hp_opt:
for param_group in diffopt.param_groups:
for param in list(opt_param['Inner'].keys())[1:]:
@ -289,7 +290,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
meta_opt.zero_grad()
elif not high_grad_track:
diffopt.detach_()
#diffopt.detach_()
model['model'].detach_()
tf = time.process_time()