mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 20:20:46 +02:00
Supression res mauvais test brutus
This commit is contained in:
parent
6b7ed4836e
commit
c4e2e30151
19 changed files with 11 additions and 182711 deletions
|
@ -2,7 +2,7 @@ from utils import *
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
#'''
|
'''
|
||||||
files=[
|
files=[
|
||||||
#"res/good_TF_tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
#"res/good_TF_tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
||||||
#"res/good_TF_tests/log/Aug_mod(Data_augV5(Uniform-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
#"res/good_TF_tests/log/Aug_mod(Data_augV5(Uniform-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
||||||
|
@ -16,7 +16,7 @@ if __name__ == "__main__":
|
||||||
data = json.load(json_file)
|
data = json.load(json_file)
|
||||||
plot_resV2(data['Log'], fig_name=file.replace('.json','').replace('log/',''), param_names=data['Param_names'])
|
plot_resV2(data['Log'], fig_name=file.replace('.json','').replace('log/',''), param_names=data['Param_names'])
|
||||||
#plot_TF_influence(data['Log'], param_names=data['Param_names'])
|
#plot_TF_influence(data['Log'], param_names=data['Param_names'])
|
||||||
#'''
|
'''
|
||||||
## Loss , Acc, Proba = f(epoch) ##
|
## Loss , Acc, Proba = f(epoch) ##
|
||||||
#plot_compare(filenames=files, fig_name="res/compare")
|
#plot_compare(filenames=files, fig_name="res/compare")
|
||||||
|
|
||||||
|
@ -76,11 +76,11 @@ if __name__ == "__main__":
|
||||||
'''
|
'''
|
||||||
|
|
||||||
#Res print
|
#Res print
|
||||||
'''
|
#'''
|
||||||
nb_run=3
|
nb_run=3
|
||||||
accs = []
|
accs = []
|
||||||
times = []
|
times = []
|
||||||
files = ["res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx2-MagFxSh)-LeNet)-150epochs(dataug:0)-1in_it-%s.json"%str(run) for run in range(nb_run)]
|
files = ["res/brutus-tests/log/Aug_mod(Data_augV5(Mix1.0-18TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-1in_it-%s.json"%str(run) for run in range(nb_run)]
|
||||||
|
|
||||||
for idx, file in enumerate(files):
|
for idx, file in enumerate(files):
|
||||||
#legend+=str(idx)+'-'+file+'\n'
|
#legend+=str(idx)+'-'+file+'\n'
|
||||||
|
@ -90,4 +90,4 @@ if __name__ == "__main__":
|
||||||
times.append(data['Time'][0])
|
times.append(data['Time'][0])
|
||||||
|
|
||||||
print(files[0], np.mean(accs), np.std(accs), np.mean(times))
|
print(files[0], np.mean(accs), np.std(accs), np.mean(times))
|
||||||
'''
|
#'''
|
|
@ -5,7 +5,7 @@ from torch.distributions import *
|
||||||
|
|
||||||
#import kornia
|
#import kornia
|
||||||
#import random
|
#import random
|
||||||
#import numpy as np
|
import numpy as np
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
import transformations as TF
|
import transformations as TF
|
||||||
|
@ -692,7 +692,6 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
else:
|
else:
|
||||||
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
class Data_augV6(nn.Module): #Optimisation sequentielle
|
class Data_augV6(nn.Module): #Optimisation sequentielle
|
||||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=None, fixed_mag=True, shared_mag=True):
|
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=None, fixed_mag=True, shared_mag=True):
|
||||||
super(Data_augV6, self).__init__()
|
super(Data_augV6, self).__init__()
|
||||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -96,7 +96,7 @@ if __name__ == "__main__":
|
||||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||||
#tf_dict = TF.TF_dict
|
#tf_dict = TF.TF_dict
|
||||||
#aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
|
#aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
|
||||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
|
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
|
||||||
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
||||||
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device)
|
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device)
|
||||||
print(str(aug_model), 'on', device_name)
|
print(str(aug_model), 'on', device_name)
|
||||||
|
|
|
@ -616,9 +616,10 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
||||||
model_copy(src=fmodel, dst=model)
|
model_copy(src=fmodel, dst=model)
|
||||||
optim_copy(dopt=diffopt, opt=inner_opt)
|
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||||
|
|
||||||
meta_opt.step()
|
if epoch>50:
|
||||||
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
|
meta_opt.step()
|
||||||
#model['data_aug'].next_TF_set()
|
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
|
||||||
|
#model['data_aug'].next_TF_set()
|
||||||
|
|
||||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||||
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue