mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-03 11:40:46 +02:00
Fixs + Benchmark
This commit is contained in:
parent
49472adfab
commit
f450fa8201
5 changed files with 60 additions and 16 deletions
|
@ -8,7 +8,8 @@ from transformations import TF_loader
|
|||
|
||||
import torchvision.models as models
|
||||
|
||||
model_list={models.resnet: ['resnet18', 'resnet50','wide_resnet50_2']} #lr=0.1
|
||||
#model_list={models.resnet: ['resnet18', 'resnet50','wide_resnet50_2']} #lr=0.1
|
||||
model_list={models.resnet: ['resnet18']}
|
||||
|
||||
optim_param={
|
||||
'Meta':{
|
||||
|
@ -28,7 +29,7 @@ optim_param={
|
|||
|
||||
res_folder="../res/benchmark/CIFAR10/"
|
||||
#res_folder="../res/HPsearch/"
|
||||
epochs= 300
|
||||
epochs= 400
|
||||
dataug_epoch_start=0
|
||||
nb_run= 1
|
||||
|
||||
|
@ -54,8 +55,8 @@ if __name__ == "__main__":
|
|||
|
||||
### Benchmark ###
|
||||
#'''
|
||||
n_inner_iter = 1
|
||||
dist_mix = [0.5, 1.0]
|
||||
n_inner_iter = 1#[0, 1]
|
||||
dist_mix = [0.5]
|
||||
N_seq_TF= [3, 4]
|
||||
mag_setup = [(False, False)] #[(True, True), (False, False)] #(FxSh, Independant)
|
||||
|
||||
|
|
14
higher/smart_aug/old/compare_TF.py
Normal file
14
higher/smart_aug/old/compare_TF.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
import RandAugment as rand
|
||||
import PIL
|
||||
import torchvision
|
||||
import transformations as TF
|
||||
tpil=torchvision.transforms.ToPILImage()
|
||||
ttensor=torchvision.transforms.ToTensor()
|
||||
|
||||
img,label =data_train[0]
|
||||
|
||||
rimg=ttensor(PIL.ImageEnhance.Color(tpil(img)).enhance(1.5))#ttensor(PIL.ImageOps.solarize(tpil(img), 50))#ttensor(tpil(img).transform(tpil(img).size, PIL.Image.AFFINE, (1, -0.1, 0, 0, 1, 0)))#rand.augmentations.FlipUD(tpil(img),1))
|
||||
timg=TF.color(img.unsqueeze(0),torch.Tensor([1.5])).squeeze(0)
|
||||
print(torch.allclose(rimg,timg, atol=1e-3))
|
||||
tpil(rimg).save('rimg.jpg')
|
||||
tpil(timg).save('timg.jpg')
|
|
@ -4,18 +4,20 @@ if __name__ == "__main__":
|
|||
|
||||
#'''
|
||||
files=[
|
||||
"../res/log/Aug_mod(Data_augV5(Mix0.8-3TFx2-MagFx)-resnet18)-2 epochs (dataug:0)- 1 in_it.json",
|
||||
"../res/HPsearch/log/Aug_mod(Data_augV5(Mix0.5-14TFx3-Mag)-ResNet)-200 epochs (dataug:0)- 1 in_it-0.json",
|
||||
#"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-0.json",
|
||||
#"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-1.json",
|
||||
#"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-2.json",
|
||||
#"res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
||||
]
|
||||
files = ["../res/benchmark/CIFAR10/log/RandAugment(N%d-M%.2f)-%s-200 epochs -%s.json"%(3,0.17,'wide_resnet50_2', str(run)) for run in range(3)]
|
||||
files = ["../res/benchmark/CIFAR100/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-Mag)-%s)-200 epochs (dataug:0)- 1 in_it-%s.json"%(0.5,3,'wide_resnet50_2', str(run)) for run in range(3)]
|
||||
|
||||
for idx, file in enumerate(files):
|
||||
#legend+=str(idx)+'-'+file+'\n'
|
||||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
plot_resV2(data['Log'], fig_name=file.replace("/log","").replace(".json",""), param_names=data['Param_names'])
|
||||
plot_resV2(data['Log'], fig_name=file.replace("/log","").replace(".json",""))#, param_names=data['Param_names'])
|
||||
#plot_TF_influence(data['Log'], param_names=data['Param_names'])
|
||||
#'''
|
||||
## Loss , Acc, Proba = f(epoch) ##
|
||||
|
@ -95,17 +97,18 @@ if __name__ == "__main__":
|
|||
'''
|
||||
|
||||
'''
|
||||
#HP search
|
||||
inner_its = [1]
|
||||
dist_mix = [0]#[0.5, 0.8, 1.0] #Uniform
|
||||
N_seq_TF= [4, 3, 2]
|
||||
dist_mix = [0.3, 0.5, 0.8, 1.0] #Uniform
|
||||
N_seq_TF= [5]
|
||||
nb_run= 3
|
||||
|
||||
for n_inner_iter in inner_its:
|
||||
for n_tf in N_seq_TF:
|
||||
for dist in dist_mix:
|
||||
|
||||
#files = ["../res/brutus-tests2/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-MagFxSh)-ResNet18)-150 epochs (dataug:0)- 1 in_it-%s.json"%(dist, n_tf, str(run)) for run in range(nb_run)]
|
||||
files = ["../res/brutus-tests2/log/Aug_mod(Data_augV5(Uniform-14TFx%d-MagFxSh)-ResNet18)-150 epochs (dataug:0)- 1 in_it-%s.json"%(n_tf, str(run)) for run in range(nb_run)]
|
||||
#files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-MagFxSh)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(dist, n_tf, str(run)) for run in range(nb_run)]
|
||||
files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Uniform-14TFx%d-MagFxSh)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(n_tf, str(run)) for run in range(nb_run)]
|
||||
accs = []
|
||||
times = []
|
||||
for idx, file in enumerate(files):
|
||||
|
@ -117,4 +120,30 @@ if __name__ == "__main__":
|
|||
print(idx, data['Accuracy'])
|
||||
|
||||
print(files[0], 'acc', np.mean(accs), '+-',np.std(accs), ',t', np.mean(times))
|
||||
'''
|
||||
|
||||
'''
|
||||
#Benchmark
|
||||
model_list=['resnet18', 'resnet50','wide_resnet50_2']
|
||||
nb_run= 3
|
||||
|
||||
for model_name in model_list:
|
||||
|
||||
files = ["../res/benchmark/CIFAR100/log/RandAugment(N%d-M%.2f)-%s-200 epochs -%s.json"%(3,0.17,model_name, str(run)) for run in range(nb_run)]
|
||||
#files = ["../res/benchmark/CIFAR10/log/%s-200 epochs -%s.json"%(model_name, str(run)) for run in range(nb_run)]
|
||||
|
||||
accs = []
|
||||
times = []
|
||||
mem_alloc = []
|
||||
mem_cach = []
|
||||
for idx, file in enumerate(files):
|
||||
#legend+=str(idx)+'-'+file+'\n'
|
||||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
accs.append(data['Accuracy'])
|
||||
times.append(data['Time'][0])
|
||||
mem_cach.append(data['Memory'])
|
||||
print(idx, data['Accuracy'])
|
||||
|
||||
print(files[0], 'acc', np.mean(accs), '+-',np.std(accs), ',t', np.mean(times), 'Mem', np.mean(mem_cach))
|
||||
'''
|
|
@ -69,8 +69,8 @@ if __name__ == "__main__":
|
|||
model = model.to(device)
|
||||
|
||||
|
||||
print("{} on {} for {} epochs".format(model_name, device_name, epochs))
|
||||
#print("RandAugment(N{}-M{:.2f})-{} on {} for {} epochs".format(rand_aug['N'],rand_aug['M'],model_name, device_name, epochs))
|
||||
print("{} on {} for {} epochs{}".format(model_name, device_name, epochs, postfix))
|
||||
#print("RandAugment(N{}-M{:.2f})-{} on {} for {} epochs{}".format(rand_aug['N'],rand_aug['M'],model_name, device_name, epochs, postfix))
|
||||
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10)
|
||||
#log= train_classic_higher(model=model, epochs=epochs)
|
||||
|
||||
|
@ -88,9 +88,9 @@ if __name__ == "__main__":
|
|||
#"Rand_Aug": rand_aug,
|
||||
"Log": log}
|
||||
print(model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{} epochs".format(model_name,epochs)
|
||||
filename = "{}-{} epochs".format(model_name,epochs)+postfix
|
||||
#print("RandAugment-",model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
#filename = "RandAugment(N{}-M{:.2f})-{}-{} epochs".format(rand_aug['N'],rand_aug['M'],model_name,epochs)+'-cosine'
|
||||
#filename = "RandAugment(N{}-M{:.2f})-{}-{} epochs".format(rand_aug['N'],rand_aug['M'],model_name,epochs)+postfix
|
||||
with open("../res/log/%s.json" % filename, "w+") as f:
|
||||
try:
|
||||
json.dump(out, f, indent=True)
|
||||
|
@ -128,7 +128,7 @@ if __name__ == "__main__":
|
|||
TF_ignore_mag=tf_ignore_mag), model).to(device)
|
||||
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
|
||||
|
||||
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
|
||||
print("{} on {} for {} epochs - {} inner_it{}".format(str(aug_model), device_name, epochs, n_inner_iter, postfix))
|
||||
log= run_dist_dataugV3(model=aug_model,
|
||||
epochs=epochs,
|
||||
inner_it=n_inner_iter,
|
||||
|
|
|
@ -247,7 +247,6 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
device = next(model.parameters()).device
|
||||
log = []
|
||||
dl_val_it = iter(dl_val)
|
||||
val_loss=None
|
||||
|
||||
high_grad_track = True
|
||||
if inner_it == 0: #No HP optimization
|
||||
|
@ -298,6 +297,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
|
||||
for epoch in range(1, epochs+1):
|
||||
t0 = time.perf_counter()
|
||||
val_loss=None
|
||||
|
||||
#Cross-Validation
|
||||
#dl_train, dl_val = cvs.next_split()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue