Changes since Teledyne

This commit is contained in:
Antoine Harlé 2020-06-23 07:55:40 -07:00
parent bd5dc63cff
commit 1060f18033
203 changed files with 24395 additions and 0 deletions

View file

@ -0,0 +1,73 @@
import argparse
#Argparse
parser = argparse.ArgumentParser(description='Run smart augmentation')
parser.add_argument('-dv','--device', default='cuda', dest='device',
help='Device : cpu / cuda')
parser.add_argument('-dt','--dtype', default='FP32', dest='dtype',
help='Data type (Default: Float32)')
parser.add_argument('-m','--model', default='resnet18', dest='model',
help='Network')
parser.add_argument('-pt','--pretrained', default='', dest='pretrained',
help='Use pretrained weight if possible')
parser.add_argument('-ep','--epochs', type=int, default=10, dest='epochs',
help='epoch')
# parser.add_argument('-ot', '--optimizer', default='SGD', dest='opt_type',
# help='Model optimizer')
parser.add_argument('-lr', type=float, default=1e-1, dest='lr',
help='Model learning rate')
parser.add_argument('-mo', '--momentum', type=float, default=0.9, dest='momentum',
help='Momentum')
parser.add_argument('-dc', '--decay', type=float, default=0.0005, dest='decay',
help='Weight decay')
parser.add_argument('-ns','--nesterov', type=bool, default=False, dest='nesterov',
help='Nesterov momentum ?')
parser.add_argument('-sc', '--scheduler', default='cosine', dest='scheduler',
help='Model learning rate scheduler')
parser.add_argument('-wu', '--warmup', type=float, default=0, dest='warmup',
help='Warmup multiplier')
parser.add_argument('-a','--augment', type=bool, default=False, dest='augment',
help='Data augmentation ?')
parser.add_argument('-N', type=int, default=1,
help='Combination of TF')
parser.add_argument('-K', type=int, default=0,
help='Number inner iteration')
parser.add_argument('-al','--augment_loss', type=int, default=1, dest='augment_loss',
help='Number of augmented example for each sample in loss computation.')
parser.add_argument('-t', '--temp', type=float, default=0.5, dest='temp',
help='Probability distribution temperature')
parser.add_argument('-tfc','--tf_config', default='../config/invScale_wide_tf_config.json', dest='tf_config',
help='TF config')
parser.add_argument('-ls', '--learn_seq', type=bool, default=False, dest='learn_seq',
help='Learn order of application of TF (DataugV7-8) ?')
parser.add_argument('-fm', '--fixed_mag', type=bool, default=False, dest='fixed_mag',
help='Fixed magnitude when learning data augmentation ?')
parser.add_argument('-sm', '--shared_mag', type=bool, default=False, dest='shared_mag',
help='Shared magnitude when learning data augmentation ?')
# parser.add_argument('-mot', '--metaoptimizer', default='Adam', dest='meta_opt_type',
# help='Meta optimizer (Augmentations)')
parser.add_argument('-mlr', type=float, default=1e-2, dest='mlr',
help='Meta learning rate (Augmentations)')
parser.add_argument('-ms', type=int, default=0, dest='meta_epoch_start',
help='Epoch at which start meta learning')
parser.add_argument('-mr', type=float, default=0.001, dest='mag_reg',
help='Augmentation magnitudes regulation factor')
parser.add_argument('-rf','--res_folder', default='../res/', dest='res_folder',
help='Results folder')
parser.add_argument('-pf','--postfix', default='', dest='postfix',
help='Res postfix')
parser.add_argument('-dr','--dataroot', default='~/scratch/data', dest='dataroot',
help='Datasets folder')
parser.add_argument('-ds','--dataset', default='CIFAR10', dest='dataset',
help='Dataset')
parser.add_argument('-bs','--batch_size', type=int, default=256, dest='batch_size',
help='Batch size') #256 (WRN) / 512
parser.add_argument('-w','--workers', type=int, default=6, dest='workers',
help='Numer of workers (Nb CPU cores).')

View file

@ -0,0 +1,247 @@
""" Script to run series of experiments.
"""
from dataug import *
#from utils import *
from train_utils import *
from transformations import TF_loader
import torchvision.models as models
from LeNet import *
#model_list={models.resnet: ['resnet18', 'resnet50','wide_resnet50_2']} #lr=0.1
model_list={models.resnet: ['wide_resnet50_2']}
optim_param={
'Meta':{
'optim':'Adam',
'lr':5e-3, #1e-2
'epoch_start': 2, #0 / 2 (Resnet?)
'reg_factor': 0.001,
'scheduler': None, #None, 'multiStep'
},
'Inner':{
'optim': 'SGD',
'lr':1e-2, #1e-2/1e-1 (ResNet)
'momentum':0.9, #0.9
'decay':0.0005, #0.0005
'nesterov':False, #False (True: Bad behavior w/ Data_aug)
'scheduler':'cosine', #None, 'cosine', 'multiStep', 'exponential'
}
}
res_folder="../res/benchmark/CIFAR10/"
#res_folder="../res/benchmark/MNIST/"
#res_folder="../res/HPsearch/"
epochs= 200
dataug_epoch_start=0
nb_run= 3
tf_config='../config/bad_tf_config.json' #'../config/wide_tf_config.json'#'../config/base_tf_config.json'
TF_loader=TF_loader()
tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(tf_config)
device = torch.device('cuda:1')
if device == torch.device('cpu'):
device_name = 'CPU'
else:
device_name = torch.cuda.get_device_name(device)
torch.backends.cudnn.benchmark = True #Faster if same input size #Not recommended for reproductibility
#Increase reproductibility
torch.manual_seed(0)
np.random.seed(0)
##########################################
if __name__ == "__main__":
### Benchmark ###
'''
inner_its = [0]
dist_mix = [0.5]
N_seq_TF= [3]
mag_setup = [(False, False)] #[(True, True), (False, False)] #(FxSh, Independant)
for model_type in model_list.keys():
for model_name in model_list[model_type]:
for run in range(nb_run):
for n_inner_iter in inner_its:
for n_tf in N_seq_TF:
for dist in dist_mix:
for m_setup in mag_setup:
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
torch.cuda.reset_max_memory_cached() #reset_peak_stats
t0 = time.perf_counter()
model = getattr(model_type, model_name)(pretrained=False, num_classes=len(dl_train.dataset.classes))
#model_name = 'LeNet'
#model = LeNet(3,10)
model = Higher_model(model, model_name) #run_dist_dataugV3
if n_inner_iter!=0:
aug_model = Augmented_model(
Data_augV5(TF_dict=tf_dict,
N_TF=n_tf,
mix_dist=dist,
fixed_prob=False,
fixed_mag=m_setup[0],
shared_mag=m_setup[1],
TF_ignore_mag=tf_ignore_mag),
model).to(device)
else:
aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=n_tf), model).to(device)
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
log= run_dist_dataugV3(model=aug_model,
epochs=epochs,
inner_it=n_inner_iter,
dataug_epoch_start=dataug_epoch_start,
opt_param=optim_param,
print_freq=epochs/4,
unsup_loss=1,
hp_opt=False,
save_sample_freq=None)
exec_time=time.perf_counter() - t0
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]),
"Time": (np.mean(times),np.std(times), exec_time),
'Optimizer': optim_param,
"Device": device_name,
"Memory": [max_allocated, max_cached],
"TF_config": tf_config,
"Param_names": aug_model.TF_names(),
"Log": log}
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs (dataug:{})- {} in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter, run)
with open(res_folder+"log/%s.json" % filename, "w+") as f:
try:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
except:
print("Failed to save logs :",f.name)
try:
plot_resV2(log, fig_name=res_folder+filename, param_names=aug_model.TF_names())
except:
print("Failed to plot res")
print(sys.exc_info()[1])
print('Execution Time : %.00f '%(exec_time))
print('-'*9)
'''
### Benchmark - RandAugment/Vanilla ###
#'''
for model_type in model_list.keys():
for model_name in model_list[model_type]:
for run in range(nb_run):
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
torch.cuda.reset_max_memory_cached() #reset_peak_stats
t0 = time.perf_counter()
model = getattr(model_type, model_name)(pretrained=False, num_classes=len(dl_train.dataset.classes)).to(device)
print("{} on {} for {} epochs".format(model_name, device_name, epochs))
#print("RandAugment(N{}-M{:.2f})-{} on {} for {} epochs".format(rand_aug['N'],rand_aug['M'],model_name, device_name, epochs))
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=epochs/4)
exec_time=time.perf_counter() - t0
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]),
"Time": (np.mean(times),np.std(times), exec_time),
'Optimizer': optim_param,
"Device": device_name,
"Memory": [max_allocated, max_cached],
#"Rand_Aug": rand_aug,
"Log": log}
print(model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs -{}-basicDA".format(model_name,epochs, run)
#print("RandAugment-",model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
#filename = "RandAugment(N{}-M{:.2f})-{}-{} epochs -{}".format(rand_aug['N'],rand_aug['M'],model_name,epochs, run)
with open(res_folder+"log/%s.json" % filename, "w+") as f:
try:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
except:
print("Failed to save logs :",f.name)
print(sys.exc_info()[1])
try:
plot_resV2(log, fig_name=res_folder+filename, param_names=aug_model.TF_names())
except:
print("Failed to plot res")
print(sys.exc_info()[1])
print('Execution Time : %.00f '%(exec_time))
print('-'*9)
#'''
### HP Search ###
'''
from LeNet import *
inner_its = [1]
dist_mix = [1.0]#[0.0, 0.5, 0.8, 1.0]
N_seq_TF= [5, 6]
mag_setup = [(True, True), (False, False)] #(FxSh, Independant)
#prob_setup = [True, False]
try:
os.mkdir(res_folder)
os.mkdir(res_folder+"log/")
except FileExistsError:
pass
for n_inner_iter in inner_its:
for n_tf in N_seq_TF:
for dist in dist_mix:
#for i in TF_nb:
for m_setup in mag_setup:
#for p_setup in prob_setup:
p_setup=False
for run in range(nb_run):
t0 = time.perf_counter()
model = getattr(models.resnet, 'resnet18')(pretrained=False, num_classes=len(dl_train.dataset.classes))
#model = LeNet(3,10)
model = Higher_model(model) #run_dist_dataugV3
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=p_setup, fixed_mag=m_setup[0], shared_mag=m_setup[1]), model).to(device)
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
log= run_dist_dataugV3(model=aug_model,
epochs=epochs,
inner_it=n_inner_iter,
dataug_epoch_start=dataug_epoch_start,
opt_param=optim_param,
print_freq=epochs/4,
unsup_loss=1,
hp_opt=False,
save_sample_freq=None)
exec_time=time.perf_counter() - t0
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs (dataug:{})- {} in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter, run)
with open(res_folder+"log/%s.json" % filename, "w+") as f:
try:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
except:
print("Failed to save logs :",f.name)
print('Execution Time : %.00f '%(exec_time))
print('-'*9)
'''

View file

@ -0,0 +1,220 @@
""" Dataset definition.
MNIST / CIFAR10 / CIFAR100 / SVHN / ImageNet
"""
import os
import torch
from torch.utils.data.dataset import ConcatDataset
import torchvision
from arg_parser import *
args = parser.parse_args()
#Wether to download data.
download_data=False
#Pin GPU memory
pin_memory=False #True :+ GPU memory / + Lent
#Data storage folder
dataroot=args.dataroot
# if args.dtype == 'FP32':
# def_type=torch.float32
# elif args.dtype == 'FP16':
# # def_type=torch.float16 #Default : float32
# def_type=torch.bfloat16
# else:
# raise Exception('dtype not supported :', args.dtype)
#ATTENTION : Dataug (Kornia) Expect image in the range of [0, 1]
transform = [
#torchvision.transforms.Grayscale(3), #MNIST
#torchvision.transforms.Resize((224,224), interpolation=2)#VGG
torchvision.transforms.ToTensor(),
#torchvision.transforms.Normalize(MEAN, STD), #CIFAR10
# torchvision.transforms.Lambda(lambda tensor: tensor.to(def_type)),
]
transform_train = [
#transforms.RandomHorizontalFlip(),
#transforms.RandomVerticalFlip(),
#torchvision.transforms.Grayscale(3), #MNIST
#torchvision.transforms.Resize((224,224), interpolation=2)
torchvision.transforms.ToTensor(),
#torchvision.transforms.Normalize(MEAN, STD), #CIFAR10
# torchvision.transforms.Lambda(lambda tensor: tensor.to(def_type)),
]
## RandAugment ##
#from RandAugment import RandAugment
# Add RandAugment with N, M(hyperparameter)
#rand_aug={'N': 2, 'M': 1}
#rand_aug={'N': 2, 'M': 9./30} #RN-ImageNet
#rand_aug={'N': 3, 'M': 5./30} #WRN-CIFAR10
#rand_aug={'N': 2, 'M': 14./30} #WRN-CIFAR100
#rand_aug={'N': 3, 'M': 7./30} #WRN-SVHN
#transform_train.transforms.insert(0, RandAugment(n=rand_aug['N'], m=rand_aug['M']))
### Classic Dataset ###
BATCH_SIZE = args.batch_size
TEST_SIZE = BATCH_SIZE
# Load Dataset
if args.dataset == 'MNIST':
transform_train.insert(0, torchvision.transforms.Grayscale(3))
transform.insert(0, torchvision.transforms.Grayscale(3))
val_set=False
data_train = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=torchvision.transforms.Compose(transform_train))
data_val = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=torchvision.transforms.Compose(transform))
data_test = torchvision.datasets.MNIST(dataroot, train=False, download=True, transform=torchvision.transforms.Compose(transform))
elif args.dataset == 'CIFAR10': #(32x32 RGB)
val_set=False
MEAN=(0.4914, 0.4822, 0.4465)
STD=(0.2023, 0.1994, 0.2010)
data_train = torchvision.datasets.CIFAR10(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform_train))
data_val = torchvision.datasets.CIFAR10(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform))
data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=download_data, transform=torchvision.transforms.Compose(transform))
elif args.dataset == 'CIFAR100': #(32x32 RGB)
val_set=False
MEAN=(0.4914, 0.4822, 0.4465)
STD=(0.2023, 0.1994, 0.2010)
data_train = torchvision.datasets.CIFAR100(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform_train))
data_val = torchvision.datasets.CIFAR100(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform))
data_test = torchvision.datasets.CIFAR100(dataroot, train=False, download=download_data, transform=torchvision.transforms.Compose(transform))
elif args.dataset == 'TinyImageNet': #(Train:100k, Val:5k, Test:5k) (64x64 RGB)
image_size=64 #128 / 224
print('Using image size', image_size)
transform_train=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform_train
transform=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform
val_set=True
MEAN=(0.485, 0.456, 0.406)
STD=(0.229, 0.224, 0.225)
data_train = torchvision.datasets.ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/train'), transform=torchvision.transforms.Compose(transform_train))
data_val = torchvision.datasets.ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/val'), transform=torchvision.transforms.Compose(transform))
data_test = torchvision.datasets.ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/test'), transform=torchvision.transforms.Compose(transform))
elif args.dataset == 'ImageNet': #
image_size=128 #224
print('Using image size', image_size)
transform_train=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform_train
transform=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform
val_set=False
MEAN=(0.485, 0.456, 0.406)
STD=(0.229, 0.224, 0.225)
data_train = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'ImageNet/train'), transform=torchvision.transforms.Compose(transform_train))
data_val = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'ImageNet/train'), transform=torchvision.transforms.Compose(transform))
data_test = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'ImageNet/validation'), transform=torchvision.transforms.Compose(transform))
else:
raise Exception('Unknown dataset')
# Ready dataloader
if not val_set : #Split Training set into Train/Val
#Validation set size [0, 1]
valid_size=0.1
train_subset_indices=range(int(len(data_train)*(1-valid_size)))
val_subset_indices=range(int(len(data_train)*(1-valid_size)),len(data_train))
#train_subset_indices=range(BATCH_SIZE*10)
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
from torch.utils.data import SubsetRandomSampler
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=args.workers, pin_memory=pin_memory)
dl_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=args.workers, pin_memory=pin_memory)
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=args.workers, pin_memory=pin_memory)
else:
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=args.workers, pin_memory=pin_memory)
dl_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE, shuffle=True, num_workers=args.workers, pin_memory=pin_memory)
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=args.workers, pin_memory=pin_memory)
#SVHN
#trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=download_data, transform=transform_train)
#extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=download_data, transform=transform_train)
#data_train = ConcatDataset([trainset, extraset])
#data_test = torchvision.datasets.SVHN(dataroot, split='test', download=download_data, transform=transform)
#ImageNet
#Necessite SciPy
# Probleme ? : https://github.com/ildoonet/pytorch-randaugment/blob/48b8f509c4bbda93bbe733d98b3fd052b6e4c8ae/RandAugment/imagenet.py#L28
#data_train = torchvision.datasets.ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='train', transform=transform_train)
#data_test = torchvision.datasets.ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)
#Cross Validation
'''
import numpy as np
from sklearn.model_selection import ShuffleSplit
from sklearn.model_selection import StratifiedShuffleSplit
class CVSplit(object):
"""Class that perform train/valid split on a dataset.
Inspired from : https://skorch.readthedocs.io/en/latest/user/dataset.html
Attributes:
_stratified (bool): Wether the split should be stratified. Recommended to be True for unbalanced dataset.
_val_size (float, int): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the validation split.
If int, represents the absolute number of validation samples.
_data (Dataset): Dataset to split.
_targets (np.array): Targets of the dataset used if _stratified is set to True.
_cv (BaseShuffleSplit) : Scikit learn object used to split.
"""
def __init__(self, data, val_size=0.1, stratified=True):
""" Intialize CVSplit.
Args:
data (Dataset): Dataset to split.
val_size (float, int): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the validation split.
If int, represents the absolute number of validation samples. (Default: 0.1)
stratified (bool): Wether the split should be stratified. Recommended to be True for unbalanced dataset.
"""
self._stratified=stratified
self._val_size=val_size
self._data=data
if self._stratified:
cv_cls = StratifiedShuffleSplit
self._targets= np.array(data_train.targets)
else:
cv_cls = ShuffleSplit
self._cv= cv_cls(test_size=val_size, random_state=0) #Random state w/ fixed seed
def next_split(self):
""" Get next cross-validation split.
Returns:
Train DataLoader, Validation DataLoader
"""
args=(np.arange(len(self._data)),)
if self._stratified:
args = args + (self._targets,)
idx_train, idx_valid = next(iter(self._cv.split(*args)))
train_subset = torch.utils.data.Subset(self._data, idx_train)
val_subset = torch.utils.data.Subset(self._data, idx_valid)
dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
return dl_train, dl_val
cvs = CVSplit(data_train, val_size=valid_size)
dl_train, dl_val = cvs.next_split()
'''
'''
from skorch.dataset import CVSplit
import numpy as np
cvs = CVSplit(cv=valid_size, stratified=True) #Stratified =True for unbalanced dataset #ShuffleSplit
def next_CVSplit():
train_subset, val_subset = cvs(data_train, y=np.array(data_train.targets))
dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
return dl_train, dl_val
dl_train, dl_val = next_CVSplit()
'''

1263
higher/smart_aug/dataug.py Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,31 @@
""" Patch for Higher package
Recommended use ::
import higher
import higher_patch
Might become unnecessary with future update of the Higher package.
"""
import higher
import torch as _torch
def detach_(self):
"""Removes all params from their compute graph in place.
"""
# detach param groups
for group in self.param_groups:
for k, v in group.items():
if isinstance(v,_torch.Tensor):
v.detach_().requires_grad_()
# detach state
for state_dict in self.state:
for k,v_dict in state_dict.items():
if isinstance(k,_torch.Tensor): k.detach_().requires_grad_()
for k2,v2 in v_dict.items():
if isinstance(v2,_torch.Tensor):
v2.detach_().requires_grad_()
higher.optim.DifferentiableOptimizer.detach_ = detach_

View file

@ -0,0 +1,56 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
## Basic CNN ##
class LeNet(nn.Module):
"""Basic CNN.
"""
def __init__(self, num_inp, num_out):
"""Init LeNet.
"""
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(num_inp, 20, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(20, 50, 5)
self.pool2 = nn.MaxPool2d(2, 2)
#self.fc1 = nn.Linear(4*4*50, 500)
self.fc1 = nn.Linear(5*5*50, 500)
self.fc2 = nn.Linear(500, num_out)
def forward(self, x):
"""Main method of LeNet
"""
x = self.pool(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
def __str__(self):
""" Get name of model
"""
return "LeNet"
#MNIST
class MLPNet(nn.Module):
def __init__(self):
super(MLPNet, self).__init__()
self.fc1 = nn.Linear(28*28, 500)
self.fc2 = nn.Linear(500, 256)
self.fc3 = nn.Linear(256, 10)
def forward(self, x):
x = x.view(-1, 28*28)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def name(self):
return "MLP"

View file

@ -0,0 +1,426 @@
import torch
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
# __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
# 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
# 'wide_resnet50_2', 'wide_resnet101_2']
# model_urls = {
# 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
# 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
# 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
# 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
# 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
# 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
# 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
# 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
# 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
# }
__all__ = ['ResNet_ABN', 'resnet18_ABN', 'resnet34_ABN', 'resnet50_ABN', 'resnet101_ABN',
'resnet152_ABN', 'resnext50_32x4d_ABN', 'resnext101_32x8d_ABN',
'wide_resnet50_2_ABN', 'wide_resnet101_2_ABN']
class aux_batchNorm(nn.Module):
def __init__(self, norm_layer, nb_features):
super(aux_batchNorm, self).__init__()
self.mode='clean'
self.bn=nn.ModuleDict({
'clean': norm_layer(nb_features),
'augmented': norm_layer(nb_features)
})
def forward(self, x):
if self.mode is 'mixed':
running_mean=(self.bn['clean'].running_mean+self.bn['augmented'].running_mean)/2
running_var=(self.bn['clean'].running_var+self.bn['augmented'].running_var)/2
return nn.functional.batch_norm(x, running_mean, running_var, self.bn['clean'].weight, self.bn['clean'].bias)
return self.bn[self.mode](x)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock_ABN(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock_ABN, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
#self.bn1 = norm_layer(planes)
self.bn1 = aux_batchNorm(norm_layer, planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
#self.bn2 = norm_layer(planes)
self.bn2 = aux_batchNorm(norm_layer, planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck_ABN(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck_ABN, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
#self.bn1 = norm_layer(width)
self.bn1 = aux_batchNorm(norm_layer, width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
# self.bn2 = norm_layer(width)
self.bn2 = aux_batchNorm(norm_layer, width)
self.conv3 = conv1x1(width, planes * self.expansion)
# self.bn3 = norm_layer(planes * self.expansion)
self.bn3 = aux_batchNorm(norm_layer, planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet_ABN(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet_ABN, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
#self.bn1 = norm_layer(self.inplanes)
self.bn1 = aux_batchNorm(norm_layer, self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
print('WARNING : zero_init_residual not implemented with ABN')
# for m in self.modules():
# if isinstance(m, Bottleneck):
# nn.init.constant_(m.bn3.weight, 0)
# elif isinstance(m, BasicBlock):
# nn.init.constant_(m.bn2.weight, 0)
# Memoire des BN layers pas fonctinnel avec Higher
# self.bn_layers=[]
# for m in self.modules():
# if isinstance(m, aux_batchNorm):
# self.bn_layers.append(m)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
#norm_layer(planes * block.expansion),
aux_batchNorm(norm_layer, planes * block.expansion)
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
def set_mode(self, mode):
# for bn in self.bn_layers:
for m in self.modules():
if isinstance(m, aux_batchNorm):
m.mode=mode
# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
# model = ResNet(block, layers, **kwargs)
# if pretrained:
# state_dict = load_state_dict_from_url(model_urls[arch],
# progress=progress)
# model.load_state_dict(state_dict)
# return model
def resnet18_ABN(pretrained=False, progress=True, **kwargs):
"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
# return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
# **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(BasicBlock_ABN, [2, 2, 2, 2], **kwargs)
def resnet34_ABN(pretrained=False, progress=True, **kwargs):
"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
# return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
# **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(BasicBlock_ABN, [3, 4, 6, 3], **kwargs)
def resnet50_ABN(pretrained=False, progress=True, **kwargs):
"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
# return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
# **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
def resnet101_ABN(pretrained=False, progress=True, **kwargs):
"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
# return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
# **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
def resnet152_ABN(pretrained=False, progress=True, **kwargs):
"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
# return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
# **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 8, 36, 3], **kwargs)
def resnext50_32x4d_ABN(pretrained=False, progress=True, **kwargs):
"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
# return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
# pretrained, progress, **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
def resnext101_32x8d_ABN(pretrained=False, progress=True, **kwargs):
"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
# return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
# pretrained, progress, **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
def wide_resnet50_2_ABN(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
# return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
# pretrained, progress, **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
def wide_resnet101_2_ABN(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
# return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
# pretrained, progress, **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)

View file

@ -0,0 +1,618 @@
'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
https://github.com/yechengxi/deconvolution
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import conv
from torch.nn.modules.utils import _pair
from functools import partial
__all__ = ['ResNet18_DC', 'ResNet34_DC', 'ResNet50_DC', 'ResNet101_DC', 'ResNet152_DC', 'WRN_DC26_10']
### Deconvolution ###
#iteratively solve for inverse sqrt of a matrix
def isqrt_newton_schulz_autograd(A, numIters):
dim = A.shape[0]
normA=A.norm()
Y = A.div(normA)
I = torch.eye(dim,dtype=A.dtype,device=A.device)
Z = torch.eye(dim,dtype=A.dtype,device=A.device)
for i in range(numIters):
T = 0.5*(3.0*I - Z@Y)
Y = Y@T
Z = T@Z
#A_sqrt = Y*torch.sqrt(normA)
A_isqrt = Z / torch.sqrt(normA)
return A_isqrt
def isqrt_newton_schulz_autograd_batch(A, numIters):
batchSize,dim,_ = A.shape
normA=A.view(batchSize, -1).norm(2, 1).view(batchSize, 1, 1)
Y = A.div(normA)
I = torch.eye(dim,dtype=A.dtype,device=A.device).unsqueeze(0).expand_as(A)
Z = torch.eye(dim,dtype=A.dtype,device=A.device).unsqueeze(0).expand_as(A)
for i in range(numIters):
T = 0.5*(3.0*I - Z.bmm(Y))
Y = Y.bmm(T)
Z = T.bmm(Z)
#A_sqrt = Y*torch.sqrt(normA)
A_isqrt = Z / torch.sqrt(normA)
return A_isqrt
#deconvolve channels
class ChannelDeconv(nn.Module):
def __init__(self, block, eps=1e-2,n_iter=5,momentum=0.1,sampling_stride=3):
super(ChannelDeconv, self).__init__()
self.eps = eps
self.n_iter=n_iter
self.momentum=momentum
self.block = block
self.register_buffer('running_mean1', torch.zeros(block, 1))
#self.register_buffer('running_cov', torch.eye(block))
self.register_buffer('running_deconv', torch.eye(block))
self.register_buffer('running_mean2', torch.zeros(1, 1))
self.register_buffer('running_var', torch.ones(1, 1))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
self.sampling_stride=sampling_stride
def forward(self, x):
x_shape = x.shape
if len(x.shape)==2:
x=x.view(x.shape[0],x.shape[1],1,1)
if len(x.shape)==3:
print('Error! Unsupprted tensor shape.')
N, C, H, W = x.size()
B = self.block
#take the first c channels out for deconv
c=int(C/B)*B
if c==0:
print('Error! block should be set smaller.')
#step 1. remove mean
if c!=C:
x1=x[:,:c].permute(1,0,2,3).contiguous().view(B,-1)
else:
x1=x.permute(1,0,2,3).contiguous().view(B,-1)
if self.sampling_stride > 1 and H >= self.sampling_stride and W >= self.sampling_stride:
x1_s = x1[:,::self.sampling_stride**2]
else:
x1_s=x1
mean1 = x1_s.mean(-1, keepdim=True)
if self.num_batches_tracked==0:
self.running_mean1.copy_(mean1.detach())
if self.training:
self.running_mean1.mul_(1-self.momentum)
self.running_mean1.add_(mean1.detach()*self.momentum)
else:
mean1 = self.running_mean1
x1=x1-mean1
#step 2. calculate deconv@x1 = cov^(-0.5)@x1
if self.training:
cov = x1_s @ x1_s.t() / x1_s.shape[1] + self.eps * torch.eye(B, dtype=x.dtype, device=x.device)
deconv = isqrt_newton_schulz_autograd(cov, self.n_iter)
if self.num_batches_tracked==0:
#self.running_cov.copy_(cov.detach())
self.running_deconv.copy_(deconv.detach())
if self.training:
#self.running_cov.mul_(1-self.momentum)
#self.running_cov.add_(cov.detach()*self.momentum)
self.running_deconv.mul_(1 - self.momentum)
self.running_deconv.add_(deconv.detach() * self.momentum)
else:
# cov = self.running_cov
deconv = self.running_deconv
x1 =deconv@x1
#reshape to N,c,J,W
x1 = x1.view(c, N, H, W).contiguous().permute(1,0,2,3)
# normalize the remaining channels
if c!=C:
x_tmp=x[:, c:].view(N,-1)
if self.sampling_stride > 1 and H>=self.sampling_stride and W>=self.sampling_stride:
x_s = x_tmp[:, ::self.sampling_stride ** 2]
else:
x_s = x_tmp
mean2=x_s.mean()
var=x_s.var()
if self.num_batches_tracked == 0:
self.running_mean2.copy_(mean2.detach())
self.running_var.copy_(var.detach())
if self.training:
self.running_mean2.mul_(1 - self.momentum)
self.running_mean2.add_(mean2.detach() * self.momentum)
self.running_var.mul_(1 - self.momentum)
self.running_var.add_(var.detach() * self.momentum)
else:
mean2 = self.running_mean2
var = self.running_var
x_tmp = (x[:, c:] - mean2) / (var + self.eps).sqrt()
x1 = torch.cat([x1, x_tmp], dim=1)
if self.training:
self.num_batches_tracked.add_(1)
if len(x_shape)==2:
x1=x1.view(x_shape)
return x1
#An alternative implementation
class Delinear(nn.Module):
__constants__ = ['bias', 'in_features', 'out_features']
def __init__(self, in_features, out_features, bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=512):
super(Delinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
if block > in_features:
block = in_features
else:
if in_features%block!=0:
block=math.gcd(block,in_features)
print('block size set to:', block)
self.block = block
self.momentum = momentum
self.n_iter = n_iter
self.eps = eps
self.register_buffer('running_mean', torch.zeros(self.block))
self.register_buffer('running_deconv', torch.eye(self.block))
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input):
if self.training:
# 1. reshape
X=input.view(-1, self.block)
# 2. subtract mean
X_mean = X.mean(0)
X = X - X_mean.unsqueeze(0)
self.running_mean.mul_(1 - self.momentum)
self.running_mean.add_(X_mean.detach() * self.momentum)
# 3. calculate COV, COV^(-0.5), then deconv
# Cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
Id = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
Cov = torch.addmm(self.eps, Id, 1. / X.shape[0], X.t(), X)
deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
# track stats for evaluation
self.running_deconv.mul_(1 - self.momentum)
self.running_deconv.add_(deconv.detach() * self.momentum)
else:
X_mean = self.running_mean
deconv = self.running_deconv
w = self.weight.view(-1, self.block) @ deconv
b = self.bias
if self.bias is not None:
b = b - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
w = w.view(self.weight.shape)
return F.linear(input, w, b)
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None
)
class FastDeconv(conv._ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,groups=1,bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3,freeze=False,freeze_iter=100):
self.momentum = momentum
self.n_iter = n_iter
self.eps = eps
self.counter=0
self.track_running_stats=True
super(FastDeconv, self).__init__(
in_channels, out_channels, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation),
False, _pair(0), groups, bias, padding_mode='zeros')
if block > in_channels:
block = in_channels
else:
if in_channels%block!=0:
block=math.gcd(block,in_channels)
if groups>1:
#grouped conv
block=in_channels//groups
self.block=block
self.num_features = kernel_size**2 *block
if groups==1:
self.register_buffer('running_mean', torch.zeros(self.num_features))
self.register_buffer('running_deconv', torch.eye(self.num_features))
else:
self.register_buffer('running_mean', torch.zeros(kernel_size ** 2 * in_channels))
self.register_buffer('running_deconv', torch.eye(self.num_features).repeat(in_channels // block, 1, 1))
self.sampling_stride=sampling_stride*stride
self.counter=0
self.freeze_iter=freeze_iter
self.freeze=freeze
def forward(self, x):
N, C, H, W = x.shape
B = self.block
frozen=self.freeze and (self.counter>self.freeze_iter)
if self.training and self.track_running_stats:
self.counter+=1
self.counter %= (self.freeze_iter * 10)
if self.training and (not frozen):
# 1. im2col: N x cols x pixels -> N*pixles x cols
if self.kernel_size[0]>1:
X = torch.nn.functional.unfold(x, self.kernel_size,self.dilation,self.padding,self.sampling_stride).transpose(1, 2).contiguous()
else:
#channel wise
X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride**2,:]
if self.groups==1:
# (C//B*N*pixels,k*k*B)
X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
else:
X=X.view(-1,X.shape[-1])
# 2. subtract mean
X_mean = X.mean(0)
X = X - X_mean.unsqueeze(0)
# 3. calculate COV, COV^(-0.5), then deconv
if self.groups==1:
#Cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
Id=torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
Cov = torch.addmm(self.eps, Id, 1. / X.shape[0], X.t(), X)
deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
else:
X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(self.groups, self.num_features, self.num_features)
Cov = torch.baddbmm(self.eps, Id, 1. / X.shape[1], X.transpose(1, 2), X)
deconv = isqrt_newton_schulz_autograd_batch(Cov, self.n_iter)
if self.track_running_stats:
self.running_mean.mul_(1 - self.momentum)
self.running_mean.add_(X_mean.detach() * self.momentum)
# track stats for evaluation
self.running_deconv.mul_(1 - self.momentum)
self.running_deconv.add_(deconv.detach() * self.momentum)
else:
X_mean = self.running_mean
deconv = self.running_deconv
#4. X * deconv * conv = X * (deconv * conv)
if self.groups==1:
w = self.weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1,self.num_features) @ deconv
b = self.bias - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
else:
w = self.weight.view(C//B, -1,self.num_features)@deconv
b = self.bias - (w @ (X_mean.view( -1,self.num_features,1))).view(self.bias.shape)
w = w.view(self.weight.shape)
x= F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)
return x
### ResNet
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1, deconv=None):
super(BasicBlock, self).__init__()
if deconv:
self.conv1 = deconv(in_planes, planes, kernel_size=3, stride=stride, padding=1)
self.conv2 = deconv(planes, planes, kernel_size=3, stride=1, padding=1)
self.deconv = True
else:
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.deconv = False
self.shortcut = nn.Sequential()
if not deconv:
self.bn1 = nn.BatchNorm2d(planes)
self.bn2 = nn.BatchNorm2d(planes)
#self.bn1 = nn.GroupNorm(planes//16,planes)
#self.bn2 = nn.GroupNorm(planes//16,planes)
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
#nn.GroupNorm(self.expansion * planes//16,self.expansion * planes)
)
else:
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
deconv(in_planes, self.expansion*planes, kernel_size=1, stride=stride)
)
def forward(self, x):
if self.deconv:
out = F.relu(self.conv1(x))
out = self.conv2(out)
out += self.shortcut(x)
out = F.relu(out)
return out
else: #self.batch_norm:
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1, deconv=None):
super(Bottleneck, self).__init__()
if deconv:
self.deconv = True
self.conv1 = deconv(in_planes, planes, kernel_size=1)
self.conv2 = deconv(planes, planes, kernel_size=3, stride=stride, padding=1)
self.conv3 = deconv(planes, self.expansion*planes, kernel_size=1)
else:
self.deconv = False
self.bn1 = nn.BatchNorm2d(planes)
self.bn2 = nn.BatchNorm2d(planes)
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
self.shortcut = nn.Sequential()
if not deconv:
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes)
)
else:
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
deconv(in_planes, self.expansion * planes, kernel_size=1, stride=stride)
)
def forward(self, x):
"""
No batch normalization for deconv.
"""
if self.deconv:
out = F.relu((self.conv1(x)))
out = F.relu((self.conv2(out)))
out = self.conv3(out)
out += self.shortcut(x)
out = F.relu(out)
return out
else:
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10, deconv=None,channel_deconv=None):
super(ResNet, self).__init__()
self.in_planes = 64
if deconv:
self.deconv = True
self.conv1 = deconv(3, 64, kernel_size=3, stride=1, padding=1)
else:
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
if not deconv:
self.bn1 = nn.BatchNorm2d(64)
#this line is really recent, take extreme care if the result is not good.
if channel_deconv:
self.deconv1=channel_deconv()
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, deconv=deconv)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, deconv=deconv)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, deconv=deconv)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, deconv=deconv)
self.linear = nn.Linear(512*block.expansion, num_classes)
def _make_layer(self, block, planes, num_blocks, stride, deconv):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride, deconv))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
if hasattr(self,'bn1'):
out = F.relu(self.bn1(self.conv1(x)))
else:
out = F.relu(self.conv1(x))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
if hasattr(self, 'deconv1'):
out = self.deconv1(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
def_deconv = partial(FastDeconv,bias=True, eps=1e-5, n_iter=5,block=64,sampling_stride=3)
#channel_deconv=partial(ChannelDeconv, block=512,eps=1e-5, n_iter=5,sampling_stride=3) #Pas forcément conseillé
def ResNet18_DC(num_classes,deconv=def_deconv,channel_deconv=None):
return ResNet(BasicBlock, [2,2,2,2],num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
def ResNet34_DC(num_classes,deconv=def_deconv,channel_deconv=None):
return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
def ResNet50_DC(num_classes,deconv=def_deconv,channel_deconv=None):
return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
def ResNet101_DC(num_classes,deconv=def_deconv,channel_deconv=None):
return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
def ResNet152_DC(num_classes,deconv=def_deconv,channel_deconv=None):
return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
import math
class Wide_ResNet_Cifar_DC(nn.Module):
def __init__(self, block, layers, wfactor, num_classes=10, deconv=None, channel_deconv=None):
super(Wide_ResNet_Cifar_DC, self).__init__()
self.depth=layers[0]*6+2
self.widen_factor=wfactor
self.inplanes = 16
self.conv1 = deconv(3, 16, kernel_size=3, stride=1, padding=1)
if channel_deconv:
self.deconv1=channel_deconv()
# self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
# self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, 16*wfactor, layers[0], stride=1, deconv=deconv)
self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2, deconv=deconv)
self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2, deconv=deconv)
self.avgpool = nn.AvgPool2d(8, stride=1)
self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, num_blocks, stride, deconv):
# downsample = None
# if stride != 1 or self.inplanes != planes * block.expansion:
# downsample = nn.Sequential(
# nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
# nn.BatchNorm2d(planes * block.expansion)
# )
# layers = []
# layers.append(block(self.inplanes, planes, stride, downsample))
# self.inplanes = planes * block.expansion
# for _ in range(1, blocks):
# layers.append(block(self.inplanes, planes))
# return nn.Sequential(*layers)
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.inplanes, planes, stride, deconv))
self.inplanes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
# x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
if hasattr(self, 'deconv1'):
out = self.deconv1(out)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def __str__(self):
""" Get name of model
"""
return "Wide_ResNet_cifar_DC%d_%d"%(self.depth,self.widen_factor)
def WRN_DC26_10(depth=26, width=10, deconv=def_deconv, channel_deconv=None, **kwargs):
assert (depth - 2) % 6 == 0
n = int((depth - 2) / 6)
return Wide_ResNet_Cifar_DC(BasicBlock, [n, n, n], width, deconv=deconv,channel_deconv=channel_deconv, **kwargs)
def test():
net = ResNet18_DC()
y = net(torch.randn(1,3,32,32))
print(y.size())
# test()

View file

@ -0,0 +1,98 @@
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import numpy as np
_bn_momentum = 0.1
def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
def conv_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.xavier_uniform_(m.weight, gain=np.sqrt(2))
init.constant_(m.bias, 0)
elif classname.find('BatchNorm') != -1:
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
class WideBasic(nn.Module):
def __init__(self, in_planes, planes, dropout_rate, stride=1):
super(WideBasic, self).__init__()
assert dropout_rate==0.0, 'dropout layer not used'
self.bn1 = nn.BatchNorm2d(in_planes, momentum=_bn_momentum)
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
#self.dropout = nn.Dropout(p=dropout_rate)
self.bn2 = nn.BatchNorm2d(planes, momentum=_bn_momentum)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
)
def forward(self, x):
# out = self.dropout(self.conv1(F.relu(self.bn1(x))))
out = self.conv1(F.relu(self.bn1(x)))
out = self.conv2(F.relu(self.bn2(out)))
out += self.shortcut(x)
return out
class WideResNet(nn.Module):
def __init__(self, depth, widen_factor, dropout_rate, num_classes):
super(WideResNet, self).__init__()
self.depth=depth
self.widen_factor=widen_factor
self.in_planes = 16
assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
n = int((depth - 4) / 6)
k = widen_factor
nStages = [16, 16*k, 32*k, 64*k]
self.conv1 = conv3x3(3, nStages[0])
self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
self.bn1 = nn.BatchNorm2d(nStages[3], momentum=_bn_momentum)
self.linear = nn.Linear(nStages[3], num_classes)
# self.apply(conv_init)
def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, dropout_rate, stride))
self.in_planes = planes
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv1(x)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = F.relu(self.bn1(out))
# out = F.avg_pool2d(out, 8)
out = F.adaptive_avg_pool2d(out, (1, 1))
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
def __str__(self):
""" Get name of model
"""
return "Wide_ResNet%d_%d"%(self.depth,self.widen_factor)

View file

@ -0,0 +1,119 @@
"""
wide resnet for cifar in pytorch
Reference:
[1] S. Zagoruyko and N. Komodakis. Wide residual networks. In BMVC, 2016.
"""
import torch
import torch.nn as nn
import math
#from models.resnet_cifar import BasicBlock
def conv3x3(in_planes, out_planes, stride=1):
" 3x3 convolution with padding "
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
expansion=1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Wide_ResNet_Cifar(nn.Module):
def __init__(self, block, layers, wfactor, num_classes=10):
super(Wide_ResNet_Cifar, self).__init__()
self.depth=layers[0]*6+2
self.widen_factor=wfactor
self.inplanes = 16
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, 16*wfactor, layers[0])
self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2)
self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2)
self.avgpool = nn.AvgPool2d(8, stride=1)
self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion)
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def __str__(self):
""" Get name of model
"""
return "Wide_ResNet_cifar%d_%d"%(self.depth,self.widen_factor)
def wide_resnet_cifar(depth, width, **kwargs):
assert (depth - 2) % 6 == 0
n = int((depth - 2) / 6)
return Wide_ResNet_Cifar(BasicBlock, [n, n, n], width, **kwargs)
if __name__=='__main__':
net = wide_resnet_cifar(20, 10)
y = net(torch.randn(1, 3, 32, 32))
print(isinstance(net, Wide_ResNet_Cifar))
print(y.size())

View file

@ -0,0 +1,490 @@
# coding=utf-8
# Copyright 2019 The Google UDA Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transforms used in the Augmentation Policies.
Copied from AutoAugment: https://github.com/tensorflow/models/blob/master/research/autoaugment/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
# pylint:disable=g-multiple-import
from PIL import ImageOps, ImageEnhance, ImageFilter, Image
# pylint:enable=g-multiple-import
#import tensorflow as tf
#FLAGS = tf.flags.FLAGS
IMAGE_SIZE = 32
# What is the dataset mean and std of the images on the training set
PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted
def get_mean_and_std():
#if FLAGS.task_name == "cifar10":
means = [0.49139968, 0.48215841, 0.44653091]
stds = [0.24703223, 0.24348513, 0.26158784]
#elif FLAGS.task_name == "svhn":
# means = [0.4376821, 0.4437697, 0.47280442]
# stds = [0.19803012, 0.20101562, 0.19703614]
#else:
# assert False
return means, stds
def random_flip(x):
"""Flip the input x horizontally with 50% probability."""
if np.random.rand(1)[0] > 0.5:
return np.fliplr(x)
return x
def zero_pad_and_crop(img, amount=4):
"""Zero pad by `amount` zero pixels on each side then take a random crop.
Args:
img: numpy image that will be zero padded and cropped.
amount: amount of zeros to pad `img` with horizontally and verically.
Returns:
The cropped zero padded img. The returned numpy array will be of the same
shape as `img`.
"""
padded_img = np.zeros((img.shape[0] + amount * 2, img.shape[1] + amount * 2,
img.shape[2]))
padded_img[amount:img.shape[0] + amount, amount:
img.shape[1] + amount, :] = img
top = np.random.randint(low=0, high=2 * amount)
left = np.random.randint(low=0, high=2 * amount)
new_img = padded_img[top:top + img.shape[0], left:left + img.shape[1], :]
return new_img
def create_cutout_mask(img_height, img_width, num_channels, size):
"""Creates a zero mask used for cutout of shape `img_height` x `img_width`.
Args:
img_height: Height of image cutout mask will be applied to.
img_width: Width of image cutout mask will be applied to.
num_channels: Number of channels in the image.
size: Size of the zeros mask.
Returns:
A mask of shape `img_height` x `img_width` with all ones except for a
square of zeros of shape `size` x `size`. This mask is meant to be
elementwise multiplied with the original image. Additionally returns
the `upper_coord` and `lower_coord` which specify where the cutout mask
will be applied.
"""
assert img_height == img_width
# Sample center where cutout mask will be applied
height_loc = np.random.randint(low=0, high=img_height)
width_loc = np.random.randint(low=0, high=img_width)
# Determine upper right and lower left corners of patch
upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2))
lower_coord = (min(img_height, height_loc + size // 2),
min(img_width, width_loc + size // 2))
mask_height = lower_coord[0] - upper_coord[0]
mask_width = lower_coord[1] - upper_coord[1]
assert mask_height > 0
assert mask_width > 0
mask = np.ones((img_height, img_width, num_channels))
zeros = np.zeros((mask_height, mask_width, num_channels))
mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = (
zeros)
return mask, upper_coord, lower_coord
def cutout_numpy(img, size=16):
"""Apply cutout with mask of shape `size` x `size` to `img`.
The cutout operation is from the paper https://arxiv.org/abs/1708.04552.
This operation applies a `size`x`size` mask of zeros to a random location
within `img`.
Args:
img: Numpy image that cutout will be applied to.
size: Height/width of the cutout mask that will be
Returns:
A numpy tensor that is the result of applying the cutout mask to `img`.
"""
img_height, img_width, num_channels = (img.shape[0], img.shape[1],
img.shape[2])
assert len(img.shape) == 3
mask, _, _ = create_cutout_mask(img_height, img_width, num_channels, size)
return img * mask
def float_parameter(level, maxval):
"""Helper function to scale `val` between 0 and maxval .
Args:
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
maxval: Maximum value that the operation can have. This will be scaled
to level/PARAMETER_MAX.
Returns:
A float that results from scaling `maxval` according to `level`.
"""
return float(level) * maxval / PARAMETER_MAX
def int_parameter(level, maxval):
"""Helper function to scale `val` between 0 and maxval .
Args:
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
maxval: Maximum value that the operation can have. This will be scaled
to level/PARAMETER_MAX.
Returns:
An int that results from scaling `maxval` according to `level`.
"""
return int(level * maxval / PARAMETER_MAX)
def pil_wrap(img, use_mean_std):
"""Convert the `img` numpy tensor to a PIL Image."""
if use_mean_std:
MEANS, STDS = get_mean_and_std()
else:
MEANS = [0, 0, 0]
STDS = [1, 1, 1]
img_ori = (img * STDS + MEANS) * 255
return Image.fromarray(
np.uint8((img * STDS + MEANS) * 255.0)).convert('RGBA')
def pil_unwrap(pil_img, use_mean_std, img_shape):
"""Converts the PIL img to a numpy array."""
if use_mean_std:
MEANS, STDS = get_mean_and_std()
else:
MEANS = [0, 0, 0]
STDS = [1, 1, 1]
pic_array = np.array(pil_img.getdata()).reshape((img_shape[0], img_shape[1], 4)) / 255.0
i1, i2 = np.where(pic_array[:, :, 3] == 0)
pic_array = (pic_array[:, :, :3] - MEANS) / STDS
pic_array[i1, i2] = [0, 0, 0]
return pic_array
def apply_policy(policy, img, use_mean_std=True):
"""Apply the `policy` to the numpy `img`.
Args:
policy: A list of tuples with the form (name, probability, level) where
`name` is the name of the augmentation operation to apply, `probability`
is the probability of applying the operation and `level` is what strength
the operation to apply.
img: Numpy image that will have `policy` applied to it.
Returns:
The result of applying `policy` to `img`.
"""
img_shape = img.shape
pil_img = pil_wrap(img, use_mean_std)
for xform in policy:
assert len(xform) == 3
name, probability, level = xform
xform_fn = NAME_TO_TRANSFORM[name].pil_transformer(
probability, level, img_shape)
pil_img = xform_fn(pil_img)
return pil_unwrap(pil_img, use_mean_std, img_shape)
class TransformFunction(object):
"""Wraps the Transform function for pretty printing options."""
def __init__(self, func, name):
self.f = func
self.name = name
def __repr__(self):
return '<' + self.name + '>'
def __call__(self, pil_img):
return self.f(pil_img)
class TransformT(object):
"""Each instance of this class represents a specific transform."""
def __init__(self, name, xform_fn):
self.name = name
self.xform = xform_fn
def pil_transformer(self, probability, level, img_shape):
def return_function(im):
if random.random() < probability:
im = self.xform(im, level, img_shape)
return im
name = self.name + '({:.1f},{})'.format(probability, level)
return TransformFunction(return_function, name)
################## Transform Functions ##################
identity = TransformT('identity', lambda pil_img, level, _: pil_img)
flip_lr = TransformT(
'FlipLR',
lambda pil_img, level, _: pil_img.transpose(Image.FLIP_LEFT_RIGHT))
flip_ud = TransformT(
'FlipUD',
lambda pil_img, level, _: pil_img.transpose(Image.FLIP_TOP_BOTTOM))
# pylint:disable=g-long-lambda
auto_contrast = TransformT(
'AutoContrast',
lambda pil_img, level, _: ImageOps.autocontrast(
pil_img.convert('RGB')).convert('RGBA'))
equalize = TransformT(
'Equalize',
lambda pil_img, level, _: ImageOps.equalize(
pil_img.convert('RGB')).convert('RGBA'))
invert = TransformT(
'Invert',
lambda pil_img, level, _: ImageOps.invert(
pil_img.convert('RGB')).convert('RGBA'))
# pylint:enable=g-long-lambda
blur = TransformT(
'Blur', lambda pil_img, level, _: pil_img.filter(ImageFilter.BLUR))
smooth = TransformT(
'Smooth',
lambda pil_img, level, _: pil_img.filter(ImageFilter.SMOOTH))
def _rotate_impl(pil_img, level, _):
"""Rotates `pil_img` from -30 to 30 degrees depending on `level`."""
degrees = int_parameter(level, 30)
if random.random() > 0.5:
degrees = -degrees
return pil_img.rotate(degrees)
rotate = TransformT('Rotate', _rotate_impl)
def _posterize_impl(pil_img, level, _):
"""Applies PIL Posterize to `pil_img`."""
level = int_parameter(level, 4)
return ImageOps.posterize(pil_img.convert('RGB'), 4 - level).convert('RGBA')
posterize = TransformT('Posterize', _posterize_impl)
def _shear_x_impl(pil_img, level, img_shape):
"""Applies PIL ShearX to `pil_img`.
The ShearX operation shears the image along the horizontal axis with `level`
magnitude.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had ShearX applied to it.
"""
level = float_parameter(level, 0.3)
if random.random() > 0.5:
level = -level
return pil_img.transform(
(img_shape[0], img_shape[1]),
Image.AFFINE,
(1, level, 0, 0, 1, 0))
shear_x = TransformT('ShearX', _shear_x_impl)
def _shear_y_impl(pil_img, level, img_shape):
"""Applies PIL ShearY to `pil_img`.
The ShearY operation shears the image along the vertical axis with `level`
magnitude.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had ShearX applied to it.
"""
level = float_parameter(level, 0.3)
if random.random() > 0.5:
level = -level
return pil_img.transform(
(img_shape[0], img_shape[1]),
Image.AFFINE,
(1, 0, 0, level, 1, 0))
shear_y = TransformT('ShearY', _shear_y_impl)
def _translate_x_impl(pil_img, level, img_shape):
"""Applies PIL TranslateX to `pil_img`.
Translate the image in the horizontal direction by `level`
number of pixels.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had TranslateX applied to it.
"""
level = int_parameter(level, 10)
if random.random() > 0.5:
level = -level
return pil_img.transform(
(img_shape[0], img_shape[1]),
Image.AFFINE,
(1, 0, level, 0, 1, 0))
translate_x = TransformT('TranslateX', _translate_x_impl)
def _translate_y_impl(pil_img, level, img_shape):
"""Applies PIL TranslateY to `pil_img`.
Translate the image in the vertical direction by `level`
number of pixels.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had TranslateY applied to it.
"""
level = int_parameter(level, 10)
if random.random() > 0.5:
level = -level
return pil_img.transform(
(img_shape[0], img_shape[1]),
Image.AFFINE,
(1, 0, 0, 0, 1, level))
translate_y = TransformT('TranslateY', _translate_y_impl)
def _crop_impl(pil_img, level, img_shape, interpolation=Image.BILINEAR):
"""Applies a crop to `pil_img` with the size depending on the `level`."""
cropped = pil_img.crop((level, level, img_shape[0] - level, img_shape[1] - level))
resized = cropped.resize((img_shape[0], img_shape[1]), interpolation)
return resized
crop_bilinear = TransformT('CropBilinear', _crop_impl)
def _solarize_impl(pil_img, level, _):
"""Applies PIL Solarize to `pil_img`.
Translate the image in the vertical direction by `level`
number of pixels.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had Solarize applied to it.
"""
level = int_parameter(level, 256)
return ImageOps.solarize(pil_img.convert('RGB'), 256 - level).convert('RGBA')
solarize = TransformT('Solarize', _solarize_impl)
def _cutout_pil_impl(pil_img, level, img_shape):
"""Apply cutout to pil_img at the specified level."""
size = int_parameter(level, 20)
if size <= 0:
return pil_img
img_height, img_width, num_channels = (img_shape[0], img_shape[1], 3)
_, upper_coord, lower_coord = (
create_cutout_mask(img_height, img_width, num_channels, size))
pixels = pil_img.load() # create the pixel map
for i in range(upper_coord[0], lower_coord[0]): # for every col:
for j in range(upper_coord[1], lower_coord[1]): # For every row
pixels[i, j] = (125, 122, 113, 0) # set the colour accordingly
return pil_img
cutout = TransformT('Cutout', _cutout_pil_impl)
def _enhancer_impl(enhancer):
"""Sets level to be between 0.1 and 1.8 for ImageEnhance transforms of PIL."""
def impl(pil_img, level, _):
v = float_parameter(level, 1.8) + .1 # going to 0 just destroys it
return enhancer(pil_img).enhance(v)
return impl
color = TransformT('Color', _enhancer_impl(ImageEnhance.Color))
contrast = TransformT('Contrast', _enhancer_impl(ImageEnhance.Contrast))
brightness = TransformT('Brightness', _enhancer_impl(
ImageEnhance.Brightness))
sharpness = TransformT('Sharpness', _enhancer_impl(ImageEnhance.Sharpness))
ALL_TRANSFORMS = [
flip_lr,
flip_ud,
auto_contrast,
equalize,
invert,
rotate,
posterize,
crop_bilinear,
solarize,
color,
contrast,
brightness,
sharpness,
shear_x,
shear_y,
translate_x,
translate_y,
cutout,
blur,
smooth
]
NAME_TO_TRANSFORM = {t.name: t for t in ALL_TRANSFORMS}
TRANSFORM_NAMES = NAME_TO_TRANSFORM.keys()

View file

@ -0,0 +1,14 @@
import RandAugment as rand
import PIL
import torchvision
import transformations as TF
tpil=torchvision.transforms.ToPILImage()
ttensor=torchvision.transforms.ToTensor()
img,label =data_train[0]
rimg=ttensor(PIL.ImageEnhance.Color(tpil(img)).enhance(1.5))#ttensor(PIL.ImageOps.solarize(tpil(img), 50))#ttensor(tpil(img).transform(tpil(img).size, PIL.Image.AFFINE, (1, -0.1, 0, 0, 1, 0)))#rand.augmentations.FlipUD(tpil(img),1))
timg=TF.color(img.unsqueeze(0),torch.Tensor([1.5])).squeeze(0)
print(torch.allclose(rimg,timg, atol=1e-3))
tpil(rimg).save('rimg.jpg')
tpil(timg).save('timg.jpg')

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,85 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import higher
import time
data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=torchvision.transforms.ToTensor())
dl_train = torch.utils.data.DataLoader(data_train, batch_size=300, shuffle=True, num_workers=0, pin_memory=False)
class Aug_model(nn.Module):
def __init__(self, model, hyper_param=True):
super(Aug_model, self).__init__()
#### Origin of the issue ? ####
if hyper_param:
self._params = nn.ParameterDict({
"hyper_param": nn.Parameter(torch.Tensor([0.5])),
})
###############################
self._mods = nn.ModuleDict({
'model': model,
})
def forward(self, x):
return self._mods['model'](x) #* self._params['hyper_param']
def __getitem__(self, key):
return self._mods[key]
class Aug_model2(nn.Module): #Slow increase like no hyper_param
def __init__(self, model, hyper_param=True):
super(Aug_model2, self).__init__()
#### Origin of the issue ? ####
if hyper_param:
self._params = nn.ParameterDict({
"hyper_param": nn.Parameter(torch.Tensor([0.5])),
})
###############################
self._mods = nn.ModuleDict({
'model': model,
'fmodel': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
})
def forward(self, x):
return self._mods['fmodel'](x) * self._params['hyper_param']
def get_diffopt(self, opt, track_higher_grads=True):
return higher.optim.get_diff_optim(opt,
self._mods['model'].parameters(),
fmodel=self._mods['fmodel'],
track_higher_grads=track_higher_grads)
def __getitem__(self, key):
return self._mods[key]
if __name__ == "__main__":
device = torch.device('cuda:1')
aug_model = Aug_model2(
model=torch.hub.load('pytorch/vision:v0.4.2', 'resnet18', pretrained=False),
hyper_param=True #False will not extend step time
).to(device)
inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9)
#fmodel = higher.patch.monkeypatch(aug_model, device=None, copy_initial_weights=True)
#diffopt = higher.optim.get_diff_optim(inner_opt, aug_model.parameters(),fmodel=fmodel,track_higher_grads=True)
diffopt = aug_model.get_diffopt(inner_opt)
for i, (xs, ys) in enumerate(dl_train):
xs, ys = xs.to(device), ys.to(device)
#logits = fmodel(xs)
logits = aug_model(xs)
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='mean')
t = time.process_time()
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
#print(len(fmodel._fast_params),"step", time.process_time()-t)
print(len(aug_model['fmodel']._fast_params),"step", time.process_time()-t)

View file

@ -0,0 +1,502 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
## Basic CNN ##
class LeNet_F(nn.Module):
def __init__(self, num_inp, num_out):
super(LeNet_F, self).__init__()
self._params = nn.ParameterDict({
'w1': nn.Parameter(torch.zeros(20, num_inp, 5, 5)),
'b1': nn.Parameter(torch.zeros(20)),
'w2': nn.Parameter(torch.zeros(50, 20, 5, 5)),
'b2': nn.Parameter(torch.zeros(50)),
#'w3': nn.Parameter(torch.zeros(500,4*4*50)), #num_imp=1
'w3': nn.Parameter(torch.zeros(500,5*5*50)), #num_imp=3
'b3': nn.Parameter(torch.zeros(500)),
'w4': nn.Parameter(torch.zeros(num_out, 500)),
'b4': nn.Parameter(torch.zeros(num_out))
})
self.initialize()
def initialize(self):
nn.init.kaiming_uniform_(self._params["w1"], a=math.sqrt(5))
nn.init.kaiming_uniform_(self._params["w2"], a=math.sqrt(5))
nn.init.kaiming_uniform_(self._params["w3"], a=math.sqrt(5))
nn.init.kaiming_uniform_(self._params["w4"], a=math.sqrt(5))
def forward(self, x):
#print("Start Shape ", x.shape)
out = F.relu(F.conv2d(input=x, weight=self._params["w1"], bias=self._params["b1"]))
#print("Shape ", out.shape)
out = F.max_pool2d(out, 2)
#print("Shape ", out.shape)
out = F.relu(F.conv2d(input=out, weight=self._params["w2"], bias=self._params["b2"]))
#print("Shape ", out.shape)
out = F.max_pool2d(out, 2)
#print("Shape ", out.shape)
out = out.view(out.size(0), -1)
#print("Shape ", out.shape)
out = F.relu(F.linear(out, self._params["w3"], self._params["b3"]))
#print("Shape ", out.shape)
out = F.linear(out, self._params["w4"], self._params["b4"])
#print("Shape ", out.shape)
#return F.log_softmax(out, dim=1)
return out
def __getitem__(self, key):
return self._params[key]
def __str__(self):
return "LeNet"
## MobileNetv2 ##
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm2d(out_planes),
nn.ReLU6(inplace=True)
)
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(nn.Module):
def __init__(self,
num_classes=1000,
width_mult=1.0,
inverted_residual_setting=None,
round_nearest=8,
block=None):
"""
MobileNet V2 main class
Args:
num_classes (int): Number of classes
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
block: Module specifying inverted residual building block for mobilenet
"""
super(MobileNetV2, self).__init__()
if block is None:
block = InvertedResidual
input_channel = 32
last_channel = 1280
if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, num_classes),
)
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def _forward_impl(self, x):
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
x = self.features(x)
x = x.mean([2, 3])
x = self.classifier(x)
return x
def forward(self, x):
return self._forward_impl(x)
def __str__(self):
return "MobileNetV2"
## ResNet ##
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
#ResNet18 : block=BasicBlock, layers=[2, 2, 2, 2]
class ResNet(nn.Module):
def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
def __str__(self):
return "ResNet18"
## Wide ResNet ##
#https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
#https://github.com/arcelien/pba/blob/master/pba/wrn.py
#https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py
'''
class BasicBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
super(BasicBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_planes)
self.relu2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
padding=1, bias=False)
self.droprate = dropRate
self.equalInOut = (in_planes == out_planes)
self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
padding=0, bias=False) or None
def forward(self, x):
if not self.equalInOut:
x = self.relu1(self.bn1(x))
else:
out = self.relu1(self.bn1(x))
out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
if self.droprate > 0:
out = F.dropout(out, p=self.droprate, training=self.training)
out = self.conv2(out)
return torch.add(x if self.equalInOut else self.convShortcut(x), out)
class NetworkBlock(nn.Module):
def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
super(NetworkBlock, self).__init__()
self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
layers = []
for i in range(int(nb_layers)):
layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
return nn.Sequential(*layers)
def forward(self, x):
return self.layer(x)
#wrn_size: 32 = WRN-28-2 ? 160 = WRN-28-10
class WideResNet(nn.Module):
#def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
def __init__(self, num_classes, wrn_size, depth=28, dropRate=0.0):
super(WideResNet, self).__init__()
self.kernel_size = wrn_size
self.depth=depth
filter_size = 3
nChannels = [min(self.kernel_size, 16), self.kernel_size, self.kernel_size * 2, self.kernel_size * 4]
strides = [1, 2, 2] # stride for each resblock
#nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
assert((depth - 4) % 6 == 0)
n = (depth - 4) / 6
block = BasicBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(filter_size, nChannels[0], kernel_size=3, stride=1,
padding=1, bias=False)
# 1st block
self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, strides[0], dropRate)
# 2nd block
self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, strides[1], dropRate)
# 3rd block
self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, strides[2], dropRate)
# global average pooling and classifier
self.bn1 = nn.BatchNorm2d(nChannels[3])
self.relu = nn.ReLU(inplace=True)
self.fc = nn.Linear(nChannels[3], num_classes)
self.nChannels = nChannels[3]
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 8)
out = out.view(-1, self.nChannels)
return self.fc(out)
def architecture(self):
return super(WideResNet, self).__str__()
def __str__(self):
return "WideResNet(s{}-d{})".format(self.kernel_size, self.depth)
'''

