mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 20:20:46 +02:00
Changes since Teledyne
This commit is contained in:
parent
bd5dc63cff
commit
1060f18033
203 changed files with 24395 additions and 0 deletions
73
higher/smart_aug/arg_parser.py
Normal file
73
higher/smart_aug/arg_parser.py
Normal file
|
@ -0,0 +1,73 @@
|
|||
import argparse
|
||||
|
||||
#Argparse
|
||||
parser = argparse.ArgumentParser(description='Run smart augmentation')
|
||||
parser.add_argument('-dv','--device', default='cuda', dest='device',
|
||||
help='Device : cpu / cuda')
|
||||
parser.add_argument('-dt','--dtype', default='FP32', dest='dtype',
|
||||
help='Data type (Default: Float32)')
|
||||
|
||||
parser.add_argument('-m','--model', default='resnet18', dest='model',
|
||||
help='Network')
|
||||
parser.add_argument('-pt','--pretrained', default='', dest='pretrained',
|
||||
help='Use pretrained weight if possible')
|
||||
|
||||
parser.add_argument('-ep','--epochs', type=int, default=10, dest='epochs',
|
||||
help='epoch')
|
||||
# parser.add_argument('-ot', '--optimizer', default='SGD', dest='opt_type',
|
||||
# help='Model optimizer')
|
||||
parser.add_argument('-lr', type=float, default=1e-1, dest='lr',
|
||||
help='Model learning rate')
|
||||
parser.add_argument('-mo', '--momentum', type=float, default=0.9, dest='momentum',
|
||||
help='Momentum')
|
||||
parser.add_argument('-dc', '--decay', type=float, default=0.0005, dest='decay',
|
||||
help='Weight decay')
|
||||
parser.add_argument('-ns','--nesterov', type=bool, default=False, dest='nesterov',
|
||||
help='Nesterov momentum ?')
|
||||
parser.add_argument('-sc', '--scheduler', default='cosine', dest='scheduler',
|
||||
help='Model learning rate scheduler')
|
||||
parser.add_argument('-wu', '--warmup', type=float, default=0, dest='warmup',
|
||||
help='Warmup multiplier')
|
||||
|
||||
|
||||
parser.add_argument('-a','--augment', type=bool, default=False, dest='augment',
|
||||
help='Data augmentation ?')
|
||||
parser.add_argument('-N', type=int, default=1,
|
||||
help='Combination of TF')
|
||||
parser.add_argument('-K', type=int, default=0,
|
||||
help='Number inner iteration')
|
||||
parser.add_argument('-al','--augment_loss', type=int, default=1, dest='augment_loss',
|
||||
help='Number of augmented example for each sample in loss computation.')
|
||||
parser.add_argument('-t', '--temp', type=float, default=0.5, dest='temp',
|
||||
help='Probability distribution temperature')
|
||||
parser.add_argument('-tfc','--tf_config', default='../config/invScale_wide_tf_config.json', dest='tf_config',
|
||||
help='TF config')
|
||||
parser.add_argument('-ls', '--learn_seq', type=bool, default=False, dest='learn_seq',
|
||||
help='Learn order of application of TF (DataugV7-8) ?')
|
||||
parser.add_argument('-fm', '--fixed_mag', type=bool, default=False, dest='fixed_mag',
|
||||
help='Fixed magnitude when learning data augmentation ?')
|
||||
parser.add_argument('-sm', '--shared_mag', type=bool, default=False, dest='shared_mag',
|
||||
help='Shared magnitude when learning data augmentation ?')
|
||||
|
||||
# parser.add_argument('-mot', '--metaoptimizer', default='Adam', dest='meta_opt_type',
|
||||
# help='Meta optimizer (Augmentations)')
|
||||
parser.add_argument('-mlr', type=float, default=1e-2, dest='mlr',
|
||||
help='Meta learning rate (Augmentations)')
|
||||
parser.add_argument('-ms', type=int, default=0, dest='meta_epoch_start',
|
||||
help='Epoch at which start meta learning')
|
||||
parser.add_argument('-mr', type=float, default=0.001, dest='mag_reg',
|
||||
help='Augmentation magnitudes regulation factor')
|
||||
|
||||
parser.add_argument('-rf','--res_folder', default='../res/', dest='res_folder',
|
||||
help='Results folder')
|
||||
parser.add_argument('-pf','--postfix', default='', dest='postfix',
|
||||
help='Res postfix')
|
||||
|
||||
parser.add_argument('-dr','--dataroot', default='~/scratch/data', dest='dataroot',
|
||||
help='Datasets folder')
|
||||
parser.add_argument('-ds','--dataset', default='CIFAR10', dest='dataset',
|
||||
help='Dataset')
|
||||
parser.add_argument('-bs','--batch_size', type=int, default=256, dest='batch_size',
|
||||
help='Batch size') #256 (WRN) / 512
|
||||
parser.add_argument('-w','--workers', type=int, default=6, dest='workers',
|
||||
help='Numer of workers (Nb CPU cores).')
|
247
higher/smart_aug/benchmark.py
Normal file
247
higher/smart_aug/benchmark.py
Normal file
|
@ -0,0 +1,247 @@
|
|||
""" Script to run series of experiments.
|
||||
|
||||
"""
|
||||
from dataug import *
|
||||
#from utils import *
|
||||
from train_utils import *
|
||||
from transformations import TF_loader
|
||||
|
||||
import torchvision.models as models
|
||||
from LeNet import *
|
||||
|
||||
#model_list={models.resnet: ['resnet18', 'resnet50','wide_resnet50_2']} #lr=0.1
|
||||
model_list={models.resnet: ['wide_resnet50_2']}
|
||||
|
||||
optim_param={
|
||||
'Meta':{
|
||||
'optim':'Adam',
|
||||
'lr':5e-3, #1e-2
|
||||
'epoch_start': 2, #0 / 2 (Resnet?)
|
||||
'reg_factor': 0.001,
|
||||
'scheduler': None, #None, 'multiStep'
|
||||
},
|
||||
'Inner':{
|
||||
'optim': 'SGD',
|
||||
'lr':1e-2, #1e-2/1e-1 (ResNet)
|
||||
'momentum':0.9, #0.9
|
||||
'decay':0.0005, #0.0005
|
||||
'nesterov':False, #False (True: Bad behavior w/ Data_aug)
|
||||
'scheduler':'cosine', #None, 'cosine', 'multiStep', 'exponential'
|
||||
}
|
||||
}
|
||||
|
||||
res_folder="../res/benchmark/CIFAR10/"
|
||||
#res_folder="../res/benchmark/MNIST/"
|
||||
#res_folder="../res/HPsearch/"
|
||||
epochs= 200
|
||||
dataug_epoch_start=0
|
||||
nb_run= 3
|
||||
|
||||
tf_config='../config/bad_tf_config.json' #'../config/wide_tf_config.json'#'../config/base_tf_config.json'
|
||||
TF_loader=TF_loader()
|
||||
tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(tf_config)
|
||||
|
||||
device = torch.device('cuda:1')
|
||||
|
||||
if device == torch.device('cpu'):
|
||||
device_name = 'CPU'
|
||||
else:
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True #Faster if same input size #Not recommended for reproductibility
|
||||
|
||||
#Increase reproductibility
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
##########################################
|
||||
if __name__ == "__main__":
|
||||
|
||||
### Benchmark ###
|
||||
'''
|
||||
inner_its = [0]
|
||||
dist_mix = [0.5]
|
||||
N_seq_TF= [3]
|
||||
mag_setup = [(False, False)] #[(True, True), (False, False)] #(FxSh, Independant)
|
||||
|
||||
for model_type in model_list.keys():
|
||||
for model_name in model_list[model_type]:
|
||||
for run in range(nb_run):
|
||||
|
||||
for n_inner_iter in inner_its:
|
||||
for n_tf in N_seq_TF:
|
||||
for dist in dist_mix:
|
||||
for m_setup in mag_setup:
|
||||
|
||||
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
|
||||
torch.cuda.reset_max_memory_cached() #reset_peak_stats
|
||||
t0 = time.perf_counter()
|
||||
|
||||
model = getattr(model_type, model_name)(pretrained=False, num_classes=len(dl_train.dataset.classes))
|
||||
#model_name = 'LeNet'
|
||||
#model = LeNet(3,10)
|
||||
|
||||
model = Higher_model(model, model_name) #run_dist_dataugV3
|
||||
if n_inner_iter!=0:
|
||||
aug_model = Augmented_model(
|
||||
Data_augV5(TF_dict=tf_dict,
|
||||
N_TF=n_tf,
|
||||
mix_dist=dist,
|
||||
fixed_prob=False,
|
||||
fixed_mag=m_setup[0],
|
||||
shared_mag=m_setup[1],
|
||||
TF_ignore_mag=tf_ignore_mag),
|
||||
model).to(device)
|
||||
else:
|
||||
aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=n_tf), model).to(device)
|
||||
|
||||
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
|
||||
log= run_dist_dataugV3(model=aug_model,
|
||||
epochs=epochs,
|
||||
inner_it=n_inner_iter,
|
||||
dataug_epoch_start=dataug_epoch_start,
|
||||
opt_param=optim_param,
|
||||
print_freq=epochs/4,
|
||||
unsup_loss=1,
|
||||
hp_opt=False,
|
||||
save_sample_freq=None)
|
||||
|
||||
exec_time=time.perf_counter() - t0
|
||||
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
|
||||
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
|
||||
####
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]),
|
||||
"Time": (np.mean(times),np.std(times), exec_time),
|
||||
'Optimizer': optim_param,
|
||||
"Device": device_name,
|
||||
"Memory": [max_allocated, max_cached],
|
||||
"TF_config": tf_config,
|
||||
"Param_names": aug_model.TF_names(),
|
||||
"Log": log}
|
||||
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{} epochs (dataug:{})- {} in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter, run)
|
||||
with open(res_folder+"log/%s.json" % filename, "w+") as f:
|
||||
try:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
except:
|
||||
print("Failed to save logs :",f.name)
|
||||
try:
|
||||
plot_resV2(log, fig_name=res_folder+filename, param_names=aug_model.TF_names())
|
||||
except:
|
||||
print("Failed to plot res")
|
||||
print(sys.exc_info()[1])
|
||||
|
||||
print('Execution Time : %.00f '%(exec_time))
|
||||
print('-'*9)
|
||||
'''
|
||||
### Benchmark - RandAugment/Vanilla ###
|
||||
#'''
|
||||
for model_type in model_list.keys():
|
||||
for model_name in model_list[model_type]:
|
||||
for run in range(nb_run):
|
||||
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
|
||||
torch.cuda.reset_max_memory_cached() #reset_peak_stats
|
||||
t0 = time.perf_counter()
|
||||
|
||||
model = getattr(model_type, model_name)(pretrained=False, num_classes=len(dl_train.dataset.classes)).to(device)
|
||||
|
||||
print("{} on {} for {} epochs".format(model_name, device_name, epochs))
|
||||
#print("RandAugment(N{}-M{:.2f})-{} on {} for {} epochs".format(rand_aug['N'],rand_aug['M'],model_name, device_name, epochs))
|
||||
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=epochs/4)
|
||||
|
||||
exec_time=time.perf_counter() - t0
|
||||
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
|
||||
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
|
||||
####
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]),
|
||||
"Time": (np.mean(times),np.std(times), exec_time),
|
||||
'Optimizer': optim_param,
|
||||
"Device": device_name,
|
||||
"Memory": [max_allocated, max_cached],
|
||||
#"Rand_Aug": rand_aug,
|
||||
"Log": log}
|
||||
print(model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{} epochs -{}-basicDA".format(model_name,epochs, run)
|
||||
#print("RandAugment-",model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
#filename = "RandAugment(N{}-M{:.2f})-{}-{} epochs -{}".format(rand_aug['N'],rand_aug['M'],model_name,epochs, run)
|
||||
with open(res_folder+"log/%s.json" % filename, "w+") as f:
|
||||
try:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
except:
|
||||
print("Failed to save logs :",f.name)
|
||||
print(sys.exc_info()[1])
|
||||
|
||||
try:
|
||||
plot_resV2(log, fig_name=res_folder+filename, param_names=aug_model.TF_names())
|
||||
except:
|
||||
print("Failed to plot res")
|
||||
print(sys.exc_info()[1])
|
||||
print('Execution Time : %.00f '%(exec_time))
|
||||
print('-'*9)
|
||||
#'''
|
||||
### HP Search ###
|
||||
'''
|
||||
from LeNet import *
|
||||
inner_its = [1]
|
||||
dist_mix = [1.0]#[0.0, 0.5, 0.8, 1.0]
|
||||
N_seq_TF= [5, 6]
|
||||
mag_setup = [(True, True), (False, False)] #(FxSh, Independant)
|
||||
#prob_setup = [True, False]
|
||||
|
||||
try:
|
||||
os.mkdir(res_folder)
|
||||
os.mkdir(res_folder+"log/")
|
||||
except FileExistsError:
|
||||
pass
|
||||
|
||||
for n_inner_iter in inner_its:
|
||||
for n_tf in N_seq_TF:
|
||||
for dist in dist_mix:
|
||||
#for i in TF_nb:
|
||||
for m_setup in mag_setup:
|
||||
#for p_setup in prob_setup:
|
||||
p_setup=False
|
||||
for run in range(nb_run):
|
||||
|
||||
t0 = time.perf_counter()
|
||||
|
||||
model = getattr(models.resnet, 'resnet18')(pretrained=False, num_classes=len(dl_train.dataset.classes))
|
||||
#model = LeNet(3,10)
|
||||
model = Higher_model(model) #run_dist_dataugV3
|
||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=p_setup, fixed_mag=m_setup[0], shared_mag=m_setup[1]), model).to(device)
|
||||
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
|
||||
|
||||
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
|
||||
log= run_dist_dataugV3(model=aug_model,
|
||||
epochs=epochs,
|
||||
inner_it=n_inner_iter,
|
||||
dataug_epoch_start=dataug_epoch_start,
|
||||
opt_param=optim_param,
|
||||
print_freq=epochs/4,
|
||||
unsup_loss=1,
|
||||
hp_opt=False,
|
||||
save_sample_freq=None)
|
||||
|
||||
exec_time=time.perf_counter() - t0
|
||||
####
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
|
||||
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{} epochs (dataug:{})- {} in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter, run)
|
||||
with open(res_folder+"log/%s.json" % filename, "w+") as f:
|
||||
try:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
except:
|
||||
print("Failed to save logs :",f.name)
|
||||
|
||||
print('Execution Time : %.00f '%(exec_time))
|
||||
print('-'*9)
|
||||
'''
|
220
higher/smart_aug/datasets.py
Normal file
220
higher/smart_aug/datasets.py
Normal file
|
@ -0,0 +1,220 @@
|
|||
""" Dataset definition.
|
||||
|
||||
MNIST / CIFAR10 / CIFAR100 / SVHN / ImageNet
|
||||
"""
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data.dataset import ConcatDataset
|
||||
import torchvision
|
||||
from arg_parser import *
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
#Wether to download data.
|
||||
download_data=False
|
||||
#Pin GPU memory
|
||||
pin_memory=False #True :+ GPU memory / + Lent
|
||||
#Data storage folder
|
||||
dataroot=args.dataroot
|
||||
|
||||
# if args.dtype == 'FP32':
|
||||
# def_type=torch.float32
|
||||
# elif args.dtype == 'FP16':
|
||||
# # def_type=torch.float16 #Default : float32
|
||||
# def_type=torch.bfloat16
|
||||
# else:
|
||||
# raise Exception('dtype not supported :', args.dtype)
|
||||
|
||||
#ATTENTION : Dataug (Kornia) Expect image in the range of [0, 1]
|
||||
transform = [
|
||||
#torchvision.transforms.Grayscale(3), #MNIST
|
||||
#torchvision.transforms.Resize((224,224), interpolation=2)#VGG
|
||||
torchvision.transforms.ToTensor(),
|
||||
#torchvision.transforms.Normalize(MEAN, STD), #CIFAR10
|
||||
# torchvision.transforms.Lambda(lambda tensor: tensor.to(def_type)),
|
||||
]
|
||||
|
||||
transform_train = [
|
||||
#transforms.RandomHorizontalFlip(),
|
||||
#transforms.RandomVerticalFlip(),
|
||||
#torchvision.transforms.Grayscale(3), #MNIST
|
||||
#torchvision.transforms.Resize((224,224), interpolation=2)
|
||||
torchvision.transforms.ToTensor(),
|
||||
#torchvision.transforms.Normalize(MEAN, STD), #CIFAR10
|
||||
# torchvision.transforms.Lambda(lambda tensor: tensor.to(def_type)),
|
||||
]
|
||||
|
||||
## RandAugment ##
|
||||
#from RandAugment import RandAugment
|
||||
# Add RandAugment with N, M(hyperparameter)
|
||||
#rand_aug={'N': 2, 'M': 1}
|
||||
#rand_aug={'N': 2, 'M': 9./30} #RN-ImageNet
|
||||
#rand_aug={'N': 3, 'M': 5./30} #WRN-CIFAR10
|
||||
#rand_aug={'N': 2, 'M': 14./30} #WRN-CIFAR100
|
||||
#rand_aug={'N': 3, 'M': 7./30} #WRN-SVHN
|
||||
#transform_train.transforms.insert(0, RandAugment(n=rand_aug['N'], m=rand_aug['M']))
|
||||
|
||||
### Classic Dataset ###
|
||||
BATCH_SIZE = args.batch_size
|
||||
TEST_SIZE = BATCH_SIZE
|
||||
# Load Dataset
|
||||
if args.dataset == 'MNIST':
|
||||
transform_train.insert(0, torchvision.transforms.Grayscale(3))
|
||||
transform.insert(0, torchvision.transforms.Grayscale(3))
|
||||
|
||||
val_set=False
|
||||
data_train = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=torchvision.transforms.Compose(transform_train))
|
||||
data_val = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=torchvision.transforms.Compose(transform))
|
||||
data_test = torchvision.datasets.MNIST(dataroot, train=False, download=True, transform=torchvision.transforms.Compose(transform))
|
||||
elif args.dataset == 'CIFAR10': #(32x32 RGB)
|
||||
val_set=False
|
||||
MEAN=(0.4914, 0.4822, 0.4465)
|
||||
STD=(0.2023, 0.1994, 0.2010)
|
||||
data_train = torchvision.datasets.CIFAR10(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform_train))
|
||||
data_val = torchvision.datasets.CIFAR10(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform))
|
||||
data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=download_data, transform=torchvision.transforms.Compose(transform))
|
||||
elif args.dataset == 'CIFAR100': #(32x32 RGB)
|
||||
val_set=False
|
||||
MEAN=(0.4914, 0.4822, 0.4465)
|
||||
STD=(0.2023, 0.1994, 0.2010)
|
||||
data_train = torchvision.datasets.CIFAR100(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform_train))
|
||||
data_val = torchvision.datasets.CIFAR100(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform))
|
||||
data_test = torchvision.datasets.CIFAR100(dataroot, train=False, download=download_data, transform=torchvision.transforms.Compose(transform))
|
||||
elif args.dataset == 'TinyImageNet': #(Train:100k, Val:5k, Test:5k) (64x64 RGB)
|
||||
image_size=64 #128 / 224
|
||||
print('Using image size', image_size)
|
||||
transform_train=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform_train
|
||||
transform=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform
|
||||
|
||||
val_set=True
|
||||
MEAN=(0.485, 0.456, 0.406)
|
||||
STD=(0.229, 0.224, 0.225)
|
||||
data_train = torchvision.datasets.ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/train'), transform=torchvision.transforms.Compose(transform_train))
|
||||
data_val = torchvision.datasets.ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/val'), transform=torchvision.transforms.Compose(transform))
|
||||
data_test = torchvision.datasets.ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/test'), transform=torchvision.transforms.Compose(transform))
|
||||
elif args.dataset == 'ImageNet': #
|
||||
image_size=128 #224
|
||||
print('Using image size', image_size)
|
||||
transform_train=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform_train
|
||||
transform=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform
|
||||
|
||||
val_set=False
|
||||
MEAN=(0.485, 0.456, 0.406)
|
||||
STD=(0.229, 0.224, 0.225)
|
||||
data_train = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'ImageNet/train'), transform=torchvision.transforms.Compose(transform_train))
|
||||
data_val = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'ImageNet/train'), transform=torchvision.transforms.Compose(transform))
|
||||
data_test = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'ImageNet/validation'), transform=torchvision.transforms.Compose(transform))
|
||||
|
||||
else:
|
||||
raise Exception('Unknown dataset')
|
||||
|
||||
# Ready dataloader
|
||||
if not val_set : #Split Training set into Train/Val
|
||||
#Validation set size [0, 1]
|
||||
valid_size=0.1
|
||||
train_subset_indices=range(int(len(data_train)*(1-valid_size)))
|
||||
val_subset_indices=range(int(len(data_train)*(1-valid_size)),len(data_train))
|
||||
#train_subset_indices=range(BATCH_SIZE*10)
|
||||
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
|
||||
|
||||
from torch.utils.data import SubsetRandomSampler
|
||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=args.workers, pin_memory=pin_memory)
|
||||
dl_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=args.workers, pin_memory=pin_memory)
|
||||
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=args.workers, pin_memory=pin_memory)
|
||||
else:
|
||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=args.workers, pin_memory=pin_memory)
|
||||
dl_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE, shuffle=True, num_workers=args.workers, pin_memory=pin_memory)
|
||||
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=args.workers, pin_memory=pin_memory)
|
||||
|
||||
|
||||
#SVHN
|
||||
#trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=download_data, transform=transform_train)
|
||||
#extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=download_data, transform=transform_train)
|
||||
#data_train = ConcatDataset([trainset, extraset])
|
||||
#data_test = torchvision.datasets.SVHN(dataroot, split='test', download=download_data, transform=transform)
|
||||
|
||||
#ImageNet
|
||||
#Necessite SciPy
|
||||
# Probleme ? : https://github.com/ildoonet/pytorch-randaugment/blob/48b8f509c4bbda93bbe733d98b3fd052b6e4c8ae/RandAugment/imagenet.py#L28
|
||||
#data_train = torchvision.datasets.ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='train', transform=transform_train)
|
||||
#data_test = torchvision.datasets.ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)
|
||||
|
||||
#Cross Validation
|
||||
'''
|
||||
import numpy as np
|
||||
from sklearn.model_selection import ShuffleSplit
|
||||
from sklearn.model_selection import StratifiedShuffleSplit
|
||||
class CVSplit(object):
|
||||
"""Class that perform train/valid split on a dataset.
|
||||
|
||||
Inspired from : https://skorch.readthedocs.io/en/latest/user/dataset.html
|
||||
|
||||
Attributes:
|
||||
_stratified (bool): Wether the split should be stratified. Recommended to be True for unbalanced dataset.
|
||||
_val_size (float, int): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the validation split.
|
||||
If int, represents the absolute number of validation samples.
|
||||
_data (Dataset): Dataset to split.
|
||||
_targets (np.array): Targets of the dataset used if _stratified is set to True.
|
||||
_cv (BaseShuffleSplit) : Scikit learn object used to split.
|
||||
|
||||
"""
|
||||
def __init__(self, data, val_size=0.1, stratified=True):
|
||||
""" Intialize CVSplit.
|
||||
|
||||
Args:
|
||||
data (Dataset): Dataset to split.
|
||||
val_size (float, int): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the validation split.
|
||||
If int, represents the absolute number of validation samples. (Default: 0.1)
|
||||
stratified (bool): Wether the split should be stratified. Recommended to be True for unbalanced dataset.
|
||||
"""
|
||||
self._stratified=stratified
|
||||
self._val_size=val_size
|
||||
|
||||
self._data=data
|
||||
if self._stratified:
|
||||
cv_cls = StratifiedShuffleSplit
|
||||
self._targets= np.array(data_train.targets)
|
||||
else:
|
||||
cv_cls = ShuffleSplit
|
||||
|
||||
self._cv= cv_cls(test_size=val_size, random_state=0) #Random state w/ fixed seed
|
||||
|
||||
def next_split(self):
|
||||
""" Get next cross-validation split.
|
||||
|
||||
Returns:
|
||||
Train DataLoader, Validation DataLoader
|
||||
"""
|
||||
args=(np.arange(len(self._data)),)
|
||||
if self._stratified:
|
||||
args = args + (self._targets,)
|
||||
|
||||
idx_train, idx_valid = next(iter(self._cv.split(*args)))
|
||||
|
||||
train_subset = torch.utils.data.Subset(self._data, idx_train)
|
||||
val_subset = torch.utils.data.Subset(self._data, idx_valid)
|
||||
|
||||
dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
|
||||
dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
|
||||
|
||||
return dl_train, dl_val
|
||||
|
||||
cvs = CVSplit(data_train, val_size=valid_size)
|
||||
dl_train, dl_val = cvs.next_split()
|
||||
'''
|
||||
|
||||
'''
|
||||
from skorch.dataset import CVSplit
|
||||
import numpy as np
|
||||
cvs = CVSplit(cv=valid_size, stratified=True) #Stratified =True for unbalanced dataset #ShuffleSplit
|
||||
|
||||
def next_CVSplit():
|
||||
|
||||
train_subset, val_subset = cvs(data_train, y=np.array(data_train.targets))
|
||||
dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
|
||||
dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
|
||||
|
||||
return dl_train, dl_val
|
||||
|
||||
dl_train, dl_val = next_CVSplit()
|
||||
'''
|
1263
higher/smart_aug/dataug.py
Normal file
1263
higher/smart_aug/dataug.py
Normal file
File diff suppressed because it is too large
Load diff
31
higher/smart_aug/higher_patch.py
Normal file
31
higher/smart_aug/higher_patch.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
""" Patch for Higher package
|
||||
|
||||
Recommended use ::
|
||||
|
||||
import higher
|
||||
import higher_patch
|
||||
|
||||
Might become unnecessary with future update of the Higher package.
|
||||
"""
|
||||
import higher
|
||||
import torch as _torch
|
||||
|
||||
def detach_(self):
|
||||
"""Removes all params from their compute graph in place.
|
||||
|
||||
"""
|
||||
# detach param groups
|
||||
for group in self.param_groups:
|
||||
for k, v in group.items():
|
||||
if isinstance(v,_torch.Tensor):
|
||||
v.detach_().requires_grad_()
|
||||
|
||||
# detach state
|
||||
for state_dict in self.state:
|
||||
for k,v_dict in state_dict.items():
|
||||
if isinstance(k,_torch.Tensor): k.detach_().requires_grad_()
|
||||
for k2,v2 in v_dict.items():
|
||||
if isinstance(v2,_torch.Tensor):
|
||||
v2.detach_().requires_grad_()
|
||||
|
||||
higher.optim.DifferentiableOptimizer.detach_ = detach_
|
56
higher/smart_aug/nets/LeNet.py
Normal file
56
higher/smart_aug/nets/LeNet.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
## Basic CNN ##
|
||||
class LeNet(nn.Module):
|
||||
"""Basic CNN.
|
||||
|
||||
"""
|
||||
def __init__(self, num_inp, num_out):
|
||||
"""Init LeNet.
|
||||
|
||||
"""
|
||||
super(LeNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(num_inp, 20, 5)
|
||||
self.pool = nn.MaxPool2d(2, 2)
|
||||
self.conv2 = nn.Conv2d(20, 50, 5)
|
||||
self.pool2 = nn.MaxPool2d(2, 2)
|
||||
#self.fc1 = nn.Linear(4*4*50, 500)
|
||||
self.fc1 = nn.Linear(5*5*50, 500)
|
||||
self.fc2 = nn.Linear(500, num_out)
|
||||
|
||||
def forward(self, x):
|
||||
"""Main method of LeNet
|
||||
|
||||
"""
|
||||
x = self.pool(F.relu(self.conv1(x)))
|
||||
x = self.pool2(F.relu(self.conv2(x)))
|
||||
x = x.view(x.size(0), -1)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
def __str__(self):
|
||||
""" Get name of model
|
||||
|
||||
"""
|
||||
return "LeNet"
|
||||
|
||||
#MNIST
|
||||
class MLPNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(MLPNet, self).__init__()
|
||||
self.fc1 = nn.Linear(28*28, 500)
|
||||
self.fc2 = nn.Linear(500, 256)
|
||||
self.fc3 = nn.Linear(256, 10)
|
||||
def forward(self, x):
|
||||
x = x.view(-1, 28*28)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
def name(self):
|
||||
return "MLP"
|
426
higher/smart_aug/nets/resnet_abn.py
Normal file
426
higher/smart_aug/nets/resnet_abn.py
Normal file
|
@ -0,0 +1,426 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
|
||||
|
||||
# __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
# 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||
# 'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
|
||||
# model_urls = {
|
||||
# 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
# 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
# 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
# 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
# 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
# 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||
# 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||
# 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||
# 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||
# }
|
||||
|
||||
__all__ = ['ResNet_ABN', 'resnet18_ABN', 'resnet34_ABN', 'resnet50_ABN', 'resnet101_ABN',
|
||||
'resnet152_ABN', 'resnext50_32x4d_ABN', 'resnext101_32x8d_ABN',
|
||||
'wide_resnet50_2_ABN', 'wide_resnet101_2_ABN']
|
||||
|
||||
class aux_batchNorm(nn.Module):
|
||||
def __init__(self, norm_layer, nb_features):
|
||||
super(aux_batchNorm, self).__init__()
|
||||
self.mode='clean'
|
||||
self.bn=nn.ModuleDict({
|
||||
'clean': norm_layer(nb_features),
|
||||
'augmented': norm_layer(nb_features)
|
||||
})
|
||||
def forward(self, x):
|
||||
if self.mode is 'mixed':
|
||||
running_mean=(self.bn['clean'].running_mean+self.bn['augmented'].running_mean)/2
|
||||
running_var=(self.bn['clean'].running_var+self.bn['augmented'].running_var)/2
|
||||
return nn.functional.batch_norm(x, running_mean, running_var, self.bn['clean'].weight, self.bn['clean'].bias)
|
||||
return self.bn[self.mode](x)
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock_ABN(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(BasicBlock_ABN, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
#self.bn1 = norm_layer(planes)
|
||||
self.bn1 = aux_batchNorm(norm_layer, planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
#self.bn2 = norm_layer(planes)
|
||||
self.bn2 = aux_batchNorm(norm_layer, planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck_ABN(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(Bottleneck_ABN, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
#self.bn1 = norm_layer(width)
|
||||
self.bn1 = aux_batchNorm(norm_layer, width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
# self.bn2 = norm_layer(width)
|
||||
self.bn2 = aux_batchNorm(norm_layer, width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
# self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.bn3 = aux_batchNorm(norm_layer, planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet_ABN(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None):
|
||||
super(ResNet_ABN, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
#self.bn1 = norm_layer(self.inplanes)
|
||||
self.bn1 = aux_batchNorm(norm_layer, self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
print('WARNING : zero_init_residual not implemented with ABN')
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, Bottleneck):
|
||||
# nn.init.constant_(m.bn3.weight, 0)
|
||||
# elif isinstance(m, BasicBlock):
|
||||
# nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
# Memoire des BN layers pas fonctinnel avec Higher
|
||||
# self.bn_layers=[]
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, aux_batchNorm):
|
||||
# self.bn_layers.append(m)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
#norm_layer(planes * block.expansion),
|
||||
aux_batchNorm(norm_layer, planes * block.expansion)
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def set_mode(self, mode):
|
||||
# for bn in self.bn_layers:
|
||||
for m in self.modules():
|
||||
if isinstance(m, aux_batchNorm):
|
||||
m.mode=mode
|
||||
|
||||
|
||||
|
||||
# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
# model = ResNet(block, layers, **kwargs)
|
||||
# if pretrained:
|
||||
# state_dict = load_state_dict_from_url(model_urls[arch],
|
||||
# progress=progress)
|
||||
# model.load_state_dict(state_dict)
|
||||
# return model
|
||||
|
||||
|
||||
def resnet18_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-18 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(BasicBlock_ABN, [2, 2, 2, 2], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnet34_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(BasicBlock_ABN, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnet50_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnet101_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnet152_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-152 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 8, 36, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnext50_32x4d_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNeXt-50 32x4d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
# return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||
# pretrained, progress, **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNeXt-101 32x8d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
# return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||
# pretrained, progress, **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def wide_resnet50_2_ABN(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-50-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
# return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||
# pretrained, progress, **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def wide_resnet101_2_ABN(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-101-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
# return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||
# pretrained, progress, **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
|
618
higher/smart_aug/nets/resnet_deconv.py
Normal file
618
higher/smart_aug/nets/resnet_deconv.py
Normal file
|
@ -0,0 +1,618 @@
|
|||
'''ResNet in PyTorch.
|
||||
For Pre-activation ResNet, see 'preact_resnet.py'.
|
||||
Reference:
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
||||
Deep Residual Learning for Image Recognition. arXiv:1512.03385
|
||||
|
||||
https://github.com/yechengxi/deconvolution
|
||||
'''
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from torch.nn.modules import conv
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
from functools import partial
|
||||
|
||||
__all__ = ['ResNet18_DC', 'ResNet34_DC', 'ResNet50_DC', 'ResNet101_DC', 'ResNet152_DC', 'WRN_DC26_10']
|
||||
|
||||
### Deconvolution ###
|
||||
|
||||
#iteratively solve for inverse sqrt of a matrix
|
||||
def isqrt_newton_schulz_autograd(A, numIters):
|
||||
dim = A.shape[0]
|
||||
normA=A.norm()
|
||||
Y = A.div(normA)
|
||||
I = torch.eye(dim,dtype=A.dtype,device=A.device)
|
||||
Z = torch.eye(dim,dtype=A.dtype,device=A.device)
|
||||
|
||||
for i in range(numIters):
|
||||
T = 0.5*(3.0*I - Z@Y)
|
||||
Y = Y@T
|
||||
Z = T@Z
|
||||
#A_sqrt = Y*torch.sqrt(normA)
|
||||
A_isqrt = Z / torch.sqrt(normA)
|
||||
return A_isqrt
|
||||
|
||||
def isqrt_newton_schulz_autograd_batch(A, numIters):
|
||||
batchSize,dim,_ = A.shape
|
||||
normA=A.view(batchSize, -1).norm(2, 1).view(batchSize, 1, 1)
|
||||
Y = A.div(normA)
|
||||
I = torch.eye(dim,dtype=A.dtype,device=A.device).unsqueeze(0).expand_as(A)
|
||||
Z = torch.eye(dim,dtype=A.dtype,device=A.device).unsqueeze(0).expand_as(A)
|
||||
|
||||
for i in range(numIters):
|
||||
T = 0.5*(3.0*I - Z.bmm(Y))
|
||||
Y = Y.bmm(T)
|
||||
Z = T.bmm(Z)
|
||||
#A_sqrt = Y*torch.sqrt(normA)
|
||||
A_isqrt = Z / torch.sqrt(normA)
|
||||
|
||||
return A_isqrt
|
||||
|
||||
|
||||
|
||||
#deconvolve channels
|
||||
class ChannelDeconv(nn.Module):
|
||||
def __init__(self, block, eps=1e-2,n_iter=5,momentum=0.1,sampling_stride=3):
|
||||
super(ChannelDeconv, self).__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.n_iter=n_iter
|
||||
self.momentum=momentum
|
||||
self.block = block
|
||||
|
||||
self.register_buffer('running_mean1', torch.zeros(block, 1))
|
||||
#self.register_buffer('running_cov', torch.eye(block))
|
||||
self.register_buffer('running_deconv', torch.eye(block))
|
||||
self.register_buffer('running_mean2', torch.zeros(1, 1))
|
||||
self.register_buffer('running_var', torch.ones(1, 1))
|
||||
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
|
||||
self.sampling_stride=sampling_stride
|
||||
def forward(self, x):
|
||||
x_shape = x.shape
|
||||
if len(x.shape)==2:
|
||||
x=x.view(x.shape[0],x.shape[1],1,1)
|
||||
if len(x.shape)==3:
|
||||
print('Error! Unsupprted tensor shape.')
|
||||
|
||||
N, C, H, W = x.size()
|
||||
B = self.block
|
||||
|
||||
#take the first c channels out for deconv
|
||||
c=int(C/B)*B
|
||||
if c==0:
|
||||
print('Error! block should be set smaller.')
|
||||
|
||||
#step 1. remove mean
|
||||
if c!=C:
|
||||
x1=x[:,:c].permute(1,0,2,3).contiguous().view(B,-1)
|
||||
else:
|
||||
x1=x.permute(1,0,2,3).contiguous().view(B,-1)
|
||||
|
||||
if self.sampling_stride > 1 and H >= self.sampling_stride and W >= self.sampling_stride:
|
||||
x1_s = x1[:,::self.sampling_stride**2]
|
||||
else:
|
||||
x1_s=x1
|
||||
|
||||
mean1 = x1_s.mean(-1, keepdim=True)
|
||||
|
||||
if self.num_batches_tracked==0:
|
||||
self.running_mean1.copy_(mean1.detach())
|
||||
if self.training:
|
||||
self.running_mean1.mul_(1-self.momentum)
|
||||
self.running_mean1.add_(mean1.detach()*self.momentum)
|
||||
else:
|
||||
mean1 = self.running_mean1
|
||||
|
||||
x1=x1-mean1
|
||||
|
||||
#step 2. calculate deconv@x1 = cov^(-0.5)@x1
|
||||
if self.training:
|
||||
cov = x1_s @ x1_s.t() / x1_s.shape[1] + self.eps * torch.eye(B, dtype=x.dtype, device=x.device)
|
||||
deconv = isqrt_newton_schulz_autograd(cov, self.n_iter)
|
||||
|
||||
if self.num_batches_tracked==0:
|
||||
#self.running_cov.copy_(cov.detach())
|
||||
self.running_deconv.copy_(deconv.detach())
|
||||
|
||||
if self.training:
|
||||
#self.running_cov.mul_(1-self.momentum)
|
||||
#self.running_cov.add_(cov.detach()*self.momentum)
|
||||
self.running_deconv.mul_(1 - self.momentum)
|
||||
self.running_deconv.add_(deconv.detach() * self.momentum)
|
||||
else:
|
||||
# cov = self.running_cov
|
||||
deconv = self.running_deconv
|
||||
|
||||
x1 =deconv@x1
|
||||
|
||||
#reshape to N,c,J,W
|
||||
x1 = x1.view(c, N, H, W).contiguous().permute(1,0,2,3)
|
||||
|
||||
# normalize the remaining channels
|
||||
if c!=C:
|
||||
x_tmp=x[:, c:].view(N,-1)
|
||||
if self.sampling_stride > 1 and H>=self.sampling_stride and W>=self.sampling_stride:
|
||||
x_s = x_tmp[:, ::self.sampling_stride ** 2]
|
||||
else:
|
||||
x_s = x_tmp
|
||||
|
||||
mean2=x_s.mean()
|
||||
var=x_s.var()
|
||||
|
||||
if self.num_batches_tracked == 0:
|
||||
self.running_mean2.copy_(mean2.detach())
|
||||
self.running_var.copy_(var.detach())
|
||||
|
||||
if self.training:
|
||||
self.running_mean2.mul_(1 - self.momentum)
|
||||
self.running_mean2.add_(mean2.detach() * self.momentum)
|
||||
self.running_var.mul_(1 - self.momentum)
|
||||
self.running_var.add_(var.detach() * self.momentum)
|
||||
else:
|
||||
mean2 = self.running_mean2
|
||||
var = self.running_var
|
||||
|
||||
x_tmp = (x[:, c:] - mean2) / (var + self.eps).sqrt()
|
||||
x1 = torch.cat([x1, x_tmp], dim=1)
|
||||
|
||||
|
||||
if self.training:
|
||||
self.num_batches_tracked.add_(1)
|
||||
|
||||
if len(x_shape)==2:
|
||||
x1=x1.view(x_shape)
|
||||
return x1
|
||||
|
||||
#An alternative implementation
|
||||
class Delinear(nn.Module):
|
||||
__constants__ = ['bias', 'in_features', 'out_features']
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=512):
|
||||
super(Delinear, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.Tensor(out_features))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.reset_parameters()
|
||||
|
||||
|
||||
|
||||
if block > in_features:
|
||||
block = in_features
|
||||
else:
|
||||
if in_features%block!=0:
|
||||
block=math.gcd(block,in_features)
|
||||
print('block size set to:', block)
|
||||
self.block = block
|
||||
self.momentum = momentum
|
||||
self.n_iter = n_iter
|
||||
self.eps = eps
|
||||
self.register_buffer('running_mean', torch.zeros(self.block))
|
||||
self.register_buffer('running_deconv', torch.eye(self.block))
|
||||
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
if self.bias is not None:
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(self.bias, -bound, bound)
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
if self.training:
|
||||
|
||||
# 1. reshape
|
||||
X=input.view(-1, self.block)
|
||||
|
||||
# 2. subtract mean
|
||||
X_mean = X.mean(0)
|
||||
X = X - X_mean.unsqueeze(0)
|
||||
self.running_mean.mul_(1 - self.momentum)
|
||||
self.running_mean.add_(X_mean.detach() * self.momentum)
|
||||
|
||||
# 3. calculate COV, COV^(-0.5), then deconv
|
||||
# Cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
|
||||
Id = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
|
||||
Cov = torch.addmm(self.eps, Id, 1. / X.shape[0], X.t(), X)
|
||||
deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
|
||||
# track stats for evaluation
|
||||
self.running_deconv.mul_(1 - self.momentum)
|
||||
self.running_deconv.add_(deconv.detach() * self.momentum)
|
||||
|
||||
else:
|
||||
X_mean = self.running_mean
|
||||
deconv = self.running_deconv
|
||||
|
||||
w = self.weight.view(-1, self.block) @ deconv
|
||||
b = self.bias
|
||||
if self.bias is not None:
|
||||
b = b - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
|
||||
w = w.view(self.weight.shape)
|
||||
return F.linear(input, w, b)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'in_features={}, out_features={}, bias={}'.format(
|
||||
self.in_features, self.out_features, self.bias is not None
|
||||
)
|
||||
|
||||
|
||||
|
||||
class FastDeconv(conv._ConvNd):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,groups=1,bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3,freeze=False,freeze_iter=100):
|
||||
self.momentum = momentum
|
||||
self.n_iter = n_iter
|
||||
self.eps = eps
|
||||
self.counter=0
|
||||
self.track_running_stats=True
|
||||
super(FastDeconv, self).__init__(
|
||||
in_channels, out_channels, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation),
|
||||
False, _pair(0), groups, bias, padding_mode='zeros')
|
||||
|
||||
if block > in_channels:
|
||||
block = in_channels
|
||||
else:
|
||||
if in_channels%block!=0:
|
||||
block=math.gcd(block,in_channels)
|
||||
|
||||
if groups>1:
|
||||
#grouped conv
|
||||
block=in_channels//groups
|
||||
|
||||
self.block=block
|
||||
|
||||
self.num_features = kernel_size**2 *block
|
||||
if groups==1:
|
||||
self.register_buffer('running_mean', torch.zeros(self.num_features))
|
||||
self.register_buffer('running_deconv', torch.eye(self.num_features))
|
||||
else:
|
||||
self.register_buffer('running_mean', torch.zeros(kernel_size ** 2 * in_channels))
|
||||
self.register_buffer('running_deconv', torch.eye(self.num_features).repeat(in_channels // block, 1, 1))
|
||||
|
||||
self.sampling_stride=sampling_stride*stride
|
||||
self.counter=0
|
||||
self.freeze_iter=freeze_iter
|
||||
self.freeze=freeze
|
||||
|
||||
def forward(self, x):
|
||||
N, C, H, W = x.shape
|
||||
B = self.block
|
||||
frozen=self.freeze and (self.counter>self.freeze_iter)
|
||||
if self.training and self.track_running_stats:
|
||||
self.counter+=1
|
||||
self.counter %= (self.freeze_iter * 10)
|
||||
|
||||
if self.training and (not frozen):
|
||||
|
||||
# 1. im2col: N x cols x pixels -> N*pixles x cols
|
||||
if self.kernel_size[0]>1:
|
||||
X = torch.nn.functional.unfold(x, self.kernel_size,self.dilation,self.padding,self.sampling_stride).transpose(1, 2).contiguous()
|
||||
else:
|
||||
#channel wise
|
||||
X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride**2,:]
|
||||
|
||||
if self.groups==1:
|
||||
# (C//B*N*pixels,k*k*B)
|
||||
X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
|
||||
else:
|
||||
X=X.view(-1,X.shape[-1])
|
||||
|
||||
# 2. subtract mean
|
||||
X_mean = X.mean(0)
|
||||
X = X - X_mean.unsqueeze(0)
|
||||
|
||||
# 3. calculate COV, COV^(-0.5), then deconv
|
||||
if self.groups==1:
|
||||
#Cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
|
||||
Id=torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
|
||||
Cov = torch.addmm(self.eps, Id, 1. / X.shape[0], X.t(), X)
|
||||
deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
|
||||
else:
|
||||
X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
|
||||
Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(self.groups, self.num_features, self.num_features)
|
||||
Cov = torch.baddbmm(self.eps, Id, 1. / X.shape[1], X.transpose(1, 2), X)
|
||||
|
||||
deconv = isqrt_newton_schulz_autograd_batch(Cov, self.n_iter)
|
||||
|
||||
if self.track_running_stats:
|
||||
self.running_mean.mul_(1 - self.momentum)
|
||||
self.running_mean.add_(X_mean.detach() * self.momentum)
|
||||
# track stats for evaluation
|
||||
self.running_deconv.mul_(1 - self.momentum)
|
||||
self.running_deconv.add_(deconv.detach() * self.momentum)
|
||||
|
||||
else:
|
||||
X_mean = self.running_mean
|
||||
deconv = self.running_deconv
|
||||
|
||||
#4. X * deconv * conv = X * (deconv * conv)
|
||||
if self.groups==1:
|
||||
w = self.weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1,self.num_features) @ deconv
|
||||
b = self.bias - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
|
||||
w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
|
||||
else:
|
||||
w = self.weight.view(C//B, -1,self.num_features)@deconv
|
||||
b = self.bias - (w @ (X_mean.view( -1,self.num_features,1))).view(self.bias.shape)
|
||||
|
||||
w = w.view(self.weight.shape)
|
||||
x= F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
return x
|
||||
|
||||
### ResNet
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, deconv=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
if deconv:
|
||||
self.conv1 = deconv(in_planes, planes, kernel_size=3, stride=stride, padding=1)
|
||||
self.conv2 = deconv(planes, planes, kernel_size=3, stride=1, padding=1)
|
||||
self.deconv = True
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.deconv = False
|
||||
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
if not deconv:
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
#self.bn1 = nn.GroupNorm(planes//16,planes)
|
||||
#self.bn2 = nn.GroupNorm(planes//16,planes)
|
||||
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion*planes)
|
||||
#nn.GroupNorm(self.expansion * planes//16,self.expansion * planes)
|
||||
)
|
||||
else:
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
deconv(in_planes, self.expansion*planes, kernel_size=1, stride=stride)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.deconv:
|
||||
out = F.relu(self.conv1(x))
|
||||
out = self.conv2(out)
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
else: #self.batch_norm:
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, deconv=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
|
||||
|
||||
if deconv:
|
||||
self.deconv = True
|
||||
self.conv1 = deconv(in_planes, planes, kernel_size=1)
|
||||
self.conv2 = deconv(planes, planes, kernel_size=3, stride=stride, padding=1)
|
||||
self.conv3 = deconv(planes, self.expansion*planes, kernel_size=1)
|
||||
|
||||
else:
|
||||
self.deconv = False
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
if not deconv:
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes)
|
||||
)
|
||||
else:
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
deconv(in_planes, self.expansion * planes, kernel_size=1, stride=stride)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
"""
|
||||
No batch normalization for deconv.
|
||||
"""
|
||||
if self.deconv:
|
||||
out = F.relu((self.conv1(x)))
|
||||
out = F.relu((self.conv2(out)))
|
||||
out = self.conv3(out)
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
else:
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, num_blocks, num_classes=10, deconv=None,channel_deconv=None):
|
||||
super(ResNet, self).__init__()
|
||||
self.in_planes = 64
|
||||
|
||||
if deconv:
|
||||
self.deconv = True
|
||||
self.conv1 = deconv(3, 64, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
|
||||
|
||||
if not deconv:
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
|
||||
#this line is really recent, take extreme care if the result is not good.
|
||||
if channel_deconv:
|
||||
self.deconv1=channel_deconv()
|
||||
|
||||
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, deconv=deconv)
|
||||
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, deconv=deconv)
|
||||
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, deconv=deconv)
|
||||
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, deconv=deconv)
|
||||
self.linear = nn.Linear(512*block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride, deconv):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride, deconv))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if hasattr(self,'bn1'):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
else:
|
||||
out = F.relu(self.conv1(x))
|
||||
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
if hasattr(self, 'deconv1'):
|
||||
out = self.deconv1(out)
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def_deconv = partial(FastDeconv,bias=True, eps=1e-5, n_iter=5,block=64,sampling_stride=3)
|
||||
#channel_deconv=partial(ChannelDeconv, block=512,eps=1e-5, n_iter=5,sampling_stride=3) #Pas forcément conseillé
|
||||
|
||||
def ResNet18_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(BasicBlock, [2,2,2,2],num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
def ResNet34_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
def ResNet50_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
def ResNet101_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
def ResNet152_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
import math
|
||||
class Wide_ResNet_Cifar_DC(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, wfactor, num_classes=10, deconv=None, channel_deconv=None):
|
||||
super(Wide_ResNet_Cifar_DC, self).__init__()
|
||||
self.depth=layers[0]*6+2
|
||||
self.widen_factor=wfactor
|
||||
|
||||
self.inplanes = 16
|
||||
self.conv1 = deconv(3, 16, kernel_size=3, stride=1, padding=1)
|
||||
if channel_deconv:
|
||||
self.deconv1=channel_deconv()
|
||||
# self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
# self.bn1 = nn.BatchNorm2d(16)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.layer1 = self._make_layer(block, 16*wfactor, layers[0], stride=1, deconv=deconv)
|
||||
self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2, deconv=deconv)
|
||||
self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2, deconv=deconv)
|
||||
self.avgpool = nn.AvgPool2d(8, stride=1)
|
||||
self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride, deconv):
|
||||
# downsample = None
|
||||
# if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
# downsample = nn.Sequential(
|
||||
# nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
# nn.BatchNorm2d(planes * block.expansion)
|
||||
# )
|
||||
|
||||
# layers = []
|
||||
# layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
# self.inplanes = planes * block.expansion
|
||||
# for _ in range(1, blocks):
|
||||
# layers.append(block(self.inplanes, planes))
|
||||
|
||||
# return nn.Sequential(*layers)
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.inplanes, planes, stride, deconv))
|
||||
self.inplanes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
# x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
|
||||
if hasattr(self, 'deconv1'):
|
||||
out = self.deconv1(out)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def __str__(self):
|
||||
""" Get name of model
|
||||
|
||||
"""
|
||||
return "Wide_ResNet_cifar_DC%d_%d"%(self.depth,self.widen_factor)
|
||||
|
||||
def WRN_DC26_10(depth=26, width=10, deconv=def_deconv, channel_deconv=None, **kwargs):
|
||||
assert (depth - 2) % 6 == 0
|
||||
n = int((depth - 2) / 6)
|
||||
return Wide_ResNet_Cifar_DC(BasicBlock, [n, n, n], width, deconv=deconv,channel_deconv=channel_deconv, **kwargs)
|
||||
|
||||
def test():
|
||||
net = ResNet18_DC()
|
||||
y = net(torch.randn(1,3,32,32))
|
||||
print(y.size())
|
||||
|
||||
# test()
|
98
higher/smart_aug/nets/wideresnet.py
Normal file
98
higher/smart_aug/nets/wideresnet.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
_bn_momentum = 0.1
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
|
||||
|
||||
|
||||
def conv_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('Conv') != -1:
|
||||
init.xavier_uniform_(m.weight, gain=np.sqrt(2))
|
||||
init.constant_(m.bias, 0)
|
||||
elif classname.find('BatchNorm') != -1:
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class WideBasic(nn.Module):
|
||||
def __init__(self, in_planes, planes, dropout_rate, stride=1):
|
||||
super(WideBasic, self).__init__()
|
||||
assert dropout_rate==0.0, 'dropout layer not used'
|
||||
self.bn1 = nn.BatchNorm2d(in_planes, momentum=_bn_momentum)
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
|
||||
#self.dropout = nn.Dropout(p=dropout_rate)
|
||||
self.bn2 = nn.BatchNorm2d(planes, momentum=_bn_momentum)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# out = self.dropout(self.conv1(F.relu(self.bn1(x))))
|
||||
out = self.conv1(F.relu(self.bn1(x)))
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out += self.shortcut(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class WideResNet(nn.Module):
|
||||
def __init__(self, depth, widen_factor, dropout_rate, num_classes):
|
||||
super(WideResNet, self).__init__()
|
||||
self.depth=depth
|
||||
self.widen_factor=widen_factor
|
||||
self.in_planes = 16
|
||||
|
||||
assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
|
||||
n = int((depth - 4) / 6)
|
||||
k = widen_factor
|
||||
|
||||
nStages = [16, 16*k, 32*k, 64*k]
|
||||
|
||||
self.conv1 = conv3x3(3, nStages[0])
|
||||
self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
|
||||
self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
|
||||
self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
|
||||
self.bn1 = nn.BatchNorm2d(nStages[3], momentum=_bn_momentum)
|
||||
self.linear = nn.Linear(nStages[3], num_classes)
|
||||
|
||||
# self.apply(conv_init)
|
||||
|
||||
def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, dropout_rate, stride))
|
||||
self.in_planes = planes
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = F.relu(self.bn1(out))
|
||||
# out = F.avg_pool2d(out, 8)
|
||||
out = F.adaptive_avg_pool2d(out, (1, 1))
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
|
||||
return out
|
||||
|
||||
def __str__(self):
|
||||
""" Get name of model
|
||||
|
||||
"""
|
||||
return "Wide_ResNet%d_%d"%(self.depth,self.widen_factor)
|
119
higher/smart_aug/nets/wideresnet_cifar.py
Normal file
119
higher/smart_aug/nets/wideresnet_cifar.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
"""
|
||||
wide resnet for cifar in pytorch
|
||||
Reference:
|
||||
[1] S. Zagoruyko and N. Komodakis. Wide residual networks. In BMVC, 2016.
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
#from models.resnet_cifar import BasicBlock
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
" 3x3 convolution with padding "
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion=1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class Wide_ResNet_Cifar(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, wfactor, num_classes=10):
|
||||
super(Wide_ResNet_Cifar, self).__init__()
|
||||
self.depth=layers[0]*6+2
|
||||
self.widen_factor=wfactor
|
||||
|
||||
self.inplanes = 16
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(16)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.layer1 = self._make_layer(block, 16*wfactor, layers[0])
|
||||
self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2)
|
||||
self.avgpool = nn.AvgPool2d(8, stride=1)
|
||||
self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion)
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def __str__(self):
|
||||
""" Get name of model
|
||||
|
||||
"""
|
||||
return "Wide_ResNet_cifar%d_%d"%(self.depth,self.widen_factor)
|
||||
|
||||
|
||||
def wide_resnet_cifar(depth, width, **kwargs):
|
||||
assert (depth - 2) % 6 == 0
|
||||
n = int((depth - 2) / 6)
|
||||
return Wide_ResNet_Cifar(BasicBlock, [n, n, n], width, **kwargs)
|
||||
|
||||
|
||||
if __name__=='__main__':
|
||||
net = wide_resnet_cifar(20, 10)
|
||||
y = net(torch.randn(1, 3, 32, 32))
|
||||
print(isinstance(net, Wide_ResNet_Cifar))
|
||||
print(y.size())
|
490
higher/smart_aug/old/augmentation_transforms.py
Normal file
490
higher/smart_aug/old/augmentation_transforms.py
Normal file
|
@ -0,0 +1,490 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2019 The Google UDA Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Transforms used in the Augmentation Policies.
|
||||
|
||||
Copied from AutoAugment: https://github.com/tensorflow/models/blob/master/research/autoaugment/
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
# pylint:disable=g-multiple-import
|
||||
from PIL import ImageOps, ImageEnhance, ImageFilter, Image
|
||||
# pylint:enable=g-multiple-import
|
||||
|
||||
#import tensorflow as tf
|
||||
|
||||
#FLAGS = tf.flags.FLAGS
|
||||
|
||||
|
||||
IMAGE_SIZE = 32
|
||||
# What is the dataset mean and std of the images on the training set
|
||||
PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted
|
||||
|
||||
|
||||
def get_mean_and_std():
|
||||
#if FLAGS.task_name == "cifar10":
|
||||
means = [0.49139968, 0.48215841, 0.44653091]
|
||||
stds = [0.24703223, 0.24348513, 0.26158784]
|
||||
#elif FLAGS.task_name == "svhn":
|
||||
# means = [0.4376821, 0.4437697, 0.47280442]
|
||||
# stds = [0.19803012, 0.20101562, 0.19703614]
|
||||
#else:
|
||||
# assert False
|
||||
return means, stds
|
||||
|
||||
|
||||
def random_flip(x):
|
||||
"""Flip the input x horizontally with 50% probability."""
|
||||
if np.random.rand(1)[0] > 0.5:
|
||||
return np.fliplr(x)
|
||||
return x
|
||||
|
||||
|
||||
def zero_pad_and_crop(img, amount=4):
|
||||
"""Zero pad by `amount` zero pixels on each side then take a random crop.
|
||||
|
||||
Args:
|
||||
img: numpy image that will be zero padded and cropped.
|
||||
amount: amount of zeros to pad `img` with horizontally and verically.
|
||||
|
||||
Returns:
|
||||
The cropped zero padded img. The returned numpy array will be of the same
|
||||
shape as `img`.
|
||||
"""
|
||||
padded_img = np.zeros((img.shape[0] + amount * 2, img.shape[1] + amount * 2,
|
||||
img.shape[2]))
|
||||
padded_img[amount:img.shape[0] + amount, amount:
|
||||
img.shape[1] + amount, :] = img
|
||||
top = np.random.randint(low=0, high=2 * amount)
|
||||
left = np.random.randint(low=0, high=2 * amount)
|
||||
new_img = padded_img[top:top + img.shape[0], left:left + img.shape[1], :]
|
||||
return new_img
|
||||
|
||||
|
||||
def create_cutout_mask(img_height, img_width, num_channels, size):
|
||||
"""Creates a zero mask used for cutout of shape `img_height` x `img_width`.
|
||||
|
||||
Args:
|
||||
img_height: Height of image cutout mask will be applied to.
|
||||
img_width: Width of image cutout mask will be applied to.
|
||||
num_channels: Number of channels in the image.
|
||||
size: Size of the zeros mask.
|
||||
|
||||
Returns:
|
||||
A mask of shape `img_height` x `img_width` with all ones except for a
|
||||
square of zeros of shape `size` x `size`. This mask is meant to be
|
||||
elementwise multiplied with the original image. Additionally returns
|
||||
the `upper_coord` and `lower_coord` which specify where the cutout mask
|
||||
will be applied.
|
||||
"""
|
||||
assert img_height == img_width
|
||||
|
||||
# Sample center where cutout mask will be applied
|
||||
height_loc = np.random.randint(low=0, high=img_height)
|
||||
width_loc = np.random.randint(low=0, high=img_width)
|
||||
|
||||
# Determine upper right and lower left corners of patch
|
||||
upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2))
|
||||
lower_coord = (min(img_height, height_loc + size // 2),
|
||||
min(img_width, width_loc + size // 2))
|
||||
mask_height = lower_coord[0] - upper_coord[0]
|
||||
mask_width = lower_coord[1] - upper_coord[1]
|
||||
assert mask_height > 0
|
||||
assert mask_width > 0
|
||||
|
||||
mask = np.ones((img_height, img_width, num_channels))
|
||||
zeros = np.zeros((mask_height, mask_width, num_channels))
|
||||
mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = (
|
||||
zeros)
|
||||
return mask, upper_coord, lower_coord
|
||||
|
||||
|
||||
def cutout_numpy(img, size=16):
|
||||
"""Apply cutout with mask of shape `size` x `size` to `img`.
|
||||
|
||||
The cutout operation is from the paper https://arxiv.org/abs/1708.04552.
|
||||
This operation applies a `size`x`size` mask of zeros to a random location
|
||||
within `img`.
|
||||
|
||||
Args:
|
||||
img: Numpy image that cutout will be applied to.
|
||||
size: Height/width of the cutout mask that will be
|
||||
|
||||
Returns:
|
||||
A numpy tensor that is the result of applying the cutout mask to `img`.
|
||||
"""
|
||||
img_height, img_width, num_channels = (img.shape[0], img.shape[1],
|
||||
img.shape[2])
|
||||
assert len(img.shape) == 3
|
||||
mask, _, _ = create_cutout_mask(img_height, img_width, num_channels, size)
|
||||
return img * mask
|
||||
|
||||
|
||||
def float_parameter(level, maxval):
|
||||
"""Helper function to scale `val` between 0 and maxval .
|
||||
|
||||
Args:
|
||||
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
|
||||
maxval: Maximum value that the operation can have. This will be scaled
|
||||
to level/PARAMETER_MAX.
|
||||
|
||||
Returns:
|
||||
A float that results from scaling `maxval` according to `level`.
|
||||
"""
|
||||
return float(level) * maxval / PARAMETER_MAX
|
||||
|
||||
|
||||
def int_parameter(level, maxval):
|
||||
"""Helper function to scale `val` between 0 and maxval .
|
||||
|
||||
Args:
|
||||
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
|
||||
maxval: Maximum value that the operation can have. This will be scaled
|
||||
to level/PARAMETER_MAX.
|
||||
|
||||
Returns:
|
||||
An int that results from scaling `maxval` according to `level`.
|
||||
"""
|
||||
return int(level * maxval / PARAMETER_MAX)
|
||||
|
||||
|
||||
def pil_wrap(img, use_mean_std):
|
||||
"""Convert the `img` numpy tensor to a PIL Image."""
|
||||
|
||||
if use_mean_std:
|
||||
MEANS, STDS = get_mean_and_std()
|
||||
else:
|
||||
MEANS = [0, 0, 0]
|
||||
STDS = [1, 1, 1]
|
||||
img_ori = (img * STDS + MEANS) * 255
|
||||
|
||||
return Image.fromarray(
|
||||
np.uint8((img * STDS + MEANS) * 255.0)).convert('RGBA')
|
||||
|
||||
|
||||
def pil_unwrap(pil_img, use_mean_std, img_shape):
|
||||
"""Converts the PIL img to a numpy array."""
|
||||
if use_mean_std:
|
||||
MEANS, STDS = get_mean_and_std()
|
||||
else:
|
||||
MEANS = [0, 0, 0]
|
||||
STDS = [1, 1, 1]
|
||||
pic_array = np.array(pil_img.getdata()).reshape((img_shape[0], img_shape[1], 4)) / 255.0
|
||||
i1, i2 = np.where(pic_array[:, :, 3] == 0)
|
||||
pic_array = (pic_array[:, :, :3] - MEANS) / STDS
|
||||
pic_array[i1, i2] = [0, 0, 0]
|
||||
return pic_array
|
||||
|
||||
|
||||
def apply_policy(policy, img, use_mean_std=True):
|
||||
"""Apply the `policy` to the numpy `img`.
|
||||
|
||||
Args:
|
||||
policy: A list of tuples with the form (name, probability, level) where
|
||||
`name` is the name of the augmentation operation to apply, `probability`
|
||||
is the probability of applying the operation and `level` is what strength
|
||||
the operation to apply.
|
||||
img: Numpy image that will have `policy` applied to it.
|
||||
|
||||
Returns:
|
||||
The result of applying `policy` to `img`.
|
||||
"""
|
||||
img_shape = img.shape
|
||||
pil_img = pil_wrap(img, use_mean_std)
|
||||
for xform in policy:
|
||||
assert len(xform) == 3
|
||||
name, probability, level = xform
|
||||
xform_fn = NAME_TO_TRANSFORM[name].pil_transformer(
|
||||
probability, level, img_shape)
|
||||
pil_img = xform_fn(pil_img)
|
||||
return pil_unwrap(pil_img, use_mean_std, img_shape)
|
||||
|
||||
|
||||
class TransformFunction(object):
|
||||
"""Wraps the Transform function for pretty printing options."""
|
||||
|
||||
def __init__(self, func, name):
|
||||
self.f = func
|
||||
self.name = name
|
||||
|
||||
def __repr__(self):
|
||||
return '<' + self.name + '>'
|
||||
|
||||
def __call__(self, pil_img):
|
||||
return self.f(pil_img)
|
||||
|
||||
|
||||
class TransformT(object):
|
||||
"""Each instance of this class represents a specific transform."""
|
||||
|
||||
def __init__(self, name, xform_fn):
|
||||
self.name = name
|
||||
self.xform = xform_fn
|
||||
|
||||
def pil_transformer(self, probability, level, img_shape):
|
||||
|
||||
def return_function(im):
|
||||
if random.random() < probability:
|
||||
im = self.xform(im, level, img_shape)
|
||||
return im
|
||||
|
||||
name = self.name + '({:.1f},{})'.format(probability, level)
|
||||
return TransformFunction(return_function, name)
|
||||
|
||||
|
||||
################## Transform Functions ##################
|
||||
identity = TransformT('identity', lambda pil_img, level, _: pil_img)
|
||||
flip_lr = TransformT(
|
||||
'FlipLR',
|
||||
lambda pil_img, level, _: pil_img.transpose(Image.FLIP_LEFT_RIGHT))
|
||||
flip_ud = TransformT(
|
||||
'FlipUD',
|
||||
lambda pil_img, level, _: pil_img.transpose(Image.FLIP_TOP_BOTTOM))
|
||||
# pylint:disable=g-long-lambda
|
||||
auto_contrast = TransformT(
|
||||
'AutoContrast',
|
||||
lambda pil_img, level, _: ImageOps.autocontrast(
|
||||
pil_img.convert('RGB')).convert('RGBA'))
|
||||
equalize = TransformT(
|
||||
'Equalize',
|
||||
lambda pil_img, level, _: ImageOps.equalize(
|
||||
pil_img.convert('RGB')).convert('RGBA'))
|
||||
invert = TransformT(
|
||||
'Invert',
|
||||
lambda pil_img, level, _: ImageOps.invert(
|
||||
pil_img.convert('RGB')).convert('RGBA'))
|
||||
# pylint:enable=g-long-lambda
|
||||
blur = TransformT(
|
||||
'Blur', lambda pil_img, level, _: pil_img.filter(ImageFilter.BLUR))
|
||||
smooth = TransformT(
|
||||
'Smooth',
|
||||
lambda pil_img, level, _: pil_img.filter(ImageFilter.SMOOTH))
|
||||
|
||||
|
||||
def _rotate_impl(pil_img, level, _):
|
||||
"""Rotates `pil_img` from -30 to 30 degrees depending on `level`."""
|
||||
degrees = int_parameter(level, 30)
|
||||
if random.random() > 0.5:
|
||||
degrees = -degrees
|
||||
return pil_img.rotate(degrees)
|
||||
|
||||
|
||||
rotate = TransformT('Rotate', _rotate_impl)
|
||||
|
||||
|
||||
def _posterize_impl(pil_img, level, _):
|
||||
"""Applies PIL Posterize to `pil_img`."""
|
||||
level = int_parameter(level, 4)
|
||||
return ImageOps.posterize(pil_img.convert('RGB'), 4 - level).convert('RGBA')
|
||||
|
||||
|
||||
posterize = TransformT('Posterize', _posterize_impl)
|
||||
|
||||
|
||||
def _shear_x_impl(pil_img, level, img_shape):
|
||||
"""Applies PIL ShearX to `pil_img`.
|
||||
|
||||
The ShearX operation shears the image along the horizontal axis with `level`
|
||||
magnitude.
|
||||
|
||||
Args:
|
||||
pil_img: Image in PIL object.
|
||||
level: Strength of the operation specified as an Integer from
|
||||
[0, `PARAMETER_MAX`].
|
||||
|
||||
Returns:
|
||||
A PIL Image that has had ShearX applied to it.
|
||||
"""
|
||||
level = float_parameter(level, 0.3)
|
||||
if random.random() > 0.5:
|
||||
level = -level
|
||||
return pil_img.transform(
|
||||
(img_shape[0], img_shape[1]),
|
||||
Image.AFFINE,
|
||||
(1, level, 0, 0, 1, 0))
|
||||
|
||||
|
||||
shear_x = TransformT('ShearX', _shear_x_impl)
|
||||
|
||||
|
||||
def _shear_y_impl(pil_img, level, img_shape):
|
||||
"""Applies PIL ShearY to `pil_img`.
|
||||
|
||||
The ShearY operation shears the image along the vertical axis with `level`
|
||||
magnitude.
|
||||
|
||||
Args:
|
||||
pil_img: Image in PIL object.
|
||||
level: Strength of the operation specified as an Integer from
|
||||
[0, `PARAMETER_MAX`].
|
||||
|
||||
Returns:
|
||||
A PIL Image that has had ShearX applied to it.
|
||||
"""
|
||||
level = float_parameter(level, 0.3)
|
||||
if random.random() > 0.5:
|
||||
level = -level
|
||||
return pil_img.transform(
|
||||
(img_shape[0], img_shape[1]),
|
||||
Image.AFFINE,
|
||||
(1, 0, 0, level, 1, 0))
|
||||
|
||||
|
||||
shear_y = TransformT('ShearY', _shear_y_impl)
|
||||
|
||||
|
||||
def _translate_x_impl(pil_img, level, img_shape):
|
||||
"""Applies PIL TranslateX to `pil_img`.
|
||||
|
||||
Translate the image in the horizontal direction by `level`
|
||||
number of pixels.
|
||||
|
||||
Args:
|
||||
pil_img: Image in PIL object.
|
||||
level: Strength of the operation specified as an Integer from
|
||||
[0, `PARAMETER_MAX`].
|
||||
|
||||
Returns:
|
||||
A PIL Image that has had TranslateX applied to it.
|
||||
"""
|
||||
level = int_parameter(level, 10)
|
||||
if random.random() > 0.5:
|
||||
level = -level
|
||||
return pil_img.transform(
|
||||
(img_shape[0], img_shape[1]),
|
||||
Image.AFFINE,
|
||||
(1, 0, level, 0, 1, 0))
|
||||
|
||||
|
||||
translate_x = TransformT('TranslateX', _translate_x_impl)
|
||||
|
||||
|
||||
def _translate_y_impl(pil_img, level, img_shape):
|
||||
"""Applies PIL TranslateY to `pil_img`.
|
||||
|
||||
Translate the image in the vertical direction by `level`
|
||||
number of pixels.
|
||||
|
||||
Args:
|
||||
pil_img: Image in PIL object.
|
||||
level: Strength of the operation specified as an Integer from
|
||||
[0, `PARAMETER_MAX`].
|
||||
|
||||
Returns:
|
||||
A PIL Image that has had TranslateY applied to it.
|
||||
"""
|
||||
level = int_parameter(level, 10)
|
||||
if random.random() > 0.5:
|
||||
level = -level
|
||||
return pil_img.transform(
|
||||
(img_shape[0], img_shape[1]),
|
||||
Image.AFFINE,
|
||||
(1, 0, 0, 0, 1, level))
|
||||
|
||||
|
||||
translate_y = TransformT('TranslateY', _translate_y_impl)
|
||||
|
||||
|
||||
def _crop_impl(pil_img, level, img_shape, interpolation=Image.BILINEAR):
|
||||
"""Applies a crop to `pil_img` with the size depending on the `level`."""
|
||||
cropped = pil_img.crop((level, level, img_shape[0] - level, img_shape[1] - level))
|
||||
resized = cropped.resize((img_shape[0], img_shape[1]), interpolation)
|
||||
return resized
|
||||
|
||||
|
||||
crop_bilinear = TransformT('CropBilinear', _crop_impl)
|
||||
|
||||
|
||||
def _solarize_impl(pil_img, level, _):
|
||||
"""Applies PIL Solarize to `pil_img`.
|
||||
|
||||
Translate the image in the vertical direction by `level`
|
||||
number of pixels.
|
||||
|
||||
Args:
|
||||
pil_img: Image in PIL object.
|
||||
level: Strength of the operation specified as an Integer from
|
||||
[0, `PARAMETER_MAX`].
|
||||
|
||||
Returns:
|
||||
A PIL Image that has had Solarize applied to it.
|
||||
"""
|
||||
level = int_parameter(level, 256)
|
||||
return ImageOps.solarize(pil_img.convert('RGB'), 256 - level).convert('RGBA')
|
||||
|
||||
|
||||
solarize = TransformT('Solarize', _solarize_impl)
|
||||
|
||||
|
||||
def _cutout_pil_impl(pil_img, level, img_shape):
|
||||
"""Apply cutout to pil_img at the specified level."""
|
||||
size = int_parameter(level, 20)
|
||||
if size <= 0:
|
||||
return pil_img
|
||||
img_height, img_width, num_channels = (img_shape[0], img_shape[1], 3)
|
||||
_, upper_coord, lower_coord = (
|
||||
create_cutout_mask(img_height, img_width, num_channels, size))
|
||||
pixels = pil_img.load() # create the pixel map
|
||||
for i in range(upper_coord[0], lower_coord[0]): # for every col:
|
||||
for j in range(upper_coord[1], lower_coord[1]): # For every row
|
||||
pixels[i, j] = (125, 122, 113, 0) # set the colour accordingly
|
||||
return pil_img
|
||||
|
||||
cutout = TransformT('Cutout', _cutout_pil_impl)
|
||||
|
||||
|
||||
def _enhancer_impl(enhancer):
|
||||
"""Sets level to be between 0.1 and 1.8 for ImageEnhance transforms of PIL."""
|
||||
def impl(pil_img, level, _):
|
||||
v = float_parameter(level, 1.8) + .1 # going to 0 just destroys it
|
||||
return enhancer(pil_img).enhance(v)
|
||||
return impl
|
||||
|
||||
|
||||
color = TransformT('Color', _enhancer_impl(ImageEnhance.Color))
|
||||
contrast = TransformT('Contrast', _enhancer_impl(ImageEnhance.Contrast))
|
||||
brightness = TransformT('Brightness', _enhancer_impl(
|
||||
ImageEnhance.Brightness))
|
||||
sharpness = TransformT('Sharpness', _enhancer_impl(ImageEnhance.Sharpness))
|
||||
|
||||
ALL_TRANSFORMS = [
|
||||
flip_lr,
|
||||
flip_ud,
|
||||
auto_contrast,
|
||||
equalize,
|
||||
invert,
|
||||
rotate,
|
||||
posterize,
|
||||
crop_bilinear,
|
||||
solarize,
|
||||
color,
|
||||
contrast,
|
||||
brightness,
|
||||
sharpness,
|
||||
shear_x,
|
||||
shear_y,
|
||||
translate_x,
|
||||
translate_y,
|
||||
cutout,
|
||||
blur,
|
||||
smooth
|
||||
]
|
||||
|
||||
NAME_TO_TRANSFORM = {t.name: t for t in ALL_TRANSFORMS}
|
||||
TRANSFORM_NAMES = NAME_TO_TRANSFORM.keys()
|
14
higher/smart_aug/old/compare_TF.py
Normal file
14
higher/smart_aug/old/compare_TF.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
import RandAugment as rand
|
||||
import PIL
|
||||
import torchvision
|
||||
import transformations as TF
|
||||
tpil=torchvision.transforms.ToPILImage()
|
||||
ttensor=torchvision.transforms.ToTensor()
|
||||
|
||||
img,label =data_train[0]
|
||||
|
||||
rimg=ttensor(PIL.ImageEnhance.Color(tpil(img)).enhance(1.5))#ttensor(PIL.ImageOps.solarize(tpil(img), 50))#ttensor(tpil(img).transform(tpil(img).size, PIL.Image.AFFINE, (1, -0.1, 0, 0, 1, 0)))#rand.augmentations.FlipUD(tpil(img),1))
|
||||
timg=TF.color(img.unsqueeze(0),torch.Tensor([1.5])).squeeze(0)
|
||||
print(torch.allclose(rimg,timg, atol=1e-3))
|
||||
tpil(rimg).save('rimg.jpg')
|
||||
tpil(timg).save('timg.jpg')
|
1424
higher/smart_aug/old/dataug_old.py
Normal file
1424
higher/smart_aug/old/dataug_old.py
Normal file
File diff suppressed because it is too large
Load diff
85
higher/smart_aug/old/higher_repro.py
Normal file
85
higher/smart_aug/old/higher_repro.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
import higher
|
||||
import time
|
||||
|
||||
data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=torchvision.transforms.ToTensor())
|
||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=300, shuffle=True, num_workers=0, pin_memory=False)
|
||||
|
||||
|
||||
class Aug_model(nn.Module):
|
||||
def __init__(self, model, hyper_param=True):
|
||||
super(Aug_model, self).__init__()
|
||||
|
||||
#### Origin of the issue ? ####
|
||||
if hyper_param:
|
||||
self._params = nn.ParameterDict({
|
||||
"hyper_param": nn.Parameter(torch.Tensor([0.5])),
|
||||
})
|
||||
###############################
|
||||
|
||||
self._mods = nn.ModuleDict({
|
||||
'model': model,
|
||||
})
|
||||
|
||||
def forward(self, x):
|
||||
return self._mods['model'](x) #* self._params['hyper_param']
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._mods[key]
|
||||
|
||||
class Aug_model2(nn.Module): #Slow increase like no hyper_param
|
||||
def __init__(self, model, hyper_param=True):
|
||||
super(Aug_model2, self).__init__()
|
||||
|
||||
#### Origin of the issue ? ####
|
||||
if hyper_param:
|
||||
self._params = nn.ParameterDict({
|
||||
"hyper_param": nn.Parameter(torch.Tensor([0.5])),
|
||||
})
|
||||
###############################
|
||||
|
||||
self._mods = nn.ModuleDict({
|
||||
'model': model,
|
||||
'fmodel': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
})
|
||||
|
||||
def forward(self, x):
|
||||
return self._mods['fmodel'](x) * self._params['hyper_param']
|
||||
|
||||
def get_diffopt(self, opt, track_higher_grads=True):
|
||||
return higher.optim.get_diff_optim(opt,
|
||||
self._mods['model'].parameters(),
|
||||
fmodel=self._mods['fmodel'],
|
||||
track_higher_grads=track_higher_grads)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._mods[key]
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
device = torch.device('cuda:1')
|
||||
aug_model = Aug_model2(
|
||||
model=torch.hub.load('pytorch/vision:v0.4.2', 'resnet18', pretrained=False),
|
||||
hyper_param=True #False will not extend step time
|
||||
).to(device)
|
||||
|
||||
inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9)
|
||||
|
||||
#fmodel = higher.patch.monkeypatch(aug_model, device=None, copy_initial_weights=True)
|
||||
#diffopt = higher.optim.get_diff_optim(inner_opt, aug_model.parameters(),fmodel=fmodel,track_higher_grads=True)
|
||||
diffopt = aug_model.get_diffopt(inner_opt)
|
||||
|
||||
for i, (xs, ys) in enumerate(dl_train):
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
#logits = fmodel(xs)
|
||||
logits = aug_model(xs)
|
||||
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='mean')
|
||||
|
||||
t = time.process_time()
|
||||
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
|
||||
#print(len(fmodel._fast_params),"step", time.process_time()-t)
|
||||
print(len(aug_model['fmodel']._fast_params),"step", time.process_time()-t)
|
502
higher/smart_aug/old/model_old.py
Normal file
502
higher/smart_aug/old/model_old.py
Normal file
|
@ -0,0 +1,502 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
## Basic CNN ##
|
||||
class LeNet_F(nn.Module):
|
||||
def __init__(self, num_inp, num_out):
|
||||
super(LeNet_F, self).__init__()
|
||||
self._params = nn.ParameterDict({
|
||||
'w1': nn.Parameter(torch.zeros(20, num_inp, 5, 5)),
|
||||
'b1': nn.Parameter(torch.zeros(20)),
|
||||
'w2': nn.Parameter(torch.zeros(50, 20, 5, 5)),
|
||||
'b2': nn.Parameter(torch.zeros(50)),
|
||||
#'w3': nn.Parameter(torch.zeros(500,4*4*50)), #num_imp=1
|
||||
'w3': nn.Parameter(torch.zeros(500,5*5*50)), #num_imp=3
|
||||
'b3': nn.Parameter(torch.zeros(500)),
|
||||
'w4': nn.Parameter(torch.zeros(num_out, 500)),
|
||||
'b4': nn.Parameter(torch.zeros(num_out))
|
||||
})
|
||||
self.initialize()
|
||||
|
||||
|
||||
def initialize(self):
|
||||
nn.init.kaiming_uniform_(self._params["w1"], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self._params["w2"], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self._params["w3"], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self._params["w4"], a=math.sqrt(5))
|
||||
|
||||
def forward(self, x):
|
||||
#print("Start Shape ", x.shape)
|
||||
out = F.relu(F.conv2d(input=x, weight=self._params["w1"], bias=self._params["b1"]))
|
||||
#print("Shape ", out.shape)
|
||||
out = F.max_pool2d(out, 2)
|
||||
#print("Shape ", out.shape)
|
||||
out = F.relu(F.conv2d(input=out, weight=self._params["w2"], bias=self._params["b2"]))
|
||||
#print("Shape ", out.shape)
|
||||
out = F.max_pool2d(out, 2)
|
||||
#print("Shape ", out.shape)
|
||||
out = out.view(out.size(0), -1)
|
||||
#print("Shape ", out.shape)
|
||||
out = F.relu(F.linear(out, self._params["w3"], self._params["b3"]))
|
||||
#print("Shape ", out.shape)
|
||||
out = F.linear(out, self._params["w4"], self._params["b4"])
|
||||
#print("Shape ", out.shape)
|
||||
#return F.log_softmax(out, dim=1)
|
||||
return out
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._params[key]
|
||||
|
||||
def __str__(self):
|
||||
return "LeNet"
|
||||
|
||||
|
||||
## MobileNetv2 ##
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
padding = (kernel_size - 1) // 2
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(self,
|
||||
num_classes=1000,
|
||||
width_mult=1.0,
|
||||
inverted_residual_setting=None,
|
||||
round_nearest=8,
|
||||
block=None):
|
||||
"""
|
||||
MobileNet V2 main class
|
||||
Args:
|
||||
num_classes (int): Number of classes
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
block: Module specifying inverted residual building block for mobilenet
|
||||
"""
|
||||
super(MobileNetV2, self).__init__()
|
||||
|
||||
if block is None:
|
||||
block = InvertedResidual
|
||||
input_channel = 32
|
||||
last_channel = 1280
|
||||
|
||||
if inverted_residual_setting is None:
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
|
||||
# only check the first element, assuming user knows t,c,n,s are required
|
||||
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||
raise ValueError("inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
||||
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||
features = [ConvBNReLU(3, input_channel, stride=2)]
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
||||
input_channel = output_channel
|
||||
# building last several layers
|
||||
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
|
||||
# make it nn.Sequential
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
# building classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(self.last_channel, num_classes),
|
||||
)
|
||||
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
||||
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
||||
x = self.features(x)
|
||||
x = x.mean([2, 3])
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def __str__(self):
|
||||
return "MobileNetV2"
|
||||
|
||||
## ResNet ##
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
__constants__ = ['downsample']
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
__constants__ = ['downsample']
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
#ResNet18 : block=BasicBlock, layers=[2, 2, 2, 2]
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], num_classes=1000, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None):
|
||||
super(ResNet, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def __str__(self):
|
||||
return "ResNet18"
|
||||
|
||||
## Wide ResNet ##
|
||||
#https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
|
||||
#https://github.com/arcelien/pba/blob/master/pba/wrn.py
|
||||
#https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py
|
||||
'''
|
||||
class BasicBlock(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_planes)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
|
||||
padding=1, bias=False)
|
||||
self.droprate = dropRate
|
||||
self.equalInOut = (in_planes == out_planes)
|
||||
self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
|
||||
padding=0, bias=False) or None
|
||||
def forward(self, x):
|
||||
if not self.equalInOut:
|
||||
x = self.relu1(self.bn1(x))
|
||||
else:
|
||||
out = self.relu1(self.bn1(x))
|
||||
out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
|
||||
if self.droprate > 0:
|
||||
out = F.dropout(out, p=self.droprate, training=self.training)
|
||||
out = self.conv2(out)
|
||||
return torch.add(x if self.equalInOut else self.convShortcut(x), out)
|
||||
|
||||
class NetworkBlock(nn.Module):
|
||||
def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
|
||||
super(NetworkBlock, self).__init__()
|
||||
self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
|
||||
def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
|
||||
layers = []
|
||||
for i in range(int(nb_layers)):
|
||||
layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
|
||||
return nn.Sequential(*layers)
|
||||
def forward(self, x):
|
||||
return self.layer(x)
|
||||
|
||||
#wrn_size: 32 = WRN-28-2 ? 160 = WRN-28-10
|
||||
class WideResNet(nn.Module):
|
||||
#def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
|
||||
def __init__(self, num_classes, wrn_size, depth=28, dropRate=0.0):
|
||||
super(WideResNet, self).__init__()
|
||||
|
||||
self.kernel_size = wrn_size
|
||||
self.depth=depth
|
||||
filter_size = 3
|
||||
nChannels = [min(self.kernel_size, 16), self.kernel_size, self.kernel_size * 2, self.kernel_size * 4]
|
||||
strides = [1, 2, 2] # stride for each resblock
|
||||
|
||||
#nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
|
||||
assert((depth - 4) % 6 == 0)
|
||||
n = (depth - 4) / 6
|
||||
block = BasicBlock
|
||||
# 1st conv before any network block
|
||||
self.conv1 = nn.Conv2d(filter_size, nChannels[0], kernel_size=3, stride=1,
|
||||
padding=1, bias=False)
|
||||
# 1st block
|
||||
self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, strides[0], dropRate)
|
||||
# 2nd block
|
||||
self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, strides[1], dropRate)
|
||||
# 3rd block
|
||||
self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, strides[2], dropRate)
|
||||
# global average pooling and classifier
|
||||
self.bn1 = nn.BatchNorm2d(nChannels[3])
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.fc = nn.Linear(nChannels[3], num_classes)
|
||||
self.nChannels = nChannels[3]
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
m.bias.data.zero_()
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.block1(out)
|
||||
out = self.block2(out)
|
||||
out = self.block3(out)
|
||||
out = self.relu(self.bn1(out))
|
||||
out = F.avg_pool2d(out, 8)
|
||||
out = out.view(-1, self.nChannels)
|
||||
return self.fc(out)
|
||||
|
||||
def architecture(self):
|
||||
return super(WideResNet, self).__str__()
|
||||
|
||||
def __str__(self):
|
||||
return "WideResNet(s{}-d{})".format(self.kernel_size, self.depth)
|
||||
'''
|
184
higher/smart_aug/old/test_brutus.py
Normal file
184
higher/smart_aug/old/test_brutus.py
Normal file
|
@ -0,0 +1,184 @@
|
|||
from model import *
|
||||
from dataug import *
|
||||
#from utils import *
|
||||
from train_utils import *
|
||||
|
||||
import torchvision.models as models
|
||||
|
||||
# Use available TF (see transformations.py)
|
||||
tf_names = [
|
||||
## Geometric TF ##
|
||||
'Identity',
|
||||
'FlipUD',
|
||||
'FlipLR',
|
||||
'Rotate',
|
||||
'TranslateX',
|
||||
'TranslateY',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
|
||||
## Color TF (Expect image in the range of [0, 1]) ##
|
||||
'Contrast',
|
||||
'Color',
|
||||
'Brightness',
|
||||
'Sharpness',
|
||||
'Posterize',
|
||||
'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
|
||||
|
||||
## Bad Tranformations ##
|
||||
# Bad Geometric TF #
|
||||
#'BShearX',
|
||||
#'BShearY',
|
||||
#'BTranslateX-',
|
||||
#'BTranslateX-',
|
||||
#'BTranslateY',
|
||||
#'BTranslateY-',
|
||||
|
||||
#'BadContrast',
|
||||
#'BadBrightness',
|
||||
|
||||
#'Random',
|
||||
#'RandBlend'
|
||||
]
|
||||
|
||||
device = torch.device('cuda')
|
||||
|
||||
if device == torch.device('cpu'):
|
||||
device_name = 'CPU'
|
||||
else:
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True #Faster if same input size #Not recommended for reproductibility
|
||||
|
||||
#Increase reproductibility
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
##########################################
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
||||
n_inner_iter = 1
|
||||
epochs = 150
|
||||
dataug_epoch_start=0
|
||||
optim_param={
|
||||
'Meta':{
|
||||
'optim':'Adam',
|
||||
'lr':1e-2, #1e-2
|
||||
},
|
||||
'Inner':{
|
||||
'optim': 'SGD',
|
||||
'lr':1e-1, #1e-2
|
||||
'momentum':0.9, #0.9
|
||||
}
|
||||
}
|
||||
|
||||
model=models.resnet18()
|
||||
|
||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||
|
||||
####
|
||||
'''
|
||||
t0 = time.process_time()
|
||||
|
||||
aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
|
||||
|
||||
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
|
||||
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None)
|
||||
|
||||
exec_time=time.process_time() - t0
|
||||
####
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
|
||||
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
|
||||
with open("res/log/%s.json" % filename, "w+") as f:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
'''
|
||||
|
||||
####
|
||||
'''
|
||||
t0 = time.process_time()
|
||||
|
||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
|
||||
|
||||
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
|
||||
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None)
|
||||
|
||||
exec_time=time.process_time() - t0
|
||||
####
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
|
||||
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
|
||||
with open("res/log/%s.json" % filename, "w+") as f:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
'''
|
||||
res_folder="../res/brutus-tests2/"
|
||||
epochs= 150
|
||||
inner_its = [1]
|
||||
dist_mix = [0.0, 0.5, 0.8, 1.0]
|
||||
dataug_epoch_starts= [0]
|
||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||
TF_nb = [len(tf_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
|
||||
N_seq_TF= [4, 3, 2]
|
||||
mag_setup = [(True,True), (False, False)] #(Fixed, Shared)
|
||||
#prob_setup = [True, False]
|
||||
nb_run= 3
|
||||
|
||||
try:
|
||||
os.mkdir(res_folder)
|
||||
os.mkdir(res_folder+"log/")
|
||||
except FileExistsError:
|
||||
pass
|
||||
|
||||
for n_inner_iter in inner_its:
|
||||
for dataug_epoch_start in dataug_epoch_starts:
|
||||
for n_tf in N_seq_TF:
|
||||
for dist in dist_mix:
|
||||
#for i in TF_nb:
|
||||
for m_setup in mag_setup:
|
||||
#for p_setup in prob_setup:
|
||||
p_setup=False
|
||||
for run in range(nb_run):
|
||||
if (n_inner_iter == 0 and (m_setup!=(True,True) and p_setup!=True)) or (p_setup and dist!=0.0): continue #Autres setup inutiles sans meta-opti
|
||||
#keys = list(TF.TF_dict.keys())[0:i]
|
||||
#ntf_dict = {k: TF.TF_dict[k] for k in keys}
|
||||
|
||||
t0 = time.process_time()
|
||||
|
||||
model = ResNet(num_classes=10)
|
||||
model = Higher_model(model) #run_dist_dataugV3
|
||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=p_setup, fixed_mag=m_setup[0], shared_mag=m_setup[1]), model).to(device)
|
||||
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
|
||||
|
||||
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
|
||||
log= run_dist_dataugV3(model=aug_model,
|
||||
epochs=epochs,
|
||||
inner_it=n_inner_iter,
|
||||
dataug_epoch_start=dataug_epoch_start,
|
||||
opt_param=optim_param,
|
||||
print_freq=50,
|
||||
KLdiv=True)
|
||||
|
||||
exec_time=time.process_time() - t0
|
||||
####
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
|
||||
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{} epochs (dataug:{})- {} in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter, run)
|
||||
with open("../res/log/%s.json" % filename, "w+") as f:
|
||||
try:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
except:
|
||||
print("Failed to save logs :",f.name)
|
||||
try:
|
||||
plot_resV2(log, fig_name="../res/"+filename, param_names=aug_model.TF_names())
|
||||
except:
|
||||
print("Failed to plot res")
|
||||
|
||||
print('Execution Time : %.00f '%(exec_time))
|
||||
print('-'*9)
|
||||
#'''
|
150
higher/smart_aug/old/test_lr.py
Normal file
150
higher/smart_aug/old/test_lr.py
Normal file
|
@ -0,0 +1,150 @@
|
|||
import numpy as np
|
||||
import json, math, time, os
|
||||
|
||||
from torch.utils.data import SubsetRandomSampler
|
||||
import torch.optim as optim
|
||||
import higher
|
||||
from model import *
|
||||
|
||||
import copy
|
||||
|
||||
BATCH_SIZE = 300
|
||||
TEST_SIZE = 300
|
||||
|
||||
mnist_train = torchvision.datasets.MNIST(
|
||||
"./data", train=True, download=True,
|
||||
transform=torchvision.transforms.Compose([
|
||||
#torchvision.transforms.RandomAffine(degrees=180, translate=None, scale=None, shear=None, resample=False, fillcolor=0),
|
||||
torchvision.transforms.ToTensor()
|
||||
])
|
||||
)
|
||||
|
||||
mnist_test = torchvision.datasets.MNIST(
|
||||
"./data", train=False, download=True, transform=torchvision.transforms.ToTensor()
|
||||
)
|
||||
|
||||
#train_subset_indices=range(int(len(mnist_train)/2))
|
||||
train_subset_indices=range(BATCH_SIZE)
|
||||
val_subset_indices=range(int(len(mnist_train)/2),len(mnist_train))
|
||||
|
||||
dl_train = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
|
||||
dl_val = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
|
||||
dl_test = torch.utils.data.DataLoader(mnist_test, batch_size=TEST_SIZE, shuffle=False)
|
||||
|
||||
|
||||
def test(model):
|
||||
model.eval()
|
||||
for i, (features, labels) in enumerate(dl_test):
|
||||
pred = model.forward(features)
|
||||
return pred.argmax(dim=1).eq(labels).sum().item() / TEST_SIZE * 100
|
||||
|
||||
def train_classic(model, optim, epochs=1):
|
||||
model.train()
|
||||
log = []
|
||||
for epoch in range(epochs):
|
||||
t0 = time.process_time()
|
||||
for i, (features, labels) in enumerate(dl_train):
|
||||
|
||||
optim.zero_grad()
|
||||
pred = model.forward(features)
|
||||
loss = F.cross_entropy(pred,labels)
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
#### Log ####
|
||||
tf = time.process_time()
|
||||
data={
|
||||
"time": tf - t0,
|
||||
}
|
||||
log.append(data)
|
||||
|
||||
times = [x["time"] for x in log]
|
||||
print("Vanilla : acc", test(model), "in (ms):", np.mean(times), "+/-", np.std(times))
|
||||
##########################################
|
||||
if __name__ == "__main__":
|
||||
|
||||
device = torch.device('cpu')
|
||||
|
||||
model = LeNet(1,10)
|
||||
opt_param = {
|
||||
"lr": torch.tensor(1e-2).requires_grad_(),
|
||||
"momentum": torch.tensor(0.9).requires_grad_()
|
||||
}
|
||||
n_inner_iter = 1
|
||||
dl_train_it = iter(dl_train)
|
||||
dl_val_it = iter(dl_val)
|
||||
epoch = 0
|
||||
epochs = 10
|
||||
|
||||
####
|
||||
train_classic(model=model, optim=torch.optim.Adam(model.parameters(), lr=0.001), epochs=epochs)
|
||||
model = LeNet(1,10)
|
||||
|
||||
meta_opt = torch.optim.Adam(opt_param.values(), lr=1e-2)
|
||||
inner_opt = torch.optim.SGD(model.parameters(), lr=opt_param['lr'], momentum=opt_param['momentum'])
|
||||
#for xs_val, ys_val in dl_val:
|
||||
while epoch < epochs:
|
||||
#print(data_aug.params["mag"], data_aug.params["mag"].grad)
|
||||
meta_opt.zero_grad()
|
||||
model.train()
|
||||
with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt): #effet copy_initial_weight pas clair...
|
||||
|
||||
for param_group in diffopt.param_groups:
|
||||
param_group['lr'] = opt_param['lr']
|
||||
param_group['momentum'] = opt_param['momentum']
|
||||
|
||||
for i in range(n_inner_iter):
|
||||
try:
|
||||
xs, ys = next(dl_train_it)
|
||||
except StopIteration: #Fin epoch train
|
||||
epoch +=1
|
||||
dl_train_it = iter(dl_train)
|
||||
xs, ys = next(dl_train_it)
|
||||
|
||||
print('Epoch', epoch)
|
||||
print('train loss',loss.item(), '/ val loss', val_loss.item())
|
||||
print('acc', test(model))
|
||||
print('opt : lr', opt_param['lr'].item(), 'momentum', opt_param['momentum'].item())
|
||||
print('-'*9)
|
||||
model.train()
|
||||
|
||||
|
||||
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
|
||||
loss = F.cross_entropy(logits, ys) # no need to call loss.backwards()
|
||||
#print('loss',loss.item())
|
||||
diffopt.step(loss) # note that `step` must take `loss` as an argument!
|
||||
# The line above gets P[t+1] from P[t] and loss[t]. `step` also returns
|
||||
# these new parameters, as an alternative to getting them from
|
||||
# `fmodel.fast_params` or `fmodel.parameters()` after calling
|
||||
# `diffopt.step`.
|
||||
|
||||
# At this point, or at any point in the iteration, you can take the
|
||||
# gradient of `fmodel.parameters()` (or equivalently
|
||||
# `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently
|
||||
# `fmodel.init_fast_params`). i.e. `fast_params` will always have
|
||||
# `grad_fn` as an attribute, and be part of the gradient tape.
|
||||
|
||||
# At the end of your inner loop you can obtain these e.g. ...
|
||||
#grad_of_grads = torch.autograd.grad(
|
||||
# meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))
|
||||
try:
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_val_it = iter(dl_val_it)
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
|
||||
val_logits = fmodel(xs_val)
|
||||
val_loss = F.cross_entropy(val_logits, ys_val)
|
||||
#print('val_loss',val_loss.item())
|
||||
|
||||
val_loss.backward()
|
||||
#meta_grads = torch.autograd.grad(val_loss, opt_lr, allow_unused=True)
|
||||
#print(meta_grads)
|
||||
for param_group in diffopt.param_groups:
|
||||
print(param_group['lr'], '/',param_group['lr'].grad)
|
||||
print(param_group['momentum'], '/',param_group['momentum'].grad)
|
||||
|
||||
#model=copy.deepcopy(fmodel)
|
||||
model.load_state_dict(fmodel.state_dict())
|
||||
|
||||
meta_opt.step()
|
866
higher/smart_aug/old/train_utils_old.py
Normal file
866
higher/smart_aug/old/train_utils_old.py
Normal file
|
@ -0,0 +1,866 @@
|
|||
import torch
|
||||
#import torch.optim
|
||||
import torchvision
|
||||
import higher
|
||||
|
||||
from datasets import *
|
||||
from utils import *
|
||||
|
||||
def train_classic_higher(model, epochs=1):
|
||||
device = next(model.parameters()).device
|
||||
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
|
||||
|
||||
model.train()
|
||||
dl_val_it = iter(dl_val)
|
||||
log = []
|
||||
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(optim, model.parameters(),fmodel=fmodel,track_higher_grads=False)
|
||||
#with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, diffopt):
|
||||
|
||||
for epoch in range(epochs):
|
||||
#print_torch_mem("Start epoch "+str(epoch))
|
||||
#print("Fast param ",len(fmodel._fast_params))
|
||||
t0 = time.process_time()
|
||||
for i, (features, labels) in enumerate(dl_train):
|
||||
#print_torch_mem("Start iter")
|
||||
features,labels = features.to(device), labels.to(device)
|
||||
|
||||
#optim.zero_grad()
|
||||
logits = model.forward(features)
|
||||
pred = F.log_softmax(logits, dim=1)
|
||||
loss = F.cross_entropy(pred,labels)
|
||||
#.backward()
|
||||
#optim.step()
|
||||
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
|
||||
|
||||
model_copy(src=fmodel, dst=model, patch_copy=False)
|
||||
optim_copy(dopt=diffopt, opt=optim)
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(optim, model.parameters(),fmodel=fmodel,track_higher_grads=False)
|
||||
|
||||
#### Tests ####
|
||||
tf = time.process_time()
|
||||
try:
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_val_it = iter(dl_val)
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
|
||||
|
||||
val_loss = F.cross_entropy(model(xs_val), ys_val)
|
||||
accuracy, _ =test(model)
|
||||
model.train()
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"param": None,
|
||||
}
|
||||
log.append(data)
|
||||
|
||||
return log
|
||||
|
||||
def train_classic_tests(model, epochs=1):
|
||||
device = next(model.parameters()).device
|
||||
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
|
||||
|
||||
countcopy=0
|
||||
model.train()
|
||||
dl_val_it = iter(dl_val)
|
||||
log = []
|
||||
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
doptim = higher.optim.get_diff_optim(optim, model.parameters(), fmodel=fmodel, track_higher_grads=False)
|
||||
for epoch in range(epochs):
|
||||
print_torch_mem("Start epoch")
|
||||
print(len(fmodel._fast_params))
|
||||
t0 = time.process_time()
|
||||
#with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=True) as (fmodel, doptim):
|
||||
|
||||
#fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
#doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True)
|
||||
|
||||
for i, (features, labels) in enumerate(dl_train):
|
||||
features,labels = features.to(device), labels.to(device)
|
||||
|
||||
#with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, doptim):
|
||||
|
||||
|
||||
#optim.zero_grad()
|
||||
pred = fmodel.forward(features)
|
||||
loss = F.cross_entropy(pred,labels)
|
||||
doptim.step(loss) #(opt.zero_grad, loss.backward, opt.step)
|
||||
#loss.backward()
|
||||
#new_params = doptim.step(loss, params=fmodel.parameters())
|
||||
#fmodel.update_params(new_params)
|
||||
|
||||
|
||||
#print('Fast param',len(fmodel._fast_params))
|
||||
#print('opt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][2]['momentum_buffer'].shape)
|
||||
|
||||
if False or (len(fmodel._fast_params)>1):
|
||||
print("fmodel fast param",len(fmodel._fast_params))
|
||||
'''
|
||||
#val_loss = F.cross_entropy(fmodel(features), labels)
|
||||
|
||||
#print_graph(val_loss)
|
||||
|
||||
#val_loss.backward()
|
||||
#print('bip')
|
||||
|
||||
tmp = fmodel.parameters()
|
||||
|
||||
#print(list(tmp)[1])
|
||||
tmp = [higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in tmp]
|
||||
#print(len(tmp))
|
||||
|
||||
#fmodel._fast_params.clear()
|
||||
del fmodel._fast_params
|
||||
fmodel._fast_params=None
|
||||
|
||||
fmodel.fast_params=tmp # Surcharge la memoire
|
||||
#fmodel.update_params(tmp) #Meilleur perf / Surcharge la memoire avec trach higher grad
|
||||
|
||||
#optim._fmodel=fmodel
|
||||
'''
|
||||
|
||||
|
||||
countcopy+=1
|
||||
model_copy(src=fmodel, dst=model, patch_copy=False)
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
#doptim.detach_dyn()
|
||||
#tmp = doptim.state
|
||||
#tmp = doptim.state_dict()
|
||||
#for k, v in tmp['state'].items():
|
||||
# print('dict',k, type(v))
|
||||
|
||||
a = optim.param_groups[0]['params'][0]
|
||||
state = optim.state[a]
|
||||
#state['momentum_buffer'] = None
|
||||
#print('opt state', type(optim.state[a]), len(optim.state[a]))
|
||||
#optim.load_state_dict(tmp)
|
||||
|
||||
|
||||
for group_idx, group in enumerate(optim.param_groups):
|
||||
# print('gp idx',group_idx)
|
||||
for p_idx, p in enumerate(group['params']):
|
||||
optim.state[p]=doptim.state[group_idx][p_idx]
|
||||
|
||||
#print('opt state', type(optim.state[a]['momentum_buffer']), optim.state[a]['momentum_buffer'][0:10])
|
||||
#print('dopt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][0]['momentum_buffer'][0:10])
|
||||
'''
|
||||
for a in tmp:
|
||||
#print(type(a), len(a))
|
||||
for nb, b in a.items():
|
||||
#print(nb, type(b), len(b))
|
||||
for n, state in b.items():
|
||||
#print(n, type(states))
|
||||
#print(state.grad_fn)
|
||||
state = torch.tensor(state.data).requires_grad_()
|
||||
#print(state.grad_fn)
|
||||
'''
|
||||
|
||||
|
||||
doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True)
|
||||
#doptim.state = tmp
|
||||
|
||||
|
||||
countcopy+=1
|
||||
model_copy(src=fmodel, dst=model)
|
||||
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||
|
||||
#### Tests ####
|
||||
tf = time.process_time()
|
||||
try:
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_val_it = iter(dl_val)
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
|
||||
|
||||
val_loss = F.cross_entropy(model(xs_val), ys_val)
|
||||
accuracy, _ =test(model)
|
||||
model.train()
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"param": None,
|
||||
}
|
||||
log.append(data)
|
||||
|
||||
#countcopy+=1
|
||||
#model_copy(src=fmodel, dst=model, patch_copy=False)
|
||||
#optim.load_state_dict(doptim.state_dict()) #Besoin sauver etat otpim ?
|
||||
|
||||
print("Copy ", countcopy)
|
||||
return log
|
||||
|
||||
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
from PIL import Image
|
||||
import augmentation_transforms
|
||||
import numpy as np
|
||||
class AugmentedDatasetV2(VisionDataset):
|
||||
def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None):
|
||||
|
||||
super(AugmentedDatasetV2, self).__init__(root, transform=transform, target_transform=target_transform)
|
||||
|
||||
supervised_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=transform)
|
||||
|
||||
self.sup_data = supervised_dataset.data if not subset else supervised_dataset.data[subset[0]:subset[1]]
|
||||
self.sup_targets = supervised_dataset.targets if not subset else supervised_dataset.targets[subset[0]:subset[1]]
|
||||
assert len(self.sup_data)==len(self.sup_targets)
|
||||
|
||||
for idx, img in enumerate(self.sup_data):
|
||||
self.sup_data[idx]= Image.fromarray(img) #to PIL Image
|
||||
|
||||
self.unsup_data=[]
|
||||
self.unsup_targets=[]
|
||||
self.origin_idx=[]
|
||||
|
||||
self.dataset_info= {
|
||||
'name': 'CIFAR10',
|
||||
'sup': len(self.sup_data),
|
||||
'unsup': len(self.unsup_data),
|
||||
'length': len(self.sup_data)+len(self.unsup_data),
|
||||
}
|
||||
|
||||
|
||||
self._TF = [
|
||||
## Geometric TF ##
|
||||
'Rotate',
|
||||
'TranslateX',
|
||||
'TranslateY',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
|
||||
'Cutout',
|
||||
|
||||
## Color TF ##
|
||||
'Contrast',
|
||||
'Color',
|
||||
'Brightness',
|
||||
'Sharpness',
|
||||
'Posterize',
|
||||
'Solarize',
|
||||
|
||||
'Invert',
|
||||
'AutoContrast',
|
||||
'Equalize',
|
||||
]
|
||||
self._op_list =[]
|
||||
self.prob=0.5
|
||||
self.mag_range=(1, 10)
|
||||
for tf in self._TF:
|
||||
for mag in range(self.mag_range[0], self.mag_range[1]):
|
||||
self._op_list+=[(tf, self.prob, mag)]
|
||||
self._nb_op = len(self._op_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Args:
|
||||
index (int): Index
|
||||
|
||||
Returns:
|
||||
tuple: (image, target) where target is index of the target class.
|
||||
"""
|
||||
aug_img, origin_img, target = self.unsup_data[index], self.sup_data[self.origin_idx[index]], self.unsup_targets[index]
|
||||
|
||||
# doing this so that it is consistent with all other datasets
|
||||
# to return a PIL Image
|
||||
#img = Image.fromarray(img)
|
||||
|
||||
if self.transform is not None:
|
||||
aug_img = self.transform(aug_img)
|
||||
origin_img = self.transform(origin_img)
|
||||
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return aug_img, origin_img, target
|
||||
|
||||
def augement_data(self, aug_copy=1):
|
||||
|
||||
policies = []
|
||||
for op_1 in self._op_list:
|
||||
for op_2 in self._op_list:
|
||||
policies += [[op_1, op_2]]
|
||||
|
||||
for idx, image in enumerate(self.sup_data):
|
||||
if idx%(self.dataset_info['sup']/5)==0: print("Augmenting data... ", idx,"/", self.dataset_info['sup'])
|
||||
#if idx==10000:break
|
||||
|
||||
for _ in range(aug_copy):
|
||||
chosen_policy = policies[np.random.choice(len(policies))]
|
||||
aug_image = augmentation_transforms.apply_policy(chosen_policy, image, use_mean_std=False) #Cast en float image
|
||||
#aug_image = augmentation_transforms.cutout_numpy(aug_image)
|
||||
|
||||
self.unsup_data+=[(aug_image*255.).astype(self.sup_data.dtype)]#Cast float image to uint8
|
||||
self.unsup_targets+=[self.sup_targets[idx]]
|
||||
self.origin_idx+=[idx]
|
||||
|
||||
#self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8
|
||||
self.unsup_data=np.array(self.unsup_data)
|
||||
|
||||
assert len(self.unsup_data)==len(self.unsup_targets)
|
||||
|
||||
self.dataset_info['unsup']=len(self.unsup_data)
|
||||
self.dataset_info['length']=self.dataset_info['sup']+self.dataset_info['unsup']
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.dataset_info['unsup']#self.dataset_info['length']
|
||||
|
||||
def __str__(self):
|
||||
return "CIFAR10(Sup:{}-Unsup:{}-{}TF(Mag{}-{}))".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF), self.mag_range[0], self.mag_range[1])
|
||||
|
||||
def train_UDA(model, dl_unsup, opt_param, epochs=1, print_freq=1):
|
||||
"""Training of a model using UDA inspired approach.
|
||||
|
||||
Intended to be used alongside an already augmented dataset (see AugmentedDatasetV2).
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model to train.
|
||||
dl_unsup (Dataloader): Data loader of unsupervised/augmented data.
|
||||
opt_param (dict): Dictionnary containing optimizers parameters.
|
||||
epochs (int): Number of epochs to perform. (default: 1)
|
||||
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
|
||||
|
||||
Returns:
|
||||
(list) Logs of training. Each items is a dict containing results of an epoch.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
opt = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
|
||||
model.train()
|
||||
dl_val_it = iter(dl_val)
|
||||
dl_unsup_it =iter(dl_unsup)
|
||||
log = []
|
||||
for epoch in range(epochs):
|
||||
#print_torch_mem("Start epoch")
|
||||
t0 = time.process_time()
|
||||
for i, (features, labels) in enumerate(dl_train):
|
||||
#print_torch_mem("Start iter")
|
||||
features,labels = features.to(device), labels.to(device)
|
||||
|
||||
optim.zero_grad()
|
||||
#Supervised
|
||||
logits = model.forward(features)
|
||||
pred = F.log_softmax(logits, dim=1)
|
||||
sup_loss = F.cross_entropy(pred,labels)
|
||||
|
||||
#Unsupervised
|
||||
try:
|
||||
aug_xs, origin_xs, ys = next(dl_unsup_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_unsup_it =iter(dl_unsup)
|
||||
aug_xs, origin_xs, ys = next(dl_unsup_it)
|
||||
aug_xs, origin_xs, ys = aug_xs.to(device), origin_xs.to(device), ys.to(device)
|
||||
|
||||
#print(aug_xs.shape, origin_xs.shape, ys.shape)
|
||||
sup_logits = model.forward(origin_xs)
|
||||
unsup_logits = model.forward(aug_xs)
|
||||
|
||||
log_sup=F.log_softmax(sup_logits, dim=1)
|
||||
log_unsup=F.log_softmax(unsup_logits, dim=1)
|
||||
#KL div w/ logits
|
||||
unsup_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_unsup)
|
||||
unsup_loss=unsup_loss.sum(dim=-1).mean()
|
||||
|
||||
#print(unsup_loss)
|
||||
unsupp_coeff = 1
|
||||
loss = sup_loss + unsup_loss * unsupp_coeff
|
||||
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
#### Tests ####
|
||||
tf = time.process_time()
|
||||
try:
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_val_it = iter(dl_val)
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
|
||||
|
||||
val_loss = F.cross_entropy(model(xs_val), ys_val)
|
||||
accuracy, _ =test(model)
|
||||
model.train()
|
||||
|
||||
#### Print ####
|
||||
if(print_freq and epoch%print_freq==0):
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Sup Loss :', sup_loss.item(), '/ unsup_loss :', unsup_loss.item())
|
||||
print('Accuracy :', accuracy)
|
||||
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"param": None,
|
||||
}
|
||||
log.append(data)
|
||||
|
||||
return log
|
||||
|
||||
|
||||
def run_simple_dataug(inner_it, epochs=1):
|
||||
device = next(model.parameters()).device
|
||||
dl_train_it = iter(dl_train)
|
||||
dl_val_it = iter(dl_val)
|
||||
|
||||
#aug_model = nn.Sequential(
|
||||
# Data_aug(),
|
||||
# LeNet(1,10),
|
||||
# )
|
||||
aug_model = Augmented_model(Data_aug(), LeNet(1,10)).to(device)
|
||||
print(str(aug_model))
|
||||
meta_opt = torch.optim.Adam(aug_model['data_aug'].parameters(), lr=1e-2)
|
||||
inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9)
|
||||
|
||||
log = []
|
||||
t0 = time.process_time()
|
||||
|
||||
epoch = 0
|
||||
while epoch < epochs:
|
||||
meta_opt.zero_grad()
|
||||
aug_model.train()
|
||||
with higher.innerloop_ctx(aug_model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt): #effet copy_initial_weight pas clair...
|
||||
|
||||
for i in range(n_inner_iter):
|
||||
try:
|
||||
xs, ys = next(dl_train_it)
|
||||
except StopIteration: #Fin epoch train
|
||||
tf = time.process_time()
|
||||
epoch +=1
|
||||
dl_train_it = iter(dl_train)
|
||||
xs, ys = next(dl_train_it)
|
||||
|
||||
accuracy, _ =test(model)
|
||||
aug_model.train()
|
||||
|
||||
#### Print ####
|
||||
print('-'*9)
|
||||
print('Epoch %d/%d'%(epoch,epochs))
|
||||
print('train loss',loss.item(), '/ val loss', val_loss.item())
|
||||
print('acc', accuracy)
|
||||
print('mag', aug_model['data_aug']['mag'].item())
|
||||
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"param": aug_model['data_aug']['mag'].item(),
|
||||
}
|
||||
log.append(data)
|
||||
t0 = time.process_time()
|
||||
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
|
||||
|
||||
loss = F.cross_entropy(logits, ys) # no need to call loss.backwards()
|
||||
#loss.backward(retain_graph=True)
|
||||
#print(fmodel['model']._params['b4'].grad)
|
||||
#print('mag', fmodel['data_aug']['mag'].grad)
|
||||
|
||||
diffopt.step(loss) # note that `step` must take `loss` as an argument!
|
||||
# The line above gets P[t+1] from P[t] and loss[t]. `step` also returns
|
||||
# these new parameters, as an alternative to getting them from
|
||||
# `fmodel.fast_params` or `fmodel.parameters()` after calling
|
||||
# `diffopt.step`.
|
||||
|
||||
# At this point, or at any point in the iteration, you can take the
|
||||
# gradient of `fmodel.parameters()` (or equivalently
|
||||
# `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently
|
||||
# `fmodel.init_fast_params`). i.e. `fast_params` will always have
|
||||
# `grad_fn` as an attribute, and be part of the gradient tape.
|
||||
|
||||
# At the end of your inner loop you can obtain these e.g. ...
|
||||
#grad_of_grads = torch.autograd.grad(
|
||||
# meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))
|
||||
try:
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_val_it = iter(dl_val)
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
|
||||
|
||||
fmodel.augment(mode=False)
|
||||
val_logits = fmodel(xs_val) #Validation sans transfornations !
|
||||
val_loss = F.cross_entropy(val_logits, ys_val)
|
||||
#print('val_loss',val_loss.item())
|
||||
val_loss.backward()
|
||||
|
||||
#print('mag', fmodel['data_aug']['mag'], '/', fmodel['data_aug']['mag'].grad)
|
||||
|
||||
#model=copy.deepcopy(fmodel)
|
||||
aug_model.load_state_dict(fmodel.state_dict()) #Do not copy gradient !
|
||||
#Copie des gradients
|
||||
for paramName, paramValue, in fmodel.named_parameters():
|
||||
for netCopyName, netCopyValue, in aug_model.named_parameters():
|
||||
if paramName == netCopyName:
|
||||
netCopyValue.grad = paramValue.grad
|
||||
|
||||
#print('mag', aug_model['data_aug']['mag'], '/', aug_model['data_aug']['mag'].grad)
|
||||
meta_opt.step()
|
||||
|
||||
plot_res(log, fig_name="res/{}-{} epochs- {} in_it".format(str(aug_model),epochs,inner_it))
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
print(str(aug_model),": acc", max([x["acc"] for x in log]), "in (ms):", np.mean(times), "+/-", np.std(times))
|
||||
|
||||
def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
|
||||
device = next(model.parameters()).device
|
||||
dl_train_it = iter(dl_train)
|
||||
dl_val_it = iter(dl_val)
|
||||
|
||||
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-3)
|
||||
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
|
||||
|
||||
high_grad_track = True
|
||||
if dataug_epoch_start>0:
|
||||
model.augment(mode=False)
|
||||
high_grad_track = False
|
||||
|
||||
model.train()
|
||||
|
||||
log = []
|
||||
t0 = time.process_time()
|
||||
|
||||
countcopy=0
|
||||
val_loss=torch.tensor(0)
|
||||
opt_param=None
|
||||
|
||||
epoch = 0
|
||||
while epoch < epochs:
|
||||
meta_opt.zero_grad()
|
||||
with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt): #effet copy_initial_weight pas clair...
|
||||
|
||||
for i in range(n_inner_iter):
|
||||
try:
|
||||
xs, ys = next(dl_train_it)
|
||||
except StopIteration: #Fin epoch train
|
||||
tf = time.process_time()
|
||||
epoch +=1
|
||||
dl_train_it = iter(dl_train)
|
||||
xs, ys = next(dl_train_it)
|
||||
|
||||
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
#viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
|
||||
accuracy, _ =test(model)
|
||||
model.train()
|
||||
|
||||
#### Print ####
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy :', accuracy)
|
||||
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
|
||||
print('TF Proba :', model['data_aug']['prob'].data)
|
||||
#print('proba grad',aug_model['data_aug']['prob'].grad)
|
||||
#############
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"param": [p for p in model['data_aug']['prob']],
|
||||
}
|
||||
log.append(data)
|
||||
#############
|
||||
|
||||
if epoch == dataug_epoch_start:
|
||||
print('Starting Data Augmention...')
|
||||
model.augment(mode=True)
|
||||
high_grad_track = True
|
||||
|
||||
t0 = time.process_time()
|
||||
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
'''
|
||||
#Methode exacte
|
||||
final_loss = 0
|
||||
for tf_idx in range(fmodel['data_aug']._nb_tf):
|
||||
fmodel['data_aug'].transf_idx=tf_idx
|
||||
logits = fmodel(xs)
|
||||
loss = F.cross_entropy(logits, ys)
|
||||
#loss.backward(retain_graph=True)
|
||||
#print('idx', tf_idx)
|
||||
#print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad)
|
||||
final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
|
||||
|
||||
loss = final_loss
|
||||
'''
|
||||
#Methode uniforme
|
||||
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
|
||||
loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
|
||||
if fmodel._data_augmentation: #Weight loss
|
||||
w_loss = fmodel['data_aug'].loss_weight().to(device)
|
||||
loss = loss * w_loss
|
||||
loss = loss.mean()
|
||||
#'''
|
||||
|
||||
#to visualize computational graph
|
||||
#print_graph(loss)
|
||||
|
||||
#loss.backward(retain_graph=True)
|
||||
#print(fmodel['model']._params['b4'].grad)
|
||||
#print('prob grad', fmodel['data_aug']['prob'].grad)
|
||||
|
||||
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
|
||||
|
||||
try:
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_val_it = iter(dl_val)
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
|
||||
|
||||
fmodel.augment(mode=False) #Validation sans transfornations !
|
||||
val_loss = F.cross_entropy(fmodel(xs_val), ys_val)
|
||||
|
||||
#print_graph(val_loss)
|
||||
|
||||
val_loss.backward()
|
||||
|
||||
countcopy+=1
|
||||
model_copy(src=fmodel, dst=model)
|
||||
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||
|
||||
meta_opt.step()
|
||||
model['data_aug'].adjust_param() #Contrainte sum(proba)=1
|
||||
|
||||
print("Copy ", countcopy)
|
||||
return log
|
||||
|
||||
def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
|
||||
device = next(model.parameters()).device
|
||||
log = []
|
||||
countcopy=0
|
||||
val_loss=torch.tensor(0) #Necessaire si pas de metastep sur une epoch
|
||||
dl_val_it = iter(dl_val)
|
||||
|
||||
#if inner_it!=0:
|
||||
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
|
||||
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
high_grad_track = True
|
||||
if inner_it == 0:
|
||||
high_grad_track=False
|
||||
if dataug_epoch_start!=0:
|
||||
model.augment(mode=False)
|
||||
high_grad_track = False
|
||||
|
||||
val_loss_monitor= None
|
||||
if loss_patience != None :
|
||||
if dataug_epoch_start==-1: val_loss_monitor = loss_monitor(patience=loss_patience, end_train=2) #1st limit = dataug start
|
||||
else: val_loss_monitor = loss_monitor(patience=loss_patience) #Val loss monitor (Not on val data : used by Dataug... => Test data)
|
||||
|
||||
model.train()
|
||||
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
||||
|
||||
meta_opt.zero_grad()
|
||||
|
||||
for epoch in range(1, epochs+1):
|
||||
#print_torch_mem("Start epoch "+str(epoch))
|
||||
#print(high_grad_track, fmodel._data_augmentation, len(fmodel._fast_params))
|
||||
t0 = time.process_time()
|
||||
#with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt):
|
||||
|
||||
for i, (xs, ys) in enumerate(dl_train):
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
#Methode exacte
|
||||
#final_loss = 0
|
||||
#for tf_idx in range(fmodel['data_aug']._nb_tf):
|
||||
# fmodel['data_aug'].transf_idx=tf_idx
|
||||
# logits = fmodel(xs)
|
||||
# loss = F.cross_entropy(logits, ys)
|
||||
# #loss.backward(retain_graph=True)
|
||||
# final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
|
||||
#loss = final_loss
|
||||
|
||||
if(not KLdiv):
|
||||
#Methode uniforme
|
||||
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
|
||||
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
|
||||
|
||||
if fmodel._data_augmentation: #Weight loss
|
||||
w_loss = fmodel['data_aug'].loss_weight()#.to(device)
|
||||
loss = loss * w_loss
|
||||
loss = loss.mean()
|
||||
|
||||
else:
|
||||
#Methode KL div
|
||||
if fmodel._data_augmentation :
|
||||
fmodel.augment(mode=False)
|
||||
sup_logits = fmodel(xs)
|
||||
fmodel.augment(mode=True)
|
||||
else:
|
||||
sup_logits = fmodel(xs)
|
||||
log_sup=F.log_softmax(sup_logits, dim=1)
|
||||
loss = F.cross_entropy(log_sup, ys)
|
||||
|
||||
if fmodel._data_augmentation:
|
||||
aug_logits = fmodel(xs)
|
||||
log_aug=F.log_softmax(aug_logits, dim=1)
|
||||
|
||||
w_loss = fmodel['data_aug'].loss_weight() #Weight loss
|
||||
|
||||
#if epoch>50: #debut differe ?
|
||||
#KL div w/ logits - Similarite predictions (distributions)
|
||||
aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
|
||||
aug_loss = aug_loss.sum(dim=-1)
|
||||
#aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none')
|
||||
aug_loss = (w_loss * aug_loss).mean()
|
||||
|
||||
aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
|
||||
|
||||
unsupp_coeff = 1
|
||||
loss += aug_loss * unsupp_coeff
|
||||
|
||||
#to visualize computational graph
|
||||
#print_graph(loss)
|
||||
|
||||
#loss.backward(retain_graph=True)
|
||||
#print(fmodel['model']._params['b4'].grad)
|
||||
#print('prob grad', fmodel['data_aug']['prob'].grad)
|
||||
|
||||
#t = time.process_time()
|
||||
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
|
||||
#print(len(fmodel._fast_params),"step", time.process_time()-t)
|
||||
|
||||
if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
|
||||
#print("meta")
|
||||
|
||||
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) #+ fmodel['data_aug'].reg_loss()
|
||||
#print_graph(val_loss)
|
||||
|
||||
#t = time.process_time()
|
||||
val_loss.backward()
|
||||
#print("meta", time.process_time()-t)
|
||||
#print('proba grad',model['data_aug']['prob'].grad)
|
||||
if model['data_aug']['prob'].grad is None or model['data_aug']['mag'] is None:
|
||||
print("Warning no grad (iter",i,") :\n Prob-",model['data_aug']['prob'].grad,"\n Mag-", model['data_aug']['mag'].grad)
|
||||
|
||||
countcopy+=1
|
||||
model_copy(src=fmodel, dst=model)
|
||||
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||
|
||||
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
|
||||
|
||||
#if epoch>50:
|
||||
meta_opt.step()
|
||||
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
|
||||
try: #Dataugv6
|
||||
model['data_aug'].next_TF_set()
|
||||
except:
|
||||
pass
|
||||
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
||||
|
||||
meta_opt.zero_grad()
|
||||
|
||||
tf = time.process_time()
|
||||
|
||||
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
|
||||
|
||||
if(not high_grad_track):
|
||||
countcopy+=1
|
||||
model_copy(src=fmodel, dst=model)
|
||||
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
|
||||
|
||||
#Necessaire pour reset higher (Accumule les fast_param meme avec track_higher_grads = False)
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
||||
|
||||
accuracy, test_loss =test(model)
|
||||
model.train()
|
||||
|
||||
#### Log ####
|
||||
#print(type(model['data_aug']) is dataug.Data_augV5)
|
||||
param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"param": param #if isinstance(model['data_aug'], Data_augV5)
|
||||
#else [p.item() for p in model['data_aug']['prob']],
|
||||
}
|
||||
log.append(data)
|
||||
#############
|
||||
#### Print ####
|
||||
if(print_freq and epoch%print_freq==0):
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy :', max([x["acc"] for x in log]))
|
||||
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
|
||||
print('TF Proba :', model['data_aug']['prob'].data)
|
||||
#print('proba grad',model['data_aug']['prob'].grad)
|
||||
print('TF Mag :', model['data_aug']['mag'].data)
|
||||
#print('Mag grad',model['data_aug']['mag'].grad)
|
||||
#print('Reg loss:', model['data_aug'].reg_loss().item())
|
||||
#print('Aug loss', aug_loss.item())
|
||||
#############
|
||||
if val_loss_monitor :
|
||||
model.eval()
|
||||
val_loss_monitor.register(test_loss)#val_loss.item())
|
||||
if val_loss_monitor.end_training(): break #Stop training
|
||||
model.train()
|
||||
|
||||
if not model.is_augmenting() and (epoch == dataug_epoch_start or (val_loss_monitor and val_loss_monitor.limit_reached()==1)):
|
||||
print('Starting Data Augmention...')
|
||||
dataug_epoch_start = epoch
|
||||
model.augment(mode=True)
|
||||
if inner_it != 0: high_grad_track = True
|
||||
|
||||
try:
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
|
||||
except:
|
||||
print("Couldn't save finals samples")
|
||||
pass
|
||||
|
||||
#print("Copy ", countcopy)
|
||||
return log
|
161
higher/smart_aug/old/utils_old.py
Normal file
161
higher/smart_aug/old/utils_old.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
import numpy as np
|
||||
import json, math, time, os
|
||||
import matplotlib.pyplot as plt
|
||||
import copy
|
||||
import gc
|
||||
|
||||
from torchviz import make_dot
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import time
|
||||
|
||||
class timer():
|
||||
def __init__(self):
|
||||
self._start_time=time.time()
|
||||
def exec_time(self):
|
||||
end = time.time()
|
||||
res = end-self._start_time
|
||||
self._start_time=end
|
||||
return res
|
||||
|
||||
def plot_res(log, fig_name='res', param_names=None):
|
||||
|
||||
epochs = [x["epoch"] for x in log]
|
||||
|
||||
fig, ax = plt.subplots(ncols=3, figsize=(15, 3))
|
||||
|
||||
ax[0].set_title('Loss')
|
||||
ax[0].plot(epochs,[x["train_loss"] for x in log], label='Train')
|
||||
ax[0].plot(epochs,[x["val_loss"] for x in log], label='Val')
|
||||
ax[0].legend()
|
||||
|
||||
ax[1].set_title('Acc')
|
||||
ax[1].plot(epochs,[x["acc"] for x in log])
|
||||
|
||||
if log[0]["param"]!= None:
|
||||
if isinstance(log[0]["param"],float):
|
||||
ax[2].set_title('Mag')
|
||||
ax[2].plot(epochs,[x["param"] for x in log], label='Mag')
|
||||
ax[2].legend()
|
||||
else :
|
||||
ax[2].set_title('Prob')
|
||||
#for idx, _ in enumerate(log[0]["param"]):
|
||||
#ax[2].plot(epochs,[x["param"][idx] for x in log], label='P'+str(idx))
|
||||
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
|
||||
proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||
ax[2].stackplot(epochs, proba, labels=param_names)
|
||||
ax[2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
|
||||
|
||||
fig_name = fig_name.replace('.',',')
|
||||
plt.savefig(fig_name)
|
||||
plt.close()
|
||||
|
||||
def plot_res_compare(filenames, fig_name='res'):
|
||||
|
||||
all_data=[]
|
||||
#legend=""
|
||||
for idx, file in enumerate(filenames):
|
||||
#legend+=str(idx)+'-'+file+'\n'
|
||||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
all_data.append(data)
|
||||
|
||||
n_tf = [len(x["Param_names"]) for x in all_data]
|
||||
acc = [x["Accuracy"] for x in all_data]
|
||||
time = [x["Time"][0] for x in all_data]
|
||||
|
||||
fig, ax = plt.subplots(ncols=3, figsize=(30, 8))
|
||||
|
||||
ax[0].plot(n_tf, acc)
|
||||
ax[1].plot(n_tf, time)
|
||||
|
||||
ax[0].set_title('Acc')
|
||||
ax[1].set_title('Time')
|
||||
#for a in ax: a.legend()
|
||||
|
||||
fig_name = fig_name.replace('.',',')
|
||||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
def plot_TF_res(log, tf_names, fig_name='res'):
|
||||
|
||||
mean = np.mean([x["param"] for x in log], axis=0)
|
||||
std = np.std([x["param"] for x in log], axis=0)
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(30, 8), sharey=True)
|
||||
ax.bar(tf_names, mean, yerr=std)
|
||||
#ax.bar(tf_names, log[-1]["param"])
|
||||
|
||||
fig_name = fig_name.replace('.',',')
|
||||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
def model_copy(src,dst, patch_copy=True, copy_grad=True):
|
||||
#model=copy.deepcopy(fmodel) #Pas approprie, on ne souhaite que les poids/grad (pas tout fmodel et ses etats)
|
||||
|
||||
dst.load_state_dict(src.state_dict()) #Do not copy gradient !
|
||||
|
||||
if patch_copy:
|
||||
dst['model'].load_state_dict(src['model'].state_dict()) #Copie donnee manquante ?
|
||||
dst['data_aug'].load_state_dict(src['data_aug'].state_dict())
|
||||
|
||||
#Copie des gradients
|
||||
if copy_grad:
|
||||
for paramName, paramValue, in src.named_parameters():
|
||||
for netCopyName, netCopyValue, in dst.named_parameters():
|
||||
if paramName == netCopyName:
|
||||
netCopyValue.grad = paramValue.grad
|
||||
#netCopyValue=copy.deepcopy(paramValue)
|
||||
|
||||
try: #Data_augV4
|
||||
dst['data_aug']._input_info = src['data_aug']._input_info
|
||||
dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix
|
||||
except:
|
||||
pass
|
||||
|
||||
def optim_copy(dopt, opt):
|
||||
|
||||
#inner_opt.load_state_dict(diffopt.state_dict()) #Besoin sauver etat otpim (momentum, etc.) => Ne copie pas le state...
|
||||
#opt_param=higher.optim.get_trainable_opt_params(diffopt)
|
||||
|
||||
for group_idx, group in enumerate(opt.param_groups):
|
||||
# print('gp idx',group_idx)
|
||||
for p_idx, p in enumerate(group['params']):
|
||||
opt.state[p]=dopt.state[group_idx][p_idx]
|
||||
|
||||
class loss_monitor(): #Voir https://github.com/pytorch/ignite
|
||||
def __init__(self, patience, end_train=1):
|
||||
self.patience = patience
|
||||
self.end_train = end_train
|
||||
self.counter = 0
|
||||
self.best_score = None
|
||||
self.reached_limit = 0
|
||||
|
||||
def register(self, loss):
|
||||
if self.best_score is None:
|
||||
self.best_score = loss
|
||||
elif loss > self.best_score:
|
||||
self.counter += 1
|
||||
#if not self.reached_limit:
|
||||
print("loss no improve counter", self.counter, self.reached_limit)
|
||||
else:
|
||||
self.best_score = loss
|
||||
self.counter = 0
|
||||
def limit_reached(self):
|
||||
if self.counter >= self.patience:
|
||||
self.counter = 0
|
||||
self.reached_limit +=1
|
||||
self.best_score = None
|
||||
return self.reached_limit
|
||||
|
||||
def end_training(self):
|
||||
if self.limit_reached() >= self.end_train:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def reset(self):
|
||||
self.__init__(self.patience, self.end_train)
|
169
higher/smart_aug/process_res.py
Normal file
169
higher/smart_aug/process_res.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
from utils import *
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
'''
|
||||
# files=[
|
||||
# "../res/log/Aug_mod(Data_augV5(Mix0.5-18TFx3-Mag)-efficientnet-b1)-200 epochs (dataug 0)- 1 in_it__AL2.json",
|
||||
# #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-0.json",
|
||||
# #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-1.json",
|
||||
# #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-2.json",
|
||||
# #"res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
||||
# ]
|
||||
files = ["../res/benchmark/CIFAR10/log/RandAugment(N%d-M%.2f)-%s-200 epochs -%s.json"%(3,1,'resnet18', str(run)) for run in range(3)]
|
||||
#files = ["../res/benchmark/CIFAR10/log/Aug_mod(RandAug(18TFx%d-Mag%d)-%s)-200 epochs (dataug:0)- 0 in_it-%s.json"%(2,1,'resnet18', str(run)) for run in range(3)]
|
||||
#files = ["../res/benchmark/CIFAR10/log/Aug_mod(Data_augV5(Mix%.1f-18TFx%d-Mag)-%s)-200 epochs (dataug:0)- 1 in_it-%s.json"%(0.5,3,'resnet18', str(run)) for run in range(3)]
|
||||
|
||||
|
||||
for idx, file in enumerate(files):
|
||||
#legend+=str(idx)+'-'+file+'\n'
|
||||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
plot_resV2(data['Log'], fig_name=file.replace("/log","").replace(".json",""), param_names=data['Param_names'], f1=True)
|
||||
#plot_TF_influence(data['Log'], param_names=data['Param_names'])
|
||||
'''
|
||||
|
||||
#Res print
|
||||
# '''
|
||||
nb_run=3
|
||||
accs = []
|
||||
aug_accs = []
|
||||
f1_max = []
|
||||
f1_min = []
|
||||
times = []
|
||||
mem = []
|
||||
|
||||
files = ["../res/benchmark/log/Aug_mod(Data_augV5(T0.5-19TFx3-Mag)-resnet18)-200 epochs (dataug 0)- 1 in_it__%s.json"%(str(run)) for run in range(1, nb_run+1)]
|
||||
|
||||
for idx, file in enumerate(files):
|
||||
#legend+=str(idx)+'-'+file+'\n'
|
||||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
accs.append(data['Accuracy'])
|
||||
aug_accs.append(data['Aug_Accuracy'][1])
|
||||
times.append(data['Time'][0])
|
||||
mem.append(data['Memory'][1])
|
||||
|
||||
acc_idx = [x['acc'] for x in data['Log']].index(data['Accuracy'])
|
||||
f1_max.append(max(data['Log'][acc_idx]['f1'])*100)
|
||||
f1_min.append(min(data['Log'][acc_idx]['f1'])*100)
|
||||
print(idx, accs[-1], aug_accs[-1])
|
||||
|
||||
print(files[0])
|
||||
print("Acc : %.2f ~ %.2f / Aug_Acc %d: %.2f ~ %.2f"%(np.mean(accs), np.std(accs), data['Aug_Accuracy'][0], np.mean(aug_accs), np.std(aug_accs)))
|
||||
print("F1 max : %.2f ~ %.2f / F1 min : %.2f ~ %.2f"%(np.mean(f1_max), np.std(f1_max), np.mean(f1_min), np.std(f1_min)))
|
||||
print("Time (h): %.2f ~ %.2f"%(np.mean(times)/3600, np.std(times)/3600))
|
||||
print("Mem (MB): %d ~ %d"%(np.mean(mem), np.std(mem)))
|
||||
# '''
|
||||
|
||||
## Loss , Acc, Proba = f(epoch) ##
|
||||
#plot_compare(filenames=files, fig_name="res/compare")
|
||||
|
||||
'''
|
||||
## Acc, Time, Epochs = f(n_tf) ##
|
||||
#fig_name="res/TF_nb_tests_compare"
|
||||
fig_name="res/TF_seq_tests_compare"
|
||||
inner_its = [0, 10]
|
||||
dataug_epoch_starts= [0]
|
||||
TF_nb = 14#[len(TF.TF_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
|
||||
N_seq_TF= [1, 2, 3, 4, 6] #[1]
|
||||
|
||||
fig, ax = plt.subplots(ncols=3, figsize=(30, 8))
|
||||
for in_it in inner_its:
|
||||
for dataug in dataug_epoch_starts:
|
||||
|
||||
#n_tf = TF_nb
|
||||
#filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF)-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(n_tf, dataug, in_it) for n_tf in TF_nb]
|
||||
#filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF x {})-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(n_tf, 1, dataug, in_it) for n_tf in TF_nb]
|
||||
|
||||
n_tf = N_seq_TF
|
||||
#filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF x {})-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(TF_nb, n_tf, dataug, in_it) for n_tf in N_seq_TF]
|
||||
filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF x {})-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(TF_nb, n_tf, dataug, in_it) for n_tf in N_seq_TF]
|
||||
|
||||
|
||||
all_data=[]
|
||||
#legend=""
|
||||
for idx, file in enumerate(filenames):
|
||||
#legend+=str(idx)+'-'+file+'\n'
|
||||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
all_data.append(data)
|
||||
|
||||
acc = [x["Accuracy"] for x in all_data]
|
||||
epochs = [len(x["Log"]) for x in all_data]
|
||||
time = [x["Time"][0] for x in all_data]
|
||||
#for i in range(len(time)): time[i] *= epochs[i] #Estimation temps total
|
||||
|
||||
ax[0].plot(n_tf, acc, label="{} in_it/{} dataug".format(in_it,dataug))
|
||||
ax[1].plot(n_tf, time, label="{} in_it/{} dataug".format(in_it,dataug))
|
||||
ax[2].plot(n_tf, epochs, label="{} in_it/{} dataug".format(in_it,dataug))
|
||||
|
||||
|
||||
#for data in all_data:
|
||||
#print(np.mean([x["param"] for x in data["Log"]], axis=0))
|
||||
#print(len(data["Param_names"]), np.argsort(np.argsort(np.mean([x["param"] for x in data["Log"]], axis=0))))
|
||||
|
||||
|
||||
ax[0].set_title('Acc')
|
||||
ax[1].set_title('Time')
|
||||
ax[2].set_title('Epochs')
|
||||
for a in ax: a.legend()
|
||||
|
||||
fig_name = fig_name.replace('.',',')
|
||||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
'''
|
||||
|
||||
|
||||
|
||||
'''
|
||||
#HP search
|
||||
inner_its = [1]
|
||||
dist_mix = [0.3, 0.5, 0.8, 1.0] #Uniform
|
||||
N_seq_TF= [3]
|
||||
nb_run= 3
|
||||
|
||||
for n_inner_iter in inner_its:
|
||||
for n_tf in N_seq_TF:
|
||||
for dist in dist_mix:
|
||||
|
||||
files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-Mag)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(dist, n_tf, str(run)) for run in range(nb_run)]
|
||||
#files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Uniform-14TFx%d-MagFxSh)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(n_tf, str(run)) for run in range(nb_run)]
|
||||
accs = []
|
||||
times = []
|
||||
for idx, file in enumerate(files):
|
||||
#legend+=str(idx)+'-'+file+'\n'
|
||||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
accs.append(data['Accuracy'])
|
||||
times.append(data['Time'][0])
|
||||
print(idx, data['Accuracy'])
|
||||
|
||||
print(files[0], 'acc', np.mean(accs), '+-',np.std(accs), ',t', np.mean(times))
|
||||
'''
|
||||
|
||||
'''
|
||||
#Benchmark
|
||||
model_list=['resnet18', 'resnet50','wide_resnet50_2']
|
||||
nb_run= 3
|
||||
|
||||
for model_name in model_list:
|
||||
|
||||
files = ["../res/benchmark/CIFAR100/log/RandAugment(N%d-M%.2f)-%s-200 epochs -%s.json"%(3,0.17,model_name, str(run)) for run in range(nb_run)]
|
||||
#files = ["../res/benchmark/CIFAR10/log/%s-200 epochs -%s.json"%(model_name, str(run)) for run in range(nb_run)]
|
||||
|
||||
accs = []
|
||||
times = []
|
||||
mem_alloc = []
|
||||
mem_cach = []
|
||||
for idx, file in enumerate(files):
|
||||
#legend+=str(idx)+'-'+file+'\n'
|
||||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
accs.append(data['Accuracy'])
|
||||
times.append(data['Time'][0])
|
||||
mem_cach.append(data['Memory'])
|
||||
print(idx, data['Accuracy'])
|
||||
|
||||
print(files[0], 'acc', np.mean(accs), '+-',np.std(accs), ',t', np.mean(times), 'Mem', np.mean(mem_cach))
|
||||
'''
|
59
higher/smart_aug/smart_aug_example.py
Normal file
59
higher/smart_aug/smart_aug_example.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
""" Example use of smart augmentation.
|
||||
|
||||
"""
|
||||
|
||||
from LeNet import *
|
||||
from dataug import *
|
||||
from train_utils import *
|
||||
|
||||
tf_config='../config/base_tf_config.json'
|
||||
TF_loader=TF_loader()
|
||||
|
||||
device = torch.device('cuda') #Select device to use
|
||||
|
||||
if device == torch.device('cpu'):
|
||||
device_name = 'CPU'
|
||||
else:
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
|
||||
##########################################
|
||||
if __name__ == "__main__":
|
||||
|
||||
#Parameters
|
||||
n_inner_iter = 1
|
||||
epochs = 150
|
||||
optim_param={
|
||||
'Meta':{
|
||||
'optim':'Adam',
|
||||
'lr':1e-2, #1e-2
|
||||
},
|
||||
'Inner':{
|
||||
'optim': 'SGD',
|
||||
'lr':1e-2, #1e-2/1e-1 (ResNet)
|
||||
'momentum':0.9, #0.9
|
||||
'decay':0.0005, #0.0005
|
||||
'nesterov':False, #False (True: Bad behavior w/ Data_aug)
|
||||
'scheduler':'cosine', #None, 'cosine', 'multiStep', 'exponential'
|
||||
}
|
||||
}
|
||||
|
||||
#Models
|
||||
model = LeNet(3,10)
|
||||
|
||||
#Smart_aug initialisation
|
||||
tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(tf_config)
|
||||
model = Higher_model(model) #run_dist_dataugV3
|
||||
aug_model = Augmented_model(
|
||||
Data_augV5(TF_dict=tf_dict,
|
||||
N_TF=3,
|
||||
mix_dist=0.8,
|
||||
fixed_prob=False,
|
||||
fixed_mag=False,
|
||||
shared_mag=False,
|
||||
TF_ignore_mag=tf_ignore_mag),
|
||||
model).to(device)
|
||||
|
||||
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
|
||||
|
||||
# Training
|
||||
trained_model = run_simple_smartaug(model=aug_model, epochs=epochs, inner_it=n_inner_iter, opt_param=optim_param)
|
204
higher/smart_aug/test_dataug.py
Normal file
204
higher/smart_aug/test_dataug.py
Normal file
|
@ -0,0 +1,204 @@
|
|||
""" Script to run experiment on smart augmentation.
|
||||
|
||||
"""
|
||||
import sys
|
||||
from dataug import *
|
||||
#from utils import *
|
||||
from train_utils import *
|
||||
from transformations import TF_loader
|
||||
# from arg_parser import *
|
||||
|
||||
TF_loader=TF_loader()
|
||||
|
||||
torch.backends.cudnn.benchmark = True #Faster if same input size #Not recommended for reproductibility
|
||||
|
||||
#Increase reproductibility
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
##########################################
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
res_folder=args.res_folder
|
||||
postfix=args.postfix
|
||||
|
||||
if args.dtype == 'FP32':
|
||||
def_type=torch.float32
|
||||
elif args.dtype == 'FP16':
|
||||
# def_type=torch.float16 #Default : float32
|
||||
def_type=torch.bfloat16
|
||||
else:
|
||||
raise Exception('dtype not supported :', args.dtype)
|
||||
torch.set_default_dtype(def_type) #Default : float32
|
||||
|
||||
|
||||
device = torch.device(args.device) #Select device to use
|
||||
if device == torch.device('cpu'):
|
||||
device_name = 'CPU'
|
||||
else:
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
|
||||
#Parameters
|
||||
n_inner_iter = args.K
|
||||
epochs = args.epochs
|
||||
dataug_epoch_start=0
|
||||
Nb_TF_seq= args.N
|
||||
optim_param={
|
||||
'Meta':{
|
||||
'optim':'Adam',
|
||||
'lr':args.mlr,
|
||||
'epoch_start': args.meta_epoch_start, #0 / 2 (Resnet?)
|
||||
'reg_factor': args.mag_reg,
|
||||
'scheduler': None, #None, 'multiStep'
|
||||
},
|
||||
'Inner':{
|
||||
'optim': 'SGD',
|
||||
'lr':args.lr, #1e-2/1e-1 (ResNet)
|
||||
'momentum':args.momentum, #0.9
|
||||
'weight_decay':args.decay, #0.0005
|
||||
'nesterov':args.nesterov, #False (True: Bad behavior w/ Data_aug)
|
||||
'scheduler': args.scheduler, #None, 'cosine', 'multiStep', 'exponential'
|
||||
'warmup':{
|
||||
'multiplier': args.warmup, #2 #+ batch_size => + mutliplier #No warmup = 0
|
||||
'epochs': 5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#Info params
|
||||
F1=True
|
||||
sample_save=None
|
||||
print_f= epochs/4
|
||||
|
||||
#Load network
|
||||
model, model_name= load_model(args.model, num_classes=len(dl_train.dataset.classes), pretrained=args.pretrained)
|
||||
|
||||
#### Classic ####
|
||||
if not args.augment:
|
||||
if device_name != 'CPU':
|
||||
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
|
||||
torch.cuda.reset_max_memory_cached() #reset_peak_stats
|
||||
t0 = time.perf_counter()
|
||||
|
||||
model = model.to(device)
|
||||
|
||||
|
||||
print("{} on {} for {} epochs{}".format(model_name, device_name, epochs, postfix))
|
||||
#print("RandAugment(N{}-M{:.2f})-{} on {} for {} epochs{}".format(rand_aug['N'],rand_aug['M'],model_name, device_name, epochs, postfix))
|
||||
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=print_f)
|
||||
#log= train_classic_higher(model=model, epochs=epochs)
|
||||
|
||||
exec_time=time.perf_counter() - t0
|
||||
|
||||
if device_name != 'CPU':
|
||||
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
|
||||
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
|
||||
else:
|
||||
max_allocated = 0.0
|
||||
max_cached=0.0
|
||||
####
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]),
|
||||
"Time": (np.mean(times),np.std(times), exec_time),
|
||||
'Optimizer': optim_param['Inner'],
|
||||
"Device": device_name,
|
||||
"Memory": [max_allocated, max_cached],
|
||||
#"Rand_Aug": rand_aug,
|
||||
"Log": log}
|
||||
print(model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{} epochs".format(model_name,epochs)+postfix
|
||||
#print("RandAugment-",model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
#filename = "RandAugment(N{}-M{:.2f})-{}-{} epochs".format(rand_aug['N'],rand_aug['M'],model_name,epochs)+postfix
|
||||
with open(res_folder+"log/%s.json" % filename, "w+") as f:
|
||||
try:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
except:
|
||||
print("Failed to save logs :",f.name)
|
||||
print(sys.exc_info()[1])
|
||||
|
||||
try:
|
||||
plot_resV2(log, fig_name=res_folder+filename, f1=F1)
|
||||
except:
|
||||
print("Failed to plot res")
|
||||
print(sys.exc_info()[1])
|
||||
|
||||
print('Execution Time (s): %.00f '%(exec_time))
|
||||
print('-'*9)
|
||||
|
||||
#### Augmented Model ####
|
||||
else:
|
||||
# tf_config='../config/invScale_wide_tf_config.json'#'../config/invScale_wide_tf_config.json'#'../config/base_tf_config.json'
|
||||
tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(args.tf_config)
|
||||
|
||||
if device_name != 'CPU':
|
||||
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
|
||||
torch.cuda.reset_max_memory_cached() #reset_peak_stats
|
||||
t0 = time.perf_counter()
|
||||
|
||||
model = Higher_model(model, model_name) #run_dist_dataugV3
|
||||
dataug_mod = 'Data_augV8' if args.learn_seq else 'Data_augV5'
|
||||
if n_inner_iter !=0:
|
||||
aug_model = Augmented_model(
|
||||
globals()[dataug_mod](TF_dict=tf_dict,
|
||||
N_TF=Nb_TF_seq,
|
||||
temp=args.temp,
|
||||
fixed_prob=False,
|
||||
fixed_mag=args.fixed_mag,
|
||||
shared_mag=args.shared_mag,
|
||||
TF_ignore_mag=tf_ignore_mag), model).to(device)
|
||||
else:
|
||||
aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=Nb_TF_seq), model).to(device)
|
||||
|
||||
print("{} on {} for {} epochs - {} inner_it{}".format(str(aug_model), device_name, epochs, n_inner_iter, postfix))
|
||||
log, aug_acc = run_dist_dataugV3(model=aug_model,
|
||||
epochs=epochs,
|
||||
inner_it=n_inner_iter,
|
||||
dataug_epoch_start=dataug_epoch_start,
|
||||
opt_param=optim_param,
|
||||
unsup_loss=1,
|
||||
augment_loss=args.augment_loss,
|
||||
hp_opt=False, #False #['lr', 'momentum', 'weight_decay']
|
||||
print_freq=print_f,
|
||||
save_sample_freq=sample_save)
|
||||
|
||||
exec_time=time.perf_counter() - t0
|
||||
if device_name != 'CPU':
|
||||
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
|
||||
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
|
||||
else:
|
||||
max_allocated = 0.0
|
||||
max_cached = 0.0
|
||||
####
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]),
|
||||
"Aug_Accuracy": [args.augment_loss, aug_acc],
|
||||
"Time": (np.mean(times),np.std(times), exec_time),
|
||||
'Optimizer': optim_param,
|
||||
"Device": device_name,
|
||||
"Memory": [max_allocated, max_cached],
|
||||
"TF_config": args.tf_config,
|
||||
"Param_names": aug_model.TF_names(),
|
||||
"Log": log}
|
||||
print(str(aug_model),": acc", out["Accuracy"], "/ aug_acc", out["Aug_Accuracy"][1] , "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{}_epochs-{}_in_it-AL{}".format(str(aug_model),epochs,n_inner_iter,args.augment_loss)+postfix
|
||||
with open(res_folder+"log/%s.json" % filename, "w+") as f:
|
||||
try:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
except:
|
||||
print("Failed to save logs :",f.name)
|
||||
print(sys.exc_info()[1])
|
||||
try:
|
||||
plot_resV2(log, fig_name=res_folder+filename, param_names=aug_model.TF_names(), f1=F1)
|
||||
except:
|
||||
print("Failed to plot res")
|
||||
print(sys.exc_info()[1])
|
||||
|
||||
print('Execution Time (s): %.00f '%(exec_time))
|
||||
print('-'*9)
|
592
higher/smart_aug/train_utils.py
Normal file
592
higher/smart_aug/train_utils.py
Normal file
|
@ -0,0 +1,592 @@
|
|||
""" Utilities function for training.
|
||||
|
||||
"""
|
||||
import sys
|
||||
import torch
|
||||
#import torch.optim
|
||||
#import torchvision
|
||||
import higher
|
||||
import higher_patch
|
||||
|
||||
from datasets import *
|
||||
from utils import *
|
||||
|
||||
from transformations import Normalizer, translate, zero_stack
|
||||
norm = Normalizer(MEAN, STD)
|
||||
confmat = ConfusionMatrix(num_classes=len(dl_test.dataset.classes))
|
||||
|
||||
max_grad = 1 #Max gradient value #Limite catastrophic drop
|
||||
|
||||
def test(model, augment=0):
|
||||
"""Evaluate a model on test data.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model to test.
|
||||
augment (int): Number of augmented example for each sample. (Default : 0)
|
||||
|
||||
Returns:
|
||||
(float, Tensor) Returns the accuracy and F1 score of the model.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
model.eval()
|
||||
# model['model']['functional'].set_mode('mixed') #ABN
|
||||
|
||||
#for i, (features, labels) in enumerate(dl_test):
|
||||
# features,labels = features.to(device), labels.to(device)
|
||||
|
||||
# pred = model.forward(features)
|
||||
# return pred.argmax(dim=1).eq(labels).sum().item() / dl_test.batch_size * 100
|
||||
|
||||
correct = 0
|
||||
total = 0
|
||||
#loss = []
|
||||
global confmat
|
||||
confmat.reset()
|
||||
with torch.no_grad():
|
||||
for features, labels in dl_test:
|
||||
features,labels = features.to(device), labels.to(device)
|
||||
|
||||
if augment>0: #Test Time Augmentation
|
||||
model.augment(True)
|
||||
# V2
|
||||
features=torch.cat([features for _ in range(augment)], dim=0) # (B,C,H,W)=>(B*augment,C,H,W)
|
||||
outputs=model(features)
|
||||
outputs=torch.cat([o.unsqueeze(dim=0) for o in outputs.chunk(chunks=augment, dim=0)],dim=0) # (B*augment,nb_class)=>(augment,B,nb_class)
|
||||
|
||||
w_losses=model['data_aug'].loss_weight(batch_norm=False) #(B*augment) if Dataug
|
||||
if w_losses.shape[0]==1: #RandAugment
|
||||
outputs=torch.sum(outputs, axis=0)/augment #mean
|
||||
else: #Dataug
|
||||
w_losses=torch.cat([w.unsqueeze(dim=0) for w in w_losses.chunk(chunks=augment, dim=0)], dim=0) #(augment, B)
|
||||
w_losses = w_losses / w_losses.sum(axis=0, keepdim=True) #sum(w_losses)=1 pour un même echantillons
|
||||
|
||||
outputs=torch.sum(outputs*w_losses.unsqueeze(dim=2).expand_as(outputs), axis=0)/augment #Weighted mean
|
||||
else:
|
||||
outputs = model(features)
|
||||
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
#loss.append(F.cross_entropy(outputs, labels).item())
|
||||
confmat.update(labels, predicted)
|
||||
|
||||
accuracy = 100 * correct / total
|
||||
|
||||
#print(confmat)
|
||||
#from sklearn.metrics import f1_score
|
||||
#f1 = f1_score(labels.data.to('cpu'), predicted.data.to('cpu'), average="macro")
|
||||
|
||||
return accuracy, confmat.f1_metric(average=None)
|
||||
|
||||
def compute_vaLoss(model, dl_it, dl):
|
||||
"""Evaluate a model on a batch of data.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model to evaluate.
|
||||
dl_it (Iterator): Data loader iterator.
|
||||
dl (DataLoader): Data loader.
|
||||
|
||||
Returns:
|
||||
(Tensor) Loss on a single batch of data.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
try:
|
||||
xs, ys = next(dl_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_it = iter(dl)
|
||||
xs, ys = next(dl_it)
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
model.eval() #Validation sans transformations !
|
||||
# model['model']['functional'].set_mode('mixed') #ABN
|
||||
# return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys)
|
||||
return F.cross_entropy(model(xs), ys)
|
||||
|
||||
def mixed_loss(xs, ys, model, unsup_factor=1, augment=1):
|
||||
"""Evaluate a model on a batch of data.
|
||||
|
||||
Compute a combinaison of losses:
|
||||
+ Supervised Cross-Entropy loss from original data.
|
||||
+ Unsupervised Cross-Entropy loss from augmented data.
|
||||
+ KL divergence loss encouraging similarity between original and augmented prediction.
|
||||
|
||||
If unsup_factor is equal to 0 or if there isn't data augmentation, only the supervised loss is computed.
|
||||
|
||||
Inspired by UDA, see: https://github.com/google-research/uda/blob/master/image/main.py
|
||||
|
||||
Args:
|
||||
xs (Tensor): Batch of data.
|
||||
ys (Tensor): Batch of labels.
|
||||
model (nn.Module): Augmented model (see dataug.py).
|
||||
unsup_factor (float): Factor by which unsupervised CE and KL div loss are multiplied.
|
||||
augment (int): Number of augmented example for each sample. (Default : 1)
|
||||
|
||||
Returns:
|
||||
(Tensor) Mixed loss if there's data augmentation, just supervised CE loss otherwise.
|
||||
"""
|
||||
|
||||
#TODO: add test to prevent augmented model error and redirect to classic loss
|
||||
if unsup_factor!=0 and model.is_augmenting() and augment>0:
|
||||
|
||||
# Supervised loss - Cross-entropy
|
||||
model.augment(mode=False)
|
||||
sup_logits = model(xs)
|
||||
model.augment(mode=True)
|
||||
|
||||
log_sup = F.log_softmax(sup_logits, dim=1)
|
||||
sup_loss = F.nll_loss(log_sup, ys)
|
||||
# sup_loss = F.cross_entropy(log_sup, ys)
|
||||
|
||||
if augment>1:
|
||||
# Unsupervised loss - Cross-Entropy
|
||||
xs_a=torch.cat([xs for _ in range(augment)], dim=0) # (B,C,H,W)=>(B*augment,C,H,W)
|
||||
ys_a=torch.cat([ys for _ in range(augment)], dim=0)
|
||||
aug_logits=model(xs_a) # (B*augment,nb_class)
|
||||
|
||||
w_loss=model['data_aug'].loss_weight() #(B*augment) if Dataug
|
||||
|
||||
log_aug = F.log_softmax(aug_logits, dim=1)
|
||||
aug_loss = F.nll_loss(log_aug, ys_a , reduction='none')
|
||||
# aug_loss = F.cross_entropy(log_aug, ys_a , reduction='none')
|
||||
aug_loss = (aug_loss * w_loss).mean()
|
||||
|
||||
#KL divergence loss (w/ logits) - Prediction/Distribution similarity
|
||||
sup_logits_a=torch.cat([sup_logits for _ in range(augment)], dim=0)
|
||||
log_sup_a=torch.cat([log_sup for _ in range(augment)], dim=0)
|
||||
|
||||
kl_loss = (F.softmax(sup_logits_a, dim=1)*(log_sup_a-log_aug)).sum(dim=-1)
|
||||
kl_loss = (w_loss * kl_loss).mean()
|
||||
else:
|
||||
# Unsupervised loss - Cross-Entropy
|
||||
aug_logits = model(xs)
|
||||
w_loss = model['data_aug'].loss_weight() #Weight loss
|
||||
|
||||
log_aug = F.log_softmax(aug_logits, dim=1)
|
||||
aug_loss = F.nll_loss(log_aug, ys , reduction='none')
|
||||
# aug_loss = F.cross_entropy(log_aug, ys , reduction='none')
|
||||
aug_loss = (aug_loss * w_loss).mean()
|
||||
|
||||
#KL divergence loss (w/ logits) - Prediction/Distribution similarity
|
||||
kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
|
||||
kl_loss = (w_loss * kl_loss).mean()
|
||||
|
||||
loss = sup_loss + unsup_factor * (aug_loss + kl_loss)
|
||||
|
||||
else: #Supervised loss - Cross-Entropy
|
||||
sup_logits = model(xs)
|
||||
loss = F.cross_entropy(sup_logits, ys)
|
||||
# log_sup = F.log_softmax(sup_logits, dim=1)
|
||||
# loss = F.cross_entropy(log_sup, ys)
|
||||
|
||||
return loss
|
||||
|
||||
def train_classic(model, opt_param, epochs=1, print_freq=1):
|
||||
"""Classic training of a model.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model to train.
|
||||
opt_param (dict): Dictionnary containing optimizers parameters.
|
||||
epochs (int): Number of epochs to perform. (default: 1)
|
||||
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
|
||||
|
||||
Returns:
|
||||
(list) Logs of training. Each items is a dict containing results of an epoch.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
|
||||
#Optimizer
|
||||
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
optim = torch.optim.SGD(model.parameters(),
|
||||
lr=opt_param['Inner']['lr'],
|
||||
momentum=opt_param['Inner']['momentum'],
|
||||
weight_decay=opt_param['Inner']['weight_decay'],
|
||||
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
#Scheduler
|
||||
inner_scheduler=None
|
||||
if opt_param['Inner']['scheduler']=='cosine':
|
||||
inner_scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=epochs, eta_min=0.)
|
||||
elif opt_param['Inner']['scheduler']=='multiStep':
|
||||
#Multistep milestones inspired by AutoAugment
|
||||
inner_scheduler=torch.optim.lr_scheduler.MultiStepLR(optim,
|
||||
milestones=[int(epochs/3), int(epochs*2/3), int(epochs*2.7/3)],
|
||||
gamma=0.1)
|
||||
elif opt_param['Inner']['scheduler']=='exponential':
|
||||
#inner_scheduler=torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.1) #Wrong gamma
|
||||
inner_scheduler=torch.optim.lr_scheduler.LambdaLR(optim, lambda epoch: (1 - epoch / epochs) ** 0.9)
|
||||
elif opt_param['Inner']['scheduler'] is not None:
|
||||
raise ValueError("Lr scheduler unknown : %s"%opt_param['Inner']['scheduler'])
|
||||
|
||||
# from warmup_scheduler import GradualWarmupScheduler
|
||||
# inner_scheduler=GradualWarmupScheduler(optim, multiplier=2, total_epoch=5, after_scheduler=inner_scheduler)
|
||||
|
||||
#Training
|
||||
model.train()
|
||||
dl_val_it = iter(dl_val)
|
||||
log = []
|
||||
for epoch in range(epochs):
|
||||
#print_torch_mem("Start epoch")
|
||||
#print(optim.param_groups[0]['lr'])
|
||||
t0 = time.perf_counter()
|
||||
for i, (features, labels) in enumerate(dl_train):
|
||||
#viz_sample_data(imgs=features, labels=labels, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
#print_torch_mem("Start iter")
|
||||
features,labels = features.to(device), labels.to(device)
|
||||
|
||||
optim.zero_grad()
|
||||
logits = model.forward(features)
|
||||
pred = F.log_softmax(logits, dim=1)
|
||||
loss = F.cross_entropy(pred,labels)
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
# print_graph(loss, '../samples/torchvision_WRN') #to visualize computational graph
|
||||
# sys.exit()
|
||||
|
||||
if inner_scheduler is not None:
|
||||
inner_scheduler.step()
|
||||
# print(optim.param_groups[0]['lr'])
|
||||
|
||||
#### Tests ####
|
||||
tf = time.perf_counter()
|
||||
|
||||
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
|
||||
accuracy, f1 =test(model)
|
||||
model.train()
|
||||
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"f1": f1.tolist(),
|
||||
"time": tf - t0,
|
||||
|
||||
"param": None,
|
||||
}
|
||||
log.append(data)
|
||||
#### Print ####
|
||||
if(print_freq and epoch%print_freq==0):
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy max:', max([x["acc"] for x in log]))
|
||||
print('F1 :', ["{0:0.4f}".format(i) for i in f1])
|
||||
|
||||
return log
|
||||
|
||||
def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, unsup_loss=1, augment_loss=1, hp_opt=False, print_freq=1, save_sample_freq=None):
|
||||
"""Training of an augmented model with higher.
|
||||
|
||||
This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py).
|
||||
Ex : Augmented_model(Data_augV5(...), Higher_model(model))
|
||||
|
||||
Training loss can either be computed directly from augmented inputs (unsup_loss=0).
|
||||
However, it is recommended to use the mixed loss computation, which combine original and augmented inputs to compute the loss (unsup_loss>0).
|
||||
|
||||
Args:
|
||||
model (nn.Module): Augmented model to train.
|
||||
opt_param (dict): Dictionnary containing optimizers parameters.
|
||||
epochs (int): Number of epochs to perform. (default: 1)
|
||||
inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1)
|
||||
dataug_epoch_start (int): Epoch when to start data augmentation. (default: 0)
|
||||
unsup_loss (float): Proportion of the unsup_loss loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
|
||||
augment_loss (int): Number of augmented example for each sample in loss computation. (Default : 1)
|
||||
hp_opt (bool): Wether to learn inner optimizer parameters. (default: False)
|
||||
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
|
||||
save_sample_freq (int): Number of epochs between saves of samples of data. If set to None, no sample will be saved. (default: None)
|
||||
|
||||
Returns:
|
||||
(list) Logs of training. Each items is a dict containing results of an epoch.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
log = []
|
||||
# kl_log={"prob":[], "mag":[]}
|
||||
dl_val_it = iter(dl_val)
|
||||
|
||||
high_grad_track = True
|
||||
if inner_it == 0: #No HP optimization
|
||||
high_grad_track=False
|
||||
if dataug_epoch_start!=0: #Augmentation de donnee differee
|
||||
model.augment(mode=False)
|
||||
high_grad_track = False
|
||||
|
||||
## Optimizers ##
|
||||
#Inner Opt
|
||||
inner_opt = torch.optim.SGD(model['model']['original'].parameters(),
|
||||
lr=opt_param['Inner']['lr'],
|
||||
momentum=opt_param['Inner']['momentum'],
|
||||
weight_decay=opt_param['Inner']['weight_decay'],
|
||||
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
diffopt = model['model'].get_diffopt(
|
||||
inner_opt,
|
||||
grad_callback=(lambda grads: clip_norm(grads, max_norm=max_grad)),
|
||||
track_higher_grads=high_grad_track)
|
||||
|
||||
#Scheduler
|
||||
inner_scheduler=None
|
||||
if opt_param['Inner']['scheduler']=='cosine':
|
||||
inner_scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(inner_opt, T_max=epochs, eta_min=0.)
|
||||
elif opt_param['Inner']['scheduler']=='multiStep':
|
||||
#Multistep milestones inspired by AutoAugment
|
||||
inner_scheduler=torch.optim.lr_scheduler.MultiStepLR(inner_opt,
|
||||
milestones=[int(epochs/3), int(epochs*2/3), int(epochs*2.7/3)],
|
||||
gamma=0.1)
|
||||
elif opt_param['Inner']['scheduler']=='exponential':
|
||||
#inner_scheduler=torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.1) #Wrong gamma
|
||||
inner_scheduler=torch.optim.lr_scheduler.LambdaLR(inner_opt, lambda epoch: (1 - epoch / epochs) ** 0.9)
|
||||
elif not(opt_param['Inner']['scheduler'] is None or opt_param['Inner']['scheduler']==''):
|
||||
raise ValueError("Lr scheduler unknown : %s"%opt_param['Inner']['scheduler'])
|
||||
|
||||
#Warmup
|
||||
if opt_param['Inner']['warmup']['multiplier']>=1:
|
||||
from warmup_scheduler import GradualWarmupScheduler
|
||||
inner_scheduler=GradualWarmupScheduler(inner_opt,
|
||||
multiplier=opt_param['Inner']['warmup']['multiplier'],
|
||||
total_epoch=opt_param['Inner']['warmup']['epochs'],
|
||||
after_scheduler=inner_scheduler)
|
||||
|
||||
#Meta Opt
|
||||
hyper_param = list(model['data_aug'].parameters())
|
||||
if hp_opt : #(deprecated)
|
||||
for param_group in diffopt.param_groups:
|
||||
# print(param_group)
|
||||
for param in hp_opt:
|
||||
param_group[param]=torch.tensor(param_group[param]).to(device).requires_grad_()
|
||||
hyper_param += [param_group[param]]
|
||||
meta_opt = torch.optim.Adam(hyper_param, lr=opt_param['Meta']['lr'])
|
||||
|
||||
#Meta-Scheduler (deprecated)
|
||||
meta_scheduler=None
|
||||
if opt_param['Meta']['scheduler']=='multiStep':
|
||||
meta_scheduler=torch.optim.lr_scheduler.MultiStepLR(meta_opt,
|
||||
milestones=[int(epochs/3), int(epochs*2/3)],# int(epochs*2.7/3)],
|
||||
gamma=3.16)
|
||||
elif opt_param['Meta']['scheduler'] is not None:
|
||||
raise ValueError("Lr scheduler unknown : %s"%opt_param['Meta']['scheduler'])
|
||||
|
||||
model.train()
|
||||
meta_opt.zero_grad()
|
||||
|
||||
for epoch in range(1, epochs+1):
|
||||
t0 = time.perf_counter()
|
||||
val_loss=None
|
||||
|
||||
#Cross-Validation
|
||||
#dl_train, dl_val = cvs.next_split()
|
||||
#dl_val_it = iter(dl_val)
|
||||
|
||||
for i, (xs, ys) in enumerate(dl_train):
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
if(unsup_loss==0):
|
||||
#Methode uniforme
|
||||
logits = model(xs) # modified `params` can also be passed as a kwarg
|
||||
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
|
||||
|
||||
if model._data_augmentation: #Weight loss
|
||||
w_loss = model['data_aug'].loss_weight()#.to(device)
|
||||
loss = loss * w_loss
|
||||
loss = loss.mean()
|
||||
|
||||
else:
|
||||
#Methode mixed
|
||||
loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss, augment=augment_loss)
|
||||
|
||||
# print_graph(loss, '../samples/pytorch_WRN') #to visualize computational graph
|
||||
# sys.exit()
|
||||
|
||||
# t = time.process_time()
|
||||
diffopt.step(loss)#(opt.zero_grad, loss.backward, opt.step)
|
||||
# print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
|
||||
|
||||
if(high_grad_track and i>0 and i%inner_it==0 and epoch>=opt_param['Meta']['epoch_start']): #Perform Meta step
|
||||
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss(opt_param['Meta']['reg_factor'])
|
||||
model.train()
|
||||
#print_graph(val_loss) #to visualize computational graph
|
||||
val_loss.backward()
|
||||
|
||||
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=max_grad, norm_type=2) #Prevent exploding grad with RNN
|
||||
|
||||
# print("Grad mix",model['data_aug']["temp"].grad)
|
||||
# prv_param=model['data_aug']._params
|
||||
meta_opt.step()
|
||||
# kl_log["prob"].append(F.kl_div(prv_param["prob"],model['data_aug']["prob"], reduction='batchmean').item())
|
||||
# kl_log["mag"].append(F.kl_div(prv_param["mag"],model['data_aug']["mag"], reduction='batchmean').item())
|
||||
|
||||
#Adjust Hyper-parameters
|
||||
model['data_aug'].adjust_param()
|
||||
if hp_opt:
|
||||
for param_group in diffopt.param_groups:
|
||||
for param in hp_opt:
|
||||
param_group[param].data = param_group[param].data.clamp(min=1e-5)
|
||||
|
||||
#Reset gradients
|
||||
diffopt.detach_()
|
||||
model['model'].detach_()
|
||||
meta_opt.zero_grad()
|
||||
|
||||
elif not high_grad_track or epoch<opt_param['Meta']['epoch_start']:
|
||||
diffopt.detach_()
|
||||
model['model'].detach_()
|
||||
meta_opt.zero_grad()
|
||||
|
||||
tf = time.perf_counter()
|
||||
|
||||
#Schedulers
|
||||
if inner_scheduler is not None:
|
||||
inner_scheduler.step()
|
||||
#Transfer inner_opt lr to diffopt
|
||||
for diff_param_group in diffopt.param_groups:
|
||||
for param_group in inner_opt.param_groups:
|
||||
diff_param_group['lr'] = param_group['lr']
|
||||
if meta_scheduler is not None:
|
||||
meta_scheduler.step()
|
||||
|
||||
# if epoch<epochs/3:
|
||||
# model['data_aug']['temp'].data=torch.tensor(0.5, device=device)
|
||||
# elif epoch>epochs/3 and epoch<(epochs*2/3):
|
||||
# model['data_aug']['temp'].data=torch.tensor(0.75, device=device)
|
||||
# elif epoch>(epochs*2/3):
|
||||
# model['data_aug']['temp'].data=torch.tensor(1.0, device=device)
|
||||
# model['data_aug']['temp'].data=torch.tensor(1./3+2/3*(epoch/epochs), device=device)
|
||||
# print('Temp',model['data_aug']['temp'])
|
||||
|
||||
if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
|
||||
try:
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
model.train()
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
|
||||
model.eval()
|
||||
except:
|
||||
print("Couldn't save samples epoch %d : %s"%(epoch, str(sys.exc_info()[1])))
|
||||
pass
|
||||
|
||||
if(not val_loss): #Compute val loss for logs
|
||||
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
|
||||
|
||||
# Test model
|
||||
accuracy, f1 =test(model)
|
||||
model.train()
|
||||
|
||||
#### Log ####
|
||||
param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"f1": f1.tolist(),
|
||||
"time": tf - t0,
|
||||
|
||||
"param": param,
|
||||
}
|
||||
if not model['data_aug']._fixed_temp: data["temp"]=model['data_aug']['temp'].item()
|
||||
if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'], 'momentum': p_grp['momentum']} for p_grp in diffopt.param_groups]
|
||||
log.append(data)
|
||||
#############
|
||||
#### Print ####
|
||||
if(print_freq and epoch%print_freq==0):
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy max:', max([x["acc"] for x in log]))
|
||||
print('F1 :', ["{0:0.4f}".format(i) for i in f1])
|
||||
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
|
||||
if not model['data_aug']._fixed_prob: print('TF Proba :', ["{0:0.4f}".format(p) for p in model['data_aug']['prob']])
|
||||
#print('proba grad',model['data_aug']['prob'].grad)
|
||||
if not model['data_aug']._fixed_mag:
|
||||
if model['data_aug']._shared_mag:
|
||||
print('TF Mag :', "{0:0.4f}".format(model['data_aug']['mag']))
|
||||
else:
|
||||
print('TF Mag :', ["{0:0.4f}".format(m) for m in model['data_aug']['mag']])
|
||||
#print('Mag grad',model['data_aug']['mag'].grad)
|
||||
if not model['data_aug']._fixed_temp: print('Temp:', model['data_aug']['temp'].item())
|
||||
#print('Reg loss:', model['data_aug'].reg_loss().item())
|
||||
# if len(kl_log["prob"])!=0:
|
||||
# print("KL prob : mean %f, std %f, max %f, min %f"%(np.mean(kl_log["prob"]), np.std(kl_log["prob"]), max(kl_log["prob"]), min(kl_log["prob"])))
|
||||
# print("KL mag : mean %f, std %f, max %f, min %f"%(np.mean(kl_log["mag"]), np.std(kl_log["mag"]), max(kl_log["mag"]), min(kl_log["mag"])))
|
||||
# kl_log={"prob":[], "mag":[]}
|
||||
|
||||
if hp_opt :
|
||||
for param_group in diffopt.param_groups:
|
||||
print('Opt param - lr:', param_group['lr'].item(),'- momentum:', param_group['momentum'].item())
|
||||
#############
|
||||
|
||||
#Augmentation de donnee differee
|
||||
if not model.is_augmenting() and (epoch == dataug_epoch_start):
|
||||
print('Starting Data Augmention...')
|
||||
dataug_epoch_start = epoch
|
||||
model.augment(mode=True)
|
||||
if inner_it != 0: #Rebuild diffopt if needed
|
||||
high_grad_track = True
|
||||
diffopt = model['model'].get_diffopt(
|
||||
inner_opt,
|
||||
grad_callback=(lambda grads: clip_norm(grads, max_norm=max_grad)),
|
||||
track_higher_grads=high_grad_track)
|
||||
|
||||
aug_acc, aug_f1 = test(model, augment=augment_loss)
|
||||
|
||||
return log, aug_acc
|
||||
|
||||
#OLD
|
||||
# def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, unsup_loss=1):
|
||||
# """Simple training of an augmented model with higher.
|
||||
|
||||
# This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py).
|
||||
# Ex : Augmented_model(Data_augV5(...), Higher_model(model))
|
||||
|
||||
# Training loss can either be computed directly from augmented inputs (unsup_loss=0).
|
||||
# However, it is recommended to use the mixed loss computation, which combine original and augmented inputs to compute the loss (unsup_loss>0).
|
||||
|
||||
# Does not support LR scheduler.
|
||||
|
||||
# Args:
|
||||
# model (nn.Module): Augmented model to train.
|
||||
# opt_param (dict): Dictionnary containing optimizers parameters.
|
||||
# epochs (int): Number of epochs to perform. (default: 1)
|
||||
# inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1)
|
||||
# print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
|
||||
# unsup_loss (float): Proportion of the unsup_loss loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
|
||||
|
||||
# Returns:
|
||||
# (dict) A dictionary containing a whole state of the trained network.
|
||||
# """
|
||||
# device = next(model.parameters()).device
|
||||
|
||||
# ## Optimizers ##
|
||||
# hyper_param = list(model['data_aug'].parameters())
|
||||
# model.start_bilevel_opt(inner_it=inner_it, hp_list=hyper_param, opt_param=opt_param, dl_val=dl_val)
|
||||
|
||||
# model.train()
|
||||
|
||||
# for epoch in range(1, epochs+1):
|
||||
# t0 = time.process_time()
|
||||
|
||||
# for i, (xs, ys) in enumerate(dl_train):
|
||||
# xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
# #Methode mixed
|
||||
# loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss)
|
||||
|
||||
# model.step(loss) #(opt.zero_grad, loss.backward, opt.step) + automatic meta-optimisation
|
||||
|
||||
# tf = time.process_time()
|
||||
|
||||
# #### Print ####
|
||||
# if(print_freq and epoch%print_freq==0):
|
||||
# print('-'*9)
|
||||
# print('Epoch : %d/%d'%(epoch,epochs))
|
||||
# print('Time : %.00f'%(tf - t0))
|
||||
# print('Train loss :',loss.item(), '/ val loss', model.val_loss().item())
|
||||
# if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data)
|
||||
# if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
|
||||
# if not model['data_aug']._fixed_temp: print('Temp:', model['data_aug']['temp'].item())
|
||||
# #############
|
||||
|
||||
# return model['model'].state_dict()
|
686
higher/smart_aug/transformations.py
Normal file
686
higher/smart_aug/transformations.py
Normal file
|
@ -0,0 +1,686 @@
|
|||
""" PyTorch implementation of some PIL image transformations.
|
||||
|
||||
Those implementation are thinked to take advantages of batched computation of PyTorch on GPU.
|
||||
|
||||
Based on Kornia library.
|
||||
See: https://github.com/kornia/kornia
|
||||
|
||||
And PIL.
|
||||
See:
|
||||
https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
|
||||
https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818
|
||||
|
||||
Inspired from AutoAugment.
|
||||
See: https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
import kornia
|
||||
import random
|
||||
import json
|
||||
|
||||
#TF that don't have use for magnitude parameter.
|
||||
TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend', 'identity', 'flip'}
|
||||
#TF which implemetation doesn't allow gradient propagaition.
|
||||
TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize', 'posterize','solarize'} #Numpy implementation would be better ?
|
||||
#TF for which magnitude should be ignored (Magnitude fixed).
|
||||
TF_ignore_mag= TF_no_mag | TF_no_grad
|
||||
|
||||
# What is the max 'level' a transform could be predicted
|
||||
PARAMETER_MAX = 1
|
||||
# What is the min 'level' a transform could be predicted
|
||||
PARAMETER_MIN = 0.01
|
||||
|
||||
#Dict containing the value for wich TF are closer to identity
|
||||
#TF_identity={
|
||||
# PARAMETER_MAX:{'Solarize', 'Posterize'},
|
||||
# PARAMETER_MAX/2:{'Contrast','Color','Brightness','Sharpness'},
|
||||
# PARAMETER_MIN:{'Rotate','TranslateX','TranslateY','ShearX','ShearY'},
|
||||
#}
|
||||
|
||||
class Normalizer(object):
|
||||
def __init__(self, mean, std):
|
||||
self.mean=torch.tensor(mean)
|
||||
self.std=torch.tensor(std)
|
||||
def __call__(self, x):
|
||||
# return x.sub_(self.mean.to(x.device)[None, :, None, None]).div_(self.std.to(x.device)[None, :, None, None])
|
||||
return kornia.color.normalize(x, self.mean, self.std)
|
||||
def reverse(self, x):
|
||||
# return x.mul_(self.std.to(x.device)[None, :, None, None]).add_(self.mean.to(x.device)[None, :, None, None])
|
||||
return kornia.color.denormalize(x, self.mean, self.std).to(torch.half).to(torch.float)
|
||||
|
||||
class TF_loader(object):
|
||||
""" Transformations builder.
|
||||
|
||||
See 'config' folder for pre-defined config files.
|
||||
|
||||
Attributes:
|
||||
_filename (str): Path to config file (JSON) used.
|
||||
_TF_dict (dict): Transformations dictionnary built from config file.
|
||||
_TF_ignore_mag (set): Ensemble of transformations names for which magnitude should be ignored.
|
||||
_TF_names (list): List of transformations names/keys.
|
||||
"""
|
||||
def __init__(self):
|
||||
""" Initialize TF_loader.
|
||||
|
||||
"""
|
||||
self._filename=''
|
||||
self._TF_dict={}
|
||||
self._TF_ignore_mag=set()
|
||||
self._TF_names=[]
|
||||
|
||||
def load_TF_dict(self, filename):
|
||||
""" Build a TF dictionnary.
|
||||
|
||||
Args:
|
||||
filename (str): Path to config file (JSON) defining the transformations.
|
||||
Returns:
|
||||
(dict, set) : TF dicttionnary built and ensemble of TF names for which mag should be ignored.
|
||||
"""
|
||||
self._filename=filename
|
||||
self._TF_names=[]
|
||||
self._TF_dict={}
|
||||
self._TF_ignore_mag=set()
|
||||
|
||||
with open(filename) as json_file:
|
||||
TF_params = json.load(json_file)
|
||||
|
||||
for tf in TF_params:
|
||||
self._TF_names.append(tf['name'])
|
||||
if tf['function'] in TF_ignore_mag:
|
||||
self._TF_ignore_mag.add(tf['name'])
|
||||
|
||||
if tf['function'] == 'identity':
|
||||
self._TF_dict[tf['name']]=(lambda x, mag: x)
|
||||
|
||||
elif tf['function'] == 'flip':
|
||||
#Inverser axes ?
|
||||
if tf['param']['axis'] == 'X':
|
||||
self._TF_dict[tf['name']]=(lambda x, mag: flipLR(x))
|
||||
elif tf['param']['axis'] == 'Y':
|
||||
self._TF_dict[tf['name']]=(lambda x, mag: flipUD(x))
|
||||
else:
|
||||
raise Exception("Unknown TF axis : %s in %s"%(tf['function'], self._filename))
|
||||
|
||||
# elif tf['function'] in {'translate', 'shear'}:
|
||||
# rand_fct= 'invScale_rand_floats' if tf['param']['invScale'] else 'rand_floats'
|
||||
# self._TF_dict[tf['name']]=self.build_lambda(tf['function'], rand_fct, tf['param']['min'], tf['param']['max'], tf['param']['absolute'], tf['param']['axis'])
|
||||
|
||||
else:
|
||||
axis = tf['param']['axis'] if 'axis' in tf['param'].keys() else None
|
||||
absolute = tf['param']['absolute'] if 'absolute' in tf['param'].keys() else True
|
||||
rand_fct= 'invScale_rand_floats' if tf['param']['invScale'] else 'rand_floats'
|
||||
self._TF_dict[tf['name']]=self.build_lambda(tf['function'], rand_fct, tf['param']['min'], tf['param']['max'], absolute, axis)
|
||||
|
||||
return self._TF_dict, self._TF_ignore_mag
|
||||
|
||||
def build_lambda(self, fct_name, rand_fct_name, minval, maxval, absolute=True, axis=None):
|
||||
""" Build a lambda function performing transformations.
|
||||
|
||||
Force different context for creation of each lambda function.
|
||||
|
||||
Args:
|
||||
fct_name (str): Name of the transformations to use (see transformations.py).
|
||||
rand_fct_name (str): Name of the random mapping function to use (see transformations.py).
|
||||
minval (float): minimum magnitude value of the TF.
|
||||
maxval (float): maximum magnitude value of the TF.
|
||||
absolute (bool): Wether the maxval should be relative (absolute=False) to the image size. (default: True)
|
||||
axis (str): Axis ('X' / 'Y') of the TF, if relevant. Should be used for (flip)/translate/shear functions. (default: None)
|
||||
|
||||
Returns:
|
||||
(function) transformations function : Tensor=f(Tensor, magnitude)
|
||||
"""
|
||||
if absolute:
|
||||
max_val_fct=(lambda x: maxval)
|
||||
else: #Relative to img size
|
||||
max_val_fct=(lambda x: x*maxval)
|
||||
|
||||
if axis is None:
|
||||
return (lambda x, mag:
|
||||
globals()[fct_name]( #getattr(TF, fct_name)
|
||||
x,
|
||||
globals()[rand_fct_name](
|
||||
size=x.shape[0],
|
||||
mag=mag,
|
||||
minval=minval,
|
||||
maxval=max_val_fct(max(x.shape[2],x.shape[3])))))
|
||||
elif axis =='X':
|
||||
return (lambda x, mag:
|
||||
globals()[fct_name](
|
||||
x,
|
||||
zero_stack(
|
||||
globals()[rand_fct_name](
|
||||
size=(x.shape[0],),
|
||||
mag=mag,
|
||||
minval=minval,
|
||||
maxval=max_val_fct(x.shape[2])),
|
||||
zero_pos=0)))
|
||||
elif axis == 'Y':
|
||||
return (lambda x, mag:
|
||||
globals()[fct_name](
|
||||
x,
|
||||
zero_stack(
|
||||
globals()[rand_fct_name](
|
||||
size=(x.shape[0],),
|
||||
mag=mag,
|
||||
minval=minval,
|
||||
maxval=max_val_fct(x.shape[3])),
|
||||
zero_pos=1)))
|
||||
else:
|
||||
raise Exception("Unknown TF axis : %s in %s"%(fct_name, self._filename))
|
||||
|
||||
def get_TF_names(self):
|
||||
return self._TF_names
|
||||
def get_TF_dict(self):
|
||||
return self._TF_dict
|
||||
## Image type cast ##
|
||||
def int_image(float_image):
|
||||
"""Convert a float Tensor/Image to an int Tensor/Image.
|
||||
|
||||
Be warry that this transformation isn't bijective, each conversion will result in small loss of information.
|
||||
Granularity: 1/256 = 0.0039.
|
||||
|
||||
This will also result in the loss of the gradient associated to input as gradient cannot be tracked on int Tensor.
|
||||
|
||||
Args:
|
||||
float_image (FloatTensor): Image tensor.
|
||||
|
||||
Returns:
|
||||
(ByteTensor) Converted tensor.
|
||||
"""
|
||||
return (float_image*255.).type(torch.uint8)
|
||||
|
||||
def float_image(int_image):
|
||||
"""Convert a int Tensor/Image to an float Tensor/Image.
|
||||
|
||||
Args:
|
||||
int_image (ByteTensor): Image tensor.
|
||||
|
||||
Returns:
|
||||
(FloatTensor) Converted tensor.
|
||||
"""
|
||||
return int_image.type(torch.float)/255.
|
||||
|
||||
## Parameters utils ##
|
||||
def rand_floats(size, mag, maxval, minval=None):
|
||||
"""Generate a batch of random values.
|
||||
|
||||
Args:
|
||||
size (int): Number of value to generate.
|
||||
mag (float): Level of the operation that will be between [PARAMETER_MIN, PARAMETER_MAX].
|
||||
maxval (float): Maximum value that can be generated. This will be scaled to mag/PARAMETER_MAX.
|
||||
minval (float): Minimum value that can be generated. (default: -maxval)
|
||||
|
||||
Returns:
|
||||
(Tensor) Generated batch of float values between [minval, maxval].
|
||||
"""
|
||||
real_mag = float_parameter(mag, maxval=maxval)
|
||||
if not minval : minval = -real_mag
|
||||
#return random.uniform(minval, real_max)
|
||||
return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[min_val, real_mag]
|
||||
|
||||
def invScale_rand_floats(size, mag, maxval, minval):
|
||||
"""Generate a batch of random values.
|
||||
|
||||
Similar to rand_floats() except that the mag is used in an inversed scale.
|
||||
|
||||
Mag:[0,PARAMETER_MAX] => [PARAMETER_MAX, 0]
|
||||
|
||||
Args:
|
||||
size (int): Number of value to generate.
|
||||
mag (float): Level of the operation that will be between [PARAMETER_MIN, PARAMETER_MAX].
|
||||
maxval (float): Maximum value that can be generated. This will be scaled to mag/PARAMETER_MAX.
|
||||
minval (float): Minimum value that can be generated. (default: -maxval)
|
||||
|
||||
Returns:
|
||||
(Tensor) Generated batch of float values between [minval, maxval].
|
||||
"""
|
||||
real_mag = float_parameter(float(PARAMETER_MAX) - mag, maxval=maxval-minval)+minval
|
||||
return real_mag + (maxval-real_mag) * torch.rand(size, device=mag.device) #[real_mag, max_val]
|
||||
|
||||
def zero_stack(tensor, zero_pos):
|
||||
"""Add a row of zeros to a Tensor.
|
||||
|
||||
This function is intended to be used with single row Tensor, thus returning a 2 dimension Tensor.
|
||||
|
||||
Args:
|
||||
tensor (Tensor): Tensor to be stacked with zeros.
|
||||
zero_pos (int): Wheter the zeros should be added before or after the Tensor. Either 0 or 1.
|
||||
|
||||
Returns:
|
||||
Stacked Tensor.
|
||||
"""
|
||||
if zero_pos==0:
|
||||
return torch.stack((tensor, torch.zeros((tensor.shape[0],), device=tensor.device)), dim=1)
|
||||
if zero_pos==1:
|
||||
return torch.stack((torch.zeros((tensor.shape[0],), device=tensor.device), tensor), dim=1)
|
||||
else:
|
||||
raise Exception("Invalid zero_pos : ", zero_pos)
|
||||
|
||||
def float_parameter(level, maxval):
|
||||
"""Scale level between 0 and maxval.
|
||||
|
||||
Args:
|
||||
level (float): Level of the operation that will be between [PARAMETER_MIN, PARAMETER_MAX].
|
||||
maxval: Maximum value that the operation can have. This will be scaled to level/PARAMETER_MAX.
|
||||
Returns:
|
||||
A float that results from scaling `maxval` according to `level`.
|
||||
"""
|
||||
|
||||
#return float(level) * maxval / PARAMETER_MAX
|
||||
return (level * maxval / PARAMETER_MAX)#.to(torch.float)
|
||||
|
||||
## Tranformations ##
|
||||
def flipLR(x):
|
||||
"""Flip horizontaly/Left-Right images.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of images.
|
||||
|
||||
Returns:
|
||||
(Tensor): Batch of fliped images.
|
||||
"""
|
||||
device = x.device
|
||||
(batch_size, channels, h, w) = x.shape
|
||||
|
||||
M =torch.tensor( [[[-1., 0., w-1],
|
||||
[ 0., 1., 0.],
|
||||
[ 0., 0., 1.]]], device=device).expand(batch_size,-1,-1)
|
||||
|
||||
# warp the original image by the found transform
|
||||
return kornia.warp_perspective(x, M, dsize=(h, w))
|
||||
|
||||
def flipUD(x):
|
||||
"""Flip vertically/Up-Down images.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of images.
|
||||
|
||||
Returns:
|
||||
(Tensor): Batch of fliped images.
|
||||
"""
|
||||
device = x.device
|
||||
(batch_size, channels, h, w) = x.shape
|
||||
|
||||
M =torch.tensor( [[[ 1., 0., 0.],
|
||||
[ 0., -1., h-1],
|
||||
[ 0., 0., 1.]]], device=device).expand(batch_size,-1,-1)
|
||||
|
||||
# warp the original image by the found transform
|
||||
return kornia.warp_perspective(x, M, dsize=(h, w))
|
||||
|
||||
def rotate(x, angle):
|
||||
"""Rotate images.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of images.
|
||||
angle (Tensor): Angles (degrees) of rotation for each images.
|
||||
|
||||
Returns:
|
||||
(Tensor): Batch of rotated images.
|
||||
"""
|
||||
return kornia.rotate(x, angle=angle.type(torch.float)) #Kornia ne supporte pas les int
|
||||
|
||||
def translate(x, translation):
|
||||
"""Translate images.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of images.
|
||||
translation (Tensor): Distance (pixels) of translation for each images.
|
||||
|
||||
Returns:
|
||||
(Tensor): Batch of translated images.
|
||||
"""
|
||||
return kornia.translate(x, translation=translation.type(torch.float)) #Kornia ne supporte pas les int
|
||||
|
||||
def shear(x, shear):
|
||||
"""Shear images.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of images.
|
||||
shear (Tensor): Angle of shear for each images.
|
||||
|
||||
Returns:
|
||||
(Tensor): Batch of skewed images.
|
||||
"""
|
||||
return kornia.shear(x, shear=shear)
|
||||
|
||||
def contrast(x, contrast_factor):
|
||||
"""Adjust contast of images.
|
||||
|
||||
Args:
|
||||
x (FloatTensor): Batch of images.
|
||||
contrast_factor (FloatTensor): Contrast adjust factor per element in the batch.
|
||||
0 generates a compleatly black image, 1 does not modify the input image while any other non-negative number modify the brightness by this factor.
|
||||
|
||||
Returns:
|
||||
(Tensor): Batch of adjusted images.
|
||||
"""
|
||||
return kornia.adjust_contrast(x, contrast_factor=contrast_factor) #Expect image in the range of [0, 1]
|
||||
|
||||
def color(x, color_factor):
|
||||
"""Adjust color of images.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of images.
|
||||
color_factor (Tensor): Color factor for each images.
|
||||
0.0 gives a black and white image. A factor of 1.0 gives the original image.
|
||||
|
||||
Returns:
|
||||
(Tensor): Batch of adjusted images.
|
||||
"""
|
||||
(batch_size, channels, h, w) = x.shape
|
||||
|
||||
gray_x = kornia.rgb_to_grayscale(x)
|
||||
gray_x = gray_x.repeat_interleave(channels, dim=1)
|
||||
return blend(gray_x, x, color_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1]
|
||||
|
||||
def brightness(x, brightness_factor):
|
||||
"""Adjust brightness of images.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of images.
|
||||
brightness_factor (Tensor): Brightness factor for each images.
|
||||
0.0 gives a black image. A factor of 1.0 gives the original image.
|
||||
|
||||
Returns:
|
||||
(Tensor): Batch of adjusted images.
|
||||
"""
|
||||
device = x.device
|
||||
|
||||
return blend(torch.zeros(x.size(), device=device), x, brightness_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1]
|
||||
|
||||
def sharpness(x, sharpness_factor):
|
||||
"""Adjust sharpness of images.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of images.
|
||||
sharpness_factor (Tensor): Sharpness factor for each images.
|
||||
0.0 gives a black image. A factor of 1.0 gives the original image.
|
||||
|
||||
Returns:
|
||||
(Tensor): Batch of adjusted images.
|
||||
"""
|
||||
device = x.device
|
||||
(batch_size, channels, h, w) = x.shape
|
||||
|
||||
k = torch.tensor([[[ 1., 1., 1.],
|
||||
[ 1., 5., 1.],
|
||||
[ 1., 1., 1.]]], device=device) #Smooth Filter : https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageFilter.py
|
||||
smooth_x = kornia.filter2D(x, kernel=k, border_type='reflect', normalized=True) #Peut etre necessaire de s'occuper du channel Alhpa differement
|
||||
|
||||
return blend(smooth_x, x, sharpness_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1]
|
||||
|
||||
def posterize(x, bits):
|
||||
"""Reduce the number of bits for each color channel.
|
||||
|
||||
Be warry that the cast to integers block the gradient propagation.
|
||||
Args:
|
||||
x (Tensor): Batch of images.
|
||||
bits (Tensor): The number of bits to keep for each channel (1-8).
|
||||
|
||||
Returns:
|
||||
(Tensor): Batch of posterized images.
|
||||
"""
|
||||
bits = bits.type(torch.uint8) #Perte du gradient
|
||||
x = int_image(x) #Expect image in the range of [0, 1]
|
||||
|
||||
mask = ~(2 ** (8 - bits) - 1).type(torch.uint8)
|
||||
|
||||
(batch_size, channels, h, w) = x.shape
|
||||
mask = mask.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||
|
||||
return float_image(x & mask)
|
||||
|
||||
def cutout(img, length):
|
||||
"""
|
||||
Args:
|
||||
img (Tensor): Batch of images. Expect image value between [0, 1].
|
||||
length (Tensor): The length (in pixels) of each square patch.
|
||||
Returns:
|
||||
Tensor: Images with single holes of dimension length x length cut out of it.
|
||||
"""
|
||||
device = img.device
|
||||
(batch_size, channels, h, w) = img.shape
|
||||
|
||||
# mask = np.ones((h, w), np.float32)
|
||||
mask = torch.ones((batch_size, h, w), device=device)
|
||||
length=length.type(torch.uint8)
|
||||
|
||||
# y = np.random.randint(h)
|
||||
# x = np.random.randint(w)
|
||||
y = torch.randint(low=0, high=h, size=(batch_size,)).to(device)
|
||||
x = torch.randint(low=0, high=w, size=(batch_size,)).to(device)
|
||||
|
||||
# y1 = np.clip(y - length // 2, 0, h)
|
||||
# y2 = np.clip(y + length // 2, 0, h)
|
||||
# x1 = np.clip(x - length // 2, 0, w)
|
||||
# x2 = np.clip(x + length // 2, 0, w)
|
||||
y1 = (y - length // 2).clamp(min=0, max=h)
|
||||
y2 = (y + length // 2).clamp(min=0, max=h)
|
||||
x1 = (x - length // 2).clamp(min=0, max=w)
|
||||
x2 = (x + length // 2).clamp(min=0, max=w)
|
||||
|
||||
# mask[y1: y2, x1: x2] = 0.
|
||||
for idx in range(batch_size): #Pas opti pour des batch
|
||||
mask[idx, y1[idx]: y2[idx], x1[idx]: x2[idx]]= 0.
|
||||
|
||||
# mask = torch.from_numpy(mask)
|
||||
# mask = mask.expand_as(img)
|
||||
mask = mask.unsqueeze(dim=1).expand_as(img)
|
||||
img = img * mask
|
||||
|
||||
return img
|
||||
|
||||
import torch.nn.functional as F
|
||||
def solarize(x, thresholds):
|
||||
"""Invert all pixel values above a threshold.
|
||||
|
||||
Be warry that the use of the inequality (x>tresholds) block the gradient propagation.
|
||||
|
||||
TODO : Make differentiable.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of images.
|
||||
thresholds (Tensor): All pixels above this level are inverted
|
||||
|
||||
Returns:
|
||||
(Tensor): Batch of solarized images.
|
||||
"""
|
||||
batch_size, channels, h, w = x.shape
|
||||
#imgs=[]
|
||||
#for idx, t in enumerate(thresholds): #Operation par image
|
||||
# mask = x[idx] > t #Perte du gradient
|
||||
#In place
|
||||
# inv_x = 1-x[idx][mask]
|
||||
# x[idx][mask]=inv_x
|
||||
#
|
||||
|
||||
#Out of place
|
||||
# im = x[idx]
|
||||
# inv_x = 1-im[mask]
|
||||
|
||||
# imgs.append(im.masked_scatter(mask,inv_x))
|
||||
|
||||
#idxs=torch.tensor(range(x.shape[0]), device=x.device)
|
||||
#idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||
#x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
|
||||
#
|
||||
|
||||
thresholds = thresholds.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||
x=torch.where(x>thresholds,1-x, x)
|
||||
|
||||
#x=x.min(thresholds)
|
||||
#inv_x = 1-x[mask]
|
||||
#x=x.where(x<thresholds,1-x)
|
||||
#x[mask]=inv_x
|
||||
#x=x.masked_scatter(mask, inv_x)
|
||||
|
||||
#Differentiable (/Thresholds) ?
|
||||
|
||||
#inv_x_bT= F.relu(x) - F.relu(x - thresholds)
|
||||
#inv_x_aT= 1-x #Besoin thresholds
|
||||
|
||||
#print('-'*10)
|
||||
#print(thresholds[0])
|
||||
#print(x[0])
|
||||
#print(inv_x_bT[0])
|
||||
#print(inv_x_aT[0])
|
||||
|
||||
#x=torch.where(x>thresholds,inv_x_aT, inv_x_bT)
|
||||
#print(torch.allclose(x, x+0.001, atol=1e-3))
|
||||
#print(torch.allclose(x, sol_x, atol=1e-2))
|
||||
#print(torch.eq(x,sol_x)[0])
|
||||
|
||||
#print(x[0])
|
||||
#print(sol_x[0])
|
||||
#'''
|
||||
return x
|
||||
|
||||
def blend(x,y,alpha):
|
||||
"""Creates a new images by interpolating between two input images, using a constant alpha.
|
||||
|
||||
x and y should have the same size.
|
||||
alpha should have the same batch size as the images.
|
||||
|
||||
Apply batch wise :
|
||||
out = image1 * (1.0 - alpha) + image2 * alpha
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of images.
|
||||
y (Tensor): Batch of images.
|
||||
alpha (Tensor): The interpolation alpha factor for each images.
|
||||
Returns:
|
||||
(Tensor): Batch of solarized images.
|
||||
"""
|
||||
#return kornia.add_weighted(src1=x, alpha=(1-alpha), src2=y, beta=alpha, gamma=0) #out=src1∗alpha+src2∗beta+gamma #Ne fonctionne pas pour des batch de alpha
|
||||
|
||||
if not isinstance(x, torch.Tensor):
|
||||
raise TypeError("x should be a tensor. Got {}".format(type(x)))
|
||||
|
||||
if not isinstance(y, torch.Tensor):
|
||||
raise TypeError("y should be a tensor. Got {}".format(type(y)))
|
||||
|
||||
assert(x.shape==y.shape and x.shape[0]==alpha.shape[0])
|
||||
|
||||
(batch_size, channels, h, w) = x.shape
|
||||
alpha = alpha.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||
res = x*(1-alpha) + y*alpha
|
||||
|
||||
return res
|
||||
|
||||
#Not working
|
||||
def auto_contrast(x):
|
||||
"""NOT TESTED - EXTRA SLOW
|
||||
|
||||
"""
|
||||
# Optimisation : Application de LUT efficace / Calcul d'histogramme par batch/channel
|
||||
print("Warning : Pas encore check !")
|
||||
(batch_size, channels, h, w) = x.shape
|
||||
x = int_image(x) #Expect image in the range of [0, 1]
|
||||
#print('Start',x[0])
|
||||
for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image
|
||||
#print(img.shape)
|
||||
for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel
|
||||
#print(chan.shape)
|
||||
hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
|
||||
|
||||
# find lowest/highest samples after preprocessing
|
||||
for lo in range(256):
|
||||
if hist[lo]:
|
||||
break
|
||||
for hi in range(255, -1, -1):
|
||||
if hist[hi]:
|
||||
break
|
||||
if hi <= lo:
|
||||
# don't bother
|
||||
pass
|
||||
else:
|
||||
scale = 255.0 / (hi - lo)
|
||||
offset = -lo * scale
|
||||
for ix in range(256):
|
||||
n_ix = int(ix * scale + offset)
|
||||
if n_ix < 0: n_ix = 0
|
||||
elif n_ix > 255: n_ix = 255
|
||||
|
||||
chan[chan==ix]=n_ix
|
||||
x[im_idx, chan_idx]=chan
|
||||
|
||||
#print('End',x[0])
|
||||
return float_image(x)
|
||||
|
||||
def equalize(x):
|
||||
""" NOT WORKING
|
||||
|
||||
"""
|
||||
raise Exception(self, "not implemented")
|
||||
# Optimisation : Application de LUT efficace / Calcul d'histogramme par batch/channel
|
||||
(batch_size, channels, h, w) = x.shape
|
||||
x = int_image(x) #Expect image in the range of [0, 1]
|
||||
#print('Start',x[0])
|
||||
for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image
|
||||
#print(img.shape)
|
||||
for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel
|
||||
#print(chan.shape)
|
||||
hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
|
||||
|
||||
return float_image(x)
|
||||
|
||||
'''
|
||||
# Dictionnary mapping tranformations identifiers to their function.
|
||||
# Each value of the dict should be a lambda function taking a (batch of data, magnitude of transformations) tuple as input and returns a batch of data.
|
||||
TF_dict={ #Dataugv5+
|
||||
## Geometric TF ##
|
||||
'Identity' : (lambda x, mag: x),
|
||||
'FlipUD' : (lambda x, mag: flipUD(x)),
|
||||
'FlipLR' : (lambda x, mag: flipLR(x)),
|
||||
'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
|
||||
'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=x.shape[2]*0.33), zero_pos=0))),
|
||||
'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=x.shape[3]*0.33), zero_pos=1))),
|
||||
'TranslateXabs': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
|
||||
'TranslateYabs': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
|
||||
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
|
||||
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
|
||||
|
||||
## Color TF (Expect image in the range of [0, 1]) ##
|
||||
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||
'Sharpness':(lambda x, mag: sharpness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
|
||||
|
||||
#Color TF (Common mag scale)
|
||||
'+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||
'+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||
'+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||
'+Sharpness':(lambda x, mag: sharpness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||
'-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||
'-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||
'-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||
'-Sharpness':(lambda x, mag: sharpness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||
'=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||
'=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
|
||||
|
||||
## Bad Tranformations ##
|
||||
# Bad Geometric TF #
|
||||
'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=0))),
|
||||
'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=1))),
|
||||
'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=0))),
|
||||
'BTranslateX-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=0))),
|
||||
'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=1))),
|
||||
'BTranslateY-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=1))),
|
||||
|
||||
# Bad Color TF #
|
||||
'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9*2, maxval=2*4))),
|
||||
'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),
|
||||
|
||||
# Random TF #
|
||||
'Random':(lambda x, mag: torch.rand_like(x)),
|
||||
'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))),
|
||||
|
||||
#Not ready for use
|
||||
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
||||
#'Equalize': (lambda mag: None),
|
||||
}
|
||||
'''
|
436
higher/smart_aug/utils.py
Normal file
436
higher/smart_aug/utils.py
Normal file
|
@ -0,0 +1,436 @@
|
|||
""" Utilties function.
|
||||
|
||||
"""
|
||||
import numpy as np
|
||||
import json, math, time, os
|
||||
import matplotlib
|
||||
matplotlib.use('Agg') #https://stackoverflow.com/questions/4706451/how-to-save-a-figure-remotely-with-pylab
|
||||
import matplotlib.pyplot as plt
|
||||
import copy
|
||||
import gc
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import time
|
||||
|
||||
from nets.LeNet import *
|
||||
from nets.wideresnet import *
|
||||
from nets.wideresnet_cifar import *
|
||||
import nets.resnet_abn as resnet_abn
|
||||
import nets.resnet_deconv as resnet_DC
|
||||
from efficientnet_pytorch import EfficientNet
|
||||
from efficientnet_pytorch.utils import url_map as EfficientNet_map
|
||||
import torchvision.models as models
|
||||
def load_model(model, num_classes, pretrained=False):
|
||||
if model in models.resnet.__all__ :
|
||||
model_name = model #'resnet18' #'resnet34' #'wide_resnet50_2'
|
||||
if pretrained :
|
||||
print("Using pretrained weights")
|
||||
model = getattr(models.resnet, model_name)(pretrained=True)
|
||||
num_ftrs = model.fc.in_features
|
||||
model.fc = nn.Linear(num_ftrs, num_classes)
|
||||
else:
|
||||
model = getattr(models.resnet, model_name)(pretrained=False, num_classes=num_classes)
|
||||
elif model in models.vgg.__all__ :
|
||||
model_name = model #'vgg11', 'vgg1_bn'
|
||||
if pretrained :
|
||||
print("Using pretrained weights")
|
||||
model = getattr(models.vgg, model_name)(pretrained=True)
|
||||
num_ftrs = model.classifier[-1].in_features
|
||||
model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
|
||||
else :
|
||||
model = getattr(models.vgg, model_name)(pretrained=False, num_classes=num_classes)
|
||||
elif model in models.densenet.__all__ :
|
||||
model_name = model #'densenet121' #'densenet201'
|
||||
if pretrained :
|
||||
print("Using pretrained weights")
|
||||
model = getattr(models.densenet, model_name)(pretrained=True)
|
||||
num_ftrs = model.classifier.in_features
|
||||
model.classifier = nn.Linear(num_ftrs, num_classes)
|
||||
else:
|
||||
model = getattr(models.densenet, model_name)(pretrained=False, num_classes=num_classes)
|
||||
elif model == 'LeNet':
|
||||
if pretrained :
|
||||
print("Pretrained weights not available")
|
||||
model = LeNet(3,num_classes)
|
||||
model_name=str(model)
|
||||
elif model == 'WideResNet':
|
||||
if pretrained :
|
||||
print("Pretrained weights not available")
|
||||
# model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_classes)
|
||||
# model = WideResNet(16, 4, dropout_rate=0.0, num_classes=num_classes)
|
||||
model = wide_resnet_cifar(26, 10, num_classes=num_classes)
|
||||
# model = wide_resnet_cifar(20, 10, num_classes=num_classes)
|
||||
model_name=str(model)
|
||||
elif model in EfficientNet_map.keys():
|
||||
model_name=model # efficientnet-b0 , efficientnet-b1, efficientnet-b4
|
||||
if pretrained: #ImageNet ou Advprop (Meilleurs perf normalement mais normalisation differentes)
|
||||
print("Using pretrained weights")
|
||||
model = EfficientNet.from_pretrained(model_name, advprop=False)
|
||||
else:
|
||||
model = EfficientNet.from_name(model_name)
|
||||
elif model in resnet_abn.__all__ :
|
||||
if pretrained :
|
||||
print("Pretrained weights not available")
|
||||
model_name=model
|
||||
model = getattr(resnet_abn, model_name)(pretrained=False, num_classes=num_classes)
|
||||
elif model in resnet_DC.__all__:
|
||||
if pretrained :
|
||||
print("Pretrained weights not available")
|
||||
model_name = model
|
||||
model = getattr(resnet_DC, model_name)(num_classes=num_classes)
|
||||
else:
|
||||
raise Exception('Unknown model')
|
||||
|
||||
return model, model_name
|
||||
|
||||
class ConfusionMatrix(object):
|
||||
""" Confusion matrix.
|
||||
|
||||
Helps computing the confusion matrix and F1 scores.
|
||||
|
||||
Example use ::
|
||||
confmat = ConfusionMatrix(...)
|
||||
|
||||
confmat.reset()
|
||||
for data in dataset:
|
||||
...
|
||||
confmat.update(...)
|
||||
|
||||
confmat.f1_metric(...)
|
||||
|
||||
Attributes:
|
||||
num_classes (int): Number of classes.
|
||||
mat (Tensor): Confusion matrix. Filled by update method.
|
||||
"""
|
||||
def __init__(self, num_classes):
|
||||
""" Initialize ConfusionMatrix.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes.
|
||||
"""
|
||||
self.num_classes = num_classes
|
||||
self.mat = None
|
||||
|
||||
def update(self, target, pred):
|
||||
""" Update the confusion matrix.
|
||||
|
||||
Args:
|
||||
target (Tensor): Target labels.
|
||||
pred (Tensor): Prediction.
|
||||
"""
|
||||
n = self.num_classes
|
||||
if self.mat is None:
|
||||
self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
|
||||
with torch.no_grad():
|
||||
k = (target >= 0) & (target < n)
|
||||
inds = n * target[k].to(torch.int64) + pred[k]
|
||||
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
|
||||
|
||||
def reset(self):
|
||||
""" Reset the Confusion matrix.
|
||||
|
||||
"""
|
||||
if self.mat is not None:
|
||||
self.mat.zero_()
|
||||
|
||||
def f1_metric(self, average=None):
|
||||
""" Compute the F1 score.
|
||||
|
||||
Inspired from :
|
||||
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
|
||||
https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
|
||||
|
||||
Args:
|
||||
average (str): Type of averaging performed on the data. (Default: None)
|
||||
``None``:
|
||||
The scores for each class are returned.
|
||||
``'micro'``:
|
||||
Calculate metrics globally by counting the total true positives,
|
||||
false negatives and false positives.
|
||||
``'macro'``:
|
||||
Calculate metrics for each label, and find their unweighted
|
||||
mean. This does not take label imbalance into account.
|
||||
Return:
|
||||
Tensor containing the F1 score. It's shape is either 1, if there was averaging, or (num_classes).
|
||||
"""
|
||||
|
||||
h = self.mat.float()
|
||||
TP = torch.diag(h)
|
||||
TN = []
|
||||
FP = []
|
||||
FN = []
|
||||
for c in range(self.num_classes):
|
||||
idx = torch.ones(self.num_classes).bool()
|
||||
idx[c] = 0
|
||||
# all non-class samples classified as non-class
|
||||
TN.append(self.mat[idx.nonzero()[:, None], idx.nonzero()].sum()) #conf_matrix[idx[:, None], idx].sum() - conf_matrix[idx, c].sum()
|
||||
# all non-class samples classified as class
|
||||
FP.append(self.mat[idx, c].sum())
|
||||
# all class samples not classified as class
|
||||
FN.append(self.mat[c, idx].sum())
|
||||
|
||||
#print('Class {}\nTP {}, TN {}, FP {}, FN {}'.format(c, TP[c], TN[c], FP[c], FN[c]))
|
||||
|
||||
tp = (TP/h.sum(1))#.sum()
|
||||
tn = (torch.tensor(TN, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
|
||||
fp = (torch.tensor(FP, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
|
||||
fn = (torch.tensor(FN, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
|
||||
|
||||
if average=="micro":
|
||||
tp, tn, fp, fn = tp.sum(), tn.sum(), fp.sum(), fn.sum()
|
||||
|
||||
epsilon = 1e-7
|
||||
precision = tp / (tp + fp + epsilon)
|
||||
recall = tp / (tp + fn + epsilon)
|
||||
|
||||
f1 = 2* (precision*recall) / (precision + recall + epsilon)
|
||||
|
||||
if average=="macro":
|
||||
f1=f1.mean()
|
||||
return f1
|
||||
|
||||
#from torchviz import make_dot
|
||||
def print_graph(PyTorch_obj=torch.randn(1, 3, 32, 32), fig_name='graph'):
|
||||
"""Save the computational graph.
|
||||
|
||||
Args:
|
||||
PyTorch_obj (Tensor): End of the graph. Commonly, the loss tensor to get the whole graph.
|
||||
fig_name (string): Relative path where to save the graph. (default: graph)
|
||||
"""
|
||||
graph=make_dot(PyTorch_obj)
|
||||
graph.format = 'png' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
|
||||
graph.render(fig_name)
|
||||
|
||||
def plot_resV2(log, fig_name='res', param_names=None, f1=True):
|
||||
"""Save a visual graph of the logs.
|
||||
|
||||
Args:
|
||||
log (dict): Logs of the training generated by most of train_utils.
|
||||
fig_name (string): Relative path where to save the graph. (default: res)
|
||||
param_names (list): Labels for the parameters. (default: None)
|
||||
f1 (bool): Wether to plot F1 scores. (default: True)
|
||||
"""
|
||||
epochs = [x["epoch"] for x in log]
|
||||
|
||||
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(30, 15))
|
||||
|
||||
ax[0, 0].set_title('Loss')
|
||||
ax[0, 0].plot(epochs,[x["train_loss"] for x in log], label='Train')
|
||||
ax[0, 0].plot(epochs,[x["val_loss"] for x in log], label='Val')
|
||||
ax[0, 0].legend()
|
||||
|
||||
ax[1, 0].set_title('Test')
|
||||
ax[1, 0].plot(epochs,[x["acc"] for x in log], label='Acc')
|
||||
|
||||
if f1 and "f1" in log[0].keys():
|
||||
#ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1')
|
||||
#'''
|
||||
#print(log[0]["f1"])
|
||||
if isinstance(log[0]["f1"], list):
|
||||
if len(log[0]["f1"])>10:
|
||||
print("Plotting results : Too many class for F1, plotting only min/max")
|
||||
ax[1, 0].plot(epochs,[max(x["f1"])*100 for x in log], label='F1-Max', ls='--')
|
||||
ax[1, 0].plot(epochs,[min(x["f1"])*100 for x in log], label='F1-Min', ls='--')
|
||||
else:
|
||||
for c in range(len(log[0]["f1"])):
|
||||
ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c), ls='--')
|
||||
else:
|
||||
ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1', ls='--')
|
||||
#'''
|
||||
|
||||
ax[1, 0].legend()
|
||||
|
||||
if log[0]["param"]!= None:
|
||||
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
|
||||
#proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||
proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||
mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||
|
||||
ax[0, 1].set_title('Prob =f(epoch)')
|
||||
ax[0, 1].stackplot(epochs, proba, labels=param_names)
|
||||
#ax[0, 1].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
|
||||
ax[1, 1].set_title('Prob =f(TF)')
|
||||
mean = np.mean(proba, axis=1)
|
||||
std = np.std(proba, axis=1)
|
||||
ax[1, 1].bar(param_names, mean, yerr=std)
|
||||
plt.sca(ax[1, 1]), plt.xticks(rotation=90)
|
||||
|
||||
ax[0, 2].set_title('Mag =f(epoch)')
|
||||
ax[0, 2].stackplot(epochs, mag, labels=param_names)
|
||||
#ax[0, 2].plot(epochs, np.array(mag).T, label=param_names)
|
||||
ax[0, 2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
|
||||
ax[1, 2].set_title('Mag =f(TF)')
|
||||
mean = np.mean(mag, axis=1)
|
||||
std = np.std(mag, axis=1)
|
||||
ax[1, 2].bar(param_names, mean, yerr=std)
|
||||
plt.sca(ax[1, 2]), plt.xticks(rotation=90)
|
||||
|
||||
|
||||
fig_name = fig_name.replace('.',',').replace(',,/','../')
|
||||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
def plot_compare(filenames, fig_name='res'):
|
||||
"""Save a visual graph comparing trainings stats.
|
||||
|
||||
Args:
|
||||
filenames (list[Strings]): Relative paths to the logs (JSON files).
|
||||
fig_name (string): Relative path where to save the graph. (default: res)
|
||||
"""
|
||||
all_data=[]
|
||||
legend=""
|
||||
for idx, file in enumerate(filenames):
|
||||
legend+=str(idx)+'-'+file+'\n'
|
||||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
all_data.append(data)
|
||||
|
||||
fig, ax = plt.subplots(ncols=3, figsize=(30, 8))
|
||||
|
||||
for data_idx, log in enumerate(all_data):
|
||||
log=log['Log']
|
||||
epochs = [x["epoch"] for x in log]
|
||||
|
||||
ax[0].plot(epochs,[x["train_loss"] for x in log], label=str(data_idx)+'-Train')
|
||||
ax[0].plot(epochs,[x["val_loss"] for x in log], label=str(data_idx)+'-Val')
|
||||
|
||||
ax[1].plot(epochs,[x["acc"] for x in log], label=str(data_idx))
|
||||
#ax[1].text(x=0.5,y=0,s=str(data_idx)+'-'+filenames[data_idx], transform=ax[1].transAxes)
|
||||
|
||||
if log[0]["param"]!= None:
|
||||
if isinstance(log[0]["param"],float):
|
||||
ax[2].plot(epochs,[x["param"] for x in log], label=str(data_idx)+'-Mag')
|
||||
|
||||
else :
|
||||
for idx, _ in enumerate(log[0]["param"]):
|
||||
ax[2].plot(epochs,[x["param"][idx] for x in log], label=str(data_idx)+'-P'+str(idx))
|
||||
|
||||
fig.suptitle(legend)
|
||||
ax[0].set_title('Loss')
|
||||
ax[1].set_title('Acc')
|
||||
ax[2].set_title('Param')
|
||||
for a in ax: a.legend()
|
||||
|
||||
fig_name = fig_name.replace('.',',')
|
||||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
|
||||
"""Save data samples.
|
||||
|
||||
Args:
|
||||
imgs (Tensor): Batch of image to sample from. Intended to contain at least 25 images.
|
||||
labels (Tensor): Labels of the images.
|
||||
fig_name (string): Relative path where to save the graph. (default: data_sample)
|
||||
weight_labels (Tensor): Weights associated to each labels. (default: None)
|
||||
"""
|
||||
sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()
|
||||
|
||||
plt.figure(figsize=(10,10))
|
||||
for i in range(25):
|
||||
plt.subplot(5,5,i+1) #Trop de figure cree ?
|
||||
plt.xticks([])
|
||||
plt.yticks([])
|
||||
plt.grid(False)
|
||||
plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
|
||||
label = str(labels[i].item())
|
||||
if torch.is_tensor(weight_labels): label+= (" - p %.2f" % weight_labels[i].item())
|
||||
plt.xlabel(label)
|
||||
|
||||
plt.savefig(fig_name)
|
||||
print("Sample saved :", fig_name)
|
||||
plt.close('all')
|
||||
|
||||
def print_torch_mem(add_info=''):
|
||||
"""Print informations on PyTorch memory usage.
|
||||
|
||||
Args:
|
||||
add_info (string): Prefix added before the print. (default: None)
|
||||
"""
|
||||
nb=0
|
||||
max_size=0
|
||||
for obj in gc.get_objects():
|
||||
#print(type(obj))
|
||||
try:
|
||||
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): # and len(obj.size())>1:
|
||||
#print(i, type(obj), obj.size())
|
||||
size = np.sum(obj.size())
|
||||
if(size>max_size): max_size=size
|
||||
nb+=1
|
||||
except:
|
||||
pass
|
||||
print(add_info, "-Pytroch tensor nb:",nb," / Max dim:", max_size)
|
||||
|
||||
#print(add_info, "-Garbage size :",len(gc.garbage))
|
||||
|
||||
"""Simple GPU memory report."""
|
||||
|
||||
mega_bytes = 1024.0 * 1024.0
|
||||
string = add_info + ' memory (MB)'
|
||||
string += ' | allocated: {}'.format(
|
||||
torch.cuda.memory_allocated() / mega_bytes)
|
||||
string += ' | max allocated: {}'.format(
|
||||
torch.cuda.max_memory_allocated() / mega_bytes)
|
||||
string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
|
||||
string += ' | max cached: {}'.format(
|
||||
torch.cuda.max_memory_cached()/ mega_bytes)
|
||||
print(string)
|
||||
|
||||
'''
|
||||
def plot_TF_influence(log, fig_name='TF_influence', param_names=None):
|
||||
proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||
mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||
|
||||
plt.figure()
|
||||
|
||||
mean = np.mean(proba, axis=1)*np.mean(mag, axis=1) #Pourrait etre interessant de multiplier avant le mean
|
||||
std = np.std(proba, axis=1)*np.std(mag, axis=1)
|
||||
plt.bar(param_names, mean, yerr=std)
|
||||
|
||||
plt.xticks(rotation=90)
|
||||
fig_name = fig_name.replace('.',',')
|
||||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
'''
|
||||
|
||||
from torch._six import inf
|
||||
def clip_norm(tensors, max_norm, norm_type=2):
|
||||
"""Clips norm of passed tensors.
|
||||
The norm is computed over all tensors together, as if they were
|
||||
concatenated into a single vector. Clipped tensors are returned.
|
||||
|
||||
See: https://github.com/facebookresearch/higher/issues/18
|
||||
|
||||
Args:
|
||||
tensors (Iterable[Tensor]): an iterable of Tensors or a
|
||||
single Tensor to be normalized.
|
||||
max_norm (float or int): max norm of the gradients
|
||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
Returns:
|
||||
Clipped (List[Tensor]) tensors.
|
||||
"""
|
||||
if isinstance(tensors, torch.Tensor):
|
||||
tensors = [tensors]
|
||||
tensors = list(tensors)
|
||||
max_norm = float(max_norm)
|
||||
norm_type = float(norm_type)
|
||||
if norm_type == inf:
|
||||
total_norm = max(t.abs().max() for t in tensors)
|
||||
else:
|
||||
total_norm = 0
|
||||
for t in tensors:
|
||||
if t is None:
|
||||
continue
|
||||
param_norm = t.norm(norm_type)
|
||||
total_norm += param_norm.item() ** norm_type
|
||||
total_norm = total_norm ** (1. / norm_type)
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
if clip_coef >= 1:
|
||||
return tensors
|
||||
#return [t.mul(clip_coef) for t in tensors]
|
||||
return [t if t is None else t.mul(clip_coef) for t in tensors]
|
Loading…
Add table
Add a link
Reference in a new issue