mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Ajout RandAugment
This commit is contained in:
parent
3c2022de32
commit
4a7e73088d
4 changed files with 249 additions and 37 deletions
|
@ -540,6 +540,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
val_loss=torch.tensor(0) #Necessaire si pas de metastep sur une epoch
|
||||
dl_val_it = iter(dl_val)
|
||||
|
||||
#if inner_it!=0:
|
||||
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2)
|
||||
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
|
||||
|
||||
|
@ -680,5 +681,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
model.augment(mode=True)
|
||||
if inner_it != 0: high_grad_track = True
|
||||
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
|
||||
#print("Copy ", countcopy)
|
||||
return log
|
Loading…
Add table
Add a link
Reference in a new issue