View file

@ -0,0 +1,184 @@
from model import *
from dataug import *
#from utils import *
from train_utils import *
import torchvision.models as models
# Use available TF (see transformations.py)
tf_names = [
## Geometric TF ##
'Identity',
'FlipUD',
'FlipLR',
'Rotate',
'TranslateX',
'TranslateY',
'ShearX',
'ShearY',
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast',
'Color',
'Brightness',
'Sharpness',
'Posterize',
'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
## Bad Tranformations ##
# Bad Geometric TF #
#'BShearX',
#'BShearY',
#'BTranslateX-',
#'BTranslateX-',
#'BTranslateY',
#'BTranslateY-',
#'BadContrast',
#'BadBrightness',
#'Random',
#'RandBlend'
]
device = torch.device('cuda')
if device == torch.device('cpu'):
device_name = 'CPU'
else:
device_name = torch.cuda.get_device_name(device)
torch.backends.cudnn.benchmark = True #Faster if same input size #Not recommended for reproductibility
#Increase reproductibility
torch.manual_seed(0)
np.random.seed(0)
##########################################
if __name__ == "__main__":
n_inner_iter = 1
epochs = 150
dataug_epoch_start=0
optim_param={
'Meta':{
'optim':'Adam',
'lr':1e-2, #1e-2
},
'Inner':{
'optim': 'SGD',
'lr':1e-1, #1e-2
'momentum':0.9, #0.9
}
}
model=models.resnet18()
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
####
'''
t0 = time.process_time()
aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None)
exec_time=time.process_time() - t0
####
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
with open("res/log/%s.json" % filename, "w+") as f:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
'''
####
'''
t0 = time.process_time()
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None)
exec_time=time.process_time() - t0
####
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
with open("res/log/%s.json" % filename, "w+") as f:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
'''
res_folder="../res/brutus-tests2/"
epochs= 150
inner_its = [1]
dist_mix = [0.0, 0.5, 0.8, 1.0]
dataug_epoch_starts= [0]
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
TF_nb = [len(tf_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
N_seq_TF= [4, 3, 2]
mag_setup = [(True,True), (False, False)] #(Fixed, Shared)
#prob_setup = [True, False]
nb_run= 3
try:
os.mkdir(res_folder)
os.mkdir(res_folder+"log/")
except FileExistsError:
pass
for n_inner_iter in inner_its:
for dataug_epoch_start in dataug_epoch_starts:
for n_tf in N_seq_TF:
for dist in dist_mix:
#for i in TF_nb:
for m_setup in mag_setup:
#for p_setup in prob_setup:
p_setup=False
for run in range(nb_run):
if (n_inner_iter == 0 and (m_setup!=(True,True) and p_setup!=True)) or (p_setup and dist!=0.0): continue #Autres setup inutiles sans meta-opti
#keys = list(TF.TF_dict.keys())[0:i]
#ntf_dict = {k: TF.TF_dict[k] for k in keys}
t0 = time.process_time()
model = ResNet(num_classes=10)
model = Higher_model(model) #run_dist_dataugV3
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=p_setup, fixed_mag=m_setup[0], shared_mag=m_setup[1]), model).to(device)
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
log= run_dist_dataugV3(model=aug_model,
epochs=epochs,
inner_it=n_inner_iter,
dataug_epoch_start=dataug_epoch_start,
opt_param=optim_param,
print_freq=50,
KLdiv=True)
exec_time=time.process_time() - t0
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs (dataug:{})- {} in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter, run)
with open("../res/log/%s.json" % filename, "w+") as f:
try:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
except:
print("Failed to save logs :",f.name)
try:
plot_resV2(log, fig_name="../res/"+filename, param_names=aug_model.TF_names())
except:
print("Failed to plot res")
print('Execution Time : %.00f '%(exec_time))
print('-'*9)
#'''

View file

@ -0,0 +1,150 @@
import numpy as np
import json, math, time, os
from torch.utils.data import SubsetRandomSampler
import torch.optim as optim
import higher
from model import *
import copy
BATCH_SIZE = 300
TEST_SIZE = 300
mnist_train = torchvision.datasets.MNIST(
"./data", train=True, download=True,
transform=torchvision.transforms.Compose([
#torchvision.transforms.RandomAffine(degrees=180, translate=None, scale=None, shear=None, resample=False, fillcolor=0),
torchvision.transforms.ToTensor()
])
)
mnist_test = torchvision.datasets.MNIST(
"./data", train=False, download=True, transform=torchvision.transforms.ToTensor()
)
#train_subset_indices=range(int(len(mnist_train)/2))
train_subset_indices=range(BATCH_SIZE)
val_subset_indices=range(int(len(mnist_train)/2),len(mnist_train))
dl_train = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
dl_val = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
dl_test = torch.utils.data.DataLoader(mnist_test, batch_size=TEST_SIZE, shuffle=False)
def test(model):
model.eval()
for i, (features, labels) in enumerate(dl_test):
pred = model.forward(features)
return pred.argmax(dim=1).eq(labels).sum().item() / TEST_SIZE * 100
def train_classic(model, optim, epochs=1):
model.train()
log = []
for epoch in range(epochs):
t0 = time.process_time()
for i, (features, labels) in enumerate(dl_train):
optim.zero_grad()
pred = model.forward(features)
loss = F.cross_entropy(pred,labels)
loss.backward()
optim.step()
#### Log ####
tf = time.process_time()
data={
"time": tf - t0,
}
log.append(data)
times = [x["time"] for x in log]
print("Vanilla : acc", test(model), "in (ms):", np.mean(times), "+/-", np.std(times))
##########################################
if __name__ == "__main__":
device = torch.device('cpu')
model = LeNet(1,10)
opt_param = {
"lr": torch.tensor(1e-2).requires_grad_(),
"momentum": torch.tensor(0.9).requires_grad_()
}
n_inner_iter = 1
dl_train_it = iter(dl_train)
dl_val_it = iter(dl_val)
epoch = 0
epochs = 10
####
train_classic(model=model, optim=torch.optim.Adam(model.parameters(), lr=0.001), epochs=epochs)
model = LeNet(1,10)
meta_opt = torch.optim.Adam(opt_param.values(), lr=1e-2)
inner_opt = torch.optim.SGD(model.parameters(), lr=opt_param['lr'], momentum=opt_param['momentum'])
#for xs_val, ys_val in dl_val:
while epoch < epochs:
#print(data_aug.params["mag"], data_aug.params["mag"].grad)
meta_opt.zero_grad()
model.train()
with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt): #effet copy_initial_weight pas clair...
for param_group in diffopt.param_groups:
param_group['lr'] = opt_param['lr']
param_group['momentum'] = opt_param['momentum']
for i in range(n_inner_iter):
try:
xs, ys = next(dl_train_it)
except StopIteration: #Fin epoch train
epoch +=1
dl_train_it = iter(dl_train)
xs, ys = next(dl_train_it)
print('Epoch', epoch)
print('train loss',loss.item(), '/ val loss', val_loss.item())
print('acc', test(model))
print('opt : lr', opt_param['lr'].item(), 'momentum', opt_param['momentum'].item())
print('-'*9)
model.train()
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(logits, ys) # no need to call loss.backwards()
#print('loss',loss.item())
diffopt.step(loss) # note that `step` must take `loss` as an argument!
# The line above gets P[t+1] from P[t] and loss[t]. `step` also returns
# these new parameters, as an alternative to getting them from
# `fmodel.fast_params` or `fmodel.parameters()` after calling
# `diffopt.step`.
# At this point, or at any point in the iteration, you can take the
# gradient of `fmodel.parameters()` (or equivalently
# `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently
# `fmodel.init_fast_params`). i.e. `fast_params` will always have
# `grad_fn` as an attribute, and be part of the gradient tape.
# At the end of your inner loop you can obtain these e.g. ...
#grad_of_grads = torch.autograd.grad(
# meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))
try:
xs_val, ys_val = next(dl_val_it)
except StopIteration: #Fin epoch val
dl_val_it = iter(dl_val_it)
xs_val, ys_val = next(dl_val_it)
val_logits = fmodel(xs_val)
val_loss = F.cross_entropy(val_logits, ys_val)
#print('val_loss',val_loss.item())
val_loss.backward()
#meta_grads = torch.autograd.grad(val_loss, opt_lr, allow_unused=True)
#print(meta_grads)
for param_group in diffopt.param_groups:
print(param_group['lr'], '/',param_group['lr'].grad)
print(param_group['momentum'], '/',param_group['momentum'].grad)
#model=copy.deepcopy(fmodel)
model.load_state_dict(fmodel.state_dict())
meta_opt.step()

View file

