Update of all the changes... (Higher: added two TFs, modified the val loss, added prob to the sample images, ...)

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-10 13:21:34 -05:00
parent e75fb96716
commit c8ce6c8024
6 changed files with 299 additions and 64 deletions


@@ -6,21 +6,21 @@ from train_utils import *
 tf_names = [
 ## Geometric TF ##
 'Identity',
-'FlipUD',
-'FlipLR',
-'Rotate',
-'TranslateX',
-'TranslateY',
-'ShearX',
-'ShearY',
+#'FlipUD',
+#'FlipLR',
+#'Rotate',
+#'TranslateX',
+#'TranslateY',
+#'ShearX',
+#'ShearY',
 ## Color TF (Expect image in the range of [0, 1]) ##
-'Contrast',
-'Color',
-'Brightness',
-'Sharpness',
-'Posterize',
-'Solarize', #=>Image in [0,1] #Not optimized for batches
+#'Contrast',
+#'Color',
+#'Brightness',
+#'Sharpness',
+#'Posterize',
+#'Solarize', #=>Image in [0,1] #Not optimized for batches
 #Color TF (Common mag scale)
 #'+Contrast',
@@ -49,6 +49,8 @@ tf_names = [
 #'BadContrast',
 #'BadBrightness',
+'Random',
+#'RandBlend'
 #Not functional
 #'Auto_Contrast', #Not optimized for batches (very slow)
 #'Equalize',
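With most transformations commented out, the visible search space shrinks to the uncommented names ('Identity' and the newly added 'Random' in the hunks above). A minimal sketch of how such a name list is consumed, mirroring the `tf_dict = {k: TF.TF_dict[k] for k in tf_names}` comprehension further down in this file; the TF_DICT registry below is a hypothetical stand-in for TF.TF_dict:

import torch

# Hypothetical stand-in for TF.TF_dict: name -> batched transform callable.
TF_DICT = {
    'Identity': lambda x, mag: x,
    'Random':   lambda x, mag: x + mag * torch.rand_like(x),
}

# Only the uncommented entries of tf_names end up in the augmentation module,
# so commenting a name out removes it from the learned search space.
tf_names = ['Identity', 'Random']
tf_dict = {k: TF_DICT[k] for k in tf_names}
print(list(tf_dict.keys()))  # ['Identity', 'Random']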
@@ -65,12 +67,12 @@ else:
 if __name__ == "__main__":
 tasks={
-'classic',
+#'classic',
 #'aug_dataset',
-#'aug_model'
+'aug_model'
 }
 n_inner_iter = 1
-epochs = 100
+epochs = 1
 dataug_epoch_start=0
 optim_param={
 'Meta':{
@@ -84,9 +86,9 @@ if __name__ == "__main__":
 }
 }
-#model = LeNet(3,10)
+model = LeNet(3,10)
 #model = MobileNetV2(num_classes=10)
-model = ResNet(num_classes=10)
+#model = ResNet(num_classes=10)
 #model = WideResNet(num_classes=10, wrn_size=32)
 #### Classic ####
@@ -95,8 +97,8 @@ if __name__ == "__main__":
 model = model.to(device)
 print("{} on {} for {} epochs".format(str(model), device_name, epochs))
-log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1)
-#log= train_classic_higher(model=model, epochs=epochs)
+#log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10)
+log= train_classic_higher(model=model, epochs=epochs)
 exec_time=time.process_time() - t0
 ####
@@ -138,7 +140,7 @@ if __name__ == "__main__":
 data_train_aug.augement_data(aug_copy=1)
 print(data_train_aug)
 unsup_ratio = 5
-dl_unsup = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE*unsup_ratio, shuffle=True)
+dl_unsup = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE*unsup_ratio, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
 unsup_xs, sup_xs, ys = next(iter(dl_unsup))
 viz_sample_data(imgs=sup_xs, labels=ys, fig_name='samples/data_sample_{}'.format(str(data_train_aug)))
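The only change in this hunk passes num_workers and pin_memory to the unsupervised DataLoader. A minimal, self-contained PyTorch sketch of what those two options do; the dataset, BATCH_SIZE and worker count below are placeholders, not the repository's values:

import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE, unsup_ratio = 32, 5  # placeholder values
data_train_aug = TensorDataset(torch.randn(1024, 3, 32, 32),
                               torch.randint(0, 10, (1024,)))

dl_unsup = DataLoader(
    data_train_aug,
    batch_size=BATCH_SIZE * unsup_ratio,
    shuffle=True,
    num_workers=2,    # load batches in parallel worker processes
    pin_memory=True,  # page-locked host memory for faster (non-blocking) GPU copies
)

xs, ys = next(iter(dl_unsup))  # plain (x, y) pairs here; the repo's dataset yields (unsup_xs, sup_xs, ys)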
@@ -172,7 +174,7 @@ if __name__ == "__main__":
 tf_dict = {k: TF.TF_dict[k] for k in tf_names}
 #aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), model).to(device)
-aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
+aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True), model).to(device)
 #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
 print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
@@ -181,7 +183,7 @@ if __name__ == "__main__":
 inner_it=n_inner_iter,
 dataug_epoch_start=dataug_epoch_start,
 opt_param=optim_param,
-print_freq=10,
+print_freq=1,
 KLdiv=True,
 loss_patience=None)
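The keyword arguments above (inner_it=n_inner_iter, opt_param=optim_param, ...) belong to the Higher-based bi-level training routine whose name falls outside this hunk. A rough, hypothetical sketch of the pattern it refers to, not the repository's implementation: n_inner_iter differentiable inner steps are unrolled with the higher library, then a validation loss is backpropagated to the augmentation's meta-parameters and applied by the 'Meta' optimizer:

import torch
import torch.nn as nn
import torch.nn.functional as F
import higher

model = nn.Linear(8, 2)                               # toy task model
aug_strength = nn.Parameter(torch.zeros(1))           # stand-in for learned augmentation params
inner_opt = torch.optim.SGD(model.parameters(), lr=1e-1)
meta_opt = torch.optim.Adam([aug_strength], lr=1e-2)  # plays the role of the 'Meta' optimizer

def augment(x):
    # Toy differentiable augmentation so gradients can reach aug_strength.
    return x + torch.sigmoid(aug_strength) * torch.randn_like(x)

def meta_step(train_batches, val_batch, n_inner_iter=1):
    xs_val, ys_val = val_batch
    meta_opt.zero_grad()
    with higher.innerloop_ctx(model, inner_opt) as (fmodel, diffopt):
        for xs, ys in train_batches[:n_inner_iter]:
            diffopt.step(F.cross_entropy(fmodel(augment(xs)), ys))  # unrolled inner update
        val_loss = F.cross_entropy(fmodel(xs_val), ys_val)
        val_loss.backward()  # gradients flow through the unrolled steps to aug_strength
    meta_opt.step()
    return val_loss.item()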