mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 20:20:46 +02:00
Bon setup fin tests brutus
This commit is contained in:
parent
c4e2e30151
commit
33ef7afd04
2 changed files with 10 additions and 10 deletions
|
@ -91,7 +91,7 @@ if __name__ == "__main__":
|
||||||
print('-'*9)
|
print('-'*9)
|
||||||
'''
|
'''
|
||||||
#### Augmented Model ####
|
#### Augmented Model ####
|
||||||
#'''
|
'''
|
||||||
t0 = time.process_time()
|
t0 = time.process_time()
|
||||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||||
#tf_dict = TF.TF_dict
|
#tf_dict = TF.TF_dict
|
||||||
|
@ -117,9 +117,9 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
print('Execution Time : %.00f '%(time.process_time() - t0))
|
print('Execution Time : %.00f '%(time.process_time() - t0))
|
||||||
print('-'*9)
|
print('-'*9)
|
||||||
#'''
|
|
||||||
#### TF tests ####
|
|
||||||
'''
|
'''
|
||||||
|
#### TF tests ####
|
||||||
|
#'''
|
||||||
res_folder="res/brutus-tests/"
|
res_folder="res/brutus-tests/"
|
||||||
epochs= 150
|
epochs= 150
|
||||||
inner_its = [1]
|
inner_its = [1]
|
||||||
|
@ -169,4 +169,4 @@ if __name__ == "__main__":
|
||||||
#plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names)
|
#plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names)
|
||||||
print('-'*9)
|
print('-'*9)
|
||||||
|
|
||||||
'''
|
#'''
|
||||||
|
|
|
@ -616,10 +616,10 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
||||||
model_copy(src=fmodel, dst=model)
|
model_copy(src=fmodel, dst=model)
|
||||||
optim_copy(dopt=diffopt, opt=inner_opt)
|
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||||
|
|
||||||
if epoch>50:
|
#if epoch>50:
|
||||||
meta_opt.step()
|
meta_opt.step()
|
||||||
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
|
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
|
||||||
#model['data_aug'].next_TF_set()
|
#model['data_aug'].next_TF_set()
|
||||||
|
|
||||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||||
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
||||||
|
@ -683,8 +683,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
||||||
model.augment(mode=True)
|
model.augment(mode=True)
|
||||||
if inner_it != 0: high_grad_track = True
|
if inner_it != 0: high_grad_track = True
|
||||||
|
|
||||||
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||||
|
|
||||||
#print("Copy ", countcopy)
|
#print("Copy ", countcopy)
|
||||||
return log
|
return log
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue