mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 04:00:46 +02:00
Suite de test brutus
This commit is contained in:
parent
cc737b7997
commit
0e7ec8b5b0
6 changed files with 65 additions and 47 deletions
|
@ -31,7 +31,7 @@ def LeNet(images, num_classes):
|
|||
n_n_fc2 = 500; # number of neurons of first fully connected layer (default = 576)
|
||||
|
||||
# 1.layer: convolution + max pooling
|
||||
W_conv1_tf = weight_variable([s_f_conv1, s_f_conv1, 1, n_f_conv1], name = 'W_conv1_tf') # (5,5,1,32)
|
||||
W_conv1_tf = weight_variable([s_f_conv1, s_f_conv1, images.shape[3], n_f_conv1], name = 'W_conv1_tf') # (5,5,1,32)
|
||||
b_conv1_tf = bias_variable([n_f_conv1], name = 'b_conv1_tf') # (32)
|
||||
h_conv1_tf = tf.nn.relu(conv2d(images,
|
||||
W_conv1_tf) + b_conv1_tf,
|
||||
|
@ -76,4 +76,4 @@ def LeNet(images, num_classes):
|
|||
# tf.argmax(y_data_tf, 1),
|
||||
# name = 'y_pred_correct_tf')
|
||||
logits = z_pred_tf
|
||||
return logits #y_pred_proba_tf
|
||||
return logits #y_pred_proba_tf
|
||||
|
|
|
@ -2,9 +2,10 @@ from utils import *
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
'''
|
||||
#'''
|
||||
files=[
|
||||
"res/log/Aug_mod(Data_augV5(Mix0.5-11TFx1-MagSh)-LeNet)-2 epochs (dataug:0)- 1 in_it.json",
|
||||
"res/good_TF_tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
||||
"res/good_TF_tests/log/Aug_mod(Data_augV5(Uniform-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
||||
]
|
||||
|
||||
for idx, file in enumerate(files):
|
||||
|
@ -12,10 +13,12 @@ if __name__ == "__main__":
|
|||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
plot_resV2(data['Log'], fig_name=file.replace('.json','').replace('log/',''), param_names=data['Param_names'])
|
||||
'''
|
||||
#plot_TF_influence(data['Log'], param_names=data['Param_names'])
|
||||
#'''
|
||||
## Loss , Acc, Proba = f(epoch) ##
|
||||
#plot_compare(filenames=files, fig_name="res/compare")
|
||||
|
||||
'''
|
||||
## Acc, Time, Epochs = f(n_tf) ##
|
||||
#fig_name="res/TF_nb_tests_compare"
|
||||
fig_name="res/TF_seq_tests_compare"
|
||||
|
@ -67,4 +70,5 @@ if __name__ == "__main__":
|
|||
|
||||
fig_name = fig_name.replace('.',',')
|
||||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
plt.close()
|
||||
'''
|
|
@ -38,8 +38,8 @@ data_test = torchvision.datasets.CIFAR10(
|
|||
"./data", train=False, download=True, transform=transform
|
||||
)
|
||||
#'''
|
||||
#train_subset_indices=range(int(len(data_train)/2))
|
||||
train_subset_indices=range(BATCH_SIZE*10)
|
||||
train_subset_indices=range(int(len(data_train)/2))
|
||||
#train_subset_indices=range(BATCH_SIZE*10)
|
||||
val_subset_indices=range(int(len(data_train)/2),len(data_train))
|
||||
|
||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
|
||||
|
|
|
@ -563,8 +563,11 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
|
||||
#Mag regularisation
|
||||
if not self._fixed_mag:
|
||||
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
|
||||
self._reg_tgt = torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
|
||||
if self._shared_mag :
|
||||
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max
|
||||
else:
|
||||
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
|
||||
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
|
||||
|
||||
def forward(self, x):
|
||||
if self._data_augmentation:
|
||||
|
@ -628,7 +631,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
|
||||
|
||||
#self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Bloque une fois au extreme
|
||||
self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
|
||||
self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX) #Bloque a PARAMETER_MAX
|
||||
|
||||
def loss_weight(self):
|
||||
# 1 seule TF
|
||||
|
@ -638,7 +641,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
#w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||
#w_loss = torch.sum(w_loss,dim=1)
|
||||
|
||||
#Plusieurs TF sequentielles
|
||||
#Plusieurs TF sequentielles (Attention ne prend pas en compte ordre !)
|
||||
w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
|
||||
for sample in self._samples:
|
||||
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
|
||||
|
@ -650,8 +653,12 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
return w_loss
|
||||
|
||||
def reg_loss(self, reg_factor=0.005):
|
||||
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
|
||||
return reg_factor * F.mse_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt.to(self._params['mag'].device), reduction='mean')
|
||||
if self._fixed_mag:
|
||||
return torch.tensor(0)
|
||||
else:
|
||||
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
|
||||
params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
|
||||
return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean')
|
||||
|
||||
def train(self, mode=None):
|
||||
if mode is None :
|
||||
|
|
|
@ -38,7 +38,7 @@ else:
|
|||
if __name__ == "__main__":
|
||||
|
||||
n_inner_iter = 10
|
||||
epochs = 2
|
||||
epochs = 200
|
||||
dataug_epoch_start=0
|
||||
|
||||
#### Classic ####
|
||||
|
@ -64,39 +64,43 @@ if __name__ == "__main__":
|
|||
print('-'*9)
|
||||
'''
|
||||
#### Augmented Model ####
|
||||
#'''
|
||||
'''
|
||||
t0 = time.process_time()
|
||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||
#tf_dict = TF.TF_dict
|
||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
|
||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=False, shared_mag=True), LeNet(3,10)).to(device)
|
||||
#aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
||||
print(str(aug_model), 'on', device_name)
|
||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=10)
|
||||
|
||||
####
|
||||
plot_resV2(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), param_names=tf_names)
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
|
||||
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
with open("res/log/%s.json" % "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), "w+") as f:
|
||||
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
|
||||
with open("res/log/%s.json" % filename, "w+") as f:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
|
||||
plot_TF_influence(log, param_names=tf_names)
|
||||
plot_resV2(log, fig_name="res/"+filename, param_names=tf_names)
|
||||
|
||||
print('Execution Time : %.00f '%(time.process_time() - t0))
|
||||
print('-'*9)
|
||||
#'''
|
||||
#### TF number tests ####
|
||||
'''
|
||||
res_folder="res/TF_nb_tests/"
|
||||
epochs= 100
|
||||
#### TF tests ####
|
||||
#'''
|
||||
res_folder="res/brutus-tests/"
|
||||
epochs= 150
|
||||
inner_its = [0, 1, 10]
|
||||
dist_mix = [0.0, 0.5]
|
||||
dataug_epoch_starts= [0]
|
||||
TF_nb = [len(TF.TF_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
|
||||
N_seq_TF= [2, 3, 4, 6]
|
||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||
TF_nb = [len(tf_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
|
||||
N_seq_TF= [1,2,3,4]#[2, 3, 4, 6]
|
||||
mag_setup = [(True,True), (False,True), (False, False)]
|
||||
nb_run= 3
|
||||
|
||||
try:
|
||||
os.mkdir(res_folder)
|
||||
|
@ -105,28 +109,31 @@ if __name__ == "__main__":
|
|||
pass
|
||||
|
||||
for n_inner_iter in inner_its:
|
||||
print("---Starting inner_it", n_inner_iter,"---")
|
||||
for dataug_epoch_start in dataug_epoch_starts:
|
||||
print("---Starting dataug", dataug_epoch_start,"---")
|
||||
for n_tf in N_seq_TF:
|
||||
for i in TF_nb:
|
||||
keys = list(TF.TF_dict.keys())[0:i]
|
||||
ntf_dict = {k: TF.TF_dict[k] for k in keys}
|
||||
for dist in dist_mix:
|
||||
#for i in TF_nb:
|
||||
for m_setup in mag_setup:
|
||||
for run in range(nb_run):
|
||||
#keys = list(TF.TF_dict.keys())[0:i]
|
||||
#ntf_dict = {k: TF.TF_dict[k] for k in keys}
|
||||
|
||||
aug_model = Augmented_model(Data_augV4(TF_dict=ntf_dict, N_TF=n_tf, mix_dist=0.0), LeNet(3,10)).to(device)
|
||||
print(str(aug_model), 'on', device_name)
|
||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
|
||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_mag=m_setup[0], shared_mag=m_setup[1]), LeNet(3,10)).to(device)
|
||||
print(str(aug_model), 'on', device_name)
|
||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=20, loss_patience=None)
|
||||
|
||||
####
|
||||
plot_res(log, fig_name=res_folder+"{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), param_names=keys)
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
|
||||
print(str(aug_model),": acc", out["Accuracy"], "in (s?):", out["Time"][0], "+/-", out["Time"][1])
|
||||
with open(res_folder+"log/%s.json" % "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), "w+") as f:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
print('-'*9)
|
||||
####
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
|
||||
print(str(aug_model),": acc", out["Accuracy"], "in (s?):", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{}epochs(dataug:{})-{}in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter,run)
|
||||
with open(res_folder+"log/%s.json" % filename, "w+") as f:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
|
||||
'''
|
||||
#plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names)
|
||||
print('-'*9)
|
||||
|
||||
#'''
|
|
@ -651,7 +651,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
#print('proba grad',model['data_aug']['prob'].grad)
|
||||
print('TF Mag :', model['data_aug']['mag'].data)
|
||||
#print('Mag grad',model['data_aug']['mag'].grad)
|
||||
print('Reg loss:', model['data_aug'].reg_loss().item())
|
||||
#print('Reg loss:', model['data_aug'].reg_loss().item())
|
||||
#############
|
||||
#### Log ####
|
||||
#print(type(model['data_aug']) is dataug.Data_augV5)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue