diff --git a/higher/smart_aug/benchmark.py b/higher/smart_aug/benchmark.py index fcaa90f..d71baa4 100644 --- a/higher/smart_aug/benchmark.py +++ b/higher/smart_aug/benchmark.py @@ -14,7 +14,7 @@ optim_param={ }, 'Inner':{ 'optim': 'SGD', - 'lr':1e-1, #1e-2 + 'lr':1e-1, #1e-2 #1e-1 for ResNet 'momentum':0.9, #0.9 } } @@ -58,6 +58,8 @@ tf_names = [ #'Random', #'RandBlend' ] +tf_dict = {k: TF.TF_dict[k] for k in tf_names} + device = torch.device('cuda') @@ -75,9 +77,8 @@ np.random.seed(0) ########################################## if __name__ == "__main__": - - tf_dict = {k: TF.TF_dict[k] for k in tf_names} - + ### Benchmark ### + ''' for model_type in model_list.keys(): for model_name in model_list[model_type]: model = getattr(model_type, model_name)(pretrained=False) @@ -124,14 +125,15 @@ if __name__ == "__main__": print('Execution Time : %.00f '%(exec_time)) print('-'*9) + + ''' + ### HP Search ### inner_its = [1] dist_mix = [0.0, 0.5, 0.8, 1.0] dataug_epoch_starts= [0] - tf_dict = {k: TF.TF_dict[k] for k in tf_names} - TF_nb = [len(tf_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)] - N_seq_TF= [4, 3, 2] - mag_setup = [(True,True), (False, False)] #(Fixed, Shared) + N_seq_TF= [2, 3, 4] + mag_setup = [(True,True), (False, False)] #(FxSh, Independant) #prob_setup = [True, False] nb_run= 3 @@ -149,12 +151,10 @@ if __name__ == "__main__": #for p_setup in prob_setup: p_setup=False for run in range(nb_run): - if (n_inner_iter == 0 and (m_setup!=(True,True) and p_setup!=True)) or (p_setup and dist!=0.0): continue #Autres setup inutiles sans meta-opti - #keys = list(TF.TF_dict.keys())[0:i] - #ntf_dict = {k: TF.TF_dict[k] for k in keys} t0 = time.process_time() + model = getattr(model_list.keys()[0], 'resnet18')(pretrained=False) model = Higher_model(model) #run_dist_dataugV3 aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=p_setup, fixed_mag=m_setup[0], shared_mag=m_setup[1]), model).to(device) #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device) @@ -183,10 +183,6 @@ if __name__ == "__main__": print('Log :\"',f.name, '\" saved !') except: print("Failed to save logs :",f.name) - try: - plot_resV2(log, fig_name="../res/"+filename, param_names=aug_model.TF_names()) - except: - print("Failed to plot res") print('Execution Time : %.00f '%(exec_time)) print('-'*9)