diff --git a/higher/test_brutus.py b/higher/test_brutus.py index 60af723..8a5dd85 100755 --- a/higher/test_brutus.py +++ b/higher/test_brutus.py @@ -35,18 +35,21 @@ if __name__ == "__main__": n_inner_iter = 1 - epochs = 100 + epochs = 150 dataug_epoch_start=0 + #model = LeNet(3,10) + model = MobileNetV2(num_classes=10) + #model = WideResNet(num_classes=10, wrn_size=32) + tf_dict = {k: TF.TF_dict[k] for k in tf_names} t0 = time.process_time() - aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device) - #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), WideResNet(num_classes=10, wrn_size=32)).to(device) - print(str(aug_model), 'on', device_name) - #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=None, loss_patience=None) + aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=True, fixed_mag=True, shared_mag=True), model).to(device) + + print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter)) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None) exec_time=time.process_time() - t0 #### @@ -60,11 +63,10 @@ if __name__ == "__main__": #### t0 = time.process_time() - #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device) - aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), WideResNet(num_classes=10, wrn_size=32)).to(device) - print(str(aug_model), 'on', device_name) - #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=None, loss_patience=None) + aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device) + + print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter)) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None) exec_time=time.process_time() - t0 #### diff --git a/higher/test_dataug.py b/higher/test_dataug.py index ec3dbce..87dae1e 100755 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -149,7 +149,6 @@ if __name__ == "__main__": #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device) print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter)) - #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None) exec_time=time.process_time() - t0