From cc737b79973abcf51fe590dd3780f447be841e47 Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Tue, 19 Nov 2019 21:46:14 -0500
Subject: [PATCH] Slight change to the mag regularization

---
 higher/datasets.py        |  4 ++--
 higher/dataug.py          |  5 +++--
 higher/test_dataug.py     |  4 ++--
 higher/transformations.py |  3 +++
 higher/utils.py           | 13 +++++++++++--
 5 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/higher/datasets.py b/higher/datasets.py
index 7d0589f..17be0ff 100644
--- a/higher/datasets.py
+++ b/higher/datasets.py
@@ -38,8 +38,8 @@ data_test = torchvision.datasets.CIFAR10(
     "./data", train=False, download=True, transform=transform
 )
 #'''
-train_subset_indices=range(int(len(data_train)/2))
-#train_subset_indices=range(BATCH_SIZE*10)
+#train_subset_indices=range(int(len(data_train)/2))
+train_subset_indices=range(BATCH_SIZE*10)
 val_subset_indices=range(int(len(data_train)/2),len(data_train))

 dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
diff --git a/higher/dataug.py b/higher/dataug.py
index b80de85..277d900 100644
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -552,6 +552,8 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
             else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]
         })

+        #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag is useless for the ignore_mag TFs
+
         #Distribution
         self._samples = []
         self._mix_dist = False
@@ -561,8 +563,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)

         #Mag regularization
         if not self._fixed_mag:
-            ignore={'Identity', 'FlipUD', 'FlipLR', 'Solarize', 'Posterize'}
-            self._reg_mask=[self._TF.index(t) for t in self._TF if t not in ignore]
+            self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
             self._reg_tgt = torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude

     def forward(self, x):
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index 56c31a6..e71d4c6 100644
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -38,7 +38,7 @@ else:

 if __name__ == "__main__":
     n_inner_iter = 10
-    epochs = 200
+    epochs = 2
     dataug_epoch_start=0

     #### Classic ####
@@ -84,7 +84,7 @@ if __name__ == "__main__":
             json.dump(out, f, indent=True)
         print('Log :\"',f.name, '\" saved !')

-        print('TF influence', TF_influence(log))
+        plot_TF_influence(log, param_names=tf_names)
         print('Execution Time : %.00f '%(time.process_time() - t0))
         print('-'*9)
     #'''
diff --git a/higher/transformations.py b/higher/transformations.py
index fa4b4b8..f2f5958 100644
--- a/higher/transformations.py
+++ b/higher/transformations.py
@@ -52,6 +52,9 @@ TF_dict={ #Dataugv5
     #'Equalize': (lambda mag: None),
 }

+TF_no_mag={'Identity', 'FlipUD', 'FlipLR'}
+TF_ignore_mag= TF_no_mag | {'Solarize', 'Posterize'}
+
 def int_image(float_image): #WARNING: slight loss of info (granularity: 1/256 = 0.0039)
     return (float_image*255.).type(torch.uint8)

diff --git a/higher/utils.py b/higher/utils.py
index 7df741a..1f2d0e6 100644
--- a/higher/utils.py
+++ b/higher/utils.py
@@ -254,11 +254,20 @@ def print_torch_mem(add_info=''):
               torch.cuda.max_memory_cached()/ mega_bytes)
     print(string)

-def TF_influence(log):
+def plot_TF_influence(log, fig_name='TF_influence', param_names=None):
     proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
     mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]

-    return np.mean(proba, axis=1)*np.mean(mag, axis=1) #Could be interesting to multiply before taking the mean
+    plt.figure()
+
+    mean = np.mean(proba, axis=1)*np.mean(mag, axis=1) #Could be interesting to multiply before taking the mean
+    std = np.std(proba, axis=1)*np.std(mag, axis=1)
+    plt.bar(param_names, mean, yerr=std)
+
+    plt.xticks(rotation=90)
+    fig_name = fig_name.replace('.',',')
+    plt.savefig(fig_name, bbox_inches='tight')
+    plt.close()

 class loss_monitor(): #See https://github.com/pytorch/ignite
     def __init__(self, patience, end_train=1):
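
Usage note: a minimal sketch of how the new plot_TF_influence helper can be
called, not a definitive example. It assumes higher/utils.py is importable as
`utils`, that it already imports numpy as `np` and matplotlib.pyplot as `plt`
(the function body above relies on both), and that log entries match the JSON
log written by test_dataug.py; the sample values below are made up. Since the
default param_names=None would make plt.bar() fail, callers should pass
explicit names, as test_dataug.py now does with tf_names.

    from utils import plot_TF_influence  # hypothetical import path

    # One log entry per epoch; each entry's "param" holds one
    # {'p': probability, 'm': magnitude} dict per transformation.
    log = [
        {"param": [{'p': 0.45, 'm': 0.30}, {'p': 0.55, 'm': 0.80}]},
        {"param": [{'p': 0.40, 'm': 0.35}, {'p': 0.60, 'm': 0.75}]},
    ]
    tf_names = ['FlipLR', 'Solarize']

    # Saves a bar chart of mean(p)*mean(m) per TF with std-based error bars;
    # any '.' in fig_name is replaced by ',' before saving.
    plot_TF_influence(log, fig_name='TF_influence_sample', param_names=tf_names)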