diff --git a/higher/datasets.py b/higher/datasets.py index 6f06b4c..6be5e9e 100755 --- a/higher/datasets.py +++ b/higher/datasets.py @@ -65,21 +65,26 @@ class AugmentedDataset(VisionDataset): self._TF = [ - 'Invert', - 'Cutout', - 'Sharpness', - 'AutoContrast', - 'Posterize', - 'ShearX', + ## Geometric TF ## + 'Rotate', 'TranslateX', 'TranslateY', + 'ShearX', 'ShearY', - 'Rotate', - 'Equalize', + + 'Cutout', + + ## Color TF ## 'Contrast', 'Color', - 'Solarize', - 'Brightness' + 'Brightness', + 'Sharpness', + #'Posterize', + #'Solarize', + + 'Invert', + 'AutoContrast', + 'Equalize', ] self._op_list =[] self.prob=0.5 @@ -119,6 +124,7 @@ class AugmentedDataset(VisionDataset): for idx, image in enumerate(self.sup_data): if (idx/self.dataset_info['sup'])%0.2==0: print("Augmenting data... ", idx,"/", self.dataset_info['sup']) + #if idx==10000:break for _ in range(aug_copy): chosen_policy = policies[np.random.choice(len(policies))] diff --git a/higher/model.py b/higher/model.py index 7a10bce..ba8064f 100755 --- a/higher/model.py +++ b/higher/model.py @@ -94,6 +94,8 @@ class NetworkBlock(nn.Module): def forward(self, x): return self.layer(x) + +#wrn_size: 32 = WRN-28-2 ? 160 = WRN-28-10 class WideResNet(nn.Module): #def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): def __init__(self, num_classes, wrn_size, depth=28, dropRate=0.0): diff --git a/higher/res/Aug_mod(Data_augV5(Uniform-14TFx3-Mag)-LeNet)-200 epochs (dataug:0)- 1 in_it.png b/higher/res/Aug_mod(Data_augV5(Uniform-14TFx3-Mag)-LeNet)-200 epochs (dataug:0)- 1 in_it.png deleted file mode 100755 index f15c3ee..0000000 Binary files a/higher/res/Aug_mod(Data_augV5(Uniform-14TFx3-Mag)-LeNet)-200 epochs (dataug:0)- 1 in_it.png and /dev/null differ diff --git a/higher/test_brutus.py b/higher/test_brutus.py index dd1d014..6935046 100755 --- a/higher/test_brutus.py +++ b/higher/test_brutus.py @@ -33,6 +33,49 @@ else: ########################################## if __name__ == "__main__": + + n_inner_iter = 1 + epochs = 200 + dataug_epoch_start=0 + + tf_dict = {k: TF.TF_dict[k] for k in tf_names} + + t0 = time.process_time() + + aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device) + #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), WideResNet(num_classes=10, wrn_size=32)).to(device) + print(str(aug_model), 'on', device_name) + #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=None, loss_patience=None) + + exec_time=time.process_time() - t0 + #### + times = [x["time"] for x in log] + out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} + filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter) + with open("res/log/%s.json" % filename, "w+") as f: + json.dump(out, f, indent=True) + print('Log :\"',f.name, '\" saved !') + + #### + t0 = time.process_time() + + #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device) + aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), WideResNet(num_classes=10, wrn_size=32)).to(device) + print(str(aug_model), 'on', device_name) + #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=None, loss_patience=None) + + exec_time=time.process_time() - t0 + #### + times = [x["time"] for x in log] + out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} + filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter) + with open("res/log/%s.json" % filename, "w+") as f: + json.dump(out, f, indent=True) + print('Log :\"',f.name, '\" saved !') + + ''' res_folder="res/brutus-tests/" epochs= 150 inner_its = [1] @@ -80,4 +123,5 @@ if __name__ == "__main__": print('Log :\"',f.name, '\" saved !') #plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names) - print('-'*9) \ No newline at end of file + print('-'*9) + ''' \ No newline at end of file diff --git a/higher/test_dataug.py b/higher/test_dataug.py index 0af56a4..40a94da 100755 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -109,12 +109,12 @@ if __name__ == "__main__": t0 = time.process_time() data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2))) - data_train_aug.augement_data(aug_copy=10) + data_train_aug.augement_data(aug_copy=30) print(data_train_aug) dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True) - #xs, ys = next(iter(dl_train)) - #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_{}'.format(str(data_train_aug))) + xs, ys = next(iter(dl_train)) + viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_{}'.format(str(data_train_aug))) model = LeNet(3,10).to(device) #model = WideResNet(num_classes=10, wrn_size=16).to(device) @@ -149,9 +149,9 @@ if __name__ == "__main__": tf_dict = {k: TF.TF_dict[k] for k in tf_names} #aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device) - aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device) - #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device) - #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device) + #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device) + aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device) + #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), WideResNet(num_classes=10, wrn_size=32)).to(device) print(str(aug_model), 'on', device_name) #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None) diff --git a/higher/train_utils.py b/higher/train_utils.py index fc8e5e9..b11fb7b 100755 --- a/higher/train_utils.py +++ b/higher/train_utils.py @@ -625,9 +625,9 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f model_copy(src=fmodel, dst=model) optim_copy(dopt=diffopt, opt=inner_opt) - if epoch>50: - meta_opt.step() - model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1 + #if epoch>50: + meta_opt.step() + model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1 #model['data_aug'].next_TF_set() fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)