@ -0,0 +1,866 @@
import torch
#import torch.optim
import torchvision
import higher
from datasets import *
from utils import *
def train_classic_higher(model, epochs=1):
device = next(model.parameters()).device
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
model.train()
dl_val_it = iter(dl_val)
log = []
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(optim, model.parameters(),fmodel=fmodel,track_higher_grads=False)
#with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, diffopt):
for epoch in range(epochs):
#print_torch_mem("Start epoch "+str(epoch))
#print("Fast param ",len(fmodel._fast_params))
t0 = time.process_time()
for i, (features, labels) in enumerate(dl_train):
#print_torch_mem("Start iter")
features,labels = features.to(device), labels.to(device)
#optim.zero_grad()
logits = model.forward(features)
pred = F.log_softmax(logits, dim=1)
loss = F.cross_entropy(pred,labels)
#.backward()
#optim.step()
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
model_copy(src=fmodel, dst=model, patch_copy=False)
optim_copy(dopt=diffopt, opt=optim)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(optim, model.parameters(),fmodel=fmodel,track_higher_grads=False)
#### Tests ####
tf = time.process_time()
try:
xs_val, ys_val = next(dl_val_it)
except StopIteration: #Fin epoch val
dl_val_it = iter(dl_val)
xs_val, ys_val = next(dl_val_it)
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
val_loss = F.cross_entropy(model(xs_val), ys_val)
accuracy, _ =test(model)
model.train()
#### Log ####
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": None,
}
log.append(data)
return log
def train_classic_tests(model, epochs=1):
device = next(model.parameters()).device
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
countcopy=0
model.train()
dl_val_it = iter(dl_val)
log = []
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
doptim = higher.optim.get_diff_optim(optim, model.parameters(), fmodel=fmodel, track_higher_grads=False)
for epoch in range(epochs):
print_torch_mem("Start epoch")
print(len(fmodel._fast_params))
t0 = time.process_time()
#with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=True) as (fmodel, doptim):
#fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
#doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True)
for i, (features, labels) in enumerate(dl_train):
features,labels = features.to(device), labels.to(device)
#with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, doptim):
#optim.zero_grad()
pred = fmodel.forward(features)
loss = F.cross_entropy(pred,labels)
doptim.step(loss) #(opt.zero_grad, loss.backward, opt.step)
#loss.backward()
#new_params = doptim.step(loss, params=fmodel.parameters())
#fmodel.update_params(new_params)
#print('Fast param',len(fmodel._fast_params))
#print('opt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][2]['momentum_buffer'].shape)
if False or (len(fmodel._fast_params)>1):
print("fmodel fast param",len(fmodel._fast_params))
'''
#val_loss = F.cross_entropy(fmodel(features), labels)
#print_graph(val_loss)
#val_loss.backward()
#print('bip')
tmp = fmodel.parameters()
#print(list(tmp)[1])
tmp = [higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in tmp]
#print(len(tmp))
#fmodel._fast_params.clear()
del fmodel._fast_params
fmodel._fast_params=None
fmodel.fast_params=tmp # Surcharge la memoire
#fmodel.update_params(tmp) #Meilleur perf / Surcharge la memoire avec trach higher grad
#optim._fmodel=fmodel
'''
countcopy+=1
model_copy(src=fmodel, dst=model, patch_copy=False)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
#doptim.detach_dyn()
#tmp = doptim.state
#tmp = doptim.state_dict()
#for k, v in tmp['state'].items():
# print('dict',k, type(v))
a = optim.param_groups[0]['params'][0]
state = optim.state[a]
#state['momentum_buffer'] = None
#print('opt state', type(optim.state[a]), len(optim.state[a]))
#optim.load_state_dict(tmp)
for group_idx, group in enumerate(optim.param_groups):
# print('gp idx',group_idx)
for p_idx, p in enumerate(group['params']):
optim.state[p]=doptim.state[group_idx][p_idx]
#print('opt state', type(optim.state[a]['momentum_buffer']), optim.state[a]['momentum_buffer'][0:10])
#print('dopt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][0]['momentum_buffer'][0:10])
'''
for a in tmp:
#print(type(a), len(a))
for nb, b in a.items():
#print(nb, type(b), len(b))
for n, state in b.items():
#print(n, type(states))
#print(state.grad_fn)
state = torch.tensor(state.data).requires_grad_()
#print(state.grad_fn)
'''
doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True)
#doptim.state = tmp
countcopy+=1
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
#### Tests ####
tf = time.process_time()
try:
xs_val, ys_val = next(dl_val_it)
except StopIteration: #Fin epoch val
dl_val_it = iter(dl_val)
xs_val, ys_val = next(dl_val_it)
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
val_loss = F.cross_entropy(model(xs_val), ys_val)
accuracy, _ =test(model)
model.train()
#### Log ####
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": None,
}
log.append(data)
#countcopy+=1
#model_copy(src=fmodel, dst=model, patch_copy=False)
#optim.load_state_dict(doptim.state_dict()) #Besoin sauver etat otpim ?
print("Copy ", countcopy)
return log
from torchvision.datasets.vision import VisionDataset
from PIL import Image
import augmentation_transforms
import numpy as np
class AugmentedDatasetV2(VisionDataset):
def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None):
super(AugmentedDatasetV2, self).__init__(root, transform=transform, target_transform=target_transform)
supervised_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=transform)
self.sup_data = supervised_dataset.data if not subset else supervised_dataset.data[subset[0]:subset[1]]
self.sup_targets = supervised_dataset.targets if not subset else supervised_dataset.targets[subset[0]:subset[1]]
assert len(self.sup_data)==len(self.sup_targets)
for idx, img in enumerate(self.sup_data):
self.sup_data[idx]= Image.fromarray(img) #to PIL Image
self.unsup_data=[]
self.unsup_targets=[]
self.origin_idx=[]
self.dataset_info= {
'name': 'CIFAR10',
'sup': len(self.sup_data),
'unsup': len(self.unsup_data),
'length': len(self.sup_data)+len(self.unsup_data),
}
self._TF = [
## Geometric TF ##
'Rotate',
'TranslateX',
'TranslateY',
'ShearX',
'ShearY',
'Cutout',
## Color TF ##
'Contrast',
'Color',
'Brightness',
'Sharpness',
'Posterize',
'Solarize',
'Invert',
'AutoContrast',
'Equalize',
]
self._op_list =[]
self.prob=0.5
self.mag_range=(1, 10)
for tf in self._TF:
for mag in range(self.mag_range[0], self.mag_range[1]):
self._op_list+=[(tf, self.prob, mag)]
self._nb_op = len(self._op_list)
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
aug_img, origin_img, target = self.unsup_data[index], self.sup_data[self.origin_idx[index]], self.unsup_targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
#img = Image.fromarray(img)
if self.transform is not None:
aug_img = self.transform(aug_img)
origin_img = self.transform(origin_img)
if self.target_transform is not None:
target = self.target_transform(target)
return aug_img, origin_img, target
def augement_data(self, aug_copy=1):
policies = []
for op_1 in self._op_list:
for op_2 in self._op_list:
policies += [[op_1, op_2]]
for idx, image in enumerate(self.sup_data):
if idx%(self.dataset_info['sup']/5)==0: print("Augmenting data... ", idx,"/", self.dataset_info['sup'])
#if idx==10000:break
for _ in range(aug_copy):
chosen_policy = policies[np.random.choice(len(policies))]
aug_image = augmentation_transforms.apply_policy(chosen_policy, image, use_mean_std=False) #Cast en float image
#aug_image = augmentation_transforms.cutout_numpy(aug_image)
self.unsup_data+=[(aug_image*255.).astype(self.sup_data.dtype)]#Cast float image to uint8
self.unsup_targets+=[self.sup_targets[idx]]
self.origin_idx+=[idx]
#self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8
self.unsup_data=np.array(self.unsup_data)
assert len(self.unsup_data)==len(self.unsup_targets)
self.dataset_info['unsup']=len(self.unsup_data)
self.dataset_info['length']=self.dataset_info['sup']+self.dataset_info['unsup']
def __len__(self):
return self.dataset_info['unsup']#self.dataset_info['length']
def __str__(self):
return "CIFAR10(Sup:{}-Unsup:{}-{}TF(Mag{}-{}))".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF), self.mag_range[0], self.mag_range[1])
def train_UDA(model, dl_unsup, opt_param, epochs=1, print_freq=1):
"""Training of a model using UDA inspired approach.
Intended to be used alongside an already augmented dataset (see AugmentedDatasetV2).
Args:
model (nn.Module): Model to train.
dl_unsup (Dataloader): Data loader of unsupervised/augmented data.
opt_param (dict): Dictionnary containing optimizers parameters.
epochs (int): Number of epochs to perform. (default: 1)
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
Returns:
(list) Logs of training. Each items is a dict containing results of an epoch.
"""
device = next(model.parameters()).device
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
opt = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
model.train()
dl_val_it = iter(dl_val)
dl_unsup_it =iter(dl_unsup)
log = []
for epoch in range(epochs):
#print_torch_mem("Start epoch")
t0 = time.process_time()
for i, (features, labels) in enumerate(dl_train):
#print_torch_mem("Start iter")
features,labels = features.to(device), labels.to(device)
optim.zero_grad()
#Supervised
logits = model.forward(features)
pred = F.log_softmax(logits, dim=1)
sup_loss = F.cross_entropy(pred,labels)
#Unsupervised
try:
aug_xs, origin_xs, ys = next(dl_unsup_it)
except StopIteration: #Fin epoch val
dl_unsup_it =iter(dl_unsup)
aug_xs, origin_xs, ys = next(dl_unsup_it)
aug_xs, origin_xs, ys = aug_xs.to(device), origin_xs.to(device), ys.to(device)
#print(aug_xs.shape, origin_xs.shape, ys.shape)
sup_logits = model.forward(origin_xs)
unsup_logits = model.forward(aug_xs)
log_sup=F.log_softmax(sup_logits, dim=1)
log_unsup=F.log_softmax(unsup_logits, dim=1)
#KL div w/ logits
unsup_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_unsup)
unsup_loss=unsup_loss.sum(dim=-1).mean()
#print(unsup_loss)
unsupp_coeff = 1
loss = sup_loss + unsup_loss * unsupp_coeff
loss.backward()
optim.step()
#### Tests ####
tf = time.process_time()
try:
xs_val, ys_val = next(dl_val_it)
except StopIteration: #Fin epoch val
dl_val_it = iter(dl_val)
xs_val, ys_val = next(dl_val_it)
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
val_loss = F.cross_entropy(model(xs_val), ys_val)
accuracy, _ =test(model)
model.train()
#### Print ####
if(print_freq and epoch%print_freq==0):
print('-'*9)
print('Epoch : %d/%d'%(epoch,epochs))
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Sup Loss :', sup_loss.item(), '/ unsup_loss :', unsup_loss.item())
print('Accuracy :', accuracy)
#### Log ####
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": None,
}
log.append(data)
return log
def run_simple_dataug(inner_it, epochs=1):
device = next(model.parameters()).device
dl_train_it = iter(dl_train)
dl_val_it = iter(dl_val)
#aug_model = nn.Sequential(
# Data_aug(),
# LeNet(1,10),
# )
aug_model = Augmented_model(Data_aug(), LeNet(1,10)).to(device)
print(str(aug_model))
meta_opt = torch.optim.Adam(aug_model['data_aug'].parameters(), lr=1e-2)
inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9)
log = []
t0 = time.process_time()
epoch = 0
while epoch < epochs:
meta_opt.zero_grad()
aug_model.train()
with higher.innerloop_ctx(aug_model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt): #effet copy_initial_weight pas clair...
for i in range(n_inner_iter):
try:
xs, ys = next(dl_train_it)
except StopIteration: #Fin epoch train
tf = time.process_time()
epoch +=1
dl_train_it = iter(dl_train)
xs, ys = next(dl_train_it)
accuracy, _ =test(model)
aug_model.train()
#### Print ####
print('-'*9)
print('Epoch %d/%d'%(epoch,epochs))
print('train loss',loss.item(), '/ val loss', val_loss.item())
print('acc', accuracy)
print('mag', aug_model['data_aug']['mag'].item())
#### Log ####
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": aug_model['data_aug']['mag'].item(),
}
log.append(data)
t0 = time.process_time()
xs, ys = xs.to(device), ys.to(device)
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(logits, ys) # no need to call loss.backwards()
#loss.backward(retain_graph=True)
#print(fmodel['model']._params['b4'].grad)
#print('mag', fmodel['data_aug']['mag'].grad)
diffopt.step(loss) # note that `step` must take `loss` as an argument!
# The line above gets P[t+1] from P[t] and loss[t]. `step` also returns
# these new parameters, as an alternative to getting them from
# `fmodel.fast_params` or `fmodel.parameters()` after calling
# `diffopt.step`.
# At this point, or at any point in the iteration, you can take the
# gradient of `fmodel.parameters()` (or equivalently
# `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently
# `fmodel.init_fast_params`). i.e. `fast_params` will always have
# `grad_fn` as an attribute, and be part of the gradient tape.
# At the end of your inner loop you can obtain these e.g. ...
#grad_of_grads = torch.autograd.grad(
# meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))
try:
xs_val, ys_val = next(dl_val_it)
except StopIteration: #Fin epoch val
dl_val_it = iter(dl_val)
xs_val, ys_val = next(dl_val_it)
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
fmodel.augment(mode=False)
val_logits = fmodel(xs_val) #Validation sans transfornations !
val_loss = F.cross_entropy(val_logits, ys_val)
#print('val_loss',val_loss.item())
val_loss.backward()
#print('mag', fmodel['data_aug']['mag'], '/', fmodel['data_aug']['mag'].grad)
#model=copy.deepcopy(fmodel)
aug_model.load_state_dict(fmodel.state_dict()) #Do not copy gradient !
#Copie des gradients
for paramName, paramValue, in fmodel.named_parameters():
for netCopyName, netCopyValue, in aug_model.named_parameters():
if paramName == netCopyName:
netCopyValue.grad = paramValue.grad
#print('mag', aug_model['data_aug']['mag'], '/', aug_model['data_aug']['mag'].grad)
meta_opt.step()
plot_res(log, fig_name="res/{}-{} epochs- {} in_it".format(str(aug_model),epochs,inner_it))
print('-'*9)
times = [x["time"] for x in log]
print(str(aug_model),": acc", max([x["acc"] for x in log]), "in (ms):", np.mean(times), "+/-", np.std(times))
def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
device = next(model.parameters()).device
dl_train_it = iter(dl_train)
dl_val_it = iter(dl_val)
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-3)
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
high_grad_track = True
if dataug_epoch_start>0:
model.augment(mode=False)
high_grad_track = False
model.train()
log = []
t0 = time.process_time()
countcopy=0
val_loss=torch.tensor(0)
opt_param=None
epoch = 0
while epoch < epochs:
meta_opt.zero_grad()
with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt): #effet copy_initial_weight pas clair...
for i in range(n_inner_iter):
try:
xs, ys = next(dl_train_it)
except StopIteration: #Fin epoch train
tf = time.process_time()
epoch +=1
dl_train_it = iter(dl_train)
xs, ys = next(dl_train_it)
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
#viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
accuracy, _ =test(model)
model.train()
#### Print ####
print('-'*9)
print('Epoch : %d/%d'%(epoch,epochs))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Accuracy :', accuracy)
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
print('TF Proba :', model['data_aug']['prob'].data)
#print('proba grad',aug_model['data_aug']['prob'].grad)
#############
#### Log ####
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": [p for p in model['data_aug']['prob']],
}
log.append(data)
#############
if epoch == dataug_epoch_start:
print('Starting Data Augmention...')
model.augment(mode=True)
high_grad_track = True
t0 = time.process_time()
xs, ys = xs.to(device), ys.to(device)
'''
#Methode exacte
final_loss = 0
for tf_idx in range(fmodel['data_aug']._nb_tf):
fmodel['data_aug'].transf_idx=tf_idx
logits = fmodel(xs)
loss = F.cross_entropy(logits, ys)
#loss.backward(retain_graph=True)
#print('idx', tf_idx)
#print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad)
final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
loss = final_loss
'''
#Methode uniforme
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
if fmodel._data_augmentation: #Weight loss
w_loss = fmodel['data_aug'].loss_weight().to(device)
loss = loss * w_loss
loss = loss.mean()
#'''
#to visualize computational graph
#print_graph(loss)
#loss.backward(retain_graph=True)
#print(fmodel['model']._params['b4'].grad)
#print('prob grad', fmodel['data_aug']['prob'].grad)
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
try:
xs_val, ys_val = next(dl_val_it)
except StopIteration: #Fin epoch val
dl_val_it = iter(dl_val)
xs_val, ys_val = next(dl_val_it)
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
fmodel.augment(mode=False) #Validation sans transfornations !
val_loss = F.cross_entropy(fmodel(xs_val), ys_val)
#print_graph(val_loss)
val_loss.backward()
countcopy+=1
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
meta_opt.step()
model['data_aug'].adjust_param() #Contrainte sum(proba)=1
print("Copy ", countcopy)
return log
def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
device = next(model.parameters()).device
log = []
countcopy=0
val_loss=torch.tensor(0) #Necessaire si pas de metastep sur une epoch
dl_val_it = iter(dl_val)
#if inner_it!=0:
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
high_grad_track = True
if inner_it == 0:
high_grad_track=False
if dataug_epoch_start!=0:
model.augment(mode=False)
high_grad_track = False
val_loss_monitor= None
if loss_patience != None :
if dataug_epoch_start==-1: val_loss_monitor = loss_monitor(patience=loss_patience, end_train=2) #1st limit = dataug start
else: val_loss_monitor = loss_monitor(patience=loss_patience) #Val loss monitor (Not on val data : used by Dataug... => Test data)
model.train()
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
meta_opt.zero_grad()
for epoch in range(1, epochs+1):
#print_torch_mem("Start epoch "+str(epoch))
#print(high_grad_track, fmodel._data_augmentation, len(fmodel._fast_params))
t0 = time.process_time()
#with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt):
for i, (xs, ys) in enumerate(dl_train):
xs, ys = xs.to(device), ys.to(device)
#Methode exacte
#final_loss = 0
#for tf_idx in range(fmodel['data_aug']._nb_tf):
# fmodel['data_aug'].transf_idx=tf_idx
# logits = fmodel(xs)
# loss = F.cross_entropy(logits, ys)
# #loss.backward(retain_graph=True)
# final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
#loss = final_loss
if(not KLdiv):
#Methode uniforme
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
if fmodel._data_augmentation: #Weight loss
w_loss = fmodel['data_aug'].loss_weight()#.to(device)
loss = loss * w_loss
loss = loss.mean()
else:
#Methode KL div
if fmodel._data_augmentation :
fmodel.augment(mode=False)
sup_logits = fmodel(xs)
fmodel.augment(mode=True)
else:
sup_logits = fmodel(xs)
log_sup=F.log_softmax(sup_logits, dim=1)
loss = F.cross_entropy(log_sup, ys)
if fmodel._data_augmentation:
aug_logits = fmodel(xs)
log_aug=F.log_softmax(aug_logits, dim=1)
w_loss = fmodel['data_aug'].loss_weight() #Weight loss
#if epoch>50: #debut differe ?
#KL div w/ logits - Similarite predictions (distributions)
aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
aug_loss = aug_loss.sum(dim=-1)
#aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none')
aug_loss = (w_loss * aug_loss).mean()
aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
unsupp_coeff = 1
loss += aug_loss * unsupp_coeff
#to visualize computational graph
#print_graph(loss)
#loss.backward(retain_graph=True)
#print(fmodel['model']._params['b4'].grad)
#print('prob grad', fmodel['data_aug']['prob'].grad)
#t = time.process_time()
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
#print(len(fmodel._fast_params),"step", time.process_time()-t)
if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
#print("meta")
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) #+ fmodel['data_aug'].reg_loss()
#print_graph(val_loss)
#t = time.process_time()
val_loss.backward()
#print("meta", time.process_time()-t)
#print('proba grad',model['data_aug']['prob'].grad)
if model['data_aug']['prob'].grad is None or model['data_aug']['mag'] is None:
print("Warning no grad (iter",i,") :\n Prob-",model['data_aug']['prob'].grad,"\n Mag-", model['data_aug']['mag'].grad)
countcopy+=1
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
#if epoch>50:
meta_opt.step()
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
try: #Dataugv6
model['data_aug'].next_TF_set()
except:
pass
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
meta_opt.zero_grad()
tf = time.process_time()
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
if(not high_grad_track):
countcopy+=1
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
#Necessaire pour reset higher (Accumule les fast_param meme avec track_higher_grads = False)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
accuracy, test_loss =test(model)
model.train()
#### Log ####
#print(type(model['data_aug']) is dataug.Data_augV5)
param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": param #if isinstance(model['data_aug'], Data_augV5)
#else [p.item() for p in model['data_aug']['prob']],
}
log.append(data)
#############
#### Print ####
if(print_freq and epoch%print_freq==0):
print('-'*9)
print('Epoch : %d/%d'%(epoch,epochs))
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Accuracy :', max([x["acc"] for x in log]))
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
print('TF Proba :', model['data_aug']['prob'].data)
#print('proba grad',model['data_aug']['prob'].grad)
print('TF Mag :', model['data_aug']['mag'].data)
#print('Mag grad',model['data_aug']['mag'].grad)
#print('Reg loss:', model['data_aug'].reg_loss().item())
#print('Aug loss', aug_loss.item())
#############
if val_loss_monitor :
model.eval()
val_loss_monitor.register(test_loss)#val_loss.item())
if val_loss_monitor.end_training(): break #Stop training
model.train()
if not model.is_augmenting() and (epoch == dataug_epoch_start or (val_loss_monitor and val_loss_monitor.limit_reached()==1)):
print('Starting Data Augmention...')
dataug_epoch_start = epoch
model.augment(mode=True)
if inner_it != 0: high_grad_track = True
try:
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
except:
print("Couldn't save finals samples")
pass
#print("Copy ", countcopy)
return log

View file

@ -0,0 +1,161 @@
import numpy as np
import json, math, time, os
import matplotlib.pyplot as plt
import copy
import gc
from torchviz import make_dot
import torch
import torch.nn.functional as F
import time
class timer():
def __init__(self):
self._start_time=time.time()
def exec_time(self):
end = time.time()
res = end-self._start_time
self._start_time=end
return res
def plot_res(log, fig_name='res', param_names=None):
epochs = [x["epoch"] for x in log]
fig, ax = plt.subplots(ncols=3, figsize=(15, 3))
ax[0].set_title('Loss')
ax[0].plot(epochs,[x["train_loss"] for x in log], label='Train')
ax[0].plot(epochs,[x["val_loss"] for x in log], label='Val')
ax[0].legend()
ax[1].set_title('Acc')
ax[1].plot(epochs,[x["acc"] for x in log])
if log[0]["param"]!= None:
if isinstance(log[0]["param"],float):
ax[2].set_title('Mag')
ax[2].plot(epochs,[x["param"] for x in log], label='Mag')
ax[2].legend()
else :
ax[2].set_title('Prob')
#for idx, _ in enumerate(log[0]["param"]):
#ax[2].plot(epochs,[x["param"][idx] for x in log], label='P'+str(idx))
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
ax[2].stackplot(epochs, proba, labels=param_names)
ax[2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name)
plt.close()
def plot_res_compare(filenames, fig_name='res'):
all_data=[]
#legend=""
for idx, file in enumerate(filenames):
#legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
all_data.append(data)
n_tf = [len(x["Param_names"]) for x in all_data]
acc = [x["Accuracy"] for x in all_data]
time = [x["Time"][0] for x in all_data]
fig, ax = plt.subplots(ncols=3, figsize=(30, 8))
ax[0].plot(n_tf, acc)
ax[1].plot(n_tf, time)
ax[0].set_title('Acc')
ax[1].set_title('Time')
#for a in ax: a.legend()
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def plot_TF_res(log, tf_names, fig_name='res'):
mean = np.mean([x["param"] for x in log], axis=0)
std = np.std([x["param"] for x in log], axis=0)
fig, ax = plt.subplots(1, 1, figsize=(30, 8), sharey=True)
ax.bar(tf_names, mean, yerr=std)
#ax.bar(tf_names, log[-1]["param"])
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def model_copy(src,dst, patch_copy=True, copy_grad=True):
#model=copy.deepcopy(fmodel) #Pas approprie, on ne souhaite que les poids/grad (pas tout fmodel et ses etats)
dst.load_state_dict(src.state_dict()) #Do not copy gradient !
if patch_copy:
dst['model'].load_state_dict(src['model'].state_dict()) #Copie donnee manquante ?
dst['data_aug'].load_state_dict(src['data_aug'].state_dict())
#Copie des gradients
if copy_grad:
for paramName, paramValue, in src.named_parameters():
for netCopyName, netCopyValue, in dst.named_parameters():
if paramName == netCopyName:
netCopyValue.grad = paramValue.grad
#netCopyValue=copy.deepcopy(paramValue)
try: #Data_augV4
dst['data_aug']._input_info = src['data_aug']._input_info
dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix
except:
pass
def optim_copy(dopt, opt):
#inner_opt.load_state_dict(diffopt.state_dict()) #Besoin sauver etat otpim (momentum, etc.) => Ne copie pas le state...
#opt_param=higher.optim.get_trainable_opt_params(diffopt)
for group_idx, group in enumerate(opt.param_groups):
# print('gp idx',group_idx)
for p_idx, p in enumerate(group['params']):
opt.state[p]=dopt.state[group_idx][p_idx]
class loss_monitor(): #Voir https://github.com/pytorch/ignite
def __init__(self, patience, end_train=1):
self.patience = patience
self.end_train = end_train
self.counter = 0
self.best_score = None
self.reached_limit = 0
def register(self, loss):
if self.best_score is None:
self.best_score = loss
elif loss > self.best_score:
self.counter += 1
#if not self.reached_limit:
print("loss no improve counter", self.counter, self.reached_limit)
else:
self.best_score = loss
self.counter = 0
def limit_reached(self):
if self.counter >= self.patience:
self.counter = 0
self.reached_limit +=1
self.best_score = None
return self.reached_limit
def end_training(self):
if self.limit_reached() >= self.end_train:
return True
else:
return False
def reset(self):
self.__init__(self.patience, self.end_train)

View file

@ -0,0 +1,169 @@
from utils import *
if __name__ == "__main__":
'''
# files=[
# "../res/log/Aug_mod(Data_augV5(Mix0.5-18TFx3-Mag)-efficientnet-b1)-200 epochs (dataug 0)- 1 in_it__AL2.json",
# #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-0.json",
# #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-1.json",
# #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-2.json",
# #"res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
# ]
files = ["../res/benchmark/CIFAR10/log/RandAugment(N%d-M%.2f)-%s-200 epochs -%s.json"%(3,1,'resnet18', str(run)) for run in range(3)]
#files = ["../res/benchmark/CIFAR10/log/Aug_mod(RandAug(18TFx%d-Mag%d)-%s)-200 epochs (dataug:0)- 0 in_it-%s.json"%(2,1,'resnet18', str(run)) for run in range(3)]
#files = ["../res/benchmark/CIFAR10/log/Aug_mod(Data_augV5(Mix%.1f-18TFx%d-Mag)-%s)-200 epochs (dataug:0)- 1 in_it-%s.json"%(0.5,3,'resnet18', str(run)) for run in range(3)]
for idx, file in enumerate(files):
#legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
plot_resV2(data['Log'], fig_name=file.replace("/log","").replace(".json",""), param_names=data['Param_names'], f1=True)
#plot_TF_influence(data['Log'], param_names=data['Param_names'])
'''
#Res print
# '''
nb_run=3
accs = []
aug_accs = []
f1_max = []
f1_min = []
times = []
mem = []
files = ["../res/benchmark/log/Aug_mod(Data_augV5(T0.5-19TFx3-Mag)-resnet18)-200 epochs (dataug 0)- 1 in_it__%s.json"%(str(run)) for run in range(1, nb_run+1)]
for idx, file in enumerate(files):
#legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
accs.append(data['Accuracy'])
aug_accs.append(data['Aug_Accuracy'][1])
times.append(data['Time'][0])
mem.append(data['Memory'][1])
acc_idx = [x['acc'] for x in data['Log']].index(data['Accuracy'])
f1_max.append(max(data['Log'][acc_idx]['f1'])*100)
f1_min.append(min(data['Log'][acc_idx]['f1'])*100)
print(idx, accs[-1], aug_accs[-1])
print(files[0])
print("Acc : %.2f ~ %.2f / Aug_Acc %d: %.2f ~ %.2f"%(np.mean(accs), np.std(accs), data['Aug_Accuracy'][0], np.mean(aug_accs), np.std(aug_accs)))
print("F1 max : %.2f ~ %.2f / F1 min : %.2f ~ %.2f"%(np.mean(f1_max), np.std(f1_max), np.mean(f1_min), np.std(f1_min)))
print("Time (h): %.2f ~ %.2f"%(np.mean(times)/3600, np.std(times)/3600))
print("Mem (MB): %d ~ %d"%(np.mean(mem), np.std(mem)))
# '''
## Loss , Acc, Proba = f(epoch) ##
#plot_compare(filenames=files, fig_name="res/compare")
'''
## Acc, Time, Epochs = f(n_tf) ##
#fig_name="res/TF_nb_tests_compare"
fig_name="res/TF_seq_tests_compare"
inner_its = [0, 10]
dataug_epoch_starts= [0]
TF_nb = 14#[len(TF.TF_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
N_seq_TF= [1, 2, 3, 4, 6] #[1]
fig, ax = plt.subplots(ncols=3, figsize=(30, 8))
for in_it in inner_its:
for dataug in dataug_epoch_starts:
#n_tf = TF_nb
#filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF)-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(n_tf, dataug, in_it) for n_tf in TF_nb]
#filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF x {})-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(n_tf, 1, dataug, in_it) for n_tf in TF_nb]
n_tf = N_seq_TF
#filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF x {})-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(TF_nb, n_tf, dataug, in_it) for n_tf in N_seq_TF]
filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF x {})-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(TF_nb, n_tf, dataug, in_it) for n_tf in N_seq_TF]
all_data=[]
#legend=""
for idx, file in enumerate(filenames):
#legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
all_data.append(data)
acc = [x["Accuracy"] for x in all_data]
epochs = [len(x["Log"]) for x in all_data]
time = [x["Time"][0] for x in all_data]
#for i in range(len(time)): time[i] *= epochs[i] #Estimation temps total
ax[0].plot(n_tf, acc, label="{} in_it/{} dataug".format(in_it,dataug))
ax[1].plot(n_tf, time, label="{} in_it/{} dataug".format(in_it,dataug))
ax[2].plot(n_tf, epochs, label="{} in_it/{} dataug".format(in_it,dataug))
#for data in all_data:
#print(np.mean([x["param"] for x in data["Log"]], axis=0))
#print(len(data["Param_names"]), np.argsort(np.argsort(np.mean([x["param"] for x in data["Log"]], axis=0))))
ax[0].set_title('Acc')
ax[1].set_title('Time')
ax[2].set_title('Epochs')
for a in ax: a.legend()
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
'''
'''
#HP search
inner_its = [1]
dist_mix = [0.3, 0.5, 0.8, 1.0] #Uniform
N_seq_TF= [3]
nb_run= 3
for n_inner_iter in inner_its:
for n_tf in N_seq_TF:
for dist in dist_mix:
files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-Mag)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(dist, n_tf, str(run)) for run in range(nb_run)]
#files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Uniform-14TFx%d-MagFxSh)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(n_tf, str(run)) for run in range(nb_run)]
accs = []
times = []
for idx, file in enumerate(files):
#legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
accs.append(data['Accuracy'])
times.append(data['Time'][0])
print(idx, data['Accuracy'])
print(files[0], 'acc', np.mean(accs), '+-',np.std(accs), ',t', np.mean(times))
'''
'''
#Benchmark
model_list=['resnet18', 'resnet50','wide_resnet50_2']
nb_run= 3
for model_name in model_list:
files = ["../res/benchmark/CIFAR100/log/RandAugment(N%d-M%.2f)-%s-200 epochs -%s.json"%(3,0.17,model_name, str(run)) for run in range(nb_run)]
#files = ["../res/benchmark/CIFAR10/log/%s-200 epochs -%s.json"%(model_name, str(run)) for run in range(nb_run)]
accs = []
times = []
mem_alloc = []
mem_cach = []
for idx, file in enumerate(files):
#legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
accs.append(data['Accuracy'])
times.append(data['Time'][0])
mem_cach.append(data['Memory'])
print(idx, data['Accuracy'])
print(files[0], 'acc', np.mean(accs), '+-',np.std(accs), ',t', np.mean(times), 'Mem', np.mean(mem_cach))
'''

