From 385bc9977ce8a832fc742c5c24ae62a0db24b9b9 Mon Sep 17 00:00:00 2001 From: "Harle, Antoine (Contracteur)" Date: Mon, 3 Feb 2020 15:08:22 -0500 Subject: [PATCH] Cross Validation splits + New mesure process time (train utils) --- higher/smart_aug/datasets.py | 27 +++++++++++++++++++++----- higher/smart_aug/test_dataug.py | 34 ++++++++++++++++----------------- higher/smart_aug/train_utils.py | 20 +++++++++++-------- 3 files changed, 51 insertions(+), 30 deletions(-) diff --git a/higher/smart_aug/datasets.py b/higher/smart_aug/datasets.py index 2f0b32d..431749e 100755 --- a/higher/smart_aug/datasets.py +++ b/higher/smart_aug/datasets.py @@ -19,6 +19,8 @@ download_data=False num_workers=2 #4 #Pin GPU memory pin_memory=False #True :+ GPU memory / + Lent +#Data storage folder +dataroot="../data" #ATTENTION : Dataug (Kornia) Expect image in the range of [0, 1] #transform_train = torchvision.transforms.Compose([ @@ -41,7 +43,6 @@ transform_train = torchvision.transforms.Compose([ #transform_train.transforms.insert(0, RandAugment(n=2, m=30)) ### Classic Dataset ### -dataroot="../data" #MNIST #data_train = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=transform_train) @@ -70,11 +71,27 @@ data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=downloa #data_test = torchvision.datasets.ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test) -train_subset_indices=range(int(len(data_train)/2)) -val_subset_indices=range(int(len(data_train)/2),len(data_train)) +#Validation set size [0, 1] +#valid_size=0.1 +#train_subset_indices=range(int(len(data_train)*(1-valid_size))) +#val_subset_indices=range(int(len(data_train)*(1-valid_size)),len(data_train)) #train_subset_indices=range(BATCH_SIZE*10) #val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20) -dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory) -dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory) +#dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory) +#dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory) dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory) + +#Cross Validation +from skorch.dataset import CVSplit +cvs = CVSplit(cv=5) + +def next_CVSplit(): + + train_subset, val_subset = cvs(data_train) + dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory) + dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory) + + return dl_train, dl_val + +dl_train, dl_val = next_CVSplit() \ No newline at end of file diff --git a/higher/smart_aug/test_dataug.py b/higher/smart_aug/test_dataug.py index 0610723..ac81561 100755 --- a/higher/smart_aug/test_dataug.py +++ b/higher/smart_aug/test_dataug.py @@ -13,19 +13,19 @@ tf_names = [ 'Identity', 'FlipUD', 'FlipLR', - #'Rotate', - #'TranslateX', - #'TranslateY', - #'ShearX', - #'ShearY', + 'Rotate', + 'TranslateX', + 'TranslateY', + 'ShearX', + 'ShearY', ## Color TF (Expect image in the range of [0, 1]) ## - #'Contrast', - #'Color', - #'Brightness', - #'Sharpness', - #'Posterize', - #'Solarize', #=>Image entre [0,1] #Pas opti pour des batch + 'Contrast', + 'Color', + 'Brightness', + 'Sharpness', + 'Posterize', + 'Solarize', #=>Image entre [0,1] #Pas opti pour des batch #Color TF (Common mag scale) #'+Contrast', @@ -74,12 +74,12 @@ if __name__ == "__main__": #Task to perform tasks={ - 'classic', - #'aug_model' + #'classic', + 'aug_model' } #Parameters n_inner_iter = 1 - epochs = 2 + epochs = 150 dataug_epoch_start=0 optim_param={ 'Meta':{ @@ -147,7 +147,7 @@ if __name__ == "__main__": tf_dict = {k: TF.TF_dict[k] for k in tf_names} model = Higher_model(model, model_name) #run_dist_dataugV3 - aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device) + aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device) #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device) print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter)) @@ -156,7 +156,7 @@ if __name__ == "__main__": inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, opt_param=optim_param, - print_freq=1, + print_freq=20, unsup_loss=1, hp_opt=False, save_sample_freq=None) @@ -174,7 +174,7 @@ if __name__ == "__main__": "Param_names": aug_model.TF_names(), "Log": log} print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1]) - filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter) + filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)+"(CV)" with open("../res/log/%s.json" % filename, "w+") as f: try: json.dump(out, f, indent=True) diff --git a/higher/smart_aug/train_utils.py b/higher/smart_aug/train_utils.py index 28aa04a..4b95c9e 100755 --- a/higher/smart_aug/train_utils.py +++ b/higher/smart_aug/train_utils.py @@ -150,7 +150,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1): log = [] for epoch in range(epochs): #print_torch_mem("Start epoch") - t0 = time.process_time() + t0 = time.perf_counter() for i, (features, labels) in enumerate(dl_train): #viz_sample_data(imgs=features, labels=labels, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch)) #print_torch_mem("Start iter") @@ -164,7 +164,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1): optim.step() #### Tests #### - tf = time.process_time() + tf = time.perf_counter() val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) accuracy, f1 =test(model) @@ -176,8 +176,8 @@ def train_classic(model, opt_param, epochs=1, print_freq=1): print('Epoch : %d/%d'%(epoch,epochs)) print('Time : %.00f'%(tf - t0)) print('Train loss :',loss.item(), '/ val loss', val_loss.item()) - print('Accuracy :', accuracy) - print('F1 :', f1.data) + print('Accuracy max:', accuracy) + print('F1 :', f1) #### Log #### data={ @@ -219,7 +219,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start """ device = next(model.parameters()).device log = [] - dl_val_it = iter(dl_val) + #dl_val_it = iter(dl_val) val_loss=None high_grad_track = True @@ -251,8 +251,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start meta_opt.zero_grad() for epoch in range(1, epochs+1): - t0 = time.process_time() + t0 = time.perf_counter() + dl_train, dl_val = next_CVSplit() + dl_val_it = iter(dl_val) + for i, (xs, ys) in enumerate(dl_train): xs, ys = xs.to(device), ys.to(device) @@ -303,7 +306,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start #diffopt.detach_() model['model'].detach_() - tf = time.process_time() + tf = time.perf_counter() if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving try: @@ -345,7 +348,8 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start print('Epoch : %d/%d'%(epoch,epochs)) print('Time : %.00f'%(tf - t0)) print('Train loss :',loss.item(), '/ val loss', val_loss.item()) - print('Accuracy :', max([x["acc"] for x in log])) + print('Accuracy max:', max([x["acc"] for x in log])) + print('F1 :', f1) print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start)) if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data) #print('proba grad',model['data_aug']['prob'].grad)