mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Changes since Teledyne
This commit is contained in:
parent
03ffd7fe05
commit
b89dac9084
185 changed files with 16668 additions and 484 deletions
73
higher/smart_aug/arg_parser.py
Normal file
73
higher/smart_aug/arg_parser.py
Normal file
|
@ -0,0 +1,73 @@
|
|||
import argparse
|
||||
|
||||
#Argparse
|
||||
parser = argparse.ArgumentParser(description='Run smart augmentation')
|
||||
parser.add_argument('-dv','--device', default='cuda', dest='device',
|
||||
help='Device : cpu / cuda')
|
||||
parser.add_argument('-dt','--dtype', default='FP32', dest='dtype',
|
||||
help='Data type (Default: Float32)')
|
||||
|
||||
parser.add_argument('-m','--model', default='resnet18', dest='model',
|
||||
help='Network')
|
||||
parser.add_argument('-pt','--pretrained', default='', dest='pretrained',
|
||||
help='Use pretrained weight if possible')
|
||||
|
||||
parser.add_argument('-ep','--epochs', type=int, default=10, dest='epochs',
|
||||
help='epoch')
|
||||
# parser.add_argument('-ot', '--optimizer', default='SGD', dest='opt_type',
|
||||
# help='Model optimizer')
|
||||
parser.add_argument('-lr', type=float, default=1e-1, dest='lr',
|
||||
help='Model learning rate')
|
||||
parser.add_argument('-mo', '--momentum', type=float, default=0.9, dest='momentum',
|
||||
help='Momentum')
|
||||
parser.add_argument('-dc', '--decay', type=float, default=0.0005, dest='decay',
|
||||
help='Weight decay')
|
||||
parser.add_argument('-ns','--nesterov', type=bool, default=False, dest='nesterov',
|
||||
help='Nesterov momentum ?')
|
||||
parser.add_argument('-sc', '--scheduler', default='cosine', dest='scheduler',
|
||||
help='Model learning rate scheduler')
|
||||
parser.add_argument('-wu', '--warmup', type=float, default=0, dest='warmup',
|
||||
help='Warmup multiplier')
|
||||
|
||||
|
||||
parser.add_argument('-a','--augment', type=bool, default=False, dest='augment',
|
||||
help='Data augmentation ?')
|
||||
parser.add_argument('-N', type=int, default=1,
|
||||
help='Combination of TF')
|
||||
parser.add_argument('-K', type=int, default=0,
|
||||
help='Number inner iteration')
|
||||
parser.add_argument('-al','--augment_loss', type=int, default=1, dest='augment_loss',
|
||||
help='Number of augmented example for each sample in loss computation.')
|
||||
parser.add_argument('-t', '--temp', type=float, default=0.5, dest='temp',
|
||||
help='Probability distribution temperature')
|
||||
parser.add_argument('-tfc','--tf_config', default='../config/invScale_wide_tf_config.json', dest='tf_config',
|
||||
help='TF config')
|
||||
parser.add_argument('-ls', '--learn_seq', type=bool, default=False, dest='learn_seq',
|
||||
help='Learn order of application of TF (DataugV7-8) ?')
|
||||
parser.add_argument('-fm', '--fixed_mag', type=bool, default=False, dest='fixed_mag',
|
||||
help='Fixed magnitude when learning data augmentation ?')
|
||||
parser.add_argument('-sm', '--shared_mag', type=bool, default=False, dest='shared_mag',
|
||||
help='Shared magnitude when learning data augmentation ?')
|
||||
|
||||
# parser.add_argument('-mot', '--metaoptimizer', default='Adam', dest='meta_opt_type',
|
||||
# help='Meta optimizer (Augmentations)')
|
||||
parser.add_argument('-mlr', type=float, default=1e-2, dest='mlr',
|
||||
help='Meta learning rate (Augmentations)')
|
||||
parser.add_argument('-ms', type=int, default=0, dest='meta_epoch_start',
|
||||
help='Epoch at which start meta learning')
|
||||
parser.add_argument('-mr', type=float, default=0.001, dest='mag_reg',
|
||||
help='Augmentation magnitudes regulation factor')
|
||||
|
||||
parser.add_argument('-rf','--res_folder', default='../res/', dest='res_folder',
|
||||
help='Results folder')
|
||||
parser.add_argument('-pf','--postfix', default='', dest='postfix',
|
||||
help='Res postfix')
|
||||
|
||||
parser.add_argument('-dr','--dataroot', default='~/scratch/data', dest='dataroot',
|
||||
help='Datasets folder')
|
||||
parser.add_argument('-ds','--dataset', default='CIFAR10', dest='dataset',
|
||||
help='Dataset')
|
||||
parser.add_argument('-bs','--batch_size', type=int, default=256, dest='batch_size',
|
||||
help='Batch size') #256 (WRN) / 512
|
||||
parser.add_argument('-w','--workers', type=int, default=6, dest='workers',
|
||||
help='Numer of workers (Nb CPU cores).')
|
|
@ -7,19 +7,22 @@ from train_utils import *
|
|||
from transformations import TF_loader
|
||||
|
||||
import torchvision.models as models
|
||||
from LeNet import *
|
||||
|
||||
#model_list={models.resnet: ['resnet18', 'resnet50','wide_resnet50_2']} #lr=0.1
|
||||
model_list={models.resnet: ['resnet18']}
|
||||
model_list={models.resnet: ['wide_resnet50_2']}
|
||||
|
||||
optim_param={
|
||||
'Meta':{
|
||||
'optim':'Adam',
|
||||
'lr':1e-2, #1e-2
|
||||
'lr':5e-3, #1e-2
|
||||
'epoch_start': 2, #0 / 2 (Resnet?)
|
||||
'reg_factor': 0.001,
|
||||
'scheduler': None, #None, 'multiStep'
|
||||
},
|
||||
'Inner':{
|
||||
'optim': 'SGD',
|
||||
'lr':1e-1, #1e-2/1e-1 (ResNet)
|
||||
'lr':1e-2, #1e-2/1e-1 (ResNet)
|
||||
'momentum':0.9, #0.9
|
||||
'decay':0.0005, #0.0005
|
||||
'nesterov':False, #False (True: Bad behavior w/ Data_aug)
|
||||
|
@ -28,16 +31,17 @@ optim_param={
|
|||
}
|
||||
|
||||
res_folder="../res/benchmark/CIFAR10/"
|
||||
#res_folder="../res/benchmark/MNIST/"
|
||||
#res_folder="../res/HPsearch/"
|
||||
epochs= 200
|
||||
dataug_epoch_start=0
|
||||
nb_run= 1
|
||||
nb_run= 3
|
||||
|
||||
tf_config='../config/wide_tf_config.json' #'../config/wide_tf_config.json'#'../config/base_tf_config.json'
|
||||
tf_config='../config/bad_tf_config.json' #'../config/wide_tf_config.json'#'../config/base_tf_config.json'
|
||||
TF_loader=TF_loader()
|
||||
tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(tf_config)
|
||||
|
||||
device = torch.device('cuda')
|
||||
device = torch.device('cuda:1')
|
||||
|
||||
if device == torch.device('cpu'):
|
||||
device_name = 'CPU'
|
||||
|
@ -54,8 +58,8 @@ np.random.seed(0)
|
|||
if __name__ == "__main__":
|
||||
|
||||
### Benchmark ###
|
||||
#'''
|
||||
inner_its = [3]
|
||||
'''
|
||||
inner_its = [0]
|
||||
dist_mix = [0.5]
|
||||
N_seq_TF= [3]
|
||||
mag_setup = [(False, False)] #[(True, True), (False, False)] #(FxSh, Independant)
|
||||
|
@ -74,6 +78,8 @@ if __name__ == "__main__":
|
|||
t0 = time.perf_counter()
|
||||
|
||||
model = getattr(model_type, model_name)(pretrained=False, num_classes=len(dl_train.dataset.classes))
|
||||
#model_name = 'LeNet'
|
||||
#model = LeNet(3,10)
|
||||
|
||||
model = Higher_model(model, model_name) #run_dist_dataugV3
|
||||
if n_inner_iter!=0:
|
||||
|
@ -122,12 +128,17 @@ if __name__ == "__main__":
|
|||
print('Log :\"',f.name, '\" saved !')
|
||||
except:
|
||||
print("Failed to save logs :",f.name)
|
||||
try:
|
||||
plot_resV2(log, fig_name=res_folder+filename, param_names=aug_model.TF_names())
|
||||
except:
|
||||
print("Failed to plot res")
|
||||
print(sys.exc_info()[1])
|
||||
|
||||
print('Execution Time : %.00f '%(exec_time))
|
||||
print('-'*9)
|
||||
#'''
|
||||
### Benchmark - RandAugment/Vanilla ###
|
||||
'''
|
||||
### Benchmark - RandAugment/Vanilla ###
|
||||
#'''
|
||||
for model_type in model_list.keys():
|
||||
for model_name in model_list[model_type]:
|
||||
for run in range(nb_run):
|
||||
|
@ -155,7 +166,7 @@ if __name__ == "__main__":
|
|||
#"Rand_Aug": rand_aug,
|
||||
"Log": log}
|
||||
print(model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{} epochs -{}".format(model_name,epochs, run)
|
||||
filename = "{}-{} epochs -{}-basicDA".format(model_name,epochs, run)
|
||||
#print("RandAugment-",model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
#filename = "RandAugment(N{}-M{:.2f})-{}-{} epochs -{}".format(rand_aug['N'],rand_aug['M'],model_name,epochs, run)
|
||||
with open(res_folder+"log/%s.json" % filename, "w+") as f:
|
||||
|
@ -166,11 +177,14 @@ if __name__ == "__main__":
|
|||
print("Failed to save logs :",f.name)
|
||||
print(sys.exc_info()[1])
|
||||
|
||||
#plot_resV2(log, fig_name=res_folder+filename)
|
||||
|
||||
try:
|
||||
plot_resV2(log, fig_name=res_folder+filename, param_names=aug_model.TF_names())
|
||||
except:
|
||||
print("Failed to plot res")
|
||||
print(sys.exc_info()[1])
|
||||
print('Execution Time : %.00f '%(exec_time))
|
||||
print('-'*9)
|
||||
'''
|
||||
#'''
|
||||
### HP Search ###
|
||||
'''
|
||||
from LeNet import *
|
||||
|
|
|
@ -2,41 +2,47 @@
|
|||
|
||||
MNIST / CIFAR10 / CIFAR100 / SVHN / ImageNet
|
||||
"""
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data.dataset import ConcatDataset
|
||||
import torchvision
|
||||
from arg_parser import *
|
||||
|
||||
#Train/Validation batch size.
|
||||
BATCH_SIZE = 512
|
||||
#Test batch size.
|
||||
TEST_SIZE = BATCH_SIZE
|
||||
#TEST_SIZE = 10000 #legerement +Rapide / + Consomation memoire !
|
||||
args = parser.parse_args()
|
||||
|
||||
#Wether to download data.
|
||||
download_data=False
|
||||
#Number of worker to use.
|
||||
num_workers=2 #4
|
||||
#Pin GPU memory
|
||||
pin_memory=False #True :+ GPU memory / + Lent
|
||||
#Data storage folder
|
||||
dataroot="../data"
|
||||
dataroot=args.dataroot
|
||||
|
||||
# if args.dtype == 'FP32':
|
||||
# def_type=torch.float32
|
||||
# elif args.dtype == 'FP16':
|
||||
# # def_type=torch.float16 #Default : float32
|
||||
# def_type=torch.bfloat16
|
||||
# else:
|
||||
# raise Exception('dtype not supported :', args.dtype)
|
||||
|
||||
#ATTENTION : Dataug (Kornia) Expect image in the range of [0, 1]
|
||||
#transform_train = torchvision.transforms.Compose([
|
||||
# torchvision.transforms.RandomHorizontalFlip(),
|
||||
# torchvision.transforms.ToTensor(),
|
||||
# torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #CIFAR10
|
||||
#])
|
||||
transform = torchvision.transforms.Compose([
|
||||
transform = [
|
||||
#torchvision.transforms.Grayscale(3), #MNIST
|
||||
#torchvision.transforms.Resize((224,224), interpolation=2)#VGG
|
||||
torchvision.transforms.ToTensor(),
|
||||
# torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #CIFAR10
|
||||
])
|
||||
#torchvision.transforms.Normalize(MEAN, STD), #CIFAR10
|
||||
# torchvision.transforms.Lambda(lambda tensor: tensor.to(def_type)),
|
||||
]
|
||||
|
||||
transform_train = torchvision.transforms.Compose([
|
||||
transform_train = [
|
||||
#transforms.RandomHorizontalFlip(),
|
||||
#transforms.RandomVerticalFlip(),
|
||||
#torchvision.transforms.Grayscale(3), #MNIST
|
||||
#torchvision.transforms.Resize((224,224), interpolation=2)
|
||||
torchvision.transforms.ToTensor(),
|
||||
])
|
||||
#torchvision.transforms.Normalize(MEAN, STD), #CIFAR10
|
||||
# torchvision.transforms.Lambda(lambda tensor: tensor.to(def_type)),
|
||||
]
|
||||
|
||||
## RandAugment ##
|
||||
#from RandAugment import RandAugment
|
||||
|
@ -49,20 +55,77 @@ transform_train = torchvision.transforms.Compose([
|
|||
#transform_train.transforms.insert(0, RandAugment(n=rand_aug['N'], m=rand_aug['M']))
|
||||
|
||||
### Classic Dataset ###
|
||||
BATCH_SIZE = args.batch_size
|
||||
TEST_SIZE = BATCH_SIZE
|
||||
# Load Dataset
|
||||
if args.dataset == 'MNIST':
|
||||
transform_train.insert(0, torchvision.transforms.Grayscale(3))
|
||||
transform.insert(0, torchvision.transforms.Grayscale(3))
|
||||
|
||||
#MNIST
|
||||
#data_train = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=transform_train)
|
||||
#data_val = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=transform)
|
||||
#data_test = torchvision.datasets.MNIST(dataroot, train=False, download=True, transform=transform)
|
||||
val_set=False
|
||||
data_train = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=torchvision.transforms.Compose(transform_train))
|
||||
data_val = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=torchvision.transforms.Compose(transform))
|
||||
data_test = torchvision.datasets.MNIST(dataroot, train=False, download=True, transform=torchvision.transforms.Compose(transform))
|
||||
elif args.dataset == 'CIFAR10': #(32x32 RGB)
|
||||
val_set=False
|
||||
MEAN=(0.4914, 0.4822, 0.4465)
|
||||
STD=(0.2023, 0.1994, 0.2010)
|
||||
data_train = torchvision.datasets.CIFAR10(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform_train))
|
||||
data_val = torchvision.datasets.CIFAR10(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform))
|
||||
data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=download_data, transform=torchvision.transforms.Compose(transform))
|
||||
elif args.dataset == 'CIFAR100': #(32x32 RGB)
|
||||
val_set=False
|
||||
MEAN=(0.4914, 0.4822, 0.4465)
|
||||
STD=(0.2023, 0.1994, 0.2010)
|
||||
data_train = torchvision.datasets.CIFAR100(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform_train))
|
||||
data_val = torchvision.datasets.CIFAR100(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform))
|
||||
data_test = torchvision.datasets.CIFAR100(dataroot, train=False, download=download_data, transform=torchvision.transforms.Compose(transform))
|
||||
elif args.dataset == 'TinyImageNet': #(Train:100k, Val:5k, Test:5k) (64x64 RGB)
|
||||
image_size=64 #128 / 224
|
||||
print('Using image size', image_size)
|
||||
transform_train=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform_train
|
||||
transform=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform
|
||||
|
||||
val_set=True
|
||||
MEAN=(0.485, 0.456, 0.406)
|
||||
STD=(0.229, 0.224, 0.225)
|
||||
data_train = torchvision.datasets.ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/train'), transform=torchvision.transforms.Compose(transform_train))
|
||||
data_val = torchvision.datasets.ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/val'), transform=torchvision.transforms.Compose(transform))
|
||||
data_test = torchvision.datasets.ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/test'), transform=torchvision.transforms.Compose(transform))
|
||||
elif args.dataset == 'ImageNet': #
|
||||
image_size=128 #224
|
||||
print('Using image size', image_size)
|
||||
transform_train=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform_train
|
||||
transform=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform
|
||||
|
||||
val_set=False
|
||||
MEAN=(0.485, 0.456, 0.406)
|
||||
STD=(0.229, 0.224, 0.225)
|
||||
data_train = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'ImageNet/train'), transform=torchvision.transforms.Compose(transform_train))
|
||||
data_val = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'ImageNet/train'), transform=torchvision.transforms.Compose(transform))
|
||||
data_test = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'ImageNet/validation'), transform=torchvision.transforms.Compose(transform))
|
||||
|
||||
#CIFAR
|
||||
data_train = torchvision.datasets.CIFAR10(dataroot, train=True, download=download_data, transform=transform_train)
|
||||
data_val = torchvision.datasets.CIFAR10(dataroot, train=True, download=download_data, transform=transform)
|
||||
data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=download_data, transform=transform)
|
||||
else:
|
||||
raise Exception('Unknown dataset')
|
||||
|
||||
# Ready dataloader
|
||||
if not val_set : #Split Training set into Train/Val
|
||||
#Validation set size [0, 1]
|
||||
valid_size=0.1
|
||||
train_subset_indices=range(int(len(data_train)*(1-valid_size)))
|
||||
val_subset_indices=range(int(len(data_train)*(1-valid_size)),len(data_train))
|
||||
#train_subset_indices=range(BATCH_SIZE*10)
|
||||
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
|
||||
|
||||
from torch.utils.data import SubsetRandomSampler
|
||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=args.workers, pin_memory=pin_memory)
|
||||
dl_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=args.workers, pin_memory=pin_memory)
|
||||
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=args.workers, pin_memory=pin_memory)
|
||||
else:
|
||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=args.workers, pin_memory=pin_memory)
|
||||
dl_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE, shuffle=True, num_workers=args.workers, pin_memory=pin_memory)
|
||||
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=args.workers, pin_memory=pin_memory)
|
||||
|
||||
#data_train = torchvision.datasets.CIFAR100(dataroot, train=True, download=download_data, transform=transform_train)
|
||||
#data_val = torchvision.datasets.CIFAR100(dataroot, train=True, download=download_data, transform=transform)
|
||||
#data_test = torchvision.datasets.CIFAR100(dataroot, train=False, download=download_data, transform=transform)
|
||||
|
||||
#SVHN
|
||||
#trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=download_data, transform=transform_train)
|
||||
|
@ -76,19 +139,6 @@ data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=downloa
|
|||
#data_train = torchvision.datasets.ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='train', transform=transform_train)
|
||||
#data_test = torchvision.datasets.ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)
|
||||
|
||||
|
||||
#Validation set size [0, 1]
|
||||
valid_size=0.1
|
||||
train_subset_indices=range(int(len(data_train)*(1-valid_size)))
|
||||
val_subset_indices=range(int(len(data_train)*(1-valid_size)),len(data_train))
|
||||
#train_subset_indices=range(BATCH_SIZE*10)
|
||||
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
|
||||
|
||||
from torch.utils.data import SubsetRandomSampler
|
||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
|
||||
dl_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
|
||||
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
|
||||
|
||||
#Cross Validation
|
||||
'''
|
||||
import numpy as np
|
||||
|
|
|
@ -18,13 +18,17 @@ import numpy as np
|
|||
import copy
|
||||
|
||||
import transformations as TF
|
||||
import torchvision
|
||||
|
||||
import higher
|
||||
import higher_patch
|
||||
|
||||
from utils import clip_norm
|
||||
from utils import clip_norm
|
||||
from train_utils import compute_vaLoss
|
||||
|
||||
from datasets import MEAN, STD
|
||||
norm = TF.Normalizer(MEAN, STD)
|
||||
|
||||
### Data augmenter ###
|
||||
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||
"""Data augmentation module with learnable parameters.
|
||||
|
@ -46,19 +50,19 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
_fixed_mag (bool): Wether to lock the TF magnitudes.
|
||||
_fixed_prob (bool): Wether to lock the TF probabilies.
|
||||
_samples (list): Sampled TF index during last forward pass.
|
||||
_mix_dist (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
|
||||
_fixed_mix (bool): Wether we lock the mix distribution factor.
|
||||
_temp (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
|
||||
_fixed_temp (bool): Wether we lock the mix distribution factor.
|
||||
_params (nn.ParameterDict): Learnable parameters.
|
||||
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
|
||||
_reg_mask (list): Mask selecting the TF considered for the regularisation.
|
||||
"""
|
||||
def __init__(self, TF_dict, N_TF=1, mix_dist=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
|
||||
def __init__(self, TF_dict, N_TF=1, temp=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
|
||||
"""Init Data_augv5.
|
||||
|
||||
Args:
|
||||
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
|
||||
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
|
||||
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0.5)
|
||||
temp (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-temp)*Uniform_distribution + temp*Real_distribution. If None is given, try to learn this parameter. (default: 0.5)
|
||||
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
|
||||
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
|
||||
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
|
||||
|
@ -88,27 +92,30 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
self._fixed_prob=fixed_prob
|
||||
self._samples = []
|
||||
|
||||
self._mix_dist = False
|
||||
if mix_dist != 0.0: #Mix dist
|
||||
self._mix_dist = True
|
||||
# self._temp = False
|
||||
# if temp != 0.0: #Mix dist
|
||||
# self._temp = True
|
||||
|
||||
self._fixed_mix=True
|
||||
if mix_dist is None: #Learn Mix dist
|
||||
self._fixed_mix = False
|
||||
mix_dist=0.5
|
||||
self._fixed_temp=True
|
||||
if temp is None: #Learn Temp
|
||||
print("WARNING: Learning Temperature parameter isn't working with this version (No grad)")
|
||||
self._fixed_temp = False
|
||||
temp=0.5
|
||||
|
||||
#Params
|
||||
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
|
||||
self._params = nn.ParameterDict({
|
||||
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
|
||||
#"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
|
||||
"prob": nn.Parameter(torch.ones(self._nb_tf)),
|
||||
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
|
||||
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
|
||||
"mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
|
||||
"temp": nn.Parameter(torch.tensor(temp))#.clamp(min=0.0,max=0.999))
|
||||
})
|
||||
|
||||
for tf in self._TF_ignore_mag :
|
||||
self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
|
||||
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
|
||||
if not self._shared_mag:
|
||||
for tf in self._TF_ignore_mag :
|
||||
self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
|
||||
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
|
||||
|
||||
#Mag regularisation
|
||||
if not self._fixed_mag:
|
||||
|
@ -117,7 +124,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
else:
|
||||
TF_mag=[t for t in self._TF if t not in self._TF_ignore_mag] #TF w/ differentiable mag
|
||||
self._reg_mask=[self._TF.index(t) for t in TF_mag]
|
||||
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
|
||||
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX, dtype=self._params['mag'].dtype) #Encourage amplitude max
|
||||
|
||||
#Prevent Identity
|
||||
#print(TF.TF_identity)
|
||||
|
@ -137,28 +144,44 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
Tensor : Batch of tranformed data.
|
||||
"""
|
||||
self._samples = torch.Tensor([])
|
||||
if self._data_augmentation:# and TF.random.random() < 0.5:
|
||||
if self._data_augmentation:
|
||||
device = x.device
|
||||
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
|
||||
|
||||
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||
# x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||
|
||||
## Echantillonage ##
|
||||
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
|
||||
|
||||
if not self._mix_dist:
|
||||
self._distrib = uniforme_dist
|
||||
else:
|
||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||
mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
|
||||
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
temp = self._params["temp"].detach() if self._fixed_temp else self._params["temp"]
|
||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||
self._distrib = F.softmax(prob*temp, dim=0)
|
||||
# prob = F.softmax(prob[1:], dim=0) #Bernouilli
|
||||
|
||||
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
|
||||
self._samples=cat_distrib.sample([self._N_seqTF])
|
||||
|
||||
#Bernoulli (Requiert Identité en position 0)
|
||||
#assert(self._TF[0]=="Identity")
|
||||
# cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf-1), device=device)*self._distrib)
|
||||
# bern_distrib = Bernoulli(torch.tensor([0.5], device=device))
|
||||
# mask = bern_distrib.sample([self._N_seqTF, batch_size]).squeeze()
|
||||
# self._samples=(cat_distrib.sample([self._N_seqTF])+1)*mask
|
||||
|
||||
for sample in self._samples:
|
||||
## Transformations ##
|
||||
x = self.apply_TF(x, sample)
|
||||
|
||||
# self._samples.to(device)
|
||||
# for n in range(self._N_seqTF):
|
||||
# # print('temp', (temp+0.3*n))
|
||||
# self._distrib = F.softmax(prob*(temp+0.2*n), dim=0)
|
||||
# # prob = F.softmax(prob[1:], dim=0) #Bernouilli
|
||||
|
||||
# cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
|
||||
# new_sample=cat_distrib.sample()
|
||||
# self._samples=torch.cat((self._samples.to(device).to(new_sample.dtype), new_sample.unsqueeze(dim=0)), dim=0)
|
||||
|
||||
# x = self.apply_TF(x, new_sample)
|
||||
# print('sample',self._samples.shape)
|
||||
return x
|
||||
|
||||
def apply_TF(self, x, sampled_TF):
|
||||
|
@ -204,20 +227,20 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
Args:
|
||||
soft (bool): Wether to use a softmax function for TF probabilites. Tends to lock the probabilities if the learning rate is low, preventing them to be learned. (default: False)
|
||||
"""
|
||||
if not self._fixed_prob:
|
||||
if soft :
|
||||
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
|
||||
else:
|
||||
self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
|
||||
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
|
||||
# if not self._fixed_prob:
|
||||
# if soft :
|
||||
# self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
|
||||
# else:
|
||||
# self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
|
||||
# self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
|
||||
|
||||
if not self._fixed_mag:
|
||||
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
|
||||
|
||||
if not self._fixed_mix:
|
||||
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
|
||||
if not self._fixed_temp:
|
||||
self._params['temp'].data = self._params['temp'].data.clamp(min=0.0, max=0.999)
|
||||
|
||||
def loss_weight(self, mean_norm=False):
|
||||
def loss_weight(self, batch_norm=True):
|
||||
""" Weights for the loss.
|
||||
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
|
||||
Should be applied to the loss before reduction.
|
||||
|
@ -225,30 +248,37 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
Do not take into account the order of application of the TF. See Data_augV7.
|
||||
|
||||
Args:
|
||||
mean_norm (bool): Wether to normalize weights by mean or by distribution. (Default: Normalize by distribution.)
|
||||
Normalizing by mean, would lend an exact normalization but can lead to unstable behavior of probabilities.
|
||||
Normalizing by distribution is a statistical approximation of the exact normalization. It lead to more smooth probabilities evolution but will only return 1 if mix_dist=1.
|
||||
batch_norm (bool): Wether to normalize mean of the weights. (Default: True)
|
||||
|
||||
Returns:
|
||||
Tensor : Loss weights.
|
||||
"""
|
||||
if len(self._samples)==0 : return torch.tensor(1, device=self._params["prob"].device) #Pas d'echantillon = pas de ponderation
|
||||
|
||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||
#prob = F.softmax(prob, dim=0)
|
||||
|
||||
#Plusieurs TF sequentielles (Attention ne prend pas en compte ordre !)
|
||||
w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
|
||||
for sample in self._samples:
|
||||
for sample in self._samples.to(torch.long):
|
||||
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
|
||||
tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
|
||||
w_loss += tmp_w
|
||||
|
||||
if mean_norm:
|
||||
w_loss = w_loss * prob
|
||||
w_loss = torch.sum(w_loss,dim=1)
|
||||
#w_loss=w_loss/w_loss.sum(dim=1, keepdim=True) #Bernoulli
|
||||
|
||||
#Normalizing by mean, would lend an exact normalization but can lead to unstable behavior of probabilities.
|
||||
w_loss = w_loss * prob
|
||||
w_loss = torch.sum(w_loss,dim=1)
|
||||
|
||||
if batch_norm:
|
||||
w_min = w_loss.min()
|
||||
w_loss = w_loss-w_min if w_min<0 else w_loss
|
||||
w_loss = w_loss/w_loss.mean() #mean(w_loss)=1
|
||||
else:
|
||||
w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||
w_loss = torch.sum(w_loss,dim=1)
|
||||
|
||||
#Normalizing by distribution is a statistical approximation of the exact normalization. It lead to more smooth probabilities evolution but will only return 1 if temp=1.
|
||||
# w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||
# w_loss = torch.sum(w_loss,dim=1)
|
||||
return w_loss
|
||||
|
||||
def reg_loss(self, reg_factor=0.005):
|
||||
|
@ -310,6 +340,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
Returns:
|
||||
nn.Parameter.
|
||||
"""
|
||||
if key == 'prob': #Override prob access
|
||||
return F.softmax(self._params["prob"]*self._params["temp"], dim=0)
|
||||
return self._params[key]
|
||||
|
||||
def __str__(self):
|
||||
|
@ -323,22 +355,20 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
mag_param='Mag'
|
||||
if self._fixed_mag: mag_param+= 'Fx'
|
||||
if self._shared_mag: mag_param+= 'Sh'
|
||||
if not self._mix_dist:
|
||||
return "Data_augV5(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
elif self._fixed_mix:
|
||||
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
# if not self._temp:
|
||||
# return "Data_augV5(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
if self._fixed_temp:
|
||||
return "Data_augV5(T%.1f%s-%dTFx%d-%s)" % (self._params['temp'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
else:
|
||||
return "Data_augV5(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
return "Data_augV5(T%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
|
||||
class Data_augV7(nn.Module): #Proba sequentielles
|
||||
class Data_augV8(nn.Module): #Apprentissage proba sequentielles
|
||||
"""Data augmentation module with learnable parameters.
|
||||
|
||||
Applies transformations (TF) to batch of data.
|
||||
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple which can be learned. For the full definiton of the TF, see transformations.py.
|
||||
The TF probabilities defines a distribution from which we sample the TF applied.
|
||||
|
||||
Replace the use of TF by TF sets which are combinaisons of classic TF.
|
||||
|
||||
Attributes:
|
||||
_data_augmentation (bool): Wether TF will be applied during forward pass.
|
||||
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
|
||||
|
@ -350,37 +380,34 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
_fixed_mag (bool): Wether to lock the TF magnitudes.
|
||||
_fixed_prob (bool): Wether to lock the TF probabilies.
|
||||
_samples (list): Sampled TF index during last forward pass.
|
||||
_mix_dist (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
|
||||
_fixed_mix (bool): Wether we lock the mix distribution factor.
|
||||
_temp (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
|
||||
_fixed_temp (bool): Wether we lock the mix distribution factor.
|
||||
_params (nn.ParameterDict): Learnable parameters.
|
||||
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
|
||||
_reg_mask (list): Mask selecting the TF considered for the regularisation.
|
||||
"""
|
||||
def __init__(self, TF_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
|
||||
"""Init Data_augv7.
|
||||
def __init__(self, TF_dict, N_TF=1, temp=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
|
||||
"""Init Data_augv8.
|
||||
|
||||
Args:
|
||||
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
|
||||
N_TF (int): Number of TF to be applied sequentially to each inputs. Minimum 2, otherwise prefer using Data_augV5. (default: 2)
|
||||
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
|
||||
N_TF (int): Number of TF to be applied sequentially to each inputs. (default: 1)
|
||||
temp (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-temp)*Uniform_distribution + temp*Real_distribution. If None is given, try to learn this parameter. (default: 0.5)
|
||||
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
|
||||
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
|
||||
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
|
||||
TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
|
||||
"""
|
||||
super(Data_augV7, self).__init__()
|
||||
super(Data_augV8, self).__init__()
|
||||
assert len(TF_dict)>0
|
||||
assert N_TF>=0
|
||||
|
||||
if N_TF<2:
|
||||
print("WARNING: Data_augv7 isn't designed to use less than 2 sequentials TF. Please use Data_augv5 instead.")
|
||||
|
||||
self._data_augmentation = True
|
||||
|
||||
#TF
|
||||
self._TF_dict = TF_dict
|
||||
self._TF= list(self._TF_dict.keys())
|
||||
self._TF_ignore_mag= TF_ignore_mag
|
||||
self._TF_ignore_mag=TF_ignore_mag
|
||||
self._nb_tf= len(self._TF)
|
||||
self._N_seqTF = N_TF
|
||||
|
||||
|
@ -395,58 +422,50 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
self._fixed_prob=fixed_prob
|
||||
self._samples = []
|
||||
|
||||
self._mix_dist = False
|
||||
if mix_dist != 0.0: #Mix dist
|
||||
self._mix_dist = True
|
||||
# self._temp = False
|
||||
# if temp != 0.0: #Mix dist
|
||||
# self._temp = True
|
||||
|
||||
self._fixed_mix=True
|
||||
if mix_dist is None: #Learn Mix dist
|
||||
self._fixed_mix = False
|
||||
mix_dist=0.5
|
||||
self._fixed_temp=True
|
||||
if temp is None: #Learn temp
|
||||
print("WARNING: Learning Temperature parameter isn't working with this version (No grad)")
|
||||
self._fixed_temp = False
|
||||
temp=0.5
|
||||
|
||||
#TF sets
|
||||
#import itertools
|
||||
#itertools.product(range(self._nb_tf), repeat=self._N_seqTF)
|
||||
|
||||
#no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}} #Specific No consecutive ops
|
||||
no_consecutive={idx for idx, t in enumerate(self._TF) if t not in {'Identity'}} #No consecutive same ops (except Identity)
|
||||
cons_test = (lambda i, idxs: i in no_consecutive and len(idxs)!=0 and i==idxs[-1]) #Exclude selected consecutive
|
||||
def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF (exclude cons_test arrangement)
|
||||
TF_sets=[]
|
||||
if set_size>1:
|
||||
for i in range(n_TF):
|
||||
if not cons_test(i, idx_prefix):
|
||||
TF_sets += generate_TF_sets(n_TF, set_size=set_size-1, idx_prefix=idx_prefix+[i])
|
||||
else:
|
||||
TF_sets+=[[idx_prefix+[i]] for i in range(n_TF) if not cons_test(i, idx_prefix)]
|
||||
return TF_sets
|
||||
|
||||
self._TF_sets=torch.ByteTensor(generate_TF_sets(self._nb_tf, self._N_seqTF)).squeeze()
|
||||
self._nb_TF_sets=len(self._TF_sets)
|
||||
print("Number of TF sets:",self._nb_TF_sets)
|
||||
#print(self._TF_sets)
|
||||
self._prob_mem=torch.zeros(self._nb_TF_sets)
|
||||
|
||||
#Params
|
||||
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
|
||||
self._params = nn.ParameterDict({
|
||||
"prob": nn.Parameter(torch.ones(self._nb_TF_sets)/self._nb_TF_sets), #Distribution prob uniforme
|
||||
#"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
|
||||
# "prob": nn.Parameter(torch.ones([self._nb_tf for _ in range(self._N_seqTF)])),
|
||||
"prob": nn.Parameter(torch.ones(self._nb_tf**self._N_seqTF)),
|
||||
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
|
||||
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
|
||||
"mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
|
||||
"temp": nn.Parameter(torch.tensor(temp))#.clamp(min=0.0,max=0.999))
|
||||
})
|
||||
|
||||
#for tf in TF.TF_no_grad :
|
||||
# if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
|
||||
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
|
||||
self._prob_mem=torch.zeros(self._nb_tf**self._N_seqTF)
|
||||
|
||||
if not self._shared_mag:
|
||||
for tf in self._TF_ignore_mag :
|
||||
self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
|
||||
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
|
||||
|
||||
#Mag regularisation
|
||||
if not self._fixed_mag:
|
||||
if self._shared_mag :
|
||||
self._reg_tgt = torch.FloatTensor(TF.PARAMETER_MAX) #Encourage amplitude max
|
||||
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max
|
||||
else:
|
||||
self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in self._TF_ignore_mag]
|
||||
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
|
||||
TF_mag=[t for t in self._TF if t not in self._TF_ignore_mag] #TF w/ differentiable mag
|
||||
self._reg_mask=[self._TF.index(t) for t in TF_mag]
|
||||
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX, dtype=self._params['mag'].dtype) #Encourage amplitude max
|
||||
|
||||
#Prevent Identity
|
||||
#print(TF.TF_identity)
|
||||
#self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=0.0)
|
||||
#for val in TF.TF_identity.keys():
|
||||
# idx=[self._reg_mask.index(self._TF.index(t)) for t in TF_mag if t in TF.TF_identity[val]]
|
||||
# self._reg_tgt[idx]=val
|
||||
#print(TF_mag, self._reg_tgt)
|
||||
|
||||
def forward(self, x):
|
||||
""" Main method of the Data augmentation module.
|
||||
|
@ -457,32 +476,54 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
Returns:
|
||||
Tensor : Batch of tranformed data.
|
||||
"""
|
||||
self._samples = None
|
||||
if self._data_augmentation:# and TF.random.random() < 0.5:
|
||||
self._samples = torch.Tensor([])
|
||||
if self._data_augmentation:
|
||||
device = x.device
|
||||
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
|
||||
|
||||
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||
# x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||
|
||||
## Echantillonage ##
|
||||
uniforme_dist = torch.ones(1,self._nb_TF_sets,device=device).softmax(dim=1)
|
||||
# if not self._temp:
|
||||
# self._distrib = torch.ones(1,self._nb_tf**self._N_seqTF,device=device).softmax(dim=1)
|
||||
# else:
|
||||
# prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"] #Uniform dist
|
||||
# # print(prob.shape)
|
||||
# # prob = prob.view(1, -1)
|
||||
# # prob = F.softmax(prob, dim=0)
|
||||
|
||||
if not self._mix_dist:
|
||||
self._distrib = uniforme_dist
|
||||
else:
|
||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||
mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
|
||||
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
# temp = self._params["temp"].detach() if self._fixed_temp else self._params["temp"] #Temperature
|
||||
# self._distrib = F.softmax(temp*prob, dim=0)
|
||||
# # self._distrib = (temp*prob+(1-temp)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
# # print(prob.shape)
|
||||
|
||||
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_TF_sets), device=device)*self._distrib)
|
||||
sample = cat_distrib.sample()
|
||||
|
||||
self._samples=sample
|
||||
TF_samples=self._TF_sets[sample,:].to(device) #[Batch_size, TFseq]
|
||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||
temp = self._params["temp"].detach() if self._fixed_temp else self._params["temp"] #Temperature
|
||||
self._distrib = F.softmax(temp*prob, dim=0)
|
||||
|
||||
for i in range(self._N_seqTF):
|
||||
cat_distrib= Categorical(probs=torch.ones((self._nb_tf**self._N_seqTF), device=device)*self._distrib)
|
||||
samples=cat_distrib.sample([batch_size]) # (batch_size)
|
||||
# print(samples.shape)
|
||||
samples=torch.zeros((batch_size, self._nb_tf**self._N_seqTF), dtype=torch.bool, device=device).scatter_(dim=1, index=samples.unsqueeze(dim=1), value=1)
|
||||
self._samples=samples
|
||||
# print(samples.shape)
|
||||
# print(samples)
|
||||
samples=samples.view((batch_size,)+tuple([self._nb_tf for _ in range(self._N_seqTF)]))
|
||||
# print(samples.shape)
|
||||
# print(samples)
|
||||
samples= torch.nonzero(samples)[:,1:].T #Find indexes (TF sequence) => (N_seqTF, batch_size)
|
||||
# print(samples.shape)
|
||||
|
||||
#Bernoulli (Requiert Identité en position 0)
|
||||
#assert(self._TF[0]=="Identity")
|
||||
# cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf-1), device=device)*self._distrib)
|
||||
# bern_distrib = Bernoulli(torch.tensor([0.5], device=device))
|
||||
# mask = bern_distrib.sample([self._N_seqTF, batch_size]).squeeze()
|
||||
# self._samples=(cat_distrib.sample([self._N_seqTF])+1)*mask
|
||||
|
||||
for sample in samples:
|
||||
## Transformations ##
|
||||
x = self.apply_TF(x, TF_samples[:,i])
|
||||
x = self.apply_TF(x, sample)
|
||||
return x
|
||||
|
||||
def apply_TF(self, x, sampled_TF):
|
||||
|
@ -526,37 +567,55 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
Ensure that the parameters value stays in the right intevals. This should be called after each update of those parameters.
|
||||
|
||||
Args:
|
||||
soft (bool): Wether to use a softmax function for TF probabilites. Not Recommended as it tends to lock the probabilities, preventing them to be learned. (default: False)
|
||||
soft (bool): Wether to use a softmax function for TF probabilites. Tends to lock the probabilities if the learning rate is low, preventing them to be learned. (default: False)
|
||||
"""
|
||||
if not self._fixed_prob:
|
||||
if soft :
|
||||
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
|
||||
else:
|
||||
self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
|
||||
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
|
||||
# if not self._fixed_prob:
|
||||
# if soft :
|
||||
# self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
|
||||
# else:
|
||||
# self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
|
||||
# self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
|
||||
|
||||
if not self._fixed_mag:
|
||||
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
|
||||
|
||||
if not self._fixed_mix:
|
||||
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
|
||||
if not self._fixed_temp:
|
||||
self._params['temp'].data = self._params['temp'].data.clamp(min=0.0, max=0.999)
|
||||
|
||||
def loss_weight(self):
|
||||
def loss_weight(self, batch_norm=True):
|
||||
""" Weights for the loss.
|
||||
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
|
||||
Should be applied to the loss before reduction.
|
||||
|
||||
Args:
|
||||
batch_norm (bool): Wether to normalize mean of the weights. (Default: True)
|
||||
|
||||
Returns:
|
||||
Tensor : Loss weights.
|
||||
"""
|
||||
if self._samples is None : return 1 #Pas d'echantillon = pas de ponderation
|
||||
device=self._params["prob"].device
|
||||
if len(self._samples)==0 : return torch.tensor(1, device=device) #Pas d'echantillon = pas de ponderation
|
||||
|
||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||
|
||||
w_loss = torch.zeros((self._samples.shape[0],self._nb_TF_sets), device=self._samples.device)
|
||||
w_loss.scatter_(1, self._samples.view(-1,1), 1)
|
||||
w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||
# print("prob",prob.shape)
|
||||
# print(self._samples.shape)
|
||||
|
||||
#w_loss=w_loss/w_loss.sum(dim=1, keepdim=True) #Bernoulli
|
||||
|
||||
#Normalizing by mean, would lend an exact normalization but can lead to unstable behavior of probabilities.
|
||||
w_loss = self._samples * prob
|
||||
w_loss = torch.sum(w_loss,dim=1)
|
||||
# print("W_loss",w_loss.shape)
|
||||
# print(w_loss)
|
||||
|
||||
if batch_norm:
|
||||
w_min = w_loss.min()
|
||||
w_loss = w_loss-w_min if w_min<0 else w_loss
|
||||
w_loss = w_loss/w_loss.mean() #mean(w_loss)=1
|
||||
|
||||
#Normalizing by distribution is a statistical approximation of the exact normalization. It lead to more smooth probabilities evolution but will only return 1 if temp=1.
|
||||
# w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||
# w_loss = torch.sum(w_loss,dim=1)
|
||||
return w_loss
|
||||
|
||||
def reg_loss(self, reg_factor=0.005):
|
||||
|
@ -573,30 +632,43 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
else:
|
||||
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
|
||||
mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
|
||||
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
|
||||
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') #Close to target ?
|
||||
#max_mag_reg = - reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') #Far from target ?
|
||||
return max_mag_reg
|
||||
|
||||
def TF_prob(self):
|
||||
""" Gives an estimation of the individual TF probabilities.
|
||||
|
||||
Be warry that the probability returned isn't exact. The TF distribution isn't fully represented by those.
|
||||
Each probability should be taken individualy. They only represent the chance for a specific TF to be picked at least once.
|
||||
|
||||
Returms:
|
||||
Tensor containing the single TF probabilities of applications.
|
||||
"""
|
||||
if torch.all(self._params['prob']!=self._prob_mem.to(self._params['prob'].device)): #Prevent recompute if originial prob didn't changed
|
||||
self._prob_mem=self._params['prob'].data.detach_()
|
||||
self._single_TF_prob=torch.zeros(self._nb_tf)
|
||||
for idx_tf in range(self._nb_tf):
|
||||
for i, t_set in enumerate(self._TF_sets):
|
||||
#uni, count = np.unique(t_set, return_counts=True)
|
||||
#if idx_tf in uni:
|
||||
# res[idx_tf]+=self._params['prob'][i]*int(count[np.where(uni==idx_tf)])
|
||||
if idx_tf in t_set:
|
||||
self._single_TF_prob[idx_tf]+=self._params['prob'][i]
|
||||
# if not torch.all(self._params['prob']==self._prob_mem.to(self._params['prob'].device)): #Prevent recompute if originial prob didn't changed
|
||||
# self._prob_mem=self._params['prob'].data.detach_()
|
||||
|
||||
return self._single_TF_prob
|
||||
# p = self._params['prob'].view([self._nb_tf for _ in range(self._N_seqTF)])
|
||||
# # print('prob',p)
|
||||
# self._single_TF_prob=p.mean(dim=[i+1 for i in range(self._N_seqTF-1)]) #Reduce to 1D tensor
|
||||
# # print(self._single_TF_prob)
|
||||
# self._single_TF_prob=F.softmax(self._single_TF_prob, dim=0)
|
||||
# print('Soft',self._single_TF_prob)
|
||||
|
||||
p=F.softmax(self._params['prob']*self._params["temp"], dim=0) #Sampling dist
|
||||
p=p.view([self._nb_tf for _ in range(self._N_seqTF)])
|
||||
p=p.mean(dim=[i+1 for i in range(self._N_seqTF-1)]) #Reduce to 1D tensor
|
||||
|
||||
#Means over each dim
|
||||
# dim_idx=tuple(range(self._N_seqTF))
|
||||
# means=[]
|
||||
# for d in dim_idx:
|
||||
# dim_mean=list(dim_idx)
|
||||
# dim_mean.remove(d)
|
||||
# means.append(p.mean(dim=dim_mean).unsqueeze(dim=1))
|
||||
# means=torch.cat(means,dim=1)
|
||||
# print(means)
|
||||
# p=means.mean(dim=1)
|
||||
# print(p)
|
||||
|
||||
return p
|
||||
|
||||
def train(self, mode=True):
|
||||
""" Set the module training mode.
|
||||
|
@ -607,7 +679,7 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
#if mode is None :
|
||||
# mode=self._data_augmentation
|
||||
self.augment(mode=mode) #Inutile si mode=None
|
||||
super(Data_augV7, self).train(mode)
|
||||
super(Data_augV8, self).train(mode)
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
|
@ -654,12 +726,13 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
mag_param='Mag'
|
||||
if self._fixed_mag: mag_param+= 'Fx'
|
||||
if self._shared_mag: mag_param+= 'Sh'
|
||||
if not self._mix_dist:
|
||||
return "Data_augV7(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
elif self._fixed_mix:
|
||||
return "Data_augV7(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
# if not self._temp:
|
||||
# return "Data_augV8(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
if self._fixed_temp:
|
||||
return "Data_augV8(T%.1f%s-%dTFx%d-%s)" % (self._params['temp'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
else:
|
||||
return "Data_augV7(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
return "Data_augV8(T%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
|
||||
|
||||
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
||||
"""RandAugment implementation.
|
||||
|
@ -703,7 +776,7 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
|||
self._shared_mag = True
|
||||
self._fixed_mag = True
|
||||
self._fixed_prob=True
|
||||
self._fixed_mix=True
|
||||
self._fixed_temp=True
|
||||
|
||||
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
|
||||
|
||||
|
@ -716,17 +789,24 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
|||
Returns:
|
||||
Tensor : Batch of tranformed data.
|
||||
"""
|
||||
if self._data_augmentation:# and TF.random.random() < 0.5:
|
||||
if self._data_augmentation:
|
||||
device = x.device
|
||||
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
|
||||
|
||||
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||
# x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||
|
||||
## Echantillonage ## == sampled_ops = np.random.choice(transforms, N)
|
||||
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
|
||||
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*uniforme_dist)
|
||||
self._samples=cat_distrib.sample([self._N_seqTF])
|
||||
|
||||
#Bernoulli (Requiert Identité en position 0)
|
||||
# uniforme_dist = torch.ones(1,self._nb_tf-1,device=device).softmax(dim=1)
|
||||
# cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf-1), device=device)*uniforme_dist)
|
||||
# bern_distrib = Bernoulli(torch.tensor([0.5], device=device))
|
||||
# mask = bern_distrib.sample([self._N_seqTF, batch_size]).squeeze()
|
||||
# self._samples=(cat_distrib.sample([self._N_seqTF])+1)*mask
|
||||
|
||||
for sample in self._samples:
|
||||
## Transformations ##
|
||||
x = self.apply_TF(x, sample)
|
||||
|
@ -765,10 +845,10 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
|||
"""
|
||||
pass #Pas de parametre a opti
|
||||
|
||||
def loss_weight(self):
|
||||
def loss_weight(self, batch_norm=False):
|
||||
"""Not used
|
||||
"""
|
||||
return 1 #Pas d'echantillon = pas de ponderation
|
||||
return torch.tensor([1], device=self._params["prob"].device) #Pas d'echantillon = pas de ponderation
|
||||
|
||||
def reg_loss(self, reg_factor=0.005):
|
||||
"""Not used
|
||||
|
@ -949,18 +1029,22 @@ class Augmented_model(nn.Module):
|
|||
|
||||
self.augment(mode=True)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, copy=False):
|
||||
""" Main method of the Augmented model.
|
||||
|
||||
Perform the forward pass of both modules.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of data.
|
||||
copy (Bool): Wether to alter a copy or the original input. It's recommended to use a copy for parallel use of the input. (Default: False)
|
||||
|
||||
Returns:
|
||||
Tensor : Output of the networks. Should be logits.
|
||||
"""
|
||||
return self._mods['model'](self._mods['data_aug'](x))
|
||||
if copy:
|
||||
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||
return self._mods['model'](norm(self._mods['data_aug'](x)))
|
||||
# return self._mods['model'](self._mods['data_aug'](x))
|
||||
|
||||
def augment(self, mode=True):
|
||||
""" Set the augmentation mode.
|
||||
|
@ -970,6 +1054,12 @@ class Augmented_model(nn.Module):
|
|||
"""
|
||||
self._data_augmentation=mode
|
||||
self._mods['data_aug'].augment(mode)
|
||||
|
||||
#ABN
|
||||
# if mode :
|
||||
# self._mods['model']['functional'].set_mode('augmented')
|
||||
# else :
|
||||
# self._mods['model']['functional'].set_mode('clean')
|
||||
|
||||
#### Encapsulation Meta Opt ####
|
||||
def start_bilevel_opt(self, inner_it, hp_list, opt_param, dl_val):
|
||||
|
|
56
higher/smart_aug/nets/LeNet.py
Normal file
56
higher/smart_aug/nets/LeNet.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
## Basic CNN ##
|
||||
class LeNet(nn.Module):
|
||||
"""Basic CNN.
|
||||
|
||||
"""
|
||||
def __init__(self, num_inp, num_out):
|
||||
"""Init LeNet.
|
||||
|
||||
"""
|
||||
super(LeNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(num_inp, 20, 5)
|
||||
self.pool = nn.MaxPool2d(2, 2)
|
||||
self.conv2 = nn.Conv2d(20, 50, 5)
|
||||
self.pool2 = nn.MaxPool2d(2, 2)
|
||||
#self.fc1 = nn.Linear(4*4*50, 500)
|
||||
self.fc1 = nn.Linear(5*5*50, 500)
|
||||
self.fc2 = nn.Linear(500, num_out)
|
||||
|
||||
def forward(self, x):
|
||||
"""Main method of LeNet
|
||||
|
||||
"""
|
||||
x = self.pool(F.relu(self.conv1(x)))
|
||||
x = self.pool2(F.relu(self.conv2(x)))
|
||||
x = x.view(x.size(0), -1)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
def __str__(self):
|
||||
""" Get name of model
|
||||
|
||||
"""
|
||||
return "LeNet"
|
||||
|
||||
#MNIST
|
||||
class MLPNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(MLPNet, self).__init__()
|
||||
self.fc1 = nn.Linear(28*28, 500)
|
||||
self.fc2 = nn.Linear(500, 256)
|
||||
self.fc3 = nn.Linear(256, 10)
|
||||
def forward(self, x):
|
||||
x = x.view(-1, 28*28)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
def name(self):
|
||||
return "MLP"
|
426
higher/smart_aug/nets/resnet_abn.py
Normal file
426
higher/smart_aug/nets/resnet_abn.py
Normal file
|
@ -0,0 +1,426 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
|
||||
|
||||
# __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
# 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||
# 'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
|
||||
# model_urls = {
|
||||
# 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
# 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
# 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
# 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
# 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
# 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||
# 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||
# 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||
# 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||
# }
|
||||
|
||||
__all__ = ['ResNet_ABN', 'resnet18_ABN', 'resnet34_ABN', 'resnet50_ABN', 'resnet101_ABN',
|
||||
'resnet152_ABN', 'resnext50_32x4d_ABN', 'resnext101_32x8d_ABN',
|
||||
'wide_resnet50_2_ABN', 'wide_resnet101_2_ABN']
|
||||
|
||||
class aux_batchNorm(nn.Module):
|
||||
def __init__(self, norm_layer, nb_features):
|
||||
super(aux_batchNorm, self).__init__()
|
||||
self.mode='clean'
|
||||
self.bn=nn.ModuleDict({
|
||||
'clean': norm_layer(nb_features),
|
||||
'augmented': norm_layer(nb_features)
|
||||
})
|
||||
def forward(self, x):
|
||||
if self.mode is 'mixed':
|
||||
running_mean=(self.bn['clean'].running_mean+self.bn['augmented'].running_mean)/2
|
||||
running_var=(self.bn['clean'].running_var+self.bn['augmented'].running_var)/2
|
||||
return nn.functional.batch_norm(x, running_mean, running_var, self.bn['clean'].weight, self.bn['clean'].bias)
|
||||
return self.bn[self.mode](x)
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock_ABN(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(BasicBlock_ABN, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
#self.bn1 = norm_layer(planes)
|
||||
self.bn1 = aux_batchNorm(norm_layer, planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
#self.bn2 = norm_layer(planes)
|
||||
self.bn2 = aux_batchNorm(norm_layer, planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck_ABN(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(Bottleneck_ABN, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
#self.bn1 = norm_layer(width)
|
||||
self.bn1 = aux_batchNorm(norm_layer, width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
# self.bn2 = norm_layer(width)
|
||||
self.bn2 = aux_batchNorm(norm_layer, width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
# self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.bn3 = aux_batchNorm(norm_layer, planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet_ABN(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None):
|
||||
super(ResNet_ABN, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
#self.bn1 = norm_layer(self.inplanes)
|
||||
self.bn1 = aux_batchNorm(norm_layer, self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
print('WARNING : zero_init_residual not implemented with ABN')
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, Bottleneck):
|
||||
# nn.init.constant_(m.bn3.weight, 0)
|
||||
# elif isinstance(m, BasicBlock):
|
||||
# nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
# Memoire des BN layers pas fonctinnel avec Higher
|
||||
# self.bn_layers=[]
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, aux_batchNorm):
|
||||
# self.bn_layers.append(m)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
#norm_layer(planes * block.expansion),
|
||||
aux_batchNorm(norm_layer, planes * block.expansion)
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def set_mode(self, mode):
|
||||
# for bn in self.bn_layers:
|
||||
for m in self.modules():
|
||||
if isinstance(m, aux_batchNorm):
|
||||
m.mode=mode
|
||||
|
||||
|
||||
|
||||
# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
# model = ResNet(block, layers, **kwargs)
|
||||
# if pretrained:
|
||||
# state_dict = load_state_dict_from_url(model_urls[arch],
|
||||
# progress=progress)
|
||||
# model.load_state_dict(state_dict)
|
||||
# return model
|
||||
|
||||
|
||||
def resnet18_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-18 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(BasicBlock_ABN, [2, 2, 2, 2], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnet34_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(BasicBlock_ABN, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnet50_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnet101_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnet152_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-152 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 8, 36, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnext50_32x4d_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNeXt-50 32x4d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
# return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||
# pretrained, progress, **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNeXt-101 32x8d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
# return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||
# pretrained, progress, **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def wide_resnet50_2_ABN(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-50-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
# return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||
# pretrained, progress, **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def wide_resnet101_2_ABN(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-101-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
# return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||
# pretrained, progress, **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
|
618
higher/smart_aug/nets/resnet_deconv.py
Normal file
618
higher/smart_aug/nets/resnet_deconv.py
Normal file
|
@ -0,0 +1,618 @@
|
|||
'''ResNet in PyTorch.
|
||||
For Pre-activation ResNet, see 'preact_resnet.py'.
|
||||
Reference:
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
||||
Deep Residual Learning for Image Recognition. arXiv:1512.03385
|
||||
|
||||
https://github.com/yechengxi/deconvolution
|
||||
'''
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from torch.nn.modules import conv
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
from functools import partial
|
||||
|
||||
__all__ = ['ResNet18_DC', 'ResNet34_DC', 'ResNet50_DC', 'ResNet101_DC', 'ResNet152_DC', 'WRN_DC26_10']
|
||||
|
||||
### Deconvolution ###
|
||||
|
||||
#iteratively solve for inverse sqrt of a matrix
|
||||
def isqrt_newton_schulz_autograd(A, numIters):
|
||||
dim = A.shape[0]
|
||||
normA=A.norm()
|
||||
Y = A.div(normA)
|
||||
I = torch.eye(dim,dtype=A.dtype,device=A.device)
|
||||
Z = torch.eye(dim,dtype=A.dtype,device=A.device)
|
||||
|
||||
for i in range(numIters):
|
||||
T = 0.5*(3.0*I - Z@Y)
|
||||
Y = Y@T
|
||||
Z = T@Z
|
||||
#A_sqrt = Y*torch.sqrt(normA)
|
||||
A_isqrt = Z / torch.sqrt(normA)
|
||||
return A_isqrt
|
||||
|
||||
def isqrt_newton_schulz_autograd_batch(A, numIters):
|
||||
batchSize,dim,_ = A.shape
|
||||
normA=A.view(batchSize, -1).norm(2, 1).view(batchSize, 1, 1)
|
||||
Y = A.div(normA)
|
||||
I = torch.eye(dim,dtype=A.dtype,device=A.device).unsqueeze(0).expand_as(A)
|
||||
Z = torch.eye(dim,dtype=A.dtype,device=A.device).unsqueeze(0).expand_as(A)
|
||||
|
||||
for i in range(numIters):
|
||||
T = 0.5*(3.0*I - Z.bmm(Y))
|
||||
Y = Y.bmm(T)
|
||||
Z = T.bmm(Z)
|
||||
#A_sqrt = Y*torch.sqrt(normA)
|
||||
A_isqrt = Z / torch.sqrt(normA)
|
||||
|
||||
return A_isqrt
|
||||
|
||||
|
||||
|
||||
#deconvolve channels
|
||||
class ChannelDeconv(nn.Module):
|
||||
def __init__(self, block, eps=1e-2,n_iter=5,momentum=0.1,sampling_stride=3):
|
||||
super(ChannelDeconv, self).__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.n_iter=n_iter
|
||||
self.momentum=momentum
|
||||
self.block = block
|
||||
|
||||
self.register_buffer('running_mean1', torch.zeros(block, 1))
|
||||
#self.register_buffer('running_cov', torch.eye(block))
|
||||
self.register_buffer('running_deconv', torch.eye(block))
|
||||
self.register_buffer('running_mean2', torch.zeros(1, 1))
|
||||
self.register_buffer('running_var', torch.ones(1, 1))
|
||||
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
|
||||
self.sampling_stride=sampling_stride
|
||||
def forward(self, x):
|
||||
x_shape = x.shape
|
||||
if len(x.shape)==2:
|
||||
x=x.view(x.shape[0],x.shape[1],1,1)
|
||||
if len(x.shape)==3:
|
||||
print('Error! Unsupprted tensor shape.')
|
||||
|
||||
N, C, H, W = x.size()
|
||||
B = self.block
|
||||
|
||||
#take the first c channels out for deconv
|
||||
c=int(C/B)*B
|
||||
if c==0:
|
||||
print('Error! block should be set smaller.')
|
||||
|
||||
#step 1. remove mean
|
||||
if c!=C:
|
||||
x1=x[:,:c].permute(1,0,2,3).contiguous().view(B,-1)
|
||||
else:
|
||||
x1=x.permute(1,0,2,3).contiguous().view(B,-1)
|
||||
|
||||
if self.sampling_stride > 1 and H >= self.sampling_stride and W >= self.sampling_stride:
|
||||
x1_s = x1[:,::self.sampling_stride**2]
|
||||
else:
|
||||
x1_s=x1
|
||||
|
||||
mean1 = x1_s.mean(-1, keepdim=True)
|
||||
|
||||
if self.num_batches_tracked==0:
|
||||
self.running_mean1.copy_(mean1.detach())
|
||||
if self.training:
|
||||
self.running_mean1.mul_(1-self.momentum)
|
||||
self.running_mean1.add_(mean1.detach()*self.momentum)
|
||||
else:
|
||||
mean1 = self.running_mean1
|
||||
|
||||
x1=x1-mean1
|
||||
|
||||
#step 2. calculate deconv@x1 = cov^(-0.5)@x1
|
||||
if self.training:
|
||||
cov = x1_s @ x1_s.t() / x1_s.shape[1] + self.eps * torch.eye(B, dtype=x.dtype, device=x.device)
|
||||
deconv = isqrt_newton_schulz_autograd(cov, self.n_iter)
|
||||
|
||||
if self.num_batches_tracked==0:
|
||||
#self.running_cov.copy_(cov.detach())
|
||||
self.running_deconv.copy_(deconv.detach())
|
||||
|
||||
if self.training:
|
||||
#self.running_cov.mul_(1-self.momentum)
|
||||
#self.running_cov.add_(cov.detach()*self.momentum)
|
||||
self.running_deconv.mul_(1 - self.momentum)
|
||||
self.running_deconv.add_(deconv.detach() * self.momentum)
|
||||
else:
|
||||
# cov = self.running_cov
|
||||
deconv = self.running_deconv
|
||||
|
||||
x1 =deconv@x1
|
||||
|
||||
#reshape to N,c,J,W
|
||||
x1 = x1.view(c, N, H, W).contiguous().permute(1,0,2,3)
|
||||
|
||||
# normalize the remaining channels
|
||||
if c!=C:
|
||||
x_tmp=x[:, c:].view(N,-1)
|
||||
if self.sampling_stride > 1 and H>=self.sampling_stride and W>=self.sampling_stride:
|
||||
x_s = x_tmp[:, ::self.sampling_stride ** 2]
|
||||
else:
|
||||
x_s = x_tmp
|
||||
|
||||
mean2=x_s.mean()
|
||||
var=x_s.var()
|
||||
|
||||
if self.num_batches_tracked == 0:
|
||||
self.running_mean2.copy_(mean2.detach())
|
||||
self.running_var.copy_(var.detach())
|
||||
|
||||
if self.training:
|
||||
self.running_mean2.mul_(1 - self.momentum)
|
||||
self.running_mean2.add_(mean2.detach() * self.momentum)
|
||||
self.running_var.mul_(1 - self.momentum)
|
||||
self.running_var.add_(var.detach() * self.momentum)
|
||||
else:
|
||||
mean2 = self.running_mean2
|
||||
var = self.running_var
|
||||
|
||||
x_tmp = (x[:, c:] - mean2) / (var + self.eps).sqrt()
|
||||
x1 = torch.cat([x1, x_tmp], dim=1)
|
||||
|
||||
|
||||
if self.training:
|
||||
self.num_batches_tracked.add_(1)
|
||||
|
||||
if len(x_shape)==2:
|
||||
x1=x1.view(x_shape)
|
||||
return x1
|
||||
|
||||
#An alternative implementation
|
||||
class Delinear(nn.Module):
|
||||
__constants__ = ['bias', 'in_features', 'out_features']
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=512):
|
||||
super(Delinear, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.Tensor(out_features))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.reset_parameters()
|
||||
|
||||
|
||||
|
||||
if block > in_features:
|
||||
block = in_features
|
||||
else:
|
||||
if in_features%block!=0:
|
||||
block=math.gcd(block,in_features)
|
||||
print('block size set to:', block)
|
||||
self.block = block
|
||||
self.momentum = momentum
|
||||
self.n_iter = n_iter
|
||||
self.eps = eps
|
||||
self.register_buffer('running_mean', torch.zeros(self.block))
|
||||
self.register_buffer('running_deconv', torch.eye(self.block))
|
||||
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
if self.bias is not None:
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(self.bias, -bound, bound)
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
if self.training:
|
||||
|
||||
# 1. reshape
|
||||
X=input.view(-1, self.block)
|
||||
|
||||
# 2. subtract mean
|
||||
X_mean = X.mean(0)
|
||||
X = X - X_mean.unsqueeze(0)
|
||||
self.running_mean.mul_(1 - self.momentum)
|
||||
self.running_mean.add_(X_mean.detach() * self.momentum)
|
||||
|
||||
# 3. calculate COV, COV^(-0.5), then deconv
|
||||
# Cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
|
||||
Id = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
|
||||
Cov = torch.addmm(self.eps, Id, 1. / X.shape[0], X.t(), X)
|
||||
deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
|
||||
# track stats for evaluation
|
||||
self.running_deconv.mul_(1 - self.momentum)
|
||||
self.running_deconv.add_(deconv.detach() * self.momentum)
|
||||
|
||||
else:
|
||||
X_mean = self.running_mean
|
||||
deconv = self.running_deconv
|
||||
|
||||
w = self.weight.view(-1, self.block) @ deconv
|
||||
b = self.bias
|
||||
if self.bias is not None:
|
||||
b = b - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
|
||||
w = w.view(self.weight.shape)
|
||||
return F.linear(input, w, b)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'in_features={}, out_features={}, bias={}'.format(
|
||||
self.in_features, self.out_features, self.bias is not None
|
||||
)
|
||||
|
||||
|
||||
|
||||
class FastDeconv(conv._ConvNd):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,groups=1,bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3,freeze=False,freeze_iter=100):
|
||||
self.momentum = momentum
|
||||
self.n_iter = n_iter
|
||||
self.eps = eps
|
||||
self.counter=0
|
||||
self.track_running_stats=True
|
||||
super(FastDeconv, self).__init__(
|
||||
in_channels, out_channels, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation),
|
||||
False, _pair(0), groups, bias, padding_mode='zeros')
|
||||
|
||||
if block > in_channels:
|
||||
block = in_channels
|
||||
else:
|
||||
if in_channels%block!=0:
|
||||
block=math.gcd(block,in_channels)
|
||||
|
||||
if groups>1:
|
||||
#grouped conv
|
||||
block=in_channels//groups
|
||||
|
||||
self.block=block
|
||||
|
||||
self.num_features = kernel_size**2 *block
|
||||
if groups==1:
|
||||
self.register_buffer('running_mean', torch.zeros(self.num_features))
|
||||
self.register_buffer('running_deconv', torch.eye(self.num_features))
|
||||
else:
|
||||
self.register_buffer('running_mean', torch.zeros(kernel_size ** 2 * in_channels))
|
||||
self.register_buffer('running_deconv', torch.eye(self.num_features).repeat(in_channels // block, 1, 1))
|
||||
|
||||
self.sampling_stride=sampling_stride*stride
|
||||
self.counter=0
|
||||
self.freeze_iter=freeze_iter
|
||||
self.freeze=freeze
|
||||
|
||||
def forward(self, x):
|
||||
N, C, H, W = x.shape
|
||||
B = self.block
|
||||
frozen=self.freeze and (self.counter>self.freeze_iter)
|
||||
if self.training and self.track_running_stats:
|
||||
self.counter+=1
|
||||
self.counter %= (self.freeze_iter * 10)
|
||||
|
||||
if self.training and (not frozen):
|
||||
|
||||
# 1. im2col: N x cols x pixels -> N*pixles x cols
|
||||
if self.kernel_size[0]>1:
|
||||
X = torch.nn.functional.unfold(x, self.kernel_size,self.dilation,self.padding,self.sampling_stride).transpose(1, 2).contiguous()
|
||||
else:
|
||||
#channel wise
|
||||
X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride**2,:]
|
||||
|
||||
if self.groups==1:
|
||||
# (C//B*N*pixels,k*k*B)
|
||||
X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
|
||||
else:
|
||||
X=X.view(-1,X.shape[-1])
|
||||
|
||||
# 2. subtract mean
|
||||
X_mean = X.mean(0)
|
||||
X = X - X_mean.unsqueeze(0)
|
||||
|
||||
# 3. calculate COV, COV^(-0.5), then deconv
|
||||
if self.groups==1:
|
||||
#Cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
|
||||
Id=torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
|
||||
Cov = torch.addmm(self.eps, Id, 1. / X.shape[0], X.t(), X)
|
||||
deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
|
||||
else:
|
||||
X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
|
||||
Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(self.groups, self.num_features, self.num_features)
|
||||
Cov = torch.baddbmm(self.eps, Id, 1. / X.shape[1], X.transpose(1, 2), X)
|
||||
|
||||
deconv = isqrt_newton_schulz_autograd_batch(Cov, self.n_iter)
|
||||
|
||||
if self.track_running_stats:
|
||||
self.running_mean.mul_(1 - self.momentum)
|
||||
self.running_mean.add_(X_mean.detach() * self.momentum)
|
||||
# track stats for evaluation
|
||||
self.running_deconv.mul_(1 - self.momentum)
|
||||
self.running_deconv.add_(deconv.detach() * self.momentum)
|
||||
|
||||
else:
|
||||
X_mean = self.running_mean
|
||||
deconv = self.running_deconv
|
||||
|
||||
#4. X * deconv * conv = X * (deconv * conv)
|
||||
if self.groups==1:
|
||||
w = self.weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1,self.num_features) @ deconv
|
||||
b = self.bias - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
|
||||
w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
|
||||
else:
|
||||
w = self.weight.view(C//B, -1,self.num_features)@deconv
|
||||
b = self.bias - (w @ (X_mean.view( -1,self.num_features,1))).view(self.bias.shape)
|
||||
|
||||
w = w.view(self.weight.shape)
|
||||
x= F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
return x
|
||||
|
||||
### ResNet
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, deconv=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
if deconv:
|
||||
self.conv1 = deconv(in_planes, planes, kernel_size=3, stride=stride, padding=1)
|
||||
self.conv2 = deconv(planes, planes, kernel_size=3, stride=1, padding=1)
|
||||
self.deconv = True
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.deconv = False
|
||||
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
if not deconv:
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
#self.bn1 = nn.GroupNorm(planes//16,planes)
|
||||
#self.bn2 = nn.GroupNorm(planes//16,planes)
|
||||
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion*planes)
|
||||
#nn.GroupNorm(self.expansion * planes//16,self.expansion * planes)
|
||||
)
|
||||
else:
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
deconv(in_planes, self.expansion*planes, kernel_size=1, stride=stride)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.deconv:
|
||||
out = F.relu(self.conv1(x))
|
||||
out = self.conv2(out)
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
else: #self.batch_norm:
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, deconv=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
|
||||
|
||||
if deconv:
|
||||
self.deconv = True
|
||||
self.conv1 = deconv(in_planes, planes, kernel_size=1)
|
||||
self.conv2 = deconv(planes, planes, kernel_size=3, stride=stride, padding=1)
|
||||
self.conv3 = deconv(planes, self.expansion*planes, kernel_size=1)
|
||||
|
||||
else:
|
||||
self.deconv = False
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
if not deconv:
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes)
|
||||
)
|
||||
else:
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
deconv(in_planes, self.expansion * planes, kernel_size=1, stride=stride)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
"""
|
||||
No batch normalization for deconv.
|
||||
"""
|
||||
if self.deconv:
|
||||
out = F.relu((self.conv1(x)))
|
||||
out = F.relu((self.conv2(out)))
|
||||
out = self.conv3(out)
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
else:
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, num_blocks, num_classes=10, deconv=None,channel_deconv=None):
|
||||
super(ResNet, self).__init__()
|
||||
self.in_planes = 64
|
||||
|
||||
if deconv:
|
||||
self.deconv = True
|
||||
self.conv1 = deconv(3, 64, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
|
||||
|
||||
if not deconv:
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
|
||||
#this line is really recent, take extreme care if the result is not good.
|
||||
if channel_deconv:
|
||||
self.deconv1=channel_deconv()
|
||||
|
||||
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, deconv=deconv)
|
||||
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, deconv=deconv)
|
||||
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, deconv=deconv)
|
||||
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, deconv=deconv)
|
||||
self.linear = nn.Linear(512*block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride, deconv):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride, deconv))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if hasattr(self,'bn1'):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
else:
|
||||
out = F.relu(self.conv1(x))
|
||||
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
if hasattr(self, 'deconv1'):
|
||||
out = self.deconv1(out)
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def_deconv = partial(FastDeconv,bias=True, eps=1e-5, n_iter=5,block=64,sampling_stride=3)
|
||||
#channel_deconv=partial(ChannelDeconv, block=512,eps=1e-5, n_iter=5,sampling_stride=3) #Pas forcément conseillé
|
||||
|
||||
def ResNet18_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(BasicBlock, [2,2,2,2],num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
def ResNet34_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
def ResNet50_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
def ResNet101_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
def ResNet152_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
import math
|
||||
class Wide_ResNet_Cifar_DC(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, wfactor, num_classes=10, deconv=None, channel_deconv=None):
|
||||
super(Wide_ResNet_Cifar_DC, self).__init__()
|
||||
self.depth=layers[0]*6+2
|
||||
self.widen_factor=wfactor
|
||||
|
||||
self.inplanes = 16
|
||||
self.conv1 = deconv(3, 16, kernel_size=3, stride=1, padding=1)
|
||||
if channel_deconv:
|
||||
self.deconv1=channel_deconv()
|
||||
# self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
# self.bn1 = nn.BatchNorm2d(16)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.layer1 = self._make_layer(block, 16*wfactor, layers[0], stride=1, deconv=deconv)
|
||||
self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2, deconv=deconv)
|
||||
self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2, deconv=deconv)
|
||||
self.avgpool = nn.AvgPool2d(8, stride=1)
|
||||
self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride, deconv):
|
||||
# downsample = None
|
||||
# if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
# downsample = nn.Sequential(
|
||||
# nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
# nn.BatchNorm2d(planes * block.expansion)
|
||||
# )
|
||||
|
||||
# layers = []
|
||||
# layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
# self.inplanes = planes * block.expansion
|
||||
# for _ in range(1, blocks):
|
||||
# layers.append(block(self.inplanes, planes))
|
||||
|
||||
# return nn.Sequential(*layers)
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.inplanes, planes, stride, deconv))
|
||||
self.inplanes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
# x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
|
||||
if hasattr(self, 'deconv1'):
|
||||
out = self.deconv1(out)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def __str__(self):
|
||||
""" Get name of model
|
||||
|
||||
"""
|
||||
return "Wide_ResNet_cifar_DC%d_%d"%(self.depth,self.widen_factor)
|
||||
|
||||
def WRN_DC26_10(depth=26, width=10, deconv=def_deconv, channel_deconv=None, **kwargs):
|
||||
assert (depth - 2) % 6 == 0
|
||||
n = int((depth - 2) / 6)
|
||||
return Wide_ResNet_Cifar_DC(BasicBlock, [n, n, n], width, deconv=deconv,channel_deconv=channel_deconv, **kwargs)
|
||||
|
||||
def test():
|
||||
net = ResNet18_DC()
|
||||
y = net(torch.randn(1,3,32,32))
|
||||
print(y.size())
|
||||
|
||||
# test()
|
98
higher/smart_aug/nets/wideresnet.py
Normal file
98
higher/smart_aug/nets/wideresnet.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
_bn_momentum = 0.1
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
|
||||
|
||||
|
||||
def conv_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('Conv') != -1:
|
||||
init.xavier_uniform_(m.weight, gain=np.sqrt(2))
|
||||
init.constant_(m.bias, 0)
|
||||
elif classname.find('BatchNorm') != -1:
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class WideBasic(nn.Module):
|
||||
def __init__(self, in_planes, planes, dropout_rate, stride=1):
|
||||
super(WideBasic, self).__init__()
|
||||
assert dropout_rate==0.0, 'dropout layer not used'
|
||||
self.bn1 = nn.BatchNorm2d(in_planes, momentum=_bn_momentum)
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
|
||||
#self.dropout = nn.Dropout(p=dropout_rate)
|
||||
self.bn2 = nn.BatchNorm2d(planes, momentum=_bn_momentum)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# out = self.dropout(self.conv1(F.relu(self.bn1(x))))
|
||||
out = self.conv1(F.relu(self.bn1(x)))
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out += self.shortcut(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class WideResNet(nn.Module):
|
||||
def __init__(self, depth, widen_factor, dropout_rate, num_classes):
|
||||
super(WideResNet, self).__init__()
|
||||
self.depth=depth
|
||||
self.widen_factor=widen_factor
|
||||
self.in_planes = 16
|
||||
|
||||
assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
|
||||
n = int((depth - 4) / 6)
|
||||
k = widen_factor
|
||||
|
||||
nStages = [16, 16*k, 32*k, 64*k]
|
||||
|
||||
self.conv1 = conv3x3(3, nStages[0])
|
||||
self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
|
||||
self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
|
||||
self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
|
||||
self.bn1 = nn.BatchNorm2d(nStages[3], momentum=_bn_momentum)
|
||||
self.linear = nn.Linear(nStages[3], num_classes)
|
||||
|
||||
# self.apply(conv_init)
|
||||
|
||||
def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, dropout_rate, stride))
|
||||
self.in_planes = planes
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = F.relu(self.bn1(out))
|
||||
# out = F.avg_pool2d(out, 8)
|
||||
out = F.adaptive_avg_pool2d(out, (1, 1))
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
|
||||
return out
|
||||
|
||||
def __str__(self):
|
||||
""" Get name of model
|
||||
|
||||
"""
|
||||
return "Wide_ResNet%d_%d"%(self.depth,self.widen_factor)
|
119
higher/smart_aug/nets/wideresnet_cifar.py
Normal file
119
higher/smart_aug/nets/wideresnet_cifar.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
"""
|
||||
wide resnet for cifar in pytorch
|
||||
Reference:
|
||||
[1] S. Zagoruyko and N. Komodakis. Wide residual networks. In BMVC, 2016.
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
#from models.resnet_cifar import BasicBlock
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
" 3x3 convolution with padding "
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion=1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class Wide_ResNet_Cifar(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, wfactor, num_classes=10):
|
||||
super(Wide_ResNet_Cifar, self).__init__()
|
||||
self.depth=layers[0]*6+2
|
||||
self.widen_factor=wfactor
|
||||
|
||||
self.inplanes = 16
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(16)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.layer1 = self._make_layer(block, 16*wfactor, layers[0])
|
||||
self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2)
|
||||
self.avgpool = nn.AvgPool2d(8, stride=1)
|
||||
self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion)
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def __str__(self):
|
||||
""" Get name of model
|
||||
|
||||
"""
|
||||
return "Wide_ResNet_cifar%d_%d"%(self.depth,self.widen_factor)
|
||||
|
||||
|
||||
def wide_resnet_cifar(depth, width, **kwargs):
|
||||
assert (depth - 2) % 6 == 0
|
||||
n = int((depth - 2) / 6)
|
||||
return Wide_ResNet_Cifar(BasicBlock, [n, n, n], width, **kwargs)
|
||||
|
||||
|
||||
if __name__=='__main__':
|
||||
net = wide_resnet_cifar(20, 10)
|
||||
y = net(torch.randn(1, 3, 32, 32))
|
||||
print(isinstance(net, Wide_ResNet_Cifar))
|
||||
print(y.size())
|
|
@ -1063,3 +1063,362 @@ class AugmentedDataset(VisionDataset):
|
|||
|
||||
def __str__(self):
|
||||
return "CIFAR10(Sup:{}-Unsup:{}-{}TF)".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF))
|
||||
|
||||
class Data_augV7(nn.Module): #Proba sequentielles
|
||||
"""Data augmentation module with learnable parameters.
|
||||
|
||||
Applies transformations (TF) to batch of data.
|
||||
Each TF is defined by a (name, probability of application, magnitude of distorsion) tuple which can be learned. For the full definiton of the TF, see transformations.py.
|
||||
The TF probabilities defines a distribution from which we sample the TF applied.
|
||||
|
||||
Replace the use of TF by TF sets which are combinaisons of classic TF.
|
||||
|
||||
Attributes:
|
||||
_data_augmentation (bool): Wether TF will be applied during forward pass.
|
||||
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
|
||||
_TF (list) : List of TF names.
|
||||
_TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
|
||||
_nb_tf (int) : Number of TF used.
|
||||
_N_seqTF (int) : Number of TF to be applied sequentially to each inputs
|
||||
_shared_mag (bool) : Wether to share a single magnitude parameters for all TF. Beware using shared mag with basic color TF as their lowest magnitude is at PARAMETER_MAX/2.
|
||||
_fixed_mag (bool): Wether to lock the TF magnitudes.
|
||||
_fixed_prob (bool): Wether to lock the TF probabilies.
|
||||
_samples (list): Sampled TF index during last forward pass.
|
||||
_temp (bool): Wether we use a mix of an uniform distribution and the real distribution (TF probabilites). If False, only a uniform distribution is used.
|
||||
_fixed_temp (bool): Wether we lock the mix distribution factor.
|
||||
_params (nn.ParameterDict): Learnable parameters.
|
||||
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (ie. we learn the magnitudes).
|
||||
_reg_mask (list): Mask selecting the TF considered for the regularisation.
|
||||
"""
|
||||
def __init__(self, TF_dict, N_TF=1, temp=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
|
||||
"""Init Data_augv7.
|
||||
|
||||
Args:
|
||||
TF_dict (dict): A dictionnary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
|
||||
N_TF (int): Number of TF to be applied sequentially to each inputs. Minimum 2, otherwise prefer using Data_augV5. (default: 2)
|
||||
temp (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-temp)*Uniform_distribution + temp*Real_distribution. If None is given, try to learn this parameter. (default: 0)
|
||||
fixed_prob (bool): Wether to lock the TF probabilies. (default: False)
|
||||
fixed_mag (bool): Wether to lock the TF magnitudes. (default: True)
|
||||
shared_mag (bool): Wether to share a single magnitude parameters for all TF. (default: True)
|
||||
TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
|
||||
"""
|
||||
super(Data_augV7, self).__init__()
|
||||
assert len(TF_dict)>0
|
||||
assert N_TF>=0
|
||||
|
||||
if N_TF<2:
|
||||
print("WARNING: Data_augv7 isn't designed to use less than 2 sequentials TF. Please use Data_augv5 instead.")
|
||||
|
||||
self._data_augmentation = True
|
||||
|
||||
#TF
|
||||
self._TF_dict = TF_dict
|
||||
self._TF= list(self._TF_dict.keys())
|
||||
self._TF_ignore_mag= TF_ignore_mag
|
||||
self._nb_tf= len(self._TF)
|
||||
self._N_seqTF = N_TF
|
||||
|
||||
#Mag
|
||||
self._shared_mag = shared_mag
|
||||
self._fixed_mag = fixed_mag
|
||||
if not self._fixed_mag and len([tf for tf in self._TF if tf not in self._TF_ignore_mag])==0:
|
||||
print("WARNING: Mag would be fixed as current TF doesn't allow gradient propagation:",self._TF)
|
||||
self._fixed_mag=True
|
||||
|
||||
#Distribution
|
||||
self._fixed_prob=fixed_prob
|
||||
self._samples = []
|
||||
|
||||
# self._temp = False
|
||||
# if temp != 0.0: #Mix dist
|
||||
# self._temp = True
|
||||
|
||||
self._fixed_temp=True
|
||||
if temp is None: #Learn Temperature
|
||||
print("WARNING: Learning Temperature parameter isn't working with this version (No grad)")
|
||||
self._fixed_temp = False
|
||||
temp=0.5
|
||||
|
||||
#TF sets
|
||||
#import itertools
|
||||
#itertools.product(range(self._nb_tf), repeat=self._N_seqTF)
|
||||
|
||||
#no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}} #Specific No consecutive ops
|
||||
no_consecutive={idx for idx, t in enumerate(self._TF) if t not in {'Identity'}} #No consecutive same ops (except Identity)
|
||||
cons_test = (lambda i, idxs: i in no_consecutive and len(idxs)!=0 and i==idxs[-1]) #Exclude selected consecutive
|
||||
def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF (exclude cons_test arrangement)
|
||||
TF_sets=[]
|
||||
if set_size>1:
|
||||
for i in range(n_TF):
|
||||
if not cons_test(i, idx_prefix):
|
||||
TF_sets += generate_TF_sets(n_TF, set_size=set_size-1, idx_prefix=idx_prefix+[i])
|
||||
else:
|
||||
TF_sets+=[[idx_prefix+[i]] for i in range(n_TF) if not cons_test(i, idx_prefix)]
|
||||
return TF_sets
|
||||
|
||||
self._TF_sets=torch.ByteTensor(generate_TF_sets(self._nb_tf, self._N_seqTF)).squeeze()
|
||||
self._nb_TF_sets=len(self._TF_sets)
|
||||
print("Number of TF sets:",self._nb_TF_sets)
|
||||
#print(self._TF_sets)
|
||||
self._prob_mem=torch.zeros(self._nb_TF_sets)
|
||||
|
||||
#Params
|
||||
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
|
||||
self._params = nn.ParameterDict({
|
||||
#"prob": nn.Parameter(torch.ones(self._nb_TF_sets)/self._nb_TF_sets), #Distribution prob uniforme
|
||||
"prob": nn.Parameter(torch.ones(self._nb_TF_sets)),
|
||||
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
|
||||
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
|
||||
"temp": nn.Parameter(torch.tensor(temp))#.clamp(min=0.0,max=0.999))
|
||||
})
|
||||
|
||||
#for tf in TF.TF_no_grad :
|
||||
# if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
|
||||
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
|
||||
|
||||
#Mag regularisation
|
||||
if not self._fixed_mag:
|
||||
if self._shared_mag :
|
||||
self._reg_tgt = torch.FloatTensor(TF.PARAMETER_MAX) #Encourage amplitude max
|
||||
else:
|
||||
self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in self._TF_ignore_mag]
|
||||
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
|
||||
|
||||
def forward(self, x):
|
||||
""" Main method of the Data augmentation module.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of data.
|
||||
|
||||
Returns:
|
||||
Tensor : Batch of tranformed data.
|
||||
"""
|
||||
self._samples = None
|
||||
if self._data_augmentation:# and TF.random.random() < 0.5:
|
||||
device = x.device
|
||||
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
|
||||
|
||||
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||
|
||||
## Echantillonage ##
|
||||
# uniforme_dist = torch.ones(1,self._nb_TF_sets,device=device).softmax(dim=1)
|
||||
|
||||
# if not self._temp:
|
||||
# self._distrib = uniforme_dist
|
||||
# else:
|
||||
# prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||
# prob = F.softmax(prob, dim=0)
|
||||
# temp = self._params["temp"].detach() if self._fixed_temp else self._params["temp"]
|
||||
# self._distrib = (temp*prob+(1-temp)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
|
||||
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_TF_sets), device=device)*self._distrib)
|
||||
sample = cat_distrib.sample()
|
||||
|
||||
self._samples=sample
|
||||
TF_samples=self._TF_sets[sample,:].to(device) #[Batch_size, TFseq]
|
||||
|
||||
for i in range(self._N_seqTF):
|
||||
## Transformations ##
|
||||
x = self.apply_TF(x, TF_samples[:,i])
|
||||
return x
|
||||
|
||||
def apply_TF(self, x, sampled_TF):
|
||||
""" Applies the sampled transformations.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of data.
|
||||
sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of tranformed data.
|
||||
"""
|
||||
device = x.device
|
||||
batch_size, channels, h, w = x.shape
|
||||
smps_x=[]
|
||||
|
||||
for tf_idx in range(self._nb_tf):
|
||||
mask = sampled_TF==tf_idx #Create selection mask
|
||||
smp_x = x[mask] #torch.masked_select() ? (Necessite d'expand le mask au meme dim)
|
||||
|
||||
if smp_x.shape[0]!=0: #if there's data to TF
|
||||
magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
|
||||
if self._fixed_mag: magnitude=magnitude.detach() #Fmodel tente systematiquement de tracker les gradient de tout les param
|
||||
|
||||
tf=self._TF[tf_idx]
|
||||
|
||||
#In place
|
||||
#x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
|
||||
|
||||
#Out of place
|
||||
smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
|
||||
idx= mask.nonzero()
|
||||
idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||
x=x.scatter(dim=0, index=idx, src=smp_x)
|
||||
|
||||
return x
|
||||
|
||||
def adjust_param(self, soft=False): #Detach from gradient ?
|
||||
""" Enforce limitations to the learned parameters.
|
||||
|
||||
Ensure that the parameters value stays in the right intevals. This should be called after each update of those parameters.
|
||||
|
||||
Args:
|
||||
soft (bool): Wether to use a softmax function for TF probabilites. Not Recommended as it tends to lock the probabilities, preventing them to be learned. (default: False)
|
||||
"""
|
||||
# if not self._fixed_prob:
|
||||
# if soft :
|
||||
# self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
|
||||
# else:
|
||||
# self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
|
||||
# self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
|
||||
|
||||
if not self._fixed_mag:
|
||||
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
|
||||
|
||||
if not self._fixed_temp:
|
||||
self._params['temp'].data = self._params['temp'].data.clamp(min=0.0, max=0.999)
|
||||
|
||||
def loss_weight(self, batch_norm=True):
|
||||
""" Weights for the loss.
|
||||
Compute the weights for the loss of each inputs depending on wich TF was applied to them.
|
||||
Should be applied to the loss before reduction.
|
||||
|
||||
Do not take into account the order of application of the TF. See Data_augV7.
|
||||
|
||||
Args:
|
||||
batch_norm (bool): Wether to normalize mean of the weights. (Default: True)
|
||||
Returns:
|
||||
Tensor : Loss weights.
|
||||
"""
|
||||
if len(self._samples)==0 : return torch.tensor(1, device=self._params["prob"].device) #Pas d'echantillon = pas de ponderation
|
||||
|
||||
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
|
||||
# prob = F.softmax(prob, dim=0)
|
||||
|
||||
w_loss = torch.zeros((self._samples.shape[0],self._nb_TF_sets), device=self._samples.device)
|
||||
w_loss.scatter_(1, self._samples.view(-1,1), 1)
|
||||
|
||||
#Normalizing by mean, would lend an exact normalization but can lead to unstable behavior of probabilities.
|
||||
w_loss = w_loss * prob
|
||||
w_loss = torch.sum(w_loss,dim=1)
|
||||
|
||||
if batch_norm:
|
||||
w_min = w_loss.min()
|
||||
w_loss = w_loss-w_min if w_min<0 else w_loss
|
||||
w_loss = w_loss/w_loss.mean() #mean(w_loss)=1
|
||||
|
||||
#Normalizing by distribution is a statistical approximation of the exact normalization. It lead to more smooth probabilities evolution but will only return 1 if temp=1.
|
||||
# w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||
# w_loss = torch.sum(w_loss,dim=1)
|
||||
|
||||
# if mean_norm:
|
||||
# w_loss = w_loss * prob
|
||||
# w_loss = torch.sum(w_loss,dim=1)
|
||||
# w_loss = w_loss/w_loss.mean() #mean(w_loss)=1
|
||||
# else:
|
||||
# w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||
# w_loss = torch.sum(w_loss,dim=1)
|
||||
return w_loss
|
||||
|
||||
def reg_loss(self, reg_factor=0.005):
|
||||
""" Regularisation term used to learn the magnitudes.
|
||||
Use an L2 loss to encourage high magnitudes TF.
|
||||
|
||||
Args:
|
||||
reg_factor (float): Factor by wich the regularisation loss is multiplied. (default: 0.005)
|
||||
Returns:
|
||||
Tensor containing the regularisation loss value.
|
||||
"""
|
||||
if self._fixed_mag or self._fixed_prob: #Not enough DOF
|
||||
return torch.tensor(0)
|
||||
else:
|
||||
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
|
||||
mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
|
||||
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
|
||||
return max_mag_reg
|
||||
|
||||
def TF_prob(self):
|
||||
""" Gives an estimation of the individual TF probabilities.
|
||||
|
||||
Be warry that the probability returned isn't exact. The TF distribution isn't fully represented by those.
|
||||
Each probability should be taken individualy. They only represent the chance for a specific TF to be picked at least once.
|
||||
|
||||
Returms:
|
||||
Tensor containing the single TF probabilities of applications.
|
||||
"""
|
||||
if torch.all(self._params['prob']!=self._prob_mem.to(self._params['prob'].device)): #Prevent recompute if originial prob didn't changed
|
||||
self._prob_mem=self._params['prob'].data.detach_()
|
||||
prob = F.softmax(self._params["prob"]*self._params["temp"], dim=0)
|
||||
self._single_TF_prob=torch.zeros(self._nb_tf)
|
||||
for idx_tf in range(self._nb_tf):
|
||||
for i, t_set in enumerate(self._TF_sets):
|
||||
#uni, count = np.unique(t_set, return_counts=True)
|
||||
#if idx_tf in uni:
|
||||
# res[idx_tf]+=self._params['prob'][i]*int(count[np.where(uni==idx_tf)])
|
||||
if idx_tf in t_set:
|
||||
self._single_TF_prob[idx_tf]+=prob[i]
|
||||
|
||||
return self._single_TF_prob
|
||||
|
||||
def train(self, mode=True):
|
||||
""" Set the module training mode.
|
||||
|
||||
Args:
|
||||
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
|
||||
"""
|
||||
#if mode is None :
|
||||
# mode=self._data_augmentation
|
||||
self.augment(mode=mode) #Inutile si mode=None
|
||||
super(Data_augV7, self).train(mode)
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
""" Set the module to evaluation mode.
|
||||
"""
|
||||
return self.train(mode=False)
|
||||
|
||||
def augment(self, mode=True):
|
||||
""" Set the augmentation mode.
|
||||
|
||||
Args:
|
||||
mode (bool): Wether to perform data augmentation on the forward pass. (default: True)
|
||||
"""
|
||||
self._data_augmentation=mode
|
||||
|
||||
def is_augmenting(self):
|
||||
""" Return wether data augmentation is applied.
|
||||
|
||||
Returns:
|
||||
bool : True if data augmentation is applied.
|
||||
"""
|
||||
return self._data_augmentation
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Access to the learnable parameters
|
||||
Args:
|
||||
key (string): Name of the learnable parameter to access.
|
||||
|
||||
Returns:
|
||||
nn.Parameter.
|
||||
"""
|
||||
if key == 'prob': #Override prob access
|
||||
return self.TF_prob()
|
||||
return self._params[key]
|
||||
|
||||
def __str__(self):
|
||||
"""Name of the module
|
||||
|
||||
Returns:
|
||||
String containing the name of the module as well as the higher levels parameters.
|
||||
"""
|
||||
dist_param=''
|
||||
if self._fixed_prob: dist_param+='Fx'
|
||||
mag_param='Mag'
|
||||
if self._fixed_mag: mag_param+= 'Fx'
|
||||
if self._shared_mag: mag_param+= 'Sh'
|
||||
# if not self._temp:
|
||||
# return "Data_augV7(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
if self._fixed_temp:
|
||||
return "Data_augV7(T%.1f%s-%dTFx%d-%s)" % (self._params['temp'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
else:
|
||||
return "Data_augV7(T%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||
|
|
|
@ -2,17 +2,18 @@ from utils import *
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
#'''
|
||||
files=[
|
||||
"../res/HPsearch/log/Aug_mod(Data_augV5(Mix0.5-14TFx3-Mag)-ResNet)-200 epochs (dataug:0)- 1 in_it-0.json",
|
||||
#"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-0.json",
|
||||
#"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-1.json",
|
||||
#"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-2.json",
|
||||
#"res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
||||
]
|
||||
#files = ["../res/benchmark/CIFAR10/log/RandAugment(N%d-M%.2f)-%s-200 epochs -%s.json"%(3,0.17,'wide_resnet50_2', str(run)) for run in range(3)]
|
||||
#files = ["../res/benchmark/CIFAR10/log/Aug_mod(RandAug(14TFx%d-Mag%d)-%s)-200 epochs (dataug:0)- 0 in_it-%s.json"%(2,1,'resnet18', str(run)) for run in range(1)]
|
||||
files = ["../res/benchmark/CIFAR10/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-Mag)-%s)-200 epochs (dataug:0)- 3 in_it-%s.json"%(0.5,3,'resnet18', str(run)) for run in range(1)]
|
||||
'''
|
||||
# files=[
|
||||
# "../res/log/Aug_mod(Data_augV5(Mix0.5-18TFx3-Mag)-efficientnet-b1)-200 epochs (dataug 0)- 1 in_it__AL2.json",
|
||||
# #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-0.json",
|
||||
# #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-1.json",
|
||||
# #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-2.json",
|
||||
# #"res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
||||
# ]
|
||||
files = ["../res/benchmark/CIFAR10/log/RandAugment(N%d-M%.2f)-%s-200 epochs -%s.json"%(3,1,'resnet18', str(run)) for run in range(3)]
|
||||
#files = ["../res/benchmark/CIFAR10/log/Aug_mod(RandAug(18TFx%d-Mag%d)-%s)-200 epochs (dataug:0)- 0 in_it-%s.json"%(2,1,'resnet18', str(run)) for run in range(3)]
|
||||
#files = ["../res/benchmark/CIFAR10/log/Aug_mod(Data_augV5(Mix%.1f-18TFx%d-Mag)-%s)-200 epochs (dataug:0)- 1 in_it-%s.json"%(0.5,3,'resnet18', str(run)) for run in range(3)]
|
||||
|
||||
|
||||
for idx, file in enumerate(files):
|
||||
#legend+=str(idx)+'-'+file+'\n'
|
||||
|
@ -20,7 +21,41 @@ if __name__ == "__main__":
|
|||
data = json.load(json_file)
|
||||
plot_resV2(data['Log'], fig_name=file.replace("/log","").replace(".json",""), param_names=data['Param_names'], f1=True)
|
||||
#plot_TF_influence(data['Log'], param_names=data['Param_names'])
|
||||
#'''
|
||||
'''
|
||||
|
||||
#Res print
|
||||
# '''
|
||||
nb_run=3
|
||||
accs = []
|
||||
aug_accs = []
|
||||
f1_max = []
|
||||
f1_min = []
|
||||
times = []
|
||||
mem = []
|
||||
|
||||
files = ["../res/benchmark/log/Aug_mod(Data_augV5(T0.5-19TFx3-Mag)-resnet18)-200 epochs (dataug 0)- 1 in_it__%s.json"%(str(run)) for run in range(1, nb_run+1)]
|
||||
|
||||
for idx, file in enumerate(files):
|
||||
#legend+=str(idx)+'-'+file+'\n'
|
||||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
accs.append(data['Accuracy'])
|
||||
aug_accs.append(data['Aug_Accuracy'][1])
|
||||
times.append(data['Time'][0])
|
||||
mem.append(data['Memory'][1])
|
||||
|
||||
acc_idx = [x['acc'] for x in data['Log']].index(data['Accuracy'])
|
||||
f1_max.append(max(data['Log'][acc_idx]['f1'])*100)
|
||||
f1_min.append(min(data['Log'][acc_idx]['f1'])*100)
|
||||
print(idx, accs[-1], aug_accs[-1])
|
||||
|
||||
print(files[0])
|
||||
print("Acc : %.2f ~ %.2f / Aug_Acc %d: %.2f ~ %.2f"%(np.mean(accs), np.std(accs), data['Aug_Accuracy'][0], np.mean(aug_accs), np.std(aug_accs)))
|
||||
print("F1 max : %.2f ~ %.2f / F1 min : %.2f ~ %.2f"%(np.mean(f1_max), np.std(f1_max), np.mean(f1_min), np.std(f1_min)))
|
||||
print("Time (h): %.2f ~ %.2f"%(np.mean(times)/3600, np.std(times)/3600))
|
||||
print("Mem (MB): %d ~ %d"%(np.mean(mem), np.std(mem)))
|
||||
# '''
|
||||
|
||||
## Loss , Acc, Proba = f(epoch) ##
|
||||
#plot_compare(filenames=files, fig_name="res/compare")
|
||||
|
||||
|
@ -79,37 +114,21 @@ if __name__ == "__main__":
|
|||
plt.close()
|
||||
'''
|
||||
|
||||
#Res print
|
||||
'''
|
||||
nb_run=3
|
||||
accs = []
|
||||
times = []
|
||||
files = ["res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-Mag)-LeNet)-150epochs(dataug:0)-1in_it-%s.json"%str(run) for run in range(nb_run)]
|
||||
|
||||
for idx, file in enumerate(files):
|
||||
#legend+=str(idx)+'-'+file+'\n'
|
||||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
accs.append(data['Accuracy'])
|
||||
times.append(data['Time'][0])
|
||||
print(idx, data['Accuracy'])
|
||||
|
||||
print(files[0], np.mean(accs), np.std(accs), np.mean(times))
|
||||
'''
|
||||
|
||||
'''
|
||||
#HP search
|
||||
inner_its = [1]
|
||||
dist_mix = [0.3, 0.5, 0.8, 1.0] #Uniform
|
||||
N_seq_TF= [5]
|
||||
N_seq_TF= [3]
|
||||
nb_run= 3
|
||||
|
||||
for n_inner_iter in inner_its:
|
||||
for n_tf in N_seq_TF:
|
||||
for dist in dist_mix:
|
||||
|
||||
#files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-MagFxSh)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(dist, n_tf, str(run)) for run in range(nb_run)]
|
||||
files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Uniform-14TFx%d-MagFxSh)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(n_tf, str(run)) for run in range(nb_run)]
|
||||
files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-Mag)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(dist, n_tf, str(run)) for run in range(nb_run)]
|
||||
#files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Uniform-14TFx%d-MagFxSh)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(n_tf, str(run)) for run in range(nb_run)]
|
||||
accs = []
|
||||
times = []
|
||||
for idx, file in enumerate(files):
|
||||
|
|
|
@ -2,22 +2,14 @@
|
|||
|
||||
"""
|
||||
import sys
|
||||
from LeNet import *
|
||||
from dataug import *
|
||||
#from utils import *
|
||||
from train_utils import *
|
||||
from transformations import TF_loader
|
||||
# from arg_parser import *
|
||||
|
||||
postfix='-MetaScheduler2'
|
||||
TF_loader=TF_loader()
|
||||
|
||||
device = torch.device('cuda') #Select device to use
|
||||
|
||||
if device == torch.device('cpu'):
|
||||
device_name = 'CPU'
|
||||
else:
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True #Faster if same input size #Not recommended for reproductibility
|
||||
|
||||
#Increase reproductibility
|
||||
|
@ -27,46 +19,68 @@ np.random.seed(0)
|
|||
##########################################
|
||||
if __name__ == "__main__":
|
||||
|
||||
#Task to perform
|
||||
tasks={
|
||||
#'classic',
|
||||
'aug_model'
|
||||
}
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
res_folder=args.res_folder
|
||||
postfix=args.postfix
|
||||
|
||||
if args.dtype == 'FP32':
|
||||
def_type=torch.float32
|
||||
elif args.dtype == 'FP16':
|
||||
# def_type=torch.float16 #Default : float32
|
||||
def_type=torch.bfloat16
|
||||
else:
|
||||
raise Exception('dtype not supported :', args.dtype)
|
||||
torch.set_default_dtype(def_type) #Default : float32
|
||||
|
||||
|
||||
device = torch.device(args.device) #Select device to use
|
||||
if device == torch.device('cpu'):
|
||||
device_name = 'CPU'
|
||||
else:
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
|
||||
#Parameters
|
||||
n_inner_iter = 1
|
||||
epochs = 100
|
||||
n_inner_iter = args.K
|
||||
epochs = args.epochs
|
||||
dataug_epoch_start=0
|
||||
Nb_TF_seq=3
|
||||
Nb_TF_seq= args.N
|
||||
optim_param={
|
||||
'Meta':{
|
||||
'optim':'Adam',
|
||||
'lr':1e-3, #1e-2
|
||||
'epoch_start': 2, #0 / 2 (Resnet?)
|
||||
'reg_factor': 0.001,
|
||||
'scheduler': 'multiStep', #None, 'multiStep'
|
||||
'lr':args.mlr,
|
||||
'epoch_start': args.meta_epoch_start, #0 / 2 (Resnet?)
|
||||
'reg_factor': args.mag_reg,
|
||||
'scheduler': None, #None, 'multiStep'
|
||||
},
|
||||
'Inner':{
|
||||
'optim': 'SGD',
|
||||
'lr':1e-1, #1e-2/1e-1 (ResNet)
|
||||
'momentum':0.9, #0.9
|
||||
'decay':0.0005, #0.0005
|
||||
'nesterov':False, #False (True: Bad behavior w/ Data_aug)
|
||||
'scheduler':'cosine', #None, 'cosine', 'multiStep', 'exponential'
|
||||
'lr':args.lr, #1e-2/1e-1 (ResNet)
|
||||
'momentum':args.momentum, #0.9
|
||||
'weight_decay':args.decay, #0.0005
|
||||
'nesterov':args.nesterov, #False (True: Bad behavior w/ Data_aug)
|
||||
'scheduler': args.scheduler, #None, 'cosine', 'multiStep', 'exponential'
|
||||
'warmup':{
|
||||
'multiplier': args.warmup, #2 #+ batch_size => + mutliplier #No warmup = 0
|
||||
'epochs': 5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#Models
|
||||
#model = LeNet(3,10)
|
||||
#model = ResNet(num_classes=10)
|
||||
import torchvision.models as models
|
||||
#model=models.resnet18()
|
||||
model_name = 'resnet18' #'wide_resnet50_2' #'resnet18' #str(model)
|
||||
model = getattr(models.resnet, model_name)(pretrained=False, num_classes=len(dl_train.dataset.classes))
|
||||
#Info params
|
||||
F1=True
|
||||
sample_save=None
|
||||
print_f= epochs/4
|
||||
|
||||
#Load network
|
||||
model, model_name= load_model(args.model, num_classes=len(dl_train.dataset.classes), pretrained=args.pretrained)
|
||||
|
||||
#### Classic ####
|
||||
if 'classic' in tasks:
|
||||
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
|
||||
torch.cuda.reset_max_memory_cached() #reset_peak_stats
|
||||
if not args.augment:
|
||||
if device_name != 'CPU':
|
||||
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
|
||||
torch.cuda.reset_max_memory_cached() #reset_peak_stats
|
||||
t0 = time.perf_counter()
|
||||
|
||||
model = model.to(device)
|
||||
|
@ -74,12 +88,17 @@ if __name__ == "__main__":
|
|||
|
||||
print("{} on {} for {} epochs{}".format(model_name, device_name, epochs, postfix))
|
||||
#print("RandAugment(N{}-M{:.2f})-{} on {} for {} epochs{}".format(rand_aug['N'],rand_aug['M'],model_name, device_name, epochs, postfix))
|
||||
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10)
|
||||
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=print_f)
|
||||
#log= train_classic_higher(model=model, epochs=epochs)
|
||||
|
||||
exec_time=time.perf_counter() - t0
|
||||
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
|
||||
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
|
||||
|
||||
if device_name != 'CPU':
|
||||
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
|
||||
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
|
||||
else:
|
||||
max_allocated = 0.0
|
||||
max_cached=0.0
|
||||
####
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
|
@ -94,7 +113,7 @@ if __name__ == "__main__":
|
|||
filename = "{}-{} epochs".format(model_name,epochs)+postfix
|
||||
#print("RandAugment-",model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
#filename = "RandAugment(N{}-M{:.2f})-{}-{} epochs".format(rand_aug['N'],rand_aug['M'],model_name,epochs)+postfix
|
||||
with open("../res/log/%s.json" % filename, "w+") as f:
|
||||
with open(res_folder+"log/%s.json" % filename, "w+") as f:
|
||||
try:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
|
@ -103,7 +122,7 @@ if __name__ == "__main__":
|
|||
print(sys.exc_info()[1])
|
||||
|
||||
try:
|
||||
plot_resV2(log, fig_name="../res/"+filename)
|
||||
plot_resV2(log, fig_name=res_folder+filename, f1=F1)
|
||||
except:
|
||||
print("Failed to plot res")
|
||||
print(sys.exc_info()[1])
|
||||
|
@ -112,55 +131,63 @@ if __name__ == "__main__":
|
|||
print('-'*9)
|
||||
|
||||
#### Augmented Model ####
|
||||
if 'aug_model' in tasks:
|
||||
tf_config='../config/invScale_wide_tf_config.json'#'../config/base_tf_config.json'
|
||||
tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(tf_config)
|
||||
else:
|
||||
# tf_config='../config/invScale_wide_tf_config.json'#'../config/invScale_wide_tf_config.json'#'../config/base_tf_config.json'
|
||||
tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(args.tf_config)
|
||||
|
||||
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
|
||||
torch.cuda.reset_max_memory_cached() #reset_peak_stats
|
||||
if device_name != 'CPU':
|
||||
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
|
||||
torch.cuda.reset_max_memory_cached() #reset_peak_stats
|
||||
t0 = time.perf_counter()
|
||||
|
||||
model = Higher_model(model, model_name) #run_dist_dataugV3
|
||||
dataug_mod = 'Data_augV8' if args.learn_seq else 'Data_augV5'
|
||||
if n_inner_iter !=0:
|
||||
aug_model = Augmented_model(
|
||||
Data_augV5(TF_dict=tf_dict,
|
||||
globals()[dataug_mod](TF_dict=tf_dict,
|
||||
N_TF=Nb_TF_seq,
|
||||
mix_dist=0.5,
|
||||
temp=args.temp,
|
||||
fixed_prob=False,
|
||||
fixed_mag=False,
|
||||
shared_mag=False,
|
||||
fixed_mag=args.fixed_mag,
|
||||
shared_mag=args.shared_mag,
|
||||
TF_ignore_mag=tf_ignore_mag), model).to(device)
|
||||
else:
|
||||
aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=Nb_TF_seq), model).to(device)
|
||||
|
||||
print("{} on {} for {} epochs - {} inner_it{}".format(str(aug_model), device_name, epochs, n_inner_iter, postfix))
|
||||
log= run_dist_dataugV3(model=aug_model,
|
||||
log, aug_acc = run_dist_dataugV3(model=aug_model,
|
||||
epochs=epochs,
|
||||
inner_it=n_inner_iter,
|
||||
dataug_epoch_start=dataug_epoch_start,
|
||||
opt_param=optim_param,
|
||||
print_freq=20,
|
||||
unsup_loss=1,
|
||||
hp_opt=False,
|
||||
save_sample_freq=None)
|
||||
augment_loss=args.augment_loss,
|
||||
hp_opt=False, #False #['lr', 'momentum', 'weight_decay']
|
||||
print_freq=print_f,
|
||||
save_sample_freq=sample_save)
|
||||
|
||||
exec_time=time.perf_counter() - t0
|
||||
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
|
||||
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
|
||||
if device_name != 'CPU':
|
||||
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
|
||||
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
|
||||
else:
|
||||
max_allocated = 0.0
|
||||
max_cached = 0.0
|
||||
####
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]),
|
||||
"Aug_Accuracy": [args.augment_loss, aug_acc],
|
||||
"Time": (np.mean(times),np.std(times), exec_time),
|
||||
'Optimizer': optim_param,
|
||||
"Device": device_name,
|
||||
"Memory": [max_allocated, max_cached],
|
||||
"TF_config": tf_config,
|
||||
"TF_config": args.tf_config,
|
||||
"Param_names": aug_model.TF_names(),
|
||||
"Log": log}
|
||||
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)+postfix
|
||||
with open("../res/log/%s.json" % filename, "w+") as f:
|
||||
print(str(aug_model),": acc", out["Accuracy"], "/ aug_acc", out["Aug_Accuracy"][1] , "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{}_epochs-{}_in_it-AL{}".format(str(aug_model),epochs,n_inner_iter,args.augment_loss)+postfix
|
||||
with open(res_folder+"log/%s.json" % filename, "w+") as f:
|
||||
try:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
|
@ -168,7 +195,7 @@ if __name__ == "__main__":
|
|||
print("Failed to save logs :",f.name)
|
||||
print(sys.exc_info()[1])
|
||||
try:
|
||||
plot_resV2(log, fig_name="../res/"+filename, param_names=aug_model.TF_names())
|
||||
plot_resV2(log, fig_name=res_folder+filename, param_names=aug_model.TF_names(), f1=F1)
|
||||
except:
|
||||
print("Failed to plot res")
|
||||
print(sys.exc_info()[1])
|
||||
|
|
|
@ -1,29 +1,35 @@
|
|||
""" Utilities function for training.
|
||||
|
||||
"""
|
||||
|
||||
import sys
|
||||
import torch
|
||||
#import torch.optim
|
||||
import torchvision
|
||||
#import torchvision
|
||||
import higher
|
||||
import higher_patch
|
||||
|
||||
from datasets import *
|
||||
from utils import *
|
||||
|
||||
from transformations import Normalizer, translate, zero_stack
|
||||
norm = Normalizer(MEAN, STD)
|
||||
confmat = ConfusionMatrix(num_classes=len(dl_test.dataset.classes))
|
||||
|
||||
def test(model):
|
||||
max_grad = 1 #Max gradient value #Limite catastrophic drop
|
||||
|
||||
def test(model, augment=0):
|
||||
"""Evaluate a model on test data.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model to test.
|
||||
augment (int): Number of augmented example for each sample. (Default : 0)
|
||||
|
||||
Returns:
|
||||
(float, Tensor) Returns the accuracy and F1 score of the model.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
model.eval()
|
||||
# model['model']['functional'].set_mode('mixed') #ABN
|
||||
|
||||
#for i, (features, labels) in enumerate(dl_test):
|
||||
# features,labels = features.to(device), labels.to(device)
|
||||
|
@ -34,12 +40,30 @@ def test(model):
|
|||
correct = 0
|
||||
total = 0
|
||||
#loss = []
|
||||
global confmat
|
||||
confmat.reset()
|
||||
with torch.no_grad():
|
||||
for features, labels in dl_test:
|
||||
features,labels = features.to(device), labels.to(device)
|
||||
|
||||
outputs = model(features)
|
||||
if augment>0: #Test Time Augmentation
|
||||
model.augment(True)
|
||||
# V2
|
||||
features=torch.cat([features for _ in range(augment)], dim=0) # (B,C,H,W)=>(B*augment,C,H,W)
|
||||
outputs=model(features)
|
||||
outputs=torch.cat([o.unsqueeze(dim=0) for o in outputs.chunk(chunks=augment, dim=0)],dim=0) # (B*augment,nb_class)=>(augment,B,nb_class)
|
||||
|
||||
w_losses=model['data_aug'].loss_weight(batch_norm=False) #(B*augment) if Dataug
|
||||
if w_losses.shape[0]==1: #RandAugment
|
||||
outputs=torch.sum(outputs, axis=0)/augment #mean
|
||||
else: #Dataug
|
||||
w_losses=torch.cat([w.unsqueeze(dim=0) for w in w_losses.chunk(chunks=augment, dim=0)], dim=0) #(augment, B)
|
||||
w_losses = w_losses / w_losses.sum(axis=0, keepdim=True) #sum(w_losses)=1 pour un même echantillons
|
||||
|
||||
outputs=torch.sum(outputs*w_losses.unsqueeze(dim=2).expand_as(outputs), axis=0)/augment #Weighted mean
|
||||
else:
|
||||
outputs = model(features)
|
||||
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
@ -74,10 +98,12 @@ def compute_vaLoss(model, dl_it, dl):
|
|||
xs, ys = next(dl_it)
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
model.eval() #Validation sans transfornations !
|
||||
return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys)
|
||||
model.eval() #Validation sans transformations !
|
||||
# model['model']['functional'].set_mode('mixed') #ABN
|
||||
# return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys)
|
||||
return F.cross_entropy(model(xs), ys)
|
||||
|
||||
def mixed_loss(xs, ys, model, unsup_factor=1):
|
||||
def mixed_loss(xs, ys, model, unsup_factor=1, augment=1):
|
||||
"""Evaluate a model on a batch of data.
|
||||
|
||||
Compute a combinaison of losses:
|
||||
|
@ -94,40 +120,64 @@ def mixed_loss(xs, ys, model, unsup_factor=1):
|
|||
ys (Tensor): Batch of labels.
|
||||
model (nn.Module): Augmented model (see dataug.py).
|
||||
unsup_factor (float): Factor by which unsupervised CE and KL div loss are multiplied.
|
||||
augment (int): Number of augmented example for each sample. (Default : 1)
|
||||
|
||||
Returns:
|
||||
(Tensor) Mixed loss if there's data augmentation, just supervised CE loss otherwise.
|
||||
"""
|
||||
|
||||
#TODO: add test to prevent augmented model error and redirect to classic loss
|
||||
if unsup_factor!=0 and model.is_augmenting():
|
||||
if unsup_factor!=0 and model.is_augmenting() and augment>0:
|
||||
|
||||
# Supervised loss (classic)
|
||||
# Supervised loss - Cross-entropy
|
||||
model.augment(mode=False)
|
||||
sup_logits = model(xs)
|
||||
model.augment(mode=True)
|
||||
|
||||
log_sup = F.log_softmax(sup_logits, dim=1)
|
||||
sup_loss = F.cross_entropy(log_sup, ys)
|
||||
sup_loss = F.nll_loss(log_sup, ys)
|
||||
# sup_loss = F.cross_entropy(log_sup, ys)
|
||||
|
||||
# Unsupervised loss
|
||||
aug_logits = model(xs)
|
||||
w_loss = model['data_aug'].loss_weight() #Weight loss
|
||||
|
||||
log_aug = F.log_softmax(aug_logits, dim=1)
|
||||
aug_loss = F.cross_entropy(log_aug, ys , reduction='none')
|
||||
aug_loss = (aug_loss * w_loss).mean()
|
||||
if augment>1:
|
||||
# Unsupervised loss - Cross-Entropy
|
||||
xs_a=torch.cat([xs for _ in range(augment)], dim=0) # (B,C,H,W)=>(B*augment,C,H,W)
|
||||
ys_a=torch.cat([ys for _ in range(augment)], dim=0)
|
||||
aug_logits=model(xs_a) # (B*augment,nb_class)
|
||||
|
||||
w_loss=model['data_aug'].loss_weight() #(B*augment) if Dataug
|
||||
|
||||
log_aug = F.log_softmax(aug_logits, dim=1)
|
||||
aug_loss = F.nll_loss(log_aug, ys_a , reduction='none')
|
||||
# aug_loss = F.cross_entropy(log_aug, ys_a , reduction='none')
|
||||
aug_loss = (aug_loss * w_loss).mean()
|
||||
|
||||
#KL divergence loss (w/ logits) - Prediction/Distribution similarity
|
||||
kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
|
||||
kl_loss = (w_loss * kl_loss).mean()
|
||||
#KL divergence loss (w/ logits) - Prediction/Distribution similarity
|
||||
sup_logits_a=torch.cat([sup_logits for _ in range(augment)], dim=0)
|
||||
log_sup_a=torch.cat([log_sup for _ in range(augment)], dim=0)
|
||||
|
||||
kl_loss = (F.softmax(sup_logits_a, dim=1)*(log_sup_a-log_aug)).sum(dim=-1)
|
||||
kl_loss = (w_loss * kl_loss).mean()
|
||||
else:
|
||||
# Unsupervised loss - Cross-Entropy
|
||||
aug_logits = model(xs)
|
||||
w_loss = model['data_aug'].loss_weight() #Weight loss
|
||||
|
||||
log_aug = F.log_softmax(aug_logits, dim=1)
|
||||
aug_loss = F.nll_loss(log_aug, ys , reduction='none')
|
||||
# aug_loss = F.cross_entropy(log_aug, ys , reduction='none')
|
||||
aug_loss = (aug_loss * w_loss).mean()
|
||||
|
||||
#KL divergence loss (w/ logits) - Prediction/Distribution similarity
|
||||
kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
|
||||
kl_loss = (w_loss * kl_loss).mean()
|
||||
|
||||
loss = sup_loss + unsup_factor * (aug_loss + kl_loss)
|
||||
|
||||
else: #Supervised loss (classic)
|
||||
else: #Supervised loss - Cross-Entropy
|
||||
sup_logits = model(xs)
|
||||
log_sup = F.log_softmax(sup_logits, dim=1)
|
||||
loss = F.cross_entropy(log_sup, ys)
|
||||
loss = F.cross_entropy(sup_logits, ys)
|
||||
# log_sup = F.log_softmax(sup_logits, dim=1)
|
||||
# loss = F.cross_entropy(log_sup, ys)
|
||||
|
||||
return loss
|
||||
|
||||
|
@ -150,7 +200,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
optim = torch.optim.SGD(model.parameters(),
|
||||
lr=opt_param['Inner']['lr'],
|
||||
momentum=opt_param['Inner']['momentum'],
|
||||
weight_decay=opt_param['Inner']['decay'],
|
||||
weight_decay=opt_param['Inner']['weight_decay'],
|
||||
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
#Scheduler
|
||||
|
@ -168,6 +218,9 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
elif opt_param['Inner']['scheduler'] is not None:
|
||||
raise ValueError("Lr scheduler unknown : %s"%opt_param['Inner']['scheduler'])
|
||||
|
||||
# from warmup_scheduler import GradualWarmupScheduler
|
||||
# inner_scheduler=GradualWarmupScheduler(optim, multiplier=2, total_epoch=5, after_scheduler=inner_scheduler)
|
||||
|
||||
#Training
|
||||
model.train()
|
||||
dl_val_it = iter(dl_val)
|
||||
|
@ -188,9 +241,12 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
# print_graph(loss, '../samples/torchvision_WRN') #to visualize computational graph
|
||||
# sys.exit()
|
||||
|
||||
if inner_scheduler is not None:
|
||||
inner_scheduler.step()
|
||||
# print(optim.param_groups[0]['lr'])
|
||||
|
||||
#### Tests ####
|
||||
tf = time.perf_counter()
|
||||
|
@ -222,7 +278,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
|
|||
|
||||
return log
|
||||
|
||||
def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, print_freq=1, unsup_loss=1, hp_opt=False, save_sample_freq=None):
|
||||
def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, unsup_loss=1, augment_loss=1, hp_opt=False, print_freq=1, save_sample_freq=None):
|
||||
"""Training of an augmented model with higher.
|
||||
|
||||
This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py).
|
||||
|
@ -237,9 +293,10 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
epochs (int): Number of epochs to perform. (default: 1)
|
||||
inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1)
|
||||
dataug_epoch_start (int): Epoch when to start data augmentation. (default: 0)
|
||||
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
|
||||
unsup_loss (float): Proportion of the unsup_loss loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
|
||||
augment_loss (int): Number of augmented example for each sample in loss computation. (Default : 1)
|
||||
hp_opt (bool): Wether to learn inner optimizer parameters. (default: False)
|
||||
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
|
||||
save_sample_freq (int): Number of epochs between saves of samples of data. If set to None, no sample will be saved. (default: None)
|
||||
|
||||
Returns:
|
||||
|
@ -247,6 +304,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
"""
|
||||
device = next(model.parameters()).device
|
||||
log = []
|
||||
# kl_log={"prob":[], "mag":[]}
|
||||
dl_val_it = iter(dl_val)
|
||||
|
||||
high_grad_track = True
|
||||
|
@ -261,12 +319,12 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
inner_opt = torch.optim.SGD(model['model']['original'].parameters(),
|
||||
lr=opt_param['Inner']['lr'],
|
||||
momentum=opt_param['Inner']['momentum'],
|
||||
weight_decay=opt_param['Inner']['decay'],
|
||||
weight_decay=opt_param['Inner']['weight_decay'],
|
||||
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
diffopt = model['model'].get_diffopt(
|
||||
inner_opt,
|
||||
grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
|
||||
grad_callback=(lambda grads: clip_norm(grads, max_norm=max_grad)),
|
||||
track_higher_grads=high_grad_track)
|
||||
|
||||
#Scheduler
|
||||
|
@ -281,23 +339,33 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
elif opt_param['Inner']['scheduler']=='exponential':
|
||||
#inner_scheduler=torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.1) #Wrong gamma
|
||||
inner_scheduler=torch.optim.lr_scheduler.LambdaLR(inner_opt, lambda epoch: (1 - epoch / epochs) ** 0.9)
|
||||
elif opt_param['Inner']['scheduler'] is not None:
|
||||
elif not(opt_param['Inner']['scheduler'] is None or opt_param['Inner']['scheduler']==''):
|
||||
raise ValueError("Lr scheduler unknown : %s"%opt_param['Inner']['scheduler'])
|
||||
|
||||
#Warmup
|
||||
if opt_param['Inner']['warmup']['multiplier']>=1:
|
||||
from warmup_scheduler import GradualWarmupScheduler
|
||||
inner_scheduler=GradualWarmupScheduler(inner_opt,
|
||||
multiplier=opt_param['Inner']['warmup']['multiplier'],
|
||||
total_epoch=opt_param['Inner']['warmup']['epochs'],
|
||||
after_scheduler=inner_scheduler)
|
||||
|
||||
#Meta Opt
|
||||
hyper_param = list(model['data_aug'].parameters())
|
||||
if hp_opt :
|
||||
if hp_opt : #(deprecated)
|
||||
for param_group in diffopt.param_groups:
|
||||
for param in list(opt_param['Inner'].keys())[1:]:
|
||||
# print(param_group)
|
||||
for param in hp_opt:
|
||||
param_group[param]=torch.tensor(param_group[param]).to(device).requires_grad_()
|
||||
hyper_param += [param_group[param]]
|
||||
meta_opt = torch.optim.Adam(hyper_param, lr=opt_param['Meta']['lr']) #lr=1e-2
|
||||
meta_opt = torch.optim.Adam(hyper_param, lr=opt_param['Meta']['lr'])
|
||||
|
||||
#Meta-Scheduler (deprecated)
|
||||
meta_scheduler=None
|
||||
if opt_param['Meta']['scheduler']=='multiStep':
|
||||
meta_scheduler=torch.optim.lr_scheduler.MultiStepLR(meta_opt,
|
||||
milestones=[int(epochs/3), int(epochs*2/3)],# int(epochs*2.7/3)],
|
||||
gamma=3.16)#10)
|
||||
gamma=3.16)
|
||||
elif opt_param['Meta']['scheduler'] is not None:
|
||||
raise ValueError("Lr scheduler unknown : %s"%opt_param['Meta']['scheduler'])
|
||||
|
||||
|
@ -314,7 +382,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
|
||||
for i, (xs, ys) in enumerate(dl_train):
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
|
||||
if(unsup_loss==0):
|
||||
#Methode uniforme
|
||||
logits = model(xs) # modified `params` can also be passed as a kwarg
|
||||
|
@ -327,31 +395,35 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
|
||||
else:
|
||||
#Methode mixed
|
||||
loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss)
|
||||
loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss, augment=augment_loss)
|
||||
|
||||
#print_graph(loss) #to visualize computational graph
|
||||
# print_graph(loss, '../samples/pytorch_WRN') #to visualize computational graph
|
||||
# sys.exit()
|
||||
|
||||
#t = time.process_time()
|
||||
# t = time.process_time()
|
||||
diffopt.step(loss)#(opt.zero_grad, loss.backward, opt.step)
|
||||
#print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
|
||||
|
||||
# print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
|
||||
|
||||
if(high_grad_track and i>0 and i%inner_it==0 and epoch>=opt_param['Meta']['epoch_start']): #Perform Meta step
|
||||
#print("meta")
|
||||
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss(opt_param['Meta']['reg_factor'])
|
||||
model.train()
|
||||
#print_graph(val_loss) #to visualize computational graph
|
||||
val_loss.backward()
|
||||
|
||||
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
|
||||
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=max_grad, norm_type=2) #Prevent exploding grad with RNN
|
||||
|
||||
# print("Grad mix",model['data_aug']["temp"].grad)
|
||||
# prv_param=model['data_aug']._params
|
||||
meta_opt.step()
|
||||
# kl_log["prob"].append(F.kl_div(prv_param["prob"],model['data_aug']["prob"], reduction='batchmean').item())
|
||||
# kl_log["mag"].append(F.kl_div(prv_param["mag"],model['data_aug']["mag"], reduction='batchmean').item())
|
||||
|
||||
#Adjust Hyper-parameters
|
||||
model['data_aug'].adjust_param() #Contrainte sum(proba)=1
|
||||
model['data_aug'].adjust_param()
|
||||
if hp_opt:
|
||||
for param_group in diffopt.param_groups:
|
||||
for param in list(opt_param['Inner'].keys())[1:]:
|
||||
param_group[param].data = param_group[param].data.clamp(min=1e-4)
|
||||
for param in hp_opt:
|
||||
param_group[param].data = param_group[param].data.clamp(min=1e-5)
|
||||
|
||||
#Reset gradients
|
||||
diffopt.detach_()
|
||||
|
@ -374,18 +446,26 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
diff_param_group['lr'] = param_group['lr']
|
||||
if meta_scheduler is not None:
|
||||
meta_scheduler.step()
|
||||
|
||||
# if epoch<epochs/3:
|
||||
# model['data_aug']['temp'].data=torch.tensor(0.5, device=device)
|
||||
# elif epoch>epochs/3 and epoch<(epochs*2/3):
|
||||
# model['data_aug']['temp'].data=torch.tensor(0.75, device=device)
|
||||
# elif epoch>(epochs*2/3):
|
||||
# model['data_aug']['temp'].data=torch.tensor(1.0, device=device)
|
||||
# model['data_aug']['temp'].data=torch.tensor(1./3+2/3*(epoch/epochs), device=device)
|
||||
# print('Temp',model['data_aug']['temp'])
|
||||
|
||||
if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
|
||||
try:
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
model.train()
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
|
||||
model.eval()
|
||||
except:
|
||||
print("Couldn't save samples epoch"+epoch)
|
||||
print("Couldn't save samples epoch %d : %s"%(epoch, str(sys.exc_info()[1])))
|
||||
pass
|
||||
|
||||
|
||||
if(not val_loss): #Compute val loss for logs
|
||||
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
|
||||
|
||||
|
@ -405,8 +485,8 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
|
||||
"param": param,
|
||||
}
|
||||
if not model['data_aug']._fixed_mix: data["mix_dist"]=model['data_aug']['mix_dist'].item()
|
||||
if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'].item(), 'momentum': p_grp['momentum'].item()} for p_grp in diffopt.param_groups]
|
||||
if not model['data_aug']._fixed_temp: data["temp"]=model['data_aug']['temp'].item()
|
||||
if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'], 'momentum': p_grp['momentum']} for p_grp in diffopt.param_groups]
|
||||
log.append(data)
|
||||
#############
|
||||
#### Print ####
|
||||
|
@ -420,11 +500,19 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
|
||||
if not model['data_aug']._fixed_prob: print('TF Proba :', ["{0:0.4f}".format(p) for p in model['data_aug']['prob']])
|
||||
#print('proba grad',model['data_aug']['prob'].grad)
|
||||
if not model['data_aug']._fixed_mag: print('TF Mag :', ["{0:0.4f}".format(m) for m in model['data_aug']['mag']])
|
||||
if not model['data_aug']._fixed_mag:
|
||||
if model['data_aug']._shared_mag:
|
||||
print('TF Mag :', "{0:0.4f}".format(model['data_aug']['mag']))
|
||||
else:
|
||||
print('TF Mag :', ["{0:0.4f}".format(m) for m in model['data_aug']['mag']])
|
||||
#print('Mag grad',model['data_aug']['mag'].grad)
|
||||
if not model['data_aug']._fixed_mix: print('Mix:', model['data_aug']['mix_dist'].item())
|
||||
if not model['data_aug']._fixed_temp: print('Temp:', model['data_aug']['temp'].item())
|
||||
#print('Reg loss:', model['data_aug'].reg_loss().item())
|
||||
|
||||
# if len(kl_log["prob"])!=0:
|
||||
# print("KL prob : mean %f, std %f, max %f, min %f"%(np.mean(kl_log["prob"]), np.std(kl_log["prob"]), max(kl_log["prob"]), min(kl_log["prob"])))
|
||||
# print("KL mag : mean %f, std %f, max %f, min %f"%(np.mean(kl_log["mag"]), np.std(kl_log["mag"]), max(kl_log["mag"]), min(kl_log["mag"])))
|
||||
# kl_log={"prob":[], "mag":[]}
|
||||
|
||||
if hp_opt :
|
||||
for param_group in diffopt.param_groups:
|
||||
print('Opt param - lr:', param_group['lr'].item(),'- momentum:', param_group['momentum'].item())
|
||||
|
@ -439,63 +527,66 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
|
|||
high_grad_track = True
|
||||
diffopt = model['model'].get_diffopt(
|
||||
inner_opt,
|
||||
grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
|
||||
grad_callback=(lambda grads: clip_norm(grads, max_norm=max_grad)),
|
||||
track_higher_grads=high_grad_track)
|
||||
|
||||
return log
|
||||
aug_acc, aug_f1 = test(model, augment=augment_loss)
|
||||
|
||||
def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, unsup_loss=1):
|
||||
"""Simple training of an augmented model with higher.
|
||||
return log, aug_acc
|
||||
|
||||
This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py).
|
||||
Ex : Augmented_model(Data_augV5(...), Higher_model(model))
|
||||
#OLD
|
||||
# def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, unsup_loss=1):
|
||||
# """Simple training of an augmented model with higher.
|
||||
|
||||
Training loss can either be computed directly from augmented inputs (unsup_loss=0).
|
||||
However, it is recommended to use the mixed loss computation, which combine original and augmented inputs to compute the loss (unsup_loss>0).
|
||||
# This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py).
|
||||
# Ex : Augmented_model(Data_augV5(...), Higher_model(model))
|
||||
|
||||
Does not support LR scheduler.
|
||||
# Training loss can either be computed directly from augmented inputs (unsup_loss=0).
|
||||
# However, it is recommended to use the mixed loss computation, which combine original and augmented inputs to compute the loss (unsup_loss>0).
|
||||
|
||||
Args:
|
||||
model (nn.Module): Augmented model to train.
|
||||
opt_param (dict): Dictionnary containing optimizers parameters.
|
||||
epochs (int): Number of epochs to perform. (default: 1)
|
||||
inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1)
|
||||
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
|
||||
unsup_loss (float): Proportion of the unsup_loss loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
|
||||
# Does not support LR scheduler.
|
||||
|
||||
# Args:
|
||||
# model (nn.Module): Augmented model to train.
|
||||
# opt_param (dict): Dictionnary containing optimizers parameters.
|
||||
# epochs (int): Number of epochs to perform. (default: 1)
|
||||
# inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1)
|
||||
# print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
|
||||
# unsup_loss (float): Proportion of the unsup_loss loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
|
||||
|
||||
Returns:
|
||||
(dict) A dictionary containing a whole state of the trained network.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
# Returns:
|
||||
# (dict) A dictionary containing a whole state of the trained network.
|
||||
# """
|
||||
# device = next(model.parameters()).device
|
||||
|
||||
## Optimizers ##
|
||||
hyper_param = list(model['data_aug'].parameters())
|
||||
model.start_bilevel_opt(inner_it=inner_it, hp_list=hyper_param, opt_param=opt_param, dl_val=dl_val)
|
||||
# ## Optimizers ##
|
||||
# hyper_param = list(model['data_aug'].parameters())
|
||||
# model.start_bilevel_opt(inner_it=inner_it, hp_list=hyper_param, opt_param=opt_param, dl_val=dl_val)
|
||||
|
||||
model.train()
|
||||
# model.train()
|
||||
|
||||
for epoch in range(1, epochs+1):
|
||||
t0 = time.process_time()
|
||||
# for epoch in range(1, epochs+1):
|
||||
# t0 = time.process_time()
|
||||
|
||||
for i, (xs, ys) in enumerate(dl_train):
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
# for i, (xs, ys) in enumerate(dl_train):
|
||||
# xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
#Methode mixed
|
||||
loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss)
|
||||
# #Methode mixed
|
||||
# loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss)
|
||||
|
||||
model.step(loss) #(opt.zero_grad, loss.backward, opt.step) + automatic meta-optimisation
|
||||
# model.step(loss) #(opt.zero_grad, loss.backward, opt.step) + automatic meta-optimisation
|
||||
|
||||
tf = time.process_time()
|
||||
# tf = time.process_time()
|
||||
|
||||
#### Print ####
|
||||
if(print_freq and epoch%print_freq==0):
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', model.val_loss().item())
|
||||
if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data)
|
||||
if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
|
||||
if not model['data_aug']._fixed_mix: print('Mix:', model['data_aug']['mix_dist'].item())
|
||||
#############
|
||||
# #### Print ####
|
||||
# if(print_freq and epoch%print_freq==0):
|
||||
# print('-'*9)
|
||||
# print('Epoch : %d/%d'%(epoch,epochs))
|
||||
# print('Time : %.00f'%(tf - t0))
|
||||
# print('Train loss :',loss.item(), '/ val loss', model.val_loss().item())
|
||||
# if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data)
|
||||
# if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
|
||||
# if not model['data_aug']._fixed_temp: print('Temp:', model['data_aug']['temp'].item())
|
||||
# #############
|
||||
|
||||
return model['model'].state_dict()
|
||||
# return model['model'].state_dict()
|
|
@ -21,8 +21,8 @@ import json
|
|||
|
||||
#TF that don't have use for magnitude parameter.
|
||||
TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend', 'identity', 'flip'}
|
||||
#TF which implemetation doesn't allow gradient propagaition.
|
||||
TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize', 'posterize','solarize'}
|
||||
#TF which implemetation doesn't allow gradient propagaition.
|
||||
TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize', 'posterize','solarize'} #Numpy implementation would be better ?
|
||||
#TF for which magnitude should be ignored (Magnitude fixed).
|
||||
TF_ignore_mag= TF_no_mag | TF_no_grad
|
||||
|
||||
|
@ -38,6 +38,17 @@ PARAMETER_MIN = 0.01
|
|||
# PARAMETER_MIN:{'Rotate','TranslateX','TranslateY','ShearX','ShearY'},
|
||||
#}
|
||||
|
||||
class Normalizer(object):
|
||||
def __init__(self, mean, std):
|
||||
self.mean=torch.tensor(mean)
|
||||
self.std=torch.tensor(std)
|
||||
def __call__(self, x):
|
||||
# return x.sub_(self.mean.to(x.device)[None, :, None, None]).div_(self.std.to(x.device)[None, :, None, None])
|
||||
return kornia.color.normalize(x, self.mean, self.std)
|
||||
def reverse(self, x):
|
||||
# return x.mul_(self.std.to(x.device)[None, :, None, None]).add_(self.mean.to(x.device)[None, :, None, None])
|
||||
return kornia.color.denormalize(x, self.mean, self.std).to(torch.half).to(torch.float)
|
||||
|
||||
class TF_loader(object):
|
||||
""" Transformations builder.
|
||||
|
||||
|
@ -91,13 +102,15 @@ class TF_loader(object):
|
|||
else:
|
||||
raise Exception("Unknown TF axis : %s in %s"%(tf['function'], self._filename))
|
||||
|
||||
elif tf['function'] in {'translate', 'shear'}:
|
||||
rand_fct= 'invScale_rand_floats' if tf['param']['invScale'] else 'rand_floats'
|
||||
self._TF_dict[tf['name']]=self.build_lambda(tf['function'], rand_fct, tf['param']['min'], tf['param']['max'], tf['param']['absolute'], tf['param']['axis'])
|
||||
# elif tf['function'] in {'translate', 'shear'}:
|
||||
# rand_fct= 'invScale_rand_floats' if tf['param']['invScale'] else 'rand_floats'
|
||||
# self._TF_dict[tf['name']]=self.build_lambda(tf['function'], rand_fct, tf['param']['min'], tf['param']['max'], tf['param']['absolute'], tf['param']['axis'])
|
||||
|
||||
else:
|
||||
axis = tf['param']['axis'] if 'axis' in tf['param'].keys() else None
|
||||
absolute = tf['param']['absolute'] if 'absolute' in tf['param'].keys() else True
|
||||
rand_fct= 'invScale_rand_floats' if tf['param']['invScale'] else 'rand_floats'
|
||||
self._TF_dict[tf['name']]=self.build_lambda(tf['function'], rand_fct, tf['param']['min'], tf['param']['max'])
|
||||
self._TF_dict[tf['name']]=self.build_lambda(tf['function'], rand_fct, tf['param']['min'], tf['param']['max'], absolute, axis)
|
||||
|
||||
return self._TF_dict, self._TF_ignore_mag
|
||||
|
||||
|
@ -130,7 +143,7 @@ class TF_loader(object):
|
|||
size=x.shape[0],
|
||||
mag=mag,
|
||||
minval=minval,
|
||||
maxval=maxval)))
|
||||
maxval=max_val_fct(max(x.shape[2],x.shape[3])))))
|
||||
elif axis =='X':
|
||||
return (lambda x, mag:
|
||||
globals()[fct_name](
|
||||
|
@ -419,6 +432,46 @@ def posterize(x, bits):
|
|||
|
||||
return float_image(x & mask)
|
||||
|
||||
def cutout(img, length):
|
||||
"""
|
||||
Args:
|
||||
img (Tensor): Batch of images. Expect image value between [0, 1].
|
||||
length (Tensor): The length (in pixels) of each square patch.
|
||||
Returns:
|
||||
Tensor: Images with single holes of dimension length x length cut out of it.
|
||||
"""
|
||||
device = img.device
|
||||
(batch_size, channels, h, w) = img.shape
|
||||
|
||||
# mask = np.ones((h, w), np.float32)
|
||||
mask = torch.ones((batch_size, h, w), device=device)
|
||||
length=length.type(torch.uint8)
|
||||
|
||||
# y = np.random.randint(h)
|
||||
# x = np.random.randint(w)
|
||||
y = torch.randint(low=0, high=h, size=(batch_size,)).to(device)
|
||||
x = torch.randint(low=0, high=w, size=(batch_size,)).to(device)
|
||||
|
||||
# y1 = np.clip(y - length // 2, 0, h)
|
||||
# y2 = np.clip(y + length // 2, 0, h)
|
||||
# x1 = np.clip(x - length // 2, 0, w)
|
||||
# x2 = np.clip(x + length // 2, 0, w)
|
||||
y1 = (y - length // 2).clamp(min=0, max=h)
|
||||
y2 = (y + length // 2).clamp(min=0, max=h)
|
||||
x1 = (x - length // 2).clamp(min=0, max=w)
|
||||
x2 = (x + length // 2).clamp(min=0, max=w)
|
||||
|
||||
# mask[y1: y2, x1: x2] = 0.
|
||||
for idx in range(batch_size): #Pas opti pour des batch
|
||||
mask[idx, y1[idx]: y2[idx], x1[idx]: x2[idx]]= 0.
|
||||
|
||||
# mask = torch.from_numpy(mask)
|
||||
# mask = mask.expand_as(img)
|
||||
mask = mask.unsqueeze(dim=1).expand_as(img)
|
||||
img = img * mask
|
||||
|
||||
return img
|
||||
|
||||
import torch.nn.functional as F
|
||||
def solarize(x, thresholds):
|
||||
"""Invert all pixel values above a threshold.
|
||||
|
|
|
@ -3,17 +3,88 @@
|
|||
"""
|
||||
import numpy as np
|
||||
import json, math, time, os
|
||||
import matplotlib
|
||||
matplotlib.use('Agg') #https://stackoverflow.com/questions/4706451/how-to-save-a-figure-remotely-with-pylab
|
||||
import matplotlib.pyplot as plt
|
||||
import copy
|
||||
import gc
|
||||
|
||||
from torchviz import make_dot
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import time
|
||||
|
||||
from nets.LeNet import *
|
||||
from nets.wideresnet import *
|
||||
from nets.wideresnet_cifar import *
|
||||
import nets.resnet_abn as resnet_abn
|
||||
import nets.resnet_deconv as resnet_DC
|
||||
from efficientnet_pytorch import EfficientNet
|
||||
from efficientnet_pytorch.utils import url_map as EfficientNet_map
|
||||
import torchvision.models as models
|
||||
def load_model(model, num_classes, pretrained=False):
|
||||
if model in models.resnet.__all__ :
|
||||
model_name = model #'resnet18' #'resnet34' #'wide_resnet50_2'
|
||||
if pretrained :
|
||||
print("Using pretrained weights")
|
||||
model = getattr(models.resnet, model_name)(pretrained=True)
|
||||
num_ftrs = model.fc.in_features
|
||||
model.fc = nn.Linear(num_ftrs, num_classes)
|
||||
else:
|
||||
model = getattr(models.resnet, model_name)(pretrained=False, num_classes=num_classes)
|
||||
elif model in models.vgg.__all__ :
|
||||
model_name = model #'vgg11', 'vgg1_bn'
|
||||
if pretrained :
|
||||
print("Using pretrained weights")
|
||||
model = getattr(models.vgg, model_name)(pretrained=True)
|
||||
num_ftrs = model.classifier[-1].in_features
|
||||
model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
|
||||
else :
|
||||
model = getattr(models.vgg, model_name)(pretrained=False, num_classes=num_classes)
|
||||
elif model in models.densenet.__all__ :
|
||||
model_name = model #'densenet121' #'densenet201'
|
||||
if pretrained :
|
||||
print("Using pretrained weights")
|
||||
model = getattr(models.densenet, model_name)(pretrained=True)
|
||||
num_ftrs = model.classifier.in_features
|
||||
model.classifier = nn.Linear(num_ftrs, num_classes)
|
||||
else:
|
||||
model = getattr(models.densenet, model_name)(pretrained=False, num_classes=num_classes)
|
||||
elif model == 'LeNet':
|
||||
if pretrained :
|
||||
print("Pretrained weights not available")
|
||||
model = LeNet(3,num_classes)
|
||||
model_name=str(model)
|
||||
elif model == 'WideResNet':
|
||||
if pretrained :
|
||||
print("Pretrained weights not available")
|
||||
# model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_classes)
|
||||
# model = WideResNet(16, 4, dropout_rate=0.0, num_classes=num_classes)
|
||||
model = wide_resnet_cifar(26, 10, num_classes=num_classes)
|
||||
# model = wide_resnet_cifar(20, 10, num_classes=num_classes)
|
||||
model_name=str(model)
|
||||
elif model in EfficientNet_map.keys():
|
||||
model_name=model # efficientnet-b0 , efficientnet-b1, efficientnet-b4
|
||||
if pretrained: #ImageNet ou Advprop (Meilleurs perf normalement mais normalisation differentes)
|
||||
print("Using pretrained weights")
|
||||
model = EfficientNet.from_pretrained(model_name, advprop=False)
|
||||
else:
|
||||
model = EfficientNet.from_name(model_name)
|
||||
elif model in resnet_abn.__all__ :
|
||||
if pretrained :
|
||||
print("Pretrained weights not available")
|
||||
model_name=model
|
||||
model = getattr(resnet_abn, model_name)(pretrained=False, num_classes=num_classes)
|
||||
elif model in resnet_DC.__all__:
|
||||
if pretrained :
|
||||
print("Pretrained weights not available")
|
||||
model_name = model
|
||||
model = getattr(resnet_DC, model_name)(num_classes=num_classes)
|
||||
else:
|
||||
raise Exception('Unknown model')
|
||||
|
||||
return model, model_name
|
||||
|
||||
class ConfusionMatrix(object):
|
||||
""" Confusion matrix.
|
||||
|
||||
|
@ -120,7 +191,8 @@ class ConfusionMatrix(object):
|
|||
f1=f1.mean()
|
||||
return f1
|
||||
|
||||
def print_graph(PyTorch_obj, fig_name='graph'):
|
||||
#from torchviz import make_dot
|
||||
def print_graph(PyTorch_obj=torch.randn(1, 3, 32, 32), fig_name='graph'):
|
||||
"""Save the computational graph.
|
||||
|
||||
Args:
|
||||
|
@ -128,7 +200,7 @@ def print_graph(PyTorch_obj, fig_name='graph'):
|
|||
fig_name (string): Relative path where to save the graph. (default: graph)
|
||||
"""
|
||||
graph=make_dot(PyTorch_obj)
|
||||
graph.format = 'pdf' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
|
||||
graph.format = 'png' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
|
||||
graph.render(fig_name)
|
||||
|
||||
def plot_resV2(log, fig_name='res', param_names=None, f1=True):
|
||||
|
@ -157,8 +229,13 @@ def plot_resV2(log, fig_name='res', param_names=None, f1=True):
|
|||
#'''
|
||||
#print(log[0]["f1"])
|
||||
if isinstance(log[0]["f1"], list):
|
||||
for c in range(len(log[0]["f1"])):
|
||||
ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c), ls='--')
|
||||
if len(log[0]["f1"])>10:
|
||||
print("Plotting results : Too many class for F1, plotting only min/max")
|
||||
ax[1, 0].plot(epochs,[max(x["f1"])*100 for x in log], label='F1-Max', ls='--')
|
||||
ax[1, 0].plot(epochs,[min(x["f1"])*100 for x in log], label='F1-Min', ls='--')
|
||||
else:
|
||||
for c in range(len(log[0]["f1"])):
|
||||
ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c), ls='--')
|
||||
else:
|
||||
ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1', ls='--')
|
||||
#'''
|
||||
|
@ -251,7 +328,6 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
|
|||
fig_name (string): Relative path where to save the graph. (default: data_sample)
|
||||
weight_labels (Tensor): Weights associated to each labels. (default: None)
|
||||
"""
|
||||
|
||||
sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()
|
||||
|
||||
plt.figure(figsize=(10,10))
|
||||
|
@ -262,7 +338,7 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
|
|||
plt.grid(False)
|
||||
plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
|
||||
label = str(labels[i].item())
|
||||
if weight_labels is not None : label+= (" - p %.2f" % weight_labels[i].item())
|
||||
if torch.is_tensor(weight_labels): label+= (" - p %.2f" % weight_labels[i].item())
|
||||
plt.xlabel(label)
|
||||
|
||||
plt.savefig(fig_name)
|
||||
|
@ -348,10 +424,13 @@ def clip_norm(tensors, max_norm, norm_type=2):
|
|||
else:
|
||||
total_norm = 0
|
||||
for t in tensors:
|
||||
if t is None:
|
||||
continue
|
||||
param_norm = t.norm(norm_type)
|
||||
total_norm += param_norm.item() ** norm_type
|
||||
total_norm = total_norm ** (1. / norm_type)
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
if clip_coef >= 1:
|
||||
return tensors
|
||||
return [t.mul(clip_coef) for t in tensors]
|
||||
#return [t.mul(clip_coef) for t in tensors]
|
||||
return [t if t is None else t.mul(clip_coef) for t in tensors]
|
Loading…
Add table
Add a link
Reference in a new issue