Minor improvement (RandAug)

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-30 11:21:25 -05:00
parent 6bba069d8a
commit 561b71b30a
5 changed files with 50 additions and 179 deletions

View file

@ -287,13 +287,19 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
diffopt.detach_()
model['model'].detach_()
meta_opt.zero_grad()
elif not high_grad_track:
diffopt.detach_()
model['model'].detach_()
tf = time.process_time()
if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
try:
viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
model.train()
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch))
model.eval()
except:
print("Couldn't save samples epoch"+epoch)
pass
@ -315,9 +321,9 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
"acc": accuracy,
"time": tf - t0,
"mix_dist": model['data_aug']['mix_dist'].item(),
"param": param,
}
if not model['data_aug']._fixed_mix: data["mix_dist"]=model['data_aug']['mix_dist'].item()
if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'].item(), 'momentum': p_grp['momentum'].item()} for p_grp in diffopt.param_groups]
log.append(data)
#############