mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Correction test MobileNet Brutus
This commit is contained in:
parent
48c3925d74
commit
6c0597e7ea
4 changed files with 140 additions and 22 deletions
|
@ -629,7 +629,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
|
|||
print("Copy ", countcopy)
|
||||
return log
|
||||
|
||||
def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None):
|
||||
def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
|
||||
device = next(model.parameters()).device
|
||||
log = []
|
||||
countcopy=0
|
||||
|
@ -796,8 +796,12 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
model.augment(mode=True)
|
||||
if inner_it != 0: high_grad_track = True
|
||||
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
try:
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
except:
|
||||
print("Couldn't save finals samples")
|
||||
pass
|
||||
|
||||
#print("Copy ", countcopy)
|
||||
return log
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue