mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 04:00:46 +02:00
minor changes
This commit is contained in:
parent
bf29d4fb6d
commit
cd6e159b77
6 changed files with 59 additions and 95 deletions
|
@ -144,6 +144,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
#print_torch_mem("Start epoch")
|
||||
t0 = time.process_time()
|
||||
for i, (features, labels) in enumerate(dl_train):
|
||||
#viz_sample_data(imgs=features, labels=labels, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
#print_torch_mem("Start iter")
|
||||
features,labels = features.to(device), labels.to(device)
|
||||
|
||||
|
@ -277,7 +278,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
meta_opt.step()
|
||||
|
||||
#Adjust Hyper-parameters
|
||||
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
|
||||
model['data_aug'].adjust_param() #Contrainte sum(proba)=1
|
||||
if hp_opt:
|
||||
for param_group in diffopt.param_groups:
|
||||
for param in list(opt_param['Inner'].keys())[1:]:
|
||||
|
@ -289,7 +290,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
meta_opt.zero_grad()
|
||||
|
||||
elif not high_grad_track:
|
||||
diffopt.detach_()
|
||||
#diffopt.detach_()
|
||||
model['model'].detach_()
|
||||
|
||||
tf = time.process_time()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue