Bon setup fin tests brutus

This commit is contained in:
Harle, Antoine (Contracteur) 2019-12-04 10:37:56 -05:00
parent c4e2e30151
commit 33ef7afd04
2 changed files with 10 additions and 10 deletions

View file

@ -91,7 +91,7 @@ if __name__ == "__main__":
print('-'*9) print('-'*9)
''' '''
#### Augmented Model #### #### Augmented Model ####
#''' '''
t0 = time.process_time() t0 = time.process_time()
tf_dict = {k: TF.TF_dict[k] for k in tf_names} tf_dict = {k: TF.TF_dict[k] for k in tf_names}
#tf_dict = TF.TF_dict #tf_dict = TF.TF_dict
@ -117,9 +117,9 @@ if __name__ == "__main__":
print('Execution Time : %.00f '%(time.process_time() - t0)) print('Execution Time : %.00f '%(time.process_time() - t0))
print('-'*9) print('-'*9)
#'''
#### TF tests ####
''' '''
#### TF tests ####
#'''
res_folder="res/brutus-tests/" res_folder="res/brutus-tests/"
epochs= 150 epochs= 150
inner_its = [1] inner_its = [1]
@ -169,4 +169,4 @@ if __name__ == "__main__":
#plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names) #plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names)
print('-'*9) print('-'*9)
''' #'''

View file

@ -616,7 +616,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
model_copy(src=fmodel, dst=model) model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt) optim_copy(dopt=diffopt, opt=inner_opt)
if epoch>50: #if epoch>50:
meta_opt.step() meta_opt.step()
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1 model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
#model['data_aug'].next_TF_set() #model['data_aug'].next_TF_set()
@ -683,8 +683,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
model.augment(mode=True) model.augment(mode=True)
if inner_it != 0: high_grad_track = True if inner_it != 0: high_grad_track = True
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch)) #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
#print("Copy ", countcopy) #print("Copy ", countcopy)
return log return log