diff --git a/higher/res/LeNet-100 epochs.png b/higher/res/LeNet-100 epochs.png index 88a1a6a..bf1475d 100644 Binary files a/higher/res/LeNet-100 epochs.png and b/higher/res/LeNet-100 epochs.png differ diff --git a/higher/test_dataug.py b/higher/test_dataug.py index a7593e8..e0ca566 100644 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -11,7 +11,36 @@ BATCH_SIZE = 300 #TEST_SIZE = 300 TEST_SIZE = 10000 +tf_names = [ + ## Geometric TF ## + 'Identity', + 'FlipUD', + 'FlipLR', + #'Rotate', + #'TranslateX', + #'TranslateY', + #'ShearX', + #'ShearY', + + ## Color TF (Expect image in the range of [0, 1]) ## + #'Contrast', + #'Color', + 'Brightness', + #'Sharpness', + #'Posterize', + #'Solarize', #=>Image entre [0,1] #Pas opti pour des batch + + #Non fonctionnel + #'Auto_Contrast', #Pas opti pour des batch (Super lent) + #'Equalize', +] + #ATTENTION : Dataug (Kornia) Expect image in the range of [0, 1] +#transform_train = torchvision.transforms.Compose([ +# torchvision.transforms.RandomHorizontalFlip(), +# torchvision.transforms.ToTensor(), + #torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #CIFAR10 +#]) transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), #torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #CIFAR10 @@ -31,6 +60,9 @@ data_test = torchvision.datasets.MNIST( data_train = torchvision.datasets.CIFAR10( "./data", train=True, download=True, transform=transform ) +#data_val = torchvision.datasets.CIFAR10( +# "./data", train=True, download=True, transform=transform +#) data_test = torchvision.datasets.CIFAR10( "./data", train=False, download=True, transform=transform ) @@ -81,7 +113,7 @@ def train_classic(model, epochs=1): dl_val_it = iter(dl_val) log = [] for epoch in range(epochs): - print_torch_mem("Start epoch") + #print_torch_mem("Start epoch") t0 = time.process_time() for i, (features, labels) in enumerate(dl_train): #print_torch_mem("Start iter") @@ -132,8 +164,8 @@ def train_classic_higher(model, epochs=1): #with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, diffopt): for epoch in range(epochs): - print_torch_mem("Start epoch "+str(epoch)) - print("Fast param ",len(fmodel._fast_params)) + #print_torch_mem("Start epoch "+str(epoch)) + #print("Fast param ",len(fmodel._fast_params)) t0 = time.process_time() for i, (features, labels) in enumerate(dl_train): #print_torch_mem("Start iter") @@ -702,8 +734,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f ########################################## if __name__ == "__main__": - n_inner_iter = 0 - epochs = 2 + n_inner_iter = 10 + epochs = 100 dataug_epoch_start=0 #### Classic #### @@ -714,7 +746,8 @@ if __name__ == "__main__": #model.augment(mode=False) print(str(model), 'on', device_name) - log= train_classic_higher(model=model, epochs=epochs) + log= train_classic(model=model, epochs=epochs) + #log= train_classic_higher(model=model, epochs=epochs) #### plot_res(log, fig_name="res/{}-{} epochs".format(str(model),epochs)) @@ -728,11 +761,13 @@ if __name__ == "__main__": print('-'*9) ''' #### Augmented Model #### - ''' + #''' + tf_dict = {k: TF.TF_dict[k] for k in tf_names} + #tf_dict = TF.TF_dict aug_model = Augmented_model(Data_augV4(TF_dict=TF.TF_dict, mix_dist=0.0), LeNet(3,10)).to(device) print(str(aug_model), 'on', device_name) #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=10) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10) #### plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)) @@ -744,13 +779,14 @@ if __name__ == "__main__": json.dump(out, f, indent=True) print('Log :\"',f.name, '\" saved !') print('-'*9) - ''' + #''' ## TF number tests ## + ''' res_folder="res/TF_nb_tests/" epochs= 200 inner_its = [0, 10] dataug_epoch_starts= [0, -1] - max_TF_nb = len(TF.TF_dict) + TF_nb = [14] #range(1,len(TF.TF_dict)+1) try: os.mkdir(res_folder) @@ -762,14 +798,14 @@ if __name__ == "__main__": print("---Starting inner_it", n_inner_iter,"---") for dataug_epoch_start in dataug_epoch_starts: print("---Starting dataug", dataug_epoch_start,"---") - for i in range(1,max_TF_nb): + for i in TF_nb: keys = list(TF.TF_dict.keys())[0:i] ntf_dict = {k: TF.TF_dict[k] for k in keys} aug_model = Augmented_model(Data_augV4(TF_dict=ntf_dict, mix_dist=0.0), LeNet(3,10)).to(device) print(str(aug_model), 'on', device_name) #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=10) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10) #### plot_res(log, fig_name=res_folder+"{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)) @@ -780,4 +816,6 @@ if __name__ == "__main__": with open(res_folder+"log/%s.json" % "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), "w+") as f: json.dump(out, f, indent=True) print('Log :\"',f.name, '\" saved !') - print('-'*9) \ No newline at end of file + print('-'*9) + + ''' \ No newline at end of file