Mirror of https://github.com/AntoineHX/smart_augmentation.git
Slight change to the mag regularisation
This commit is contained in:
commit cc737b7997 (parent 64282bda3a)

5 changed files with 21 additions and 8 deletions
@@ -38,8 +38,8 @@ data_test = torchvision.datasets.CIFAR10(
     "./data", train=False, download=True, transform=transform
 )
 #'''
-train_subset_indices=range(int(len(data_train)/2))
-#train_subset_indices=range(BATCH_SIZE*10)
+#train_subset_indices=range(int(len(data_train)/2))
+train_subset_indices=range(BATCH_SIZE*10)
 val_subset_indices=range(int(len(data_train)/2),len(data_train))
 
 dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
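This hunk swaps which train_subset_indices definition is active: the half-dataset split is commented out in favour of a tiny BATCH_SIZE*10 subset, which reads like a quick-debug configuration. A minimal self-contained sketch of the resulting loaders follows; the BATCH_SIZE value and the transform are assumptions, as neither appears in this hunk.

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, SubsetRandomSampler

BATCH_SIZE = 300  # assumption: the actual value is defined elsewhere in the file

transform = transforms.ToTensor()  # assumption: the real transform is not shown here
data_train = torchvision.datasets.CIFAR10(
    "./data", train=True, download=True, transform=transform)

# Debug split enabled by this commit: only 10 batches worth of samples.
train_subset_indices = range(BATCH_SIZE * 10)
# Validation keeps the second half of the training set, as in the diff.
val_subset_indices = range(len(data_train) // 2, len(data_train))

# SubsetRandomSampler shuffles within its subset, so shuffle stays False.
dl_train = DataLoader(data_train, batch_size=BATCH_SIZE,
                      sampler=SubsetRandomSampler(train_subset_indices))
dl_val = DataLoader(data_train, batch_size=BATCH_SIZE,
                    sampler=SubsetRandomSampler(val_subset_indices))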
@@ -552,6 +552,8 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
                 else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]
         })
 
+        #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag useless for the ignore_mag TFs
+
         #Distribution
         self._samples = []
         self._mix_dist = False
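The added (still commented-out) line would zero the magnitude parameters of the no-mag transforms in place: x.data -= x.data leaves zeros while bypassing autograd, so those entries would start at zero without the operation being recorded on the graph. It is left disabled in this commit.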
@@ -561,8 +563,7 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
 
         #Mag regularisation
         if not self._fixed_mag:
-            ignore={'Identity', 'FlipUD', 'FlipLR', 'Solarize', 'Posterize'}
-            self._reg_mask=[self._TF.index(t) for t in self._TF if t not in ignore]
+            self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
             self._reg_tgt = torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude
 
     def forward(self, x):
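The hard-coded ignore set is replaced by the shared TF.TF_ignore_mag constant (introduced in the TF.py hunk below), and _reg_tgt pins the regularisation target at TF.PARAMETER_MAX. The loss that consumes this mask/target pair is not part of the diff; a plausible sketch, assuming an MSE pull of the masked magnitudes toward the target, could look like this.

import torch
import torch.nn.functional as F

# Illustrative only: the names mirror the diff, the values are stand-ins.
PARAMETER_MAX = 1.0                      # assumption: scale of the mag parameters
mag = torch.rand(5, requires_grad=True)  # stand-in for self._params['mag']
reg_mask = [0, 2, 4]                     # indices of TFs not in TF_ignore_mag
reg_tgt = torch.full((len(reg_mask),), PARAMETER_MAX)

# Penalise masked magnitudes for sitting below the max amplitude.
reg_loss = F.mse_loss(mag[reg_mask], reg_tgt)
reg_loss.backward()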
@@ -38,7 +38,7 @@ else:
 if __name__ == "__main__":
 
     n_inner_iter = 10
-    epochs = 200
+    epochs = 2
     dataug_epoch_start=0
 
     #### Classic ####
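epochs drops from 200 to 2, matching the shrunken BATCH_SIZE*10 training subset from the first hunk; both changes point to a fast debug run rather than a real training configuration.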
@@ -84,7 +84,7 @@ if __name__ == "__main__":
         json.dump(out, f, indent=True)
         print('Log :\"',f.name, '\" saved !')
 
-        print('TF influence', TF_influence(log))
+        plot_TF_influence(log, param_names=tf_names)
         print('Execution Time : %.00f '%(time.process_time() - t0))
         print('-'*9)
         #'''
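The numeric TF_influence printout is replaced by a call to plot_TF_influence, whose new implementation appears in the last hunk below; tf_names is presumably the list of transform names used as bar labels.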
@@ -52,6 +52,9 @@ TF_dict={ #Dataugv5
     #'Equalize': (lambda mag: None),
 }
 
+TF_no_mag={'Identity', 'FlipUD', 'FlipLR'}
+TF_ignore_mag= TF_no_mag | {'Solarize', 'Posterize'}
+
 def int_image(float_image): #WARNING: slight loss of info (granularity: 1/256 = 0.0039)
     return (float_image*255.).type(torch.uint8)
 
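TF_no_mag gathers the transforms that take no magnitude argument at all, while TF_ignore_mag (its union with Solarize and Posterize) marks everything the magnitude regulariser should skip, presumably because Solarize and Posterize magnitudes act as thresholds or bit depths rather than amplitudes worth maximising. Centralising these sets in TF.py lets the augmentation module reference them instead of hard-coding the names, which is exactly what the Data_augV5 hunk above now does.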
@@ -254,11 +254,20 @@ def print_torch_mem(add_info=''):
         torch.cuda.max_memory_cached()/ mega_bytes)
     print(string)
 
-def TF_influence(log):
+def plot_TF_influence(log, fig_name='TF_influence', param_names=None):
     proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
     mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
 
-    return np.mean(proba, axis=1)*np.mean(mag, axis=1) #Could be interesting to multiply before taking the mean
+    plt.figure()
+
+    mean = np.mean(proba, axis=1)*np.mean(mag, axis=1) #Could be interesting to multiply before taking the mean
+    std = np.std(proba, axis=1)*np.std(mag, axis=1)
+    plt.bar(param_names, mean, yerr=std)
+
+    plt.xticks(rotation=90)
+    fig_name = fig_name.replace('.',',')
+    plt.savefig(fig_name, bbox_inches='tight')
+    plt.close()
 
 class loss_monitor(): #See https://github.com/pytorch/ignite
     def __init__(self, patience, end_train=1):
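TF_influence thus becomes plot_TF_influence, turning the mean(p)*mean(m) influence score into a bar chart with a std-based error bar per transform. A usage sketch with a toy log follows; the log structure (one dict per epoch, whose "param" entry holds a {'p', 'm'} dict per transform) is inferred from the comprehensions above, and the values are made up.

import numpy as np
import matplotlib.pyplot as plt

# Assumes plot_TF_influence from this file is importable.
tf_names = ['Identity', 'FlipLR', 'Rotate']
log = [
    {"param": [{'p': 0.3, 'm': 0.5}, {'p': 0.4, 'm': 0.8}, {'p': 0.3, 'm': 0.6}]},
    {"param": [{'p': 0.2, 'm': 0.6}, {'p': 0.5, 'm': 0.9}, {'p': 0.3, 'm': 0.7}]},
]

plot_TF_influence(log, fig_name='TF_influence_test', param_names=tf_names)
# Saves 'TF_influence_test'; dots in the name are swapped for commas first.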