From 240ec5e581d3734376a4f0b4b7c6d007cb9faefb Mon Sep 17 00:00:00 2001 From: "Harle, Antoine (Contracteur)" Date: Fri, 6 Dec 2019 14:19:05 -0500 Subject: [PATCH] Test MobileNet Brutus --- higher/test_brutus.py | 24 +++++++++++++----------- higher/test_dataug.py | 1 - 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/higher/test_brutus.py b/higher/test_brutus.py index 60af723..8a5dd85 100755 --- a/higher/test_brutus.py +++ b/higher/test_brutus.py @@ -35,18 +35,21 @@ if __name__ == "__main__": n_inner_iter = 1 - epochs = 100 + epochs = 150 dataug_epoch_start=0 + #model = LeNet(3,10) + model = MobileNetV2(num_classes=10) + #model = WideResNet(num_classes=10, wrn_size=32) + tf_dict = {k: TF.TF_dict[k] for k in tf_names} t0 = time.process_time() - aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device) - #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), WideResNet(num_classes=10, wrn_size=32)).to(device) - print(str(aug_model), 'on', device_name) - #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=None, loss_patience=None) + aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=True, fixed_mag=True, shared_mag=True), model).to(device) + + print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter)) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None) exec_time=time.process_time() - t0 #### @@ -60,11 +63,10 @@ if __name__ == "__main__": #### t0 = time.process_time() - #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device) - aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), WideResNet(num_classes=10, wrn_size=32)).to(device) - print(str(aug_model), 'on', device_name) - #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=None, loss_patience=None) + aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device) + + print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter)) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None) exec_time=time.process_time() - t0 #### diff --git a/higher/test_dataug.py b/higher/test_dataug.py index ec3dbce..87dae1e 100755 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -149,7 +149,6 @@ if __name__ == "__main__": #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device) print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter)) - #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None) exec_time=time.process_time() - t0