View file

@ -0,0 +1,59 @@
""" Example use of smart augmentation.
"""
from LeNet import *
from dataug import *
from train_utils import *
tf_config='../config/base_tf_config.json'
TF_loader=TF_loader()
device = torch.device('cuda') #Select device to use
if device == torch.device('cpu'):
device_name = 'CPU'
else:
device_name = torch.cuda.get_device_name(device)
##########################################
if __name__ == "__main__":
#Parameters
n_inner_iter = 1
epochs = 150
optim_param={
'Meta':{
'optim':'Adam',
'lr':1e-2, #1e-2
},
'Inner':{
'optim': 'SGD',
'lr':1e-2, #1e-2/1e-1 (ResNet)
'momentum':0.9, #0.9
'decay':0.0005, #0.0005
'nesterov':False, #False (True: Bad behavior w/ Data_aug)
'scheduler':'cosine', #None, 'cosine', 'multiStep', 'exponential'
}
}
#Models
model = LeNet(3,10)
#Smart_aug initialisation
tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(tf_config)
model = Higher_model(model) #run_dist_dataugV3
aug_model = Augmented_model(
Data_augV5(TF_dict=tf_dict,
N_TF=3,
mix_dist=0.8,
fixed_prob=False,
fixed_mag=False,
shared_mag=False,
TF_ignore_mag=tf_ignore_mag),
model).to(device)
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
# Training
trained_model = run_simple_smartaug(model=aug_model, epochs=epochs, inner_it=n_inner_iter, opt_param=optim_param)

View file

@ -0,0 +1,204 @@
""" Script to run experiment on smart augmentation.
"""
import sys
from dataug import *
#from utils import *
from train_utils import *
from transformations import TF_loader
# from arg_parser import *
TF_loader=TF_loader()
torch.backends.cudnn.benchmark = True #Faster if same input size #Not recommended for reproductibility
#Increase reproductibility
torch.manual_seed(0)
np.random.seed(0)
##########################################
if __name__ == "__main__":
args = parser.parse_args()
print(args)
res_folder=args.res_folder
postfix=args.postfix
if args.dtype == 'FP32':
def_type=torch.float32
elif args.dtype == 'FP16':
# def_type=torch.float16 #Default : float32
def_type=torch.bfloat16
else:
raise Exception('dtype not supported :', args.dtype)
torch.set_default_dtype(def_type) #Default : float32
device = torch.device(args.device) #Select device to use
if device == torch.device('cpu'):
device_name = 'CPU'
else:
device_name = torch.cuda.get_device_name(device)
#Parameters
n_inner_iter = args.K
epochs = args.epochs
dataug_epoch_start=0
Nb_TF_seq= args.N
optim_param={
'Meta':{
'optim':'Adam',
'lr':args.mlr,
'epoch_start': args.meta_epoch_start, #0 / 2 (Resnet?)
'reg_factor': args.mag_reg,
'scheduler': None, #None, 'multiStep'
},
'Inner':{
'optim': 'SGD',
'lr':args.lr, #1e-2/1e-1 (ResNet)
'momentum':args.momentum, #0.9
'weight_decay':args.decay, #0.0005
'nesterov':args.nesterov, #False (True: Bad behavior w/ Data_aug)
'scheduler': args.scheduler, #None, 'cosine', 'multiStep', 'exponential'
'warmup':{
'multiplier': args.warmup, #2 #+ batch_size => + mutliplier #No warmup = 0
'epochs': 5
}
}
}
#Info params
F1=True
sample_save=None
print_f= epochs/4
#Load network
model, model_name= load_model(args.model, num_classes=len(dl_train.dataset.classes), pretrained=args.pretrained)
#### Classic ####
if not args.augment:
if device_name != 'CPU':
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
torch.cuda.reset_max_memory_cached() #reset_peak_stats
t0 = time.perf_counter()
model = model.to(device)
print("{} on {} for {} epochs{}".format(model_name, device_name, epochs, postfix))
#print("RandAugment(N{}-M{:.2f})-{} on {} for {} epochs{}".format(rand_aug['N'],rand_aug['M'],model_name, device_name, epochs, postfix))
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=print_f)
#log= train_classic_higher(model=model, epochs=epochs)
exec_time=time.perf_counter() - t0
if device_name != 'CPU':
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
else:
max_allocated = 0.0
max_cached=0.0
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]),
"Time": (np.mean(times),np.std(times), exec_time),
'Optimizer': optim_param['Inner'],
"Device": device_name,
"Memory": [max_allocated, max_cached],
#"Rand_Aug": rand_aug,
"Log": log}
print(model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs".format(model_name,epochs)+postfix
#print("RandAugment-",model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
#filename = "RandAugment(N{}-M{:.2f})-{}-{} epochs".format(rand_aug['N'],rand_aug['M'],model_name,epochs)+postfix
with open(res_folder+"log/%s.json" % filename, "w+") as f:
try:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
except:
print("Failed to save logs :",f.name)
print(sys.exc_info()[1])
try:
plot_resV2(log, fig_name=res_folder+filename, f1=F1)
except:
print("Failed to plot res")
print(sys.exc_info()[1])
print('Execution Time (s): %.00f '%(exec_time))
print('-'*9)
#### Augmented Model ####
else:
# tf_config='../config/invScale_wide_tf_config.json'#'../config/invScale_wide_tf_config.json'#'../config/base_tf_config.json'
tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(args.tf_config)
if device_name != 'CPU':
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
torch.cuda.reset_max_memory_cached() #reset_peak_stats
t0 = time.perf_counter()
model = Higher_model(model, model_name) #run_dist_dataugV3
dataug_mod = 'Data_augV8' if args.learn_seq else 'Data_augV5'
if n_inner_iter !=0:
aug_model = Augmented_model(
globals()[dataug_mod](TF_dict=tf_dict,
N_TF=Nb_TF_seq,
temp=args.temp,
fixed_prob=False,
fixed_mag=args.fixed_mag,
shared_mag=args.shared_mag,
TF_ignore_mag=tf_ignore_mag), model).to(device)
else:
aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=Nb_TF_seq), model).to(device)
print("{} on {} for {} epochs - {} inner_it{}".format(str(aug_model), device_name, epochs, n_inner_iter, postfix))
log, aug_acc = run_dist_dataugV3(model=aug_model,
epochs=epochs,
inner_it=n_inner_iter,
dataug_epoch_start=dataug_epoch_start,
opt_param=optim_param,
unsup_loss=1,
augment_loss=args.augment_loss,
hp_opt=False, #False #['lr', 'momentum', 'weight_decay']
print_freq=print_f,
save_sample_freq=sample_save)
exec_time=time.perf_counter() - t0
if device_name != 'CPU':
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
else:
max_allocated = 0.0
max_cached = 0.0
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]),
"Aug_Accuracy": [args.augment_loss, aug_acc],
"Time": (np.mean(times),np.std(times), exec_time),
'Optimizer': optim_param,
"Device": device_name,
"Memory": [max_allocated, max_cached],
"TF_config": args.tf_config,
"Param_names": aug_model.TF_names(),
"Log": log}
print(str(aug_model),": acc", out["Accuracy"], "/ aug_acc", out["Aug_Accuracy"][1] , "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{}_epochs-{}_in_it-AL{}".format(str(aug_model),epochs,n_inner_iter,args.augment_loss)+postfix
with open(res_folder+"log/%s.json" % filename, "w+") as f:
try:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
except:
print("Failed to save logs :",f.name)
print(sys.exc_info()[1])
try:
plot_resV2(log, fig_name=res_folder+filename, param_names=aug_model.TF_names(), f1=F1)
except:
print("Failed to plot res")
print(sys.exc_info()[1])
print('Execution Time (s): %.00f '%(exec_time))
print('-'*9)

View file

@ -0,0 +1,592 @@
""" Utilities function for training.
"""
import sys
import torch
#import torch.optim
#import torchvision
import higher
import higher_patch
from datasets import *
from utils import *
from transformations import Normalizer, translate, zero_stack
norm = Normalizer(MEAN, STD)
confmat = ConfusionMatrix(num_classes=len(dl_test.dataset.classes))
max_grad = 1 #Max gradient value #Limite catastrophic drop
def test(model, augment=0):
"""Evaluate a model on test data.
Args:
model (nn.Module): Model to test.
augment (int): Number of augmented example for each sample. (Default : 0)
Returns:
(float, Tensor) Returns the accuracy and F1 score of the model.
"""
device = next(model.parameters()).device
model.eval()
# model['model']['functional'].set_mode('mixed') #ABN
#for i, (features, labels) in enumerate(dl_test):
# features,labels = features.to(device), labels.to(device)
# pred = model.forward(features)
# return pred.argmax(dim=1).eq(labels).sum().item() / dl_test.batch_size * 100
correct = 0
total = 0
#loss = []
global confmat
confmat.reset()
with torch.no_grad():
for features, labels in dl_test:
features,labels = features.to(device), labels.to(device)
if augment>0: #Test Time Augmentation
model.augment(True)
# V2
features=torch.cat([features for _ in range(augment)], dim=0) # (B,C,H,W)=>(B*augment,C,H,W)
outputs=model(features)
outputs=torch.cat([o.unsqueeze(dim=0) for o in outputs.chunk(chunks=augment, dim=0)],dim=0) # (B*augment,nb_class)=>(augment,B,nb_class)
w_losses=model['data_aug'].loss_weight(batch_norm=False) #(B*augment) if Dataug
if w_losses.shape[0]==1: #RandAugment
outputs=torch.sum(outputs, axis=0)/augment #mean
else: #Dataug
w_losses=torch.cat([w.unsqueeze(dim=0) for w in w_losses.chunk(chunks=augment, dim=0)], dim=0) #(augment, B)
w_losses = w_losses / w_losses.sum(axis=0, keepdim=True) #sum(w_losses)=1 pour un même echantillons
outputs=torch.sum(outputs*w_losses.unsqueeze(dim=2).expand_as(outputs), axis=0)/augment #Weighted mean
else:
outputs = model(features)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
#loss.append(F.cross_entropy(outputs, labels).item())
confmat.update(labels, predicted)
accuracy = 100 * correct / total
#print(confmat)
#from sklearn.metrics import f1_score
#f1 = f1_score(labels.data.to('cpu'), predicted.data.to('cpu'), average="macro")
return accuracy, confmat.f1_metric(average=None)
def compute_vaLoss(model, dl_it, dl):
"""Evaluate a model on a batch of data.
Args:
model (nn.Module): Model to evaluate.
dl_it (Iterator): Data loader iterator.
dl (DataLoader): Data loader.
Returns:
(Tensor) Loss on a single batch of data.
"""
device = next(model.parameters()).device
try:
xs, ys = next(dl_it)
except StopIteration: #Fin epoch val
dl_it = iter(dl)
xs, ys = next(dl_it)
xs, ys = xs.to(device), ys.to(device)
model.eval() #Validation sans transformations !
# model['model']['functional'].set_mode('mixed') #ABN
# return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys)
return F.cross_entropy(model(xs), ys)
def mixed_loss(xs, ys, model, unsup_factor=1, augment=1):
"""Evaluate a model on a batch of data.
Compute a combinaison of losses:
+ Supervised Cross-Entropy loss from original data.
+ Unsupervised Cross-Entropy loss from augmented data.
+ KL divergence loss encouraging similarity between original and augmented prediction.
If unsup_factor is equal to 0 or if there isn't data augmentation, only the supervised loss is computed.
Inspired by UDA, see: https://github.com/google-research/uda/blob/master/image/main.py
Args:
xs (Tensor): Batch of data.
ys (Tensor): Batch of labels.
model (nn.Module): Augmented model (see dataug.py).
unsup_factor (float): Factor by which unsupervised CE and KL div loss are multiplied.
augment (int): Number of augmented example for each sample. (Default : 1)
Returns:
(Tensor) Mixed loss if there's data augmentation, just supervised CE loss otherwise.
"""
#TODO: add test to prevent augmented model error and redirect to classic loss
if unsup_factor!=0 and model.is_augmenting() and augment>0:
# Supervised loss - Cross-entropy
model.augment(mode=False)
sup_logits = model(xs)
model.augment(mode=True)
log_sup = F.log_softmax(sup_logits, dim=1)
sup_loss = F.nll_loss(log_sup, ys)
# sup_loss = F.cross_entropy(log_sup, ys)
if augment>1:
# Unsupervised loss - Cross-Entropy
xs_a=torch.cat([xs for _ in range(augment)], dim=0) # (B,C,H,W)=>(B*augment,C,H,W)
ys_a=torch.cat([ys for _ in range(augment)], dim=0)
aug_logits=model(xs_a) # (B*augment,nb_class)
w_loss=model['data_aug'].loss_weight() #(B*augment) if Dataug
log_aug = F.log_softmax(aug_logits, dim=1)
aug_loss = F.nll_loss(log_aug, ys_a , reduction='none')
# aug_loss = F.cross_entropy(log_aug, ys_a , reduction='none')
aug_loss = (aug_loss * w_loss).mean()
#KL divergence loss (w/ logits) - Prediction/Distribution similarity
sup_logits_a=torch.cat([sup_logits for _ in range(augment)], dim=0)
log_sup_a=torch.cat([log_sup for _ in range(augment)], dim=0)
kl_loss = (F.softmax(sup_logits_a, dim=1)*(log_sup_a-log_aug)).sum(dim=-1)
kl_loss = (w_loss * kl_loss).mean()
else:
# Unsupervised loss - Cross-Entropy
aug_logits = model(xs)
w_loss = model['data_aug'].loss_weight() #Weight loss
log_aug = F.log_softmax(aug_logits, dim=1)
aug_loss = F.nll_loss(log_aug, ys , reduction='none')
# aug_loss = F.cross_entropy(log_aug, ys , reduction='none')
aug_loss = (aug_loss * w_loss).mean()
#KL divergence loss (w/ logits) - Prediction/Distribution similarity
kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
kl_loss = (w_loss * kl_loss).mean()
loss = sup_loss + unsup_factor * (aug_loss + kl_loss)
else: #Supervised loss - Cross-Entropy
sup_logits = model(xs)
loss = F.cross_entropy(sup_logits, ys)
# log_sup = F.log_softmax(sup_logits, dim=1)
# loss = F.cross_entropy(log_sup, ys)
return loss
def train_classic(model, opt_param, epochs=1, print_freq=1):
"""Classic training of a model.
Args:
model (nn.Module): Model to train.
opt_param (dict): Dictionnary containing optimizers parameters.
epochs (int): Number of epochs to perform. (default: 1)
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
Returns:
(list) Logs of training. Each items is a dict containing results of an epoch.
"""
device = next(model.parameters()).device
#Optimizer
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
optim = torch.optim.SGD(model.parameters(),
lr=opt_param['Inner']['lr'],
momentum=opt_param['Inner']['momentum'],
weight_decay=opt_param['Inner']['weight_decay'],
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
#Scheduler
inner_scheduler=None
if opt_param['Inner']['scheduler']=='cosine':
inner_scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=epochs, eta_min=0.)
elif opt_param['Inner']['scheduler']=='multiStep':
#Multistep milestones inspired by AutoAugment
inner_scheduler=torch.optim.lr_scheduler.MultiStepLR(optim,
milestones=[int(epochs/3), int(epochs*2/3), int(epochs*2.7/3)],
gamma=0.1)
elif opt_param['Inner']['scheduler']=='exponential':
#inner_scheduler=torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.1) #Wrong gamma
inner_scheduler=torch.optim.lr_scheduler.LambdaLR(optim, lambda epoch: (1 - epoch / epochs) ** 0.9)
elif opt_param['Inner']['scheduler'] is not None:
raise ValueError("Lr scheduler unknown : %s"%opt_param['Inner']['scheduler'])
# from warmup_scheduler import GradualWarmupScheduler
# inner_scheduler=GradualWarmupScheduler(optim, multiplier=2, total_epoch=5, after_scheduler=inner_scheduler)
#Training
model.train()
dl_val_it = iter(dl_val)
log = []
for epoch in range(epochs):
#print_torch_mem("Start epoch")
#print(optim.param_groups[0]['lr'])
t0 = time.perf_counter()
for i, (features, labels) in enumerate(dl_train):
#viz_sample_data(imgs=features, labels=labels, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
#print_torch_mem("Start iter")
features,labels = features.to(device), labels.to(device)
optim.zero_grad()
logits = model.forward(features)
pred = F.log_softmax(logits, dim=1)
loss = F.cross_entropy(pred,labels)
loss.backward()
optim.step()
# print_graph(loss, '../samples/torchvision_WRN') #to visualize computational graph
# sys.exit()
if inner_scheduler is not None:
inner_scheduler.step()
# print(optim.param_groups[0]['lr'])
#### Tests ####
tf = time.perf_counter()
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
accuracy, f1 =test(model)
model.train()
#### Log ####
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"f1": f1.tolist(),
"time": tf - t0,
"param": None,
}
log.append(data)
#### Print ####
if(print_freq and epoch%print_freq==0):
print('-'*9)
print('Epoch : %d/%d'%(epoch,epochs))
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Accuracy max:', max([x["acc"] for x in log]))
print('F1 :', ["{0:0.4f}".format(i) for i in f1])
return log
def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, unsup_loss=1, augment_loss=1, hp_opt=False, print_freq=1, save_sample_freq=None):
"""Training of an augmented model with higher.
This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py).
Ex : Augmented_model(Data_augV5(...), Higher_model(model))
Training loss can either be computed directly from augmented inputs (unsup_loss=0).
However, it is recommended to use the mixed loss computation, which combine original and augmented inputs to compute the loss (unsup_loss>0).
Args:
model (nn.Module): Augmented model to train.
opt_param (dict): Dictionnary containing optimizers parameters.
epochs (int): Number of epochs to perform. (default: 1)
inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1)
dataug_epoch_start (int): Epoch when to start data augmentation. (default: 0)
unsup_loss (float): Proportion of the unsup_loss loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
augment_loss (int): Number of augmented example for each sample in loss computation. (Default : 1)
hp_opt (bool): Wether to learn inner optimizer parameters. (default: False)
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
save_sample_freq (int): Number of epochs between saves of samples of data. If set to None, no sample will be saved. (default: None)
Returns:
(list) Logs of training. Each items is a dict containing results of an epoch.
"""
device = next(model.parameters()).device
log = []
# kl_log={"prob":[], "mag":[]}
dl_val_it = iter(dl_val)
high_grad_track = True
if inner_it == 0: #No HP optimization
high_grad_track=False
if dataug_epoch_start!=0: #Augmentation de donnee differee
model.augment(mode=False)
high_grad_track = False
## Optimizers ##
#Inner Opt
inner_opt = torch.optim.SGD(model['model']['original'].parameters(),
lr=opt_param['Inner']['lr'],
momentum=opt_param['Inner']['momentum'],
weight_decay=opt_param['Inner']['weight_decay'],
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
diffopt = model['model'].get_diffopt(
inner_opt,
grad_callback=(lambda grads: clip_norm(grads, max_norm=max_grad)),
track_higher_grads=high_grad_track)
#Scheduler
inner_scheduler=None
if opt_param['Inner']['scheduler']=='cosine':
inner_scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(inner_opt, T_max=epochs, eta_min=0.)
elif opt_param['Inner']['scheduler']=='multiStep':
#Multistep milestones inspired by AutoAugment
inner_scheduler=torch.optim.lr_scheduler.MultiStepLR(inner_opt,
milestones=[int(epochs/3), int(epochs*2/3), int(epochs*2.7/3)],
gamma=0.1)
elif opt_param['Inner']['scheduler']=='exponential':
#inner_scheduler=torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.1) #Wrong gamma
inner_scheduler=torch.optim.lr_scheduler.LambdaLR(inner_opt, lambda epoch: (1 - epoch / epochs) ** 0.9)
elif not(opt_param['Inner']['scheduler'] is None or opt_param['Inner']['scheduler']==''):
raise ValueError("Lr scheduler unknown : %s"%opt_param['Inner']['scheduler'])
#Warmup
if opt_param['Inner']['warmup']['multiplier']>=1:
from warmup_scheduler import GradualWarmupScheduler
inner_scheduler=GradualWarmupScheduler(inner_opt,
multiplier=opt_param['Inner']['warmup']['multiplier'],
total_epoch=opt_param['Inner']['warmup']['epochs'],
after_scheduler=inner_scheduler)
#Meta Opt
hyper_param = list(model['data_aug'].parameters())
if hp_opt : #(deprecated)
for param_group in diffopt.param_groups:
# print(param_group)
for param in hp_opt:
param_group[param]=torch.tensor(param_group[param]).to(device).requires_grad_()
hyper_param += [param_group[param]]
meta_opt = torch.optim.Adam(hyper_param, lr=opt_param['Meta']['lr'])
#Meta-Scheduler (deprecated)
meta_scheduler=None
if opt_param['Meta']['scheduler']=='multiStep':
meta_scheduler=torch.optim.lr_scheduler.MultiStepLR(meta_opt,
milestones=[int(epochs/3), int(epochs*2/3)],# int(epochs*2.7/3)],
gamma=3.16)
elif opt_param['Meta']['scheduler'] is not None:
raise ValueError("Lr scheduler unknown : %s"%opt_param['Meta']['scheduler'])
model.train()
meta_opt.zero_grad()
for epoch in range(1, epochs+1):
t0 = time.perf_counter()
val_loss=None
#Cross-Validation
#dl_train, dl_val = cvs.next_split()
#dl_val_it = iter(dl_val)
for i, (xs, ys) in enumerate(dl_train):
xs, ys = xs.to(device), ys.to(device)
if(unsup_loss==0):
#Methode uniforme
logits = model(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
if model._data_augmentation: #Weight loss
w_loss = model['data_aug'].loss_weight()#.to(device)
loss = loss * w_loss
loss = loss.mean()
else:
#Methode mixed
loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss, augment=augment_loss)
# print_graph(loss, '../samples/pytorch_WRN') #to visualize computational graph
# sys.exit()
# t = time.process_time()
diffopt.step(loss)#(opt.zero_grad, loss.backward, opt.step)
# print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
if(high_grad_track and i>0 and i%inner_it==0 and epoch>=opt_param['Meta']['epoch_start']): #Perform Meta step
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss(opt_param['Meta']['reg_factor'])
model.train()
#print_graph(val_loss) #to visualize computational graph
val_loss.backward()
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=max_grad, norm_type=2) #Prevent exploding grad with RNN
# print("Grad mix",model['data_aug']["temp"].grad)
# prv_param=model['data_aug']._params
meta_opt.step()
# kl_log["prob"].append(F.kl_div(prv_param["prob"],model['data_aug']["prob"], reduction='batchmean').item())
# kl_log["mag"].append(F.kl_div(prv_param["mag"],model['data_aug']["mag"], reduction='batchmean').item())
#Adjust Hyper-parameters
model['data_aug'].adjust_param()
if hp_opt:
for param_group in diffopt.param_groups:
for param in hp_opt:
param_group[param].data = param_group[param].data.clamp(min=1e-5)
#Reset gradients
diffopt.detach_()
model['model'].detach_()
meta_opt.zero_grad()
elif not high_grad_track or epoch<opt_param['Meta']['epoch_start']:
diffopt.detach_()
model['model'].detach_()
meta_opt.zero_grad()
tf = time.perf_counter()
#Schedulers
if inner_scheduler is not None:
inner_scheduler.step()
#Transfer inner_opt lr to diffopt
for diff_param_group in diffopt.param_groups:
for param_group in inner_opt.param_groups:
diff_param_group['lr'] = param_group['lr']
if meta_scheduler is not None:
meta_scheduler.step()
# if epoch<epochs/3:
# model['data_aug']['temp'].data=torch.tensor(0.5, device=device)
# elif epoch>epochs/3 and epoch<(epochs*2/3):
# model['data_aug']['temp'].data=torch.tensor(0.75, device=device)
# elif epoch>(epochs*2/3):
# model['data_aug']['temp'].data=torch.tensor(1.0, device=device)
# model['data_aug']['temp'].data=torch.tensor(1./3+2/3*(epoch/epochs), device=device)
# print('Temp',model['data_aug']['temp'])
if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
try:
viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
model.train()
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
model.eval()
except:
print("Couldn't save samples epoch %d : %s"%(epoch, str(sys.exc_info()[1])))
pass
if(not val_loss): #Compute val loss for logs
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
# Test model
accuracy, f1 =test(model)
model.train()
#### Log ####
param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"f1": f1.tolist(),
"time": tf - t0,
"param": param,
}
if not model['data_aug']._fixed_temp: data["temp"]=model['data_aug']['temp'].item()
if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'], 'momentum': p_grp['momentum']} for p_grp in diffopt.param_groups]
log.append(data)
#############
#### Print ####
if(print_freq and epoch%print_freq==0):
print('-'*9)
print('Epoch : %d/%d'%(epoch,epochs))
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Accuracy max:', max([x["acc"] for x in log]))
print('F1 :', ["{0:0.4f}".format(i) for i in f1])
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
if not model['data_aug']._fixed_prob: print('TF Proba :', ["{0:0.4f}".format(p) for p in model['data_aug']['prob']])
#print('proba grad',model['data_aug']['prob'].grad)
if not model['data_aug']._fixed_mag:
if model['data_aug']._shared_mag:
print('TF Mag :', "{0:0.4f}".format(model['data_aug']['mag']))
else:
print('TF Mag :', ["{0:0.4f}".format(m) for m in model['data_aug']['mag']])
#print('Mag grad',model['data_aug']['mag'].grad)
if not model['data_aug']._fixed_temp: print('Temp:', model['data_aug']['temp'].item())
#print('Reg loss:', model['data_aug'].reg_loss().item())
# if len(kl_log["prob"])!=0:
# print("KL prob : mean %f, std %f, max %f, min %f"%(np.mean(kl_log["prob"]), np.std(kl_log["prob"]), max(kl_log["prob"]), min(kl_log["prob"])))
# print("KL mag : mean %f, std %f, max %f, min %f"%(np.mean(kl_log["mag"]), np.std(kl_log["mag"]), max(kl_log["mag"]), min(kl_log["mag"])))
# kl_log={"prob":[], "mag":[]}
if hp_opt :
for param_group in diffopt.param_groups:
print('Opt param - lr:', param_group['lr'].item(),'- momentum:', param_group['momentum'].item())
#############
#Augmentation de donnee differee
if not model.is_augmenting() and (epoch == dataug_epoch_start):
print('Starting Data Augmention...')
dataug_epoch_start = epoch
model.augment(mode=True)
if inner_it != 0: #Rebuild diffopt if needed
high_grad_track = True
diffopt = model['model'].get_diffopt(
inner_opt,
grad_callback=(lambda grads: clip_norm(grads, max_norm=max_grad)),
track_higher_grads=high_grad_track)
aug_acc, aug_f1 = test(model, augment=augment_loss)
return log, aug_acc
#OLD
# def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, unsup_loss=1):
# """Simple training of an augmented model with higher.
# This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py).
# Ex : Augmented_model(Data_augV5(...), Higher_model(model))
# Training loss can either be computed directly from augmented inputs (unsup_loss=0).
# However, it is recommended to use the mixed loss computation, which combine original and augmented inputs to compute the loss (unsup_loss>0).
# Does not support LR scheduler.
# Args:
# model (nn.Module): Augmented model to train.
# opt_param (dict): Dictionnary containing optimizers parameters.
# epochs (int): Number of epochs to perform. (default: 1)
# inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1)
# print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
# unsup_loss (float): Proportion of the unsup_loss loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
# Returns:
# (dict) A dictionary containing a whole state of the trained network.
# """
# device = next(model.parameters()).device
# ## Optimizers ##
# hyper_param = list(model['data_aug'].parameters())
# model.start_bilevel_opt(inner_it=inner_it, hp_list=hyper_param, opt_param=opt_param, dl_val=dl_val)
# model.train()
# for epoch in range(1, epochs+1):
# t0 = time.process_time()
# for i, (xs, ys) in enumerate(dl_train):
# xs, ys = xs.to(device), ys.to(device)
# #Methode mixed
# loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss)
# model.step(loss) #(opt.zero_grad, loss.backward, opt.step) + automatic meta-optimisation
# tf = time.process_time()
# #### Print ####
# if(print_freq and epoch%print_freq==0):
# print('-'*9)
# print('Epoch : %d/%d'%(epoch,epochs))
# print('Time : %.00f'%(tf - t0))
# print('Train loss :',loss.item(), '/ val loss', model.val_loss().item())
# if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data)
# if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
# if not model['data_aug']._fixed_temp: print('Temp:', model['data_aug']['temp'].item())
# #############
# return model['model'].state_dict()

