mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Modification/Suppression des BadTF
This commit is contained in:
parent
23351ec13c
commit
3425ba2ceb
2 changed files with 14 additions and 18 deletions
|
@ -746,7 +746,10 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
|||
#if epoch>50:
|
||||
meta_opt.step()
|
||||
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
|
||||
#model['data_aug'].next_TF_set()
|
||||
try: #Dataugv6
|
||||
model['data_aug'].next_TF_set()
|
||||
except:
|
||||
pass
|
||||
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
||||
|
@ -754,7 +757,7 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
|||
tf = time.process_time()
|
||||
|
||||
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
|
||||
|
||||
if(not high_grad_track):
|
||||
countcopy+=1
|
||||
|
@ -810,7 +813,7 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
|||
dataug_epoch_start = epoch
|
||||
model.augment(mode=True)
|
||||
if inner_it != 0: high_grad_track = True
|
||||
|
||||
|
||||
try:
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue