leger chgt mag regularisation

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-19 21:46:14 -05:00
parent 64282bda3a
commit cc737b7997
5 changed files with 21 additions and 8 deletions

View file

@ -38,8 +38,8 @@ data_test = torchvision.datasets.CIFAR10(
"./data", train=False, download=True, transform=transform
)
#'''
train_subset_indices=range(int(len(data_train)/2))
#train_subset_indices=range(BATCH_SIZE*10)
#train_subset_indices=range(int(len(data_train)/2))
train_subset_indices=range(BATCH_SIZE*10)
val_subset_indices=range(int(len(data_train)/2),len(data_train))
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))

View file

@ -552,6 +552,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]
})
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
#Distribution
self._samples = []
self._mix_dist = False
@ -561,8 +563,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
#Mag regularisation
if not self._fixed_mag:
ignore={'Identity', 'FlipUD', 'FlipLR', 'Solarize', 'Posterize'}
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in ignore]
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
self._reg_tgt = torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
def forward(self, x):

View file

@ -38,7 +38,7 @@ else:
if __name__ == "__main__":
n_inner_iter = 10
epochs = 200
epochs = 2
dataug_epoch_start=0
#### Classic ####
@ -84,7 +84,7 @@ if __name__ == "__main__":
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
print('TF influence', TF_influence(log))
plot_TF_influence(log, param_names=tf_names)
print('Execution Time : %.00f '%(time.process_time() - t0))
print('-'*9)
#'''

View file

@ -52,6 +52,9 @@ TF_dict={ #Dataugv5
#'Equalize': (lambda mag: None),
}
TF_no_mag={'Identity', 'FlipUD', 'FlipLR'}
TF_ignore_mag= TF_no_mag | {'Solarize', 'Posterize'}
def int_image(float_image): #ATTENTION : legere perte d'info (granularite : 1/256 = 0.0039)
return (float_image*255.).type(torch.uint8)

View file

@ -254,11 +254,20 @@ def print_torch_mem(add_info=''):
torch.cuda.max_memory_cached()/ mega_bytes)
print(string)
def TF_influence(log):
def plot_TF_influence(log, fig_name='TF_influence', param_names=None):
proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
return np.mean(proba, axis=1)*np.mean(mag, axis=1) #Pourrait etre interessant de multiplier avant le mean
plt.figure()
mean = np.mean(proba, axis=1)*np.mean(mag, axis=1) #Pourrait etre interessant de multiplier avant le mean
std = np.std(proba, axis=1)*np.std(mag, axis=1)
plt.bar(param_names, mean, yerr=std)
plt.xticks(rotation=90)
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
class loss_monitor(): #Voir https://github.com/pytorch/ignite
def __init__(self, patience, end_train=1):