View file

@ -0,0 +1,686 @@
""" PyTorch implementation of some PIL image transformations.
Those implementation are thinked to take advantages of batched computation of PyTorch on GPU.
Based on Kornia library.
See: https://github.com/kornia/kornia
And PIL.
See:
https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818
Inspired from AutoAugment.
See: https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py
"""
import torch
import kornia
import random
import json
#TF that don't have use for magnitude parameter.
TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend', 'identity', 'flip'}
#TF which implemetation doesn't allow gradient propagaition.
TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize', 'posterize','solarize'} #Numpy implementation would be better ?
#TF for which magnitude should be ignored (Magnitude fixed).
TF_ignore_mag= TF_no_mag | TF_no_grad
# What is the max 'level' a transform could be predicted
PARAMETER_MAX = 1
# What is the min 'level' a transform could be predicted
PARAMETER_MIN = 0.01
#Dict containing the value for wich TF are closer to identity
#TF_identity={
# PARAMETER_MAX:{'Solarize', 'Posterize'},
# PARAMETER_MAX/2:{'Contrast','Color','Brightness','Sharpness'},
# PARAMETER_MIN:{'Rotate','TranslateX','TranslateY','ShearX','ShearY'},
#}
class Normalizer(object):
def __init__(self, mean, std):
self.mean=torch.tensor(mean)
self.std=torch.tensor(std)
def __call__(self, x):
# return x.sub_(self.mean.to(x.device)[None, :, None, None]).div_(self.std.to(x.device)[None, :, None, None])
return kornia.color.normalize(x, self.mean, self.std)
def reverse(self, x):
# return x.mul_(self.std.to(x.device)[None, :, None, None]).add_(self.mean.to(x.device)[None, :, None, None])
return kornia.color.denormalize(x, self.mean, self.std).to(torch.half).to(torch.float)
class TF_loader(object):
""" Transformations builder.
See 'config' folder for pre-defined config files.
Attributes:
_filename (str): Path to config file (JSON) used.
_TF_dict (dict): Transformations dictionnary built from config file.
_TF_ignore_mag (set): Ensemble of transformations names for which magnitude should be ignored.
_TF_names (list): List of transformations names/keys.
"""
def __init__(self):
""" Initialize TF_loader.
"""
self._filename=''
self._TF_dict={}
self._TF_ignore_mag=set()
self._TF_names=[]
def load_TF_dict(self, filename):
""" Build a TF dictionnary.
Args:
filename (str): Path to config file (JSON) defining the transformations.
Returns:
(dict, set) : TF dicttionnary built and ensemble of TF names for which mag should be ignored.
"""
self._filename=filename
self._TF_names=[]
self._TF_dict={}
self._TF_ignore_mag=set()
with open(filename) as json_file:
TF_params = json.load(json_file)
for tf in TF_params:
self._TF_names.append(tf['name'])
if tf['function'] in TF_ignore_mag:
self._TF_ignore_mag.add(tf['name'])
if tf['function'] == 'identity':
self._TF_dict[tf['name']]=(lambda x, mag: x)
elif tf['function'] == 'flip':
#Inverser axes ?
if tf['param']['axis'] == 'X':
self._TF_dict[tf['name']]=(lambda x, mag: flipLR(x))
elif tf['param']['axis'] == 'Y':
self._TF_dict[tf['name']]=(lambda x, mag: flipUD(x))
else:
raise Exception("Unknown TF axis : %s in %s"%(tf['function'], self._filename))
# elif tf['function'] in {'translate', 'shear'}:
# rand_fct= 'invScale_rand_floats' if tf['param']['invScale'] else 'rand_floats'
# self._TF_dict[tf['name']]=self.build_lambda(tf['function'], rand_fct, tf['param']['min'], tf['param']['max'], tf['param']['absolute'], tf['param']['axis'])
else:
axis = tf['param']['axis'] if 'axis' in tf['param'].keys() else None
absolute = tf['param']['absolute'] if 'absolute' in tf['param'].keys() else True
rand_fct= 'invScale_rand_floats' if tf['param']['invScale'] else 'rand_floats'
self._TF_dict[tf['name']]=self.build_lambda(tf['function'], rand_fct, tf['param']['min'], tf['param']['max'], absolute, axis)
return self._TF_dict, self._TF_ignore_mag
def build_lambda(self, fct_name, rand_fct_name, minval, maxval, absolute=True, axis=None):
""" Build a lambda function performing transformations.
Force different context for creation of each lambda function.
Args:
fct_name (str): Name of the transformations to use (see transformations.py).
rand_fct_name (str): Name of the random mapping function to use (see transformations.py).
minval (float): minimum magnitude value of the TF.
maxval (float): maximum magnitude value of the TF.
absolute (bool): Wether the maxval should be relative (absolute=False) to the image size. (default: True)
axis (str): Axis ('X' / 'Y') of the TF, if relevant. Should be used for (flip)/translate/shear functions. (default: None)
Returns:
(function) transformations function : Tensor=f(Tensor, magnitude)
"""
if absolute:
max_val_fct=(lambda x: maxval)
else: #Relative to img size
max_val_fct=(lambda x: x*maxval)
if axis is None:
return (lambda x, mag:
globals()[fct_name]( #getattr(TF, fct_name)
x,
globals()[rand_fct_name](
size=x.shape[0],
mag=mag,
minval=minval,
maxval=max_val_fct(max(x.shape[2],x.shape[3])))))
elif axis =='X':
return (lambda x, mag:
globals()[fct_name](
x,
zero_stack(
globals()[rand_fct_name](
size=(x.shape[0],),
mag=mag,
minval=minval,
maxval=max_val_fct(x.shape[2])),
zero_pos=0)))
elif axis == 'Y':
return (lambda x, mag:
globals()[fct_name](
x,
zero_stack(
globals()[rand_fct_name](
size=(x.shape[0],),
mag=mag,
minval=minval,
maxval=max_val_fct(x.shape[3])),
zero_pos=1)))
else:
raise Exception("Unknown TF axis : %s in %s"%(fct_name, self._filename))
def get_TF_names(self):
return self._TF_names
def get_TF_dict(self):
return self._TF_dict
## Image type cast ##
def int_image(float_image):
"""Convert a float Tensor/Image to an int Tensor/Image.
Be warry that this transformation isn't bijective, each conversion will result in small loss of information.
Granularity: 1/256 = 0.0039.
This will also result in the loss of the gradient associated to input as gradient cannot be tracked on int Tensor.
Args:
float_image (FloatTensor): Image tensor.
Returns:
(ByteTensor) Converted tensor.
"""
return (float_image*255.).type(torch.uint8)
def float_image(int_image):
"""Convert a int Tensor/Image to an float Tensor/Image.
Args:
int_image (ByteTensor): Image tensor.
Returns:
(FloatTensor) Converted tensor.
"""
return int_image.type(torch.float)/255.
## Parameters utils ##
def rand_floats(size, mag, maxval, minval=None):
"""Generate a batch of random values.
Args:
size (int): Number of value to generate.
mag (float): Level of the operation that will be between [PARAMETER_MIN, PARAMETER_MAX].
maxval (float): Maximum value that can be generated. This will be scaled to mag/PARAMETER_MAX.
minval (float): Minimum value that can be generated. (default: -maxval)
Returns:
(Tensor) Generated batch of float values between [minval, maxval].
"""
real_mag = float_parameter(mag, maxval=maxval)
if not minval : minval = -real_mag
#return random.uniform(minval, real_max)
return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[min_val, real_mag]
def invScale_rand_floats(size, mag, maxval, minval):
"""Generate a batch of random values.
Similar to rand_floats() except that the mag is used in an inversed scale.
Mag:[0,PARAMETER_MAX] => [PARAMETER_MAX, 0]
Args:
size (int): Number of value to generate.
mag (float): Level of the operation that will be between [PARAMETER_MIN, PARAMETER_MAX].
maxval (float): Maximum value that can be generated. This will be scaled to mag/PARAMETER_MAX.
minval (float): Minimum value that can be generated. (default: -maxval)
Returns:
(Tensor) Generated batch of float values between [minval, maxval].
"""
real_mag = float_parameter(float(PARAMETER_MAX) - mag, maxval=maxval-minval)+minval
return real_mag + (maxval-real_mag) * torch.rand(size, device=mag.device) #[real_mag, max_val]
def zero_stack(tensor, zero_pos):
"""Add a row of zeros to a Tensor.
This function is intended to be used with single row Tensor, thus returning a 2 dimension Tensor.
Args:
tensor (Tensor): Tensor to be stacked with zeros.
zero_pos (int): Wheter the zeros should be added before or after the Tensor. Either 0 or 1.
Returns:
Stacked Tensor.
"""
if zero_pos==0:
return torch.stack((tensor, torch.zeros((tensor.shape[0],), device=tensor.device)), dim=1)
if zero_pos==1:
return torch.stack((torch.zeros((tensor.shape[0],), device=tensor.device), tensor), dim=1)
else:
raise Exception("Invalid zero_pos : ", zero_pos)
def float_parameter(level, maxval):
"""Scale level between 0 and maxval.
Args:
level (float): Level of the operation that will be between [PARAMETER_MIN, PARAMETER_MAX].
maxval: Maximum value that the operation can have. This will be scaled to level/PARAMETER_MAX.
Returns:
A float that results from scaling `maxval` according to `level`.
"""
#return float(level) * maxval / PARAMETER_MAX
return (level * maxval / PARAMETER_MAX)#.to(torch.float)
## Tranformations ##
def flipLR(x):
"""Flip horizontaly/Left-Right images.
Args:
x (Tensor): Batch of images.
Returns:
(Tensor): Batch of fliped images.
"""
device = x.device
(batch_size, channels, h, w) = x.shape
M =torch.tensor( [[[-1., 0., w-1],
[ 0., 1., 0.],
[ 0., 0., 1.]]], device=device).expand(batch_size,-1,-1)
# warp the original image by the found transform
return kornia.warp_perspective(x, M, dsize=(h, w))
def flipUD(x):
"""Flip vertically/Up-Down images.
Args:
x (Tensor): Batch of images.
Returns:
(Tensor): Batch of fliped images.
"""
device = x.device
(batch_size, channels, h, w) = x.shape
M =torch.tensor( [[[ 1., 0., 0.],
[ 0., -1., h-1],
[ 0., 0., 1.]]], device=device).expand(batch_size,-1,-1)
# warp the original image by the found transform
return kornia.warp_perspective(x, M, dsize=(h, w))
def rotate(x, angle):
"""Rotate images.
Args:
x (Tensor): Batch of images.
angle (Tensor): Angles (degrees) of rotation for each images.
Returns:
(Tensor): Batch of rotated images.
"""
return kornia.rotate(x, angle=angle.type(torch.float)) #Kornia ne supporte pas les int
def translate(x, translation):
"""Translate images.
Args:
x (Tensor): Batch of images.
translation (Tensor): Distance (pixels) of translation for each images.
Returns:
(Tensor): Batch of translated images.
"""
return kornia.translate(x, translation=translation.type(torch.float)) #Kornia ne supporte pas les int
def shear(x, shear):
"""Shear images.
Args:
x (Tensor): Batch of images.
shear (Tensor): Angle of shear for each images.
Returns:
(Tensor): Batch of skewed images.
"""
return kornia.shear(x, shear=shear)
def contrast(x, contrast_factor):
"""Adjust contast of images.
Args:
x (FloatTensor): Batch of images.
contrast_factor (FloatTensor): Contrast adjust factor per element in the batch.
0 generates a compleatly black image, 1 does not modify the input image while any other non-negative number modify the brightness by this factor.
Returns:
(Tensor): Batch of adjusted images.
"""
return kornia.adjust_contrast(x, contrast_factor=contrast_factor) #Expect image in the range of [0, 1]
def color(x, color_factor):
"""Adjust color of images.
Args:
x (Tensor): Batch of images.
color_factor (Tensor): Color factor for each images.
0.0 gives a black and white image. A factor of 1.0 gives the original image.
Returns:
(Tensor): Batch of adjusted images.
"""
(batch_size, channels, h, w) = x.shape
gray_x = kornia.rgb_to_grayscale(x)
gray_x = gray_x.repeat_interleave(channels, dim=1)
return blend(gray_x, x, color_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1]
def brightness(x, brightness_factor):
"""Adjust brightness of images.
Args:
x (Tensor): Batch of images.
brightness_factor (Tensor): Brightness factor for each images.
0.0 gives a black image. A factor of 1.0 gives the original image.
Returns:
(Tensor): Batch of adjusted images.
"""
device = x.device
return blend(torch.zeros(x.size(), device=device), x, brightness_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1]
def sharpness(x, sharpness_factor):
"""Adjust sharpness of images.
Args:
x (Tensor): Batch of images.
sharpness_factor (Tensor): Sharpness factor for each images.
0.0 gives a black image. A factor of 1.0 gives the original image.
Returns:
(Tensor): Batch of adjusted images.
"""
device = x.device
(batch_size, channels, h, w) = x.shape
k = torch.tensor([[[ 1., 1., 1.],
[ 1., 5., 1.],
[ 1., 1., 1.]]], device=device) #Smooth Filter : https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageFilter.py
smooth_x = kornia.filter2D(x, kernel=k, border_type='reflect', normalized=True) #Peut etre necessaire de s'occuper du channel Alhpa differement
return blend(smooth_x, x, sharpness_factor).clamp(min=0.0,max=1.0) #Expect image in the range of [0, 1]
def posterize(x, bits):
"""Reduce the number of bits for each color channel.
Be warry that the cast to integers block the gradient propagation.
Args:
x (Tensor): Batch of images.
bits (Tensor): The number of bits to keep for each channel (1-8).
Returns:
(Tensor): Batch of posterized images.
"""
bits = bits.type(torch.uint8) #Perte du gradient
x = int_image(x) #Expect image in the range of [0, 1]
mask = ~(2 ** (8 - bits) - 1).type(torch.uint8)
(batch_size, channels, h, w) = x.shape
mask = mask.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
return float_image(x & mask)
def cutout(img, length):
"""
Args:
img (Tensor): Batch of images. Expect image value between [0, 1].
length (Tensor): The length (in pixels) of each square patch.
Returns:
Tensor: Images with single holes of dimension length x length cut out of it.
"""
device = img.device
(batch_size, channels, h, w) = img.shape
# mask = np.ones((h, w), np.float32)
mask = torch.ones((batch_size, h, w), device=device)
length=length.type(torch.uint8)
# y = np.random.randint(h)
# x = np.random.randint(w)
y = torch.randint(low=0, high=h, size=(batch_size,)).to(device)
x = torch.randint(low=0, high=w, size=(batch_size,)).to(device)
# y1 = np.clip(y - length // 2, 0, h)
# y2 = np.clip(y + length // 2, 0, h)
# x1 = np.clip(x - length // 2, 0, w)
# x2 = np.clip(x + length // 2, 0, w)
y1 = (y - length // 2).clamp(min=0, max=h)
y2 = (y + length // 2).clamp(min=0, max=h)
x1 = (x - length // 2).clamp(min=0, max=w)
x2 = (x + length // 2).clamp(min=0, max=w)
# mask[y1: y2, x1: x2] = 0.
for idx in range(batch_size): #Pas opti pour des batch
mask[idx, y1[idx]: y2[idx], x1[idx]: x2[idx]]= 0.
# mask = torch.from_numpy(mask)
# mask = mask.expand_as(img)
mask = mask.unsqueeze(dim=1).expand_as(img)
img = img * mask
return img
import torch.nn.functional as F
def solarize(x, thresholds):
"""Invert all pixel values above a threshold.
Be warry that the use of the inequality (x>tresholds) block the gradient propagation.
TODO : Make differentiable.
Args:
x (Tensor): Batch of images.
thresholds (Tensor): All pixels above this level are inverted
Returns:
(Tensor): Batch of solarized images.
"""
batch_size, channels, h, w = x.shape
#imgs=[]
#for idx, t in enumerate(thresholds): #Operation par image
# mask = x[idx] > t #Perte du gradient
#In place
# inv_x = 1-x[idx][mask]
# x[idx][mask]=inv_x
#
#Out of place
# im = x[idx]
# inv_x = 1-im[mask]
# imgs.append(im.masked_scatter(mask,inv_x))
#idxs=torch.tensor(range(x.shape[0]), device=x.device)
#idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
#x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
#
thresholds = thresholds.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
x=torch.where(x>thresholds,1-x, x)
#x=x.min(thresholds)
#inv_x = 1-x[mask]
#x=x.where(x<thresholds,1-x)
#x[mask]=inv_x
#x=x.masked_scatter(mask, inv_x)
#Differentiable (/Thresholds) ?
#inv_x_bT= F.relu(x) - F.relu(x - thresholds)
#inv_x_aT= 1-x #Besoin thresholds
#print('-'*10)
#print(thresholds[0])
#print(x[0])
#print(inv_x_bT[0])
#print(inv_x_aT[0])
#x=torch.where(x>thresholds,inv_x_aT, inv_x_bT)
#print(torch.allclose(x, x+0.001, atol=1e-3))
#print(torch.allclose(x, sol_x, atol=1e-2))
#print(torch.eq(x,sol_x)[0])
#print(x[0])
#print(sol_x[0])
#'''
return x
def blend(x,y,alpha):
"""Creates a new images by interpolating between two input images, using a constant alpha.
x and y should have the same size.
alpha should have the same batch size as the images.
Apply batch wise :
out = image1 * (1.0 - alpha) + image2 * alpha
Args:
x (Tensor): Batch of images.
y (Tensor): Batch of images.
alpha (Tensor): The interpolation alpha factor for each images.
Returns:
(Tensor): Batch of solarized images.
"""
#return kornia.add_weighted(src1=x, alpha=(1-alpha), src2=y, beta=alpha, gamma=0) #out=src1alpha+src2beta+gamma #Ne fonctionne pas pour des batch de alpha
if not isinstance(x, torch.Tensor):
raise TypeError("x should be a tensor. Got {}".format(type(x)))
if not isinstance(y, torch.Tensor):
raise TypeError("y should be a tensor. Got {}".format(type(y)))
assert(x.shape==y.shape and x.shape[0]==alpha.shape[0])
(batch_size, channels, h, w) = x.shape
alpha = alpha.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
res = x*(1-alpha) + y*alpha
return res
#Not working
def auto_contrast(x):
"""NOT TESTED - EXTRA SLOW
"""
# Optimisation : Application de LUT efficace / Calcul d'histogramme par batch/channel
print("Warning : Pas encore check !")
(batch_size, channels, h, w) = x.shape
x = int_image(x) #Expect image in the range of [0, 1]
#print('Start',x[0])
for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image
#print(img.shape)
for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel
#print(chan.shape)
hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
# find lowest/highest samples after preprocessing
for lo in range(256):
if hist[lo]:
break
for hi in range(255, -1, -1):
if hist[hi]:
break
if hi <= lo:
# don't bother
pass
else:
scale = 255.0 / (hi - lo)
offset = -lo * scale
for ix in range(256):
n_ix = int(ix * scale + offset)
if n_ix < 0: n_ix = 0
elif n_ix > 255: n_ix = 255
chan[chan==ix]=n_ix
x[im_idx, chan_idx]=chan
#print('End',x[0])
return float_image(x)
def equalize(x):
""" NOT WORKING
"""
raise Exception(self, "not implemented")
# Optimisation : Application de LUT efficace / Calcul d'histogramme par batch/channel
(batch_size, channels, h, w) = x.shape
x = int_image(x) #Expect image in the range of [0, 1]
#print('Start',x[0])
for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image
#print(img.shape)
for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel
#print(chan.shape)
hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE
return float_image(x)
'''
# Dictionnary mapping tranformations identifiers to their function.
# Each value of the dict should be a lambda function taking a (batch of data, magnitude of transformations) tuple as input and returns a batch of data.
TF_dict={ #Dataugv5+
## Geometric TF ##
'Identity' : (lambda x, mag: x),
'FlipUD' : (lambda x, mag: flipUD(x)),
'FlipLR' : (lambda x, mag: flipLR(x)),
'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=x.shape[2]*0.33), zero_pos=0))),
'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=x.shape[3]*0.33), zero_pos=1))),
'TranslateXabs': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
'TranslateYabs': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
'Sharpness':(lambda x, mag: sharpness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
#Color TF (Common mag scale)
'+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
'+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
'+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
'+Sharpness':(lambda x, mag: sharpness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
'-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
'-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
'-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
'-Sharpness':(lambda x, mag: sharpness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
'=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
'=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
## Bad Tranformations ##
# Bad Geometric TF #
'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=0))),
'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=1))),
'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=0))),
'BTranslateX-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=0))),
'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=1))),
'BTranslateY-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=1))),
# Bad Color TF #
'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9*2, maxval=2*4))),
'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),
# Random TF #
'Random':(lambda x, mag: torch.rand_like(x)),
'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))),
#Not ready for use
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
#'Equalize': (lambda mag: None),
}
'''

436
higher/smart_aug/utils.py Normal file
View file

@ -0,0 +1,436 @@
""" Utilties function.
"""
import numpy as np
import json, math, time, os
import matplotlib
matplotlib.use('Agg') #https://stackoverflow.com/questions/4706451/how-to-save-a-figure-remotely-with-pylab
import matplotlib.pyplot as plt
import copy
import gc
import torch
import torch.nn.functional as F
import time
from nets.LeNet import *
from nets.wideresnet import *
from nets.wideresnet_cifar import *
import nets.resnet_abn as resnet_abn
import nets.resnet_deconv as resnet_DC
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import url_map as EfficientNet_map
import torchvision.models as models
def load_model(model, num_classes, pretrained=False):
if model in models.resnet.__all__ :
model_name = model #'resnet18' #'resnet34' #'wide_resnet50_2'
if pretrained :
print("Using pretrained weights")
model = getattr(models.resnet, model_name)(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
else:
model = getattr(models.resnet, model_name)(pretrained=False, num_classes=num_classes)
elif model in models.vgg.__all__ :
model_name = model #'vgg11', 'vgg1_bn'
if pretrained :
print("Using pretrained weights")
model = getattr(models.vgg, model_name)(pretrained=True)
num_ftrs = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
else :
model = getattr(models.vgg, model_name)(pretrained=False, num_classes=num_classes)
elif model in models.densenet.__all__ :
model_name = model #'densenet121' #'densenet201'
if pretrained :
print("Using pretrained weights")
model = getattr(models.densenet, model_name)(pretrained=True)
num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, num_classes)
else:
model = getattr(models.densenet, model_name)(pretrained=False, num_classes=num_classes)
elif model == 'LeNet':
if pretrained :
print("Pretrained weights not available")
model = LeNet(3,num_classes)
model_name=str(model)
elif model == 'WideResNet':
if pretrained :
print("Pretrained weights not available")
# model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_classes)
# model = WideResNet(16, 4, dropout_rate=0.0, num_classes=num_classes)
model = wide_resnet_cifar(26, 10, num_classes=num_classes)
# model = wide_resnet_cifar(20, 10, num_classes=num_classes)
model_name=str(model)
elif model in EfficientNet_map.keys():
model_name=model # efficientnet-b0 , efficientnet-b1, efficientnet-b4
if pretrained: #ImageNet ou Advprop (Meilleurs perf normalement mais normalisation differentes)
print("Using pretrained weights")
model = EfficientNet.from_pretrained(model_name, advprop=False)
else:
model = EfficientNet.from_name(model_name)
elif model in resnet_abn.__all__ :
if pretrained :
print("Pretrained weights not available")
model_name=model
model = getattr(resnet_abn, model_name)(pretrained=False, num_classes=num_classes)
elif model in resnet_DC.__all__:
if pretrained :
print("Pretrained weights not available")
model_name = model
model = getattr(resnet_DC, model_name)(num_classes=num_classes)
else:
raise Exception('Unknown model')
return model, model_name
class ConfusionMatrix(object):
""" Confusion matrix.
Helps computing the confusion matrix and F1 scores.
Example use ::
confmat = ConfusionMatrix(...)
confmat.reset()
for data in dataset:
...
confmat.update(...)
confmat.f1_metric(...)
Attributes:
num_classes (int): Number of classes.
mat (Tensor): Confusion matrix. Filled by update method.
"""
def __init__(self, num_classes):
""" Initialize ConfusionMatrix.
Args:
num_classes (int): Number of classes.
"""
self.num_classes = num_classes
self.mat = None
def update(self, target, pred):
""" Update the confusion matrix.
Args:
target (Tensor): Target labels.
pred (Tensor): Prediction.
"""
n = self.num_classes
if self.mat is None:
self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
with torch.no_grad():
k = (target >= 0) & (target < n)
inds = n * target[k].to(torch.int64) + pred[k]
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
def reset(self):
""" Reset the Confusion matrix.
"""
if self.mat is not None:
self.mat.zero_()
def f1_metric(self, average=None):
""" Compute the F1 score.
Inspired from :
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
Args:
average (str): Type of averaging performed on the data. (Default: None)
``None``:
The scores for each class are returned.
``'micro'``:
Calculate metrics globally by counting the total true positives,
false negatives and false positives.
``'macro'``:
Calculate metrics for each label, and find their unweighted
mean. This does not take label imbalance into account.
Return:
Tensor containing the F1 score. It's shape is either 1, if there was averaging, or (num_classes).
"""
h = self.mat.float()
TP = torch.diag(h)
TN = []
FP = []
FN = []
for c in range(self.num_classes):
idx = torch.ones(self.num_classes).bool()
idx[c] = 0
# all non-class samples classified as non-class
TN.append(self.mat[idx.nonzero()[:, None], idx.nonzero()].sum()) #conf_matrix[idx[:, None], idx].sum() - conf_matrix[idx, c].sum()
# all non-class samples classified as class
FP.append(self.mat[idx, c].sum())
# all class samples not classified as class
FN.append(self.mat[c, idx].sum())
#print('Class {}\nTP {}, TN {}, FP {}, FN {}'.format(c, TP[c], TN[c], FP[c], FN[c]))
tp = (TP/h.sum(1))#.sum()
tn = (torch.tensor(TN, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
fp = (torch.tensor(FP, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
fn = (torch.tensor(FN, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
if average=="micro":
tp, tn, fp, fn = tp.sum(), tn.sum(), fp.sum(), fn.sum()
epsilon = 1e-7
precision = tp / (tp + fp + epsilon)
recall = tp / (tp + fn + epsilon)
f1 = 2* (precision*recall) / (precision + recall + epsilon)
if average=="macro":
f1=f1.mean()
return f1
#from torchviz import make_dot
def print_graph(PyTorch_obj=torch.randn(1, 3, 32, 32), fig_name='graph'):
"""Save the computational graph.
Args:
PyTorch_obj (Tensor): End of the graph. Commonly, the loss tensor to get the whole graph.
fig_name (string): Relative path where to save the graph. (default: graph)
"""
graph=make_dot(PyTorch_obj)
graph.format = 'png' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
graph.render(fig_name)
def plot_resV2(log, fig_name='res', param_names=None, f1=True):
"""Save a visual graph of the logs.
Args:
log (dict): Logs of the training generated by most of train_utils.
fig_name (string): Relative path where to save the graph. (default: res)
param_names (list): Labels for the parameters. (default: None)
f1 (bool): Wether to plot F1 scores. (default: True)
"""
epochs = [x["epoch"] for x in log]
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(30, 15))
ax[0, 0].set_title('Loss')
ax[0, 0].plot(epochs,[x["train_loss"] for x in log], label='Train')
ax[0, 0].plot(epochs,[x["val_loss"] for x in log], label='Val')
ax[0, 0].legend()
ax[1, 0].set_title('Test')
ax[1, 0].plot(epochs,[x["acc"] for x in log], label='Acc')
if f1 and "f1" in log[0].keys():
#ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1')
#'''
#print(log[0]["f1"])
if isinstance(log[0]["f1"], list):
if len(log[0]["f1"])>10:
print("Plotting results : Too many class for F1, plotting only min/max")
ax[1, 0].plot(epochs,[max(x["f1"])*100 for x in log], label='F1-Max', ls='--')
ax[1, 0].plot(epochs,[min(x["f1"])*100 for x in log], label='F1-Min', ls='--')
else:
for c in range(len(log[0]["f1"])):
ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c), ls='--')
else:
ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1', ls='--')
#'''
ax[1, 0].legend()
if log[0]["param"]!= None:
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
#proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
ax[0, 1].set_title('Prob =f(epoch)')
ax[0, 1].stackplot(epochs, proba, labels=param_names)
#ax[0, 1].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
ax[1, 1].set_title('Prob =f(TF)')
mean = np.mean(proba, axis=1)
std = np.std(proba, axis=1)
ax[1, 1].bar(param_names, mean, yerr=std)
plt.sca(ax[1, 1]), plt.xticks(rotation=90)
ax[0, 2].set_title('Mag =f(epoch)')
ax[0, 2].stackplot(epochs, mag, labels=param_names)
#ax[0, 2].plot(epochs, np.array(mag).T, label=param_names)
ax[0, 2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
ax[1, 2].set_title('Mag =f(TF)')
mean = np.mean(mag, axis=1)
std = np.std(mag, axis=1)
ax[1, 2].bar(param_names, mean, yerr=std)
plt.sca(ax[1, 2]), plt.xticks(rotation=90)
fig_name = fig_name.replace('.',',').replace(',,/','../')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def plot_compare(filenames, fig_name='res'):
"""Save a visual graph comparing trainings stats.
Args:
filenames (list[Strings]): Relative paths to the logs (JSON files).
fig_name (string): Relative path where to save the graph. (default: res)
"""
all_data=[]
legend=""
for idx, file in enumerate(filenames):
legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
all_data.append(data)
fig, ax = plt.subplots(ncols=3, figsize=(30, 8))
for data_idx, log in enumerate(all_data):
log=log['Log']
epochs = [x["epoch"] for x in log]
ax[0].plot(epochs,[x["train_loss"] for x in log], label=str(data_idx)+'-Train')
ax[0].plot(epochs,[x["val_loss"] for x in log], label=str(data_idx)+'-Val')
ax[1].plot(epochs,[x["acc"] for x in log], label=str(data_idx))
#ax[1].text(x=0.5,y=0,s=str(data_idx)+'-'+filenames[data_idx], transform=ax[1].transAxes)
if log[0]["param"]!= None:
if isinstance(log[0]["param"],float):
ax[2].plot(epochs,[x["param"] for x in log], label=str(data_idx)+'-Mag')
else :
for idx, _ in enumerate(log[0]["param"]):
ax[2].plot(epochs,[x["param"][idx] for x in log], label=str(data_idx)+'-P'+str(idx))
fig.suptitle(legend)
ax[0].set_title('Loss')
ax[1].set_title('Acc')
ax[2].set_title('Param')
for a in ax: a.legend()
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
"""Save data samples.
Args:
imgs (Tensor): Batch of image to sample from. Intended to contain at least 25 images.
labels (Tensor): Labels of the images.
fig_name (string): Relative path where to save the graph. (default: data_sample)
weight_labels (Tensor): Weights associated to each labels. (default: None)
"""
sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()
plt.figure(figsize=(10,10))
for i in range(25):
plt.subplot(5,5,i+1) #Trop de figure cree ?
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
label = str(labels[i].item())
if torch.is_tensor(weight_labels): label+= (" - p %.2f" % weight_labels[i].item())
plt.xlabel(label)
plt.savefig(fig_name)
print("Sample saved :", fig_name)
plt.close('all')
def print_torch_mem(add_info=''):
"""Print informations on PyTorch memory usage.
Args:
add_info (string): Prefix added before the print. (default: None)
"""
nb=0
max_size=0
for obj in gc.get_objects():
#print(type(obj))
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): # and len(obj.size())>1:
#print(i, type(obj), obj.size())
size = np.sum(obj.size())
if(size>max_size): max_size=size
nb+=1
except:
pass
print(add_info, "-Pytroch tensor nb:",nb," / Max dim:", max_size)
#print(add_info, "-Garbage size :",len(gc.garbage))
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
string = add_info + ' memory (MB)'
string += ' | allocated: {}'.format(
torch.cuda.memory_allocated() / mega_bytes)
string += ' | max allocated: {}'.format(
torch.cuda.max_memory_allocated() / mega_bytes)
string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
string += ' | max cached: {}'.format(
torch.cuda.max_memory_cached()/ mega_bytes)
print(string)
'''
def plot_TF_influence(log, fig_name='TF_influence', param_names=None):
proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
plt.figure()
mean = np.mean(proba, axis=1)*np.mean(mag, axis=1) #Pourrait etre interessant de multiplier avant le mean
std = np.std(proba, axis=1)*np.std(mag, axis=1)
plt.bar(param_names, mean, yerr=std)
plt.xticks(rotation=90)
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
'''
from torch._six import inf
def clip_norm(tensors, max_norm, norm_type=2):
"""Clips norm of passed tensors.
The norm is computed over all tensors together, as if they were
concatenated into a single vector. Clipped tensors are returned.
See: https://github.com/facebookresearch/higher/issues/18
Args:
tensors (Iterable[Tensor]): an iterable of Tensors or a
single Tensor to be normalized.
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Clipped (List[Tensor]) tensors.
"""
if isinstance(tensors, torch.Tensor):
tensors = [tensors]
tensors = list(tensors)
max_norm = float(max_norm)
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(t.abs().max() for t in tensors)
else:
total_norm = 0
for t in tensors:
if t is None:
continue
param_norm = t.norm(norm_type)
total_norm += param_norm.item() ** norm_type
total_norm = total_norm ** (1. / norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef >= 1:
return tensors
#return [t.mul(clip_coef) for t in tensors]
return [t if t is None else t.mul(clip_coef) for t in tensors]