mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Final result brutus
This commit is contained in:
parent
6f9cb2cd68
commit
f657a9f989
16 changed files with 182704 additions and 4 deletions
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -618,7 +618,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
|
||||
meta_opt.step()
|
||||
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
|
||||
model['data_aug'].next_TF_set()
|
||||
#model['data_aug'].next_TF_set()
|
||||
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
||||
|
@ -682,8 +682,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
model.augment(mode=True)
|
||||
if inner_it != 0: high_grad_track = True
|
||||
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
|
||||
#print("Copy ", countcopy)
|
||||
return log
|
||||
return log
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue