Correction test MobileNet Brutus

This commit is contained in:
Harle, Antoine (Contracteur) 2019-12-09 10:46:53 -05:00
parent 48c3925d74
commit 6c0597e7ea
4 changed files with 140 additions and 22 deletions

View file

@ -629,7 +629,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
print("Copy ", countcopy)
return log
def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None):
def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
device = next(model.parameters()).device
log = []
countcopy=0
@ -796,8 +796,12 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
model.augment(mode=True)
if inner_it != 0: high_grad_track = True
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
try:
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
except:
print("Couldn't save finals samples")
pass
#print("Copy ", countcopy)
return log