mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-06-29 00:15:25 +02:00
Minor improvement (RandAug)
This commit is contained in:
parent
6bba069d8a
commit
561b71b30a
5 changed files with 50 additions and 179 deletions
|
@ -287,13 +287,19 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
diffopt.detach_()
|
||||
model['model'].detach_()
|
||||
meta_opt.zero_grad()
|
||||
|
||||
elif not high_grad_track:
|
||||
diffopt.detach_()
|
||||
model['model'].detach_()
|
||||
|
||||
tf = time.process_time()
|
||||
|
||||
if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
|
||||
try:
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
model.train()
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch))
|
||||
model.eval()
|
||||
except:
|
||||
print("Couldn't save samples epoch"+epoch)
|
||||
pass
|
||||
|
@ -315,9 +321,9 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"mix_dist": model['data_aug']['mix_dist'].item(),
|
||||
"param": param,
|
||||
}
|
||||
if not model['data_aug']._fixed_mix: data["mix_dist"]=model['data_aug']['mix_dist'].item()
|
||||
if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'].item(), 'momentum': p_grp['momentum'].item()} for p_grp in diffopt.param_groups]
|
||||
log.append(data)
|
||||
#############
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue