Changes since Teledyne

This commit is contained in:
Antoine Harlé 2024-08-20 11:53:35 +02:00 committed by AntoineH
parent 03ffd7fe05
commit b89dac9084
185 changed files with 16668 additions and 484 deletions

View file

@@ -0,0 +1,73 @@
import argparse
#Argparse
parser = argparse.ArgumentParser(description='Run smart augmentation')
parser.add_argument('-dv','--device', default='cuda', dest='device',
help='Device : cpu / cuda')
parser.add_argument('-dt','--dtype', default='FP32', dest='dtype',
help='Data type (Default: Float32)')
parser.add_argument('-m','--model', default='resnet18', dest='model',
help='Network')
parser.add_argument('-pt','--pretrained', default='', dest='pretrained',
help='Use pretrained weights if possible')
parser.add_argument('-ep','--epochs', type=int, default=10, dest='epochs',
help='Number of training epochs')
# parser.add_argument('-ot', '--optimizer', default='SGD', dest='opt_type',
# help='Model optimizer')
parser.add_argument('-lr', type=float, default=1e-1, dest='lr',
help='Model learning rate')
parser.add_argument('-mo', '--momentum', type=float, default=0.9, dest='momentum',
help='Momentum')
parser.add_argument('-dc', '--decay', type=float, default=0.0005, dest='decay',
help='Weight decay')
parser.add_argument('-ns','--nesterov', type=bool, default=False, dest='nesterov',
help='Nesterov momentum ?')
parser.add_argument('-sc', '--scheduler', default='cosine', dest='scheduler',
help='Model learning rate scheduler')
parser.add_argument('-wu', '--warmup', type=float, default=0, dest='warmup',
help='Warmup multiplier')
parser.add_argument('-a','--augment', type=bool, default=False, dest='augment',
help='Data augmentation ?')
parser.add_argument('-N', type=int, default=1,
help='Number of TF applied sequentially (N)')
parser.add_argument('-K', type=int, default=0,
help='Number of inner iterations')
parser.add_argument('-al','--augment_loss', type=int, default=1, dest='augment_loss',
help='Number of augmented examples for each sample in loss computation.')
parser.add_argument('-t', '--temp', type=float, default=0.5, dest='temp',
help='Probability distribution temperature')
parser.add_argument('-tfc','--tf_config', default='../config/invScale_wide_tf_config.json', dest='tf_config',
help='TF config')
parser.add_argument('-ls', '--learn_seq', type=bool, default=False, dest='learn_seq',
help='Learn order of application of TF (DataugV7-8) ?')
parser.add_argument('-fm', '--fixed_mag', type=bool, default=False, dest='fixed_mag',
help='Fixed magnitude when learning data augmentation ?')
parser.add_argument('-sm', '--shared_mag', type=bool, default=False, dest='shared_mag',
help='Shared magnitude when learning data augmentation ?')
# parser.add_argument('-mot', '--metaoptimizer', default='Adam', dest='meta_opt_type',
# help='Meta optimizer (Augmentations)')
parser.add_argument('-mlr', type=float, default=1e-2, dest='mlr',
help='Meta learning rate (Augmentations)')
parser.add_argument('-ms', type=int, default=0, dest='meta_epoch_start',
help='Epoch at which start meta learning')
parser.add_argument('-mr', type=float, default=0.001, dest='mag_reg',
help='Augmentation magnitudes regulation factor')
parser.add_argument('-rf','--res_folder', default='../res/', dest='res_folder',
help='Results folder')
parser.add_argument('-pf','--postfix', default='', dest='postfix',
help='Res postfix')
parser.add_argument('-dr','--dataroot', default='~/scratch/data', dest='dataroot',
help='Datasets folder')
parser.add_argument('-ds','--dataset', default='CIFAR10', dest='dataset',
help='Dataset')
parser.add_argument('-bs','--batch_size', type=int, default=256, dest='batch_size',
help='Batch size') #256 (WRN) / 512
parser.add_argument('-w','--workers', type=int, default=6, dest='workers',
help='Number of workers (CPU cores).')
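Note: the boolean flags above (-ns, -a, -ls, -fm, -sm) are declared with type=bool, which argparse applies as bool(str): any non-empty string, including 'False', parses as True, so these flags can only be disabled by omitting them. A minimal sketch of a converter that could be substituted (str2bool is a hypothetical helper, not part of this commit):

import argparse

def str2bool(v):
    #Map common string spellings onto a real boolean for argparse.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

#Usage: parser.add_argument('-ns', '--nesterov', type=str2bool, default=False, dest='nesterov')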

View file

@@ -7,19 +7,22 @@ from train_utils import *
from transformations import TF_loader
import torchvision.models as models
from LeNet import *
#model_list={models.resnet: ['resnet18', 'resnet50','wide_resnet50_2']} #lr=0.1
model_list={models.resnet: ['resnet18']}
model_list={models.resnet: ['wide_resnet50_2']}
optim_param={
'Meta':{
'optim':'Adam',
'lr':1e-2, #1e-2
'lr':5e-3, #1e-2
'epoch_start': 2, #0 / 2 (Resnet?)
'reg_factor': 0.001,
'scheduler': None, #None, 'multiStep'
},
'Inner':{
'optim': 'SGD',
'lr':1e-1, #1e-2/1e-1 (ResNet)
'lr':1e-2, #1e-2/1e-1 (ResNet)
'momentum':0.9, #0.9
'decay':0.0005, #0.0005
'nesterov':False, #False (True: Bad behavior w/ Data_aug)
@@ -28,16 +31,17 @@ optim_param={
}
res_folder="../res/benchmark/CIFAR10/"
#res_folder="../res/benchmark/MNIST/"
#res_folder="../res/HPsearch/"
epochs= 200
dataug_epoch_start=0
nb_run= 1
nb_run= 3
tf_config='../config/wide_tf_config.json' #'../config/wide_tf_config.json'#'../config/base_tf_config.json'
tf_config='../config/bad_tf_config.json' #'../config/wide_tf_config.json'#'../config/base_tf_config.json'
TF_loader=TF_loader()
tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(tf_config)
device = torch.device('cuda')
device = torch.device('cuda:1')
if device == torch.device('cpu'):
device_name = 'CPU'
@@ -54,8 +58,8 @@ np.random.seed(0)
if __name__ == "__main__":
### Benchmark ###
#'''
inner_its = [3]
'''
inner_its = [0]
dist_mix = [0.5]
N_seq_TF= [3]
mag_setup = [(False, False)] #[(True, True), (False, False)] #(FxSh, Independent)
@@ -74,6 +78,8 @@ if __name__ == "__main__":
t0 = time.perf_counter()
model = getattr(model_type, model_name)(pretrained=False, num_classes=len(dl_train.dataset.classes))
#model_name = 'LeNet'
#model = LeNet(3,10)
model = Higher_model(model, model_name) #run_dist_dataugV3
if n_inner_iter!=0:
@@ -122,12 +128,17 @@ if __name__ == "__main__":
print('Log :\"',f.name, '\" saved !')
except:
print("Failed to save logs :",f.name)
try:
plot_resV2(log, fig_name=res_folder+filename, param_names=aug_model.TF_names())
except:
print("Failed to plot res")
print(sys.exc_info()[1])
print('Execution Time : %.00f '%(exec_time))
print('-'*9)
#'''
### Benchmark - RandAugment/Vanilla ###
'''
### Benchmark - RandAugment/Vanilla ###
#'''
for model_type in model_list.keys():
for model_name in model_list[model_type]:
for run in range(nb_run):
@@ -155,7 +166,7 @@ if __name__ == "__main__":
#"Rand_Aug": rand_aug,
"Log": log}
print(model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs -{}".format(model_name,epochs, run)
filename = "{}-{} epochs -{}-basicDA".format(model_name,epochs, run)
#print("RandAugment-",model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
#filename = "RandAugment(N{}-M{:.2f})-{}-{} epochs -{}".format(rand_aug['N'],rand_aug['M'],model_name,epochs, run)
with open(res_folder+"log/%s.json" % filename, "w+") as f:
@@ -166,11 +177,14 @@ if __name__ == "__main__":
print("Failed to save logs :",f.name)
print(sys.exc_info()[1])
#plot_resV2(log, fig_name=res_folder+filename)
try:
plot_resV2(log, fig_name=res_folder+filename, param_names=aug_model.TF_names())
except:
print("Failed to plot res")
print(sys.exc_info()[1])
print('Execution Time : %.00f '%(exec_time))
print('-'*9)
'''
#'''
### HP Search ###
'''
from LeNet import *
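For reference, a hedged sketch of how the optim_param dict above could be mapped onto torch.optim constructors (build_optimizers is a hypothetical helper for illustration; the actual wiring lives in train_utils):

import torch

def build_optimizers(model_params, aug_params, optim_param):
    #Inner optimizer for the model weights, per optim_param['Inner'].
    inner = torch.optim.SGD(model_params,
        lr=optim_param['Inner']['lr'],
        momentum=optim_param['Inner']['momentum'],
        weight_decay=optim_param['Inner']['decay'],
        nesterov=optim_param['Inner']['nesterov'])
    #Meta optimizer for the augmentation parameters, per optim_param['Meta'].
    meta = torch.optim.Adam(aug_params, lr=optim_param['Meta']['lr'])
    return inner, meta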

View file

@@ -2,41 +2,47 @@
MNIST / CIFAR10 / CIFAR100 / SVHN / ImageNet
"""
import os
import torch
from torch.utils.data.dataset import ConcatDataset
import torchvision
from arg_parser import *
#Train/Validation batch size.
BATCH_SIZE = 512
#Test batch size.
TEST_SIZE = BATCH_SIZE
#TEST_SIZE = 10000 #Slightly faster / more memory consumption!
args = parser.parse_args()
#Whether to download data.
download_data=False
#Number of workers to use.
num_workers=2 #4
#Pin GPU memory
pin_memory=False #True: + GPU memory / slower
#Data storage folder
dataroot="../data"
dataroot=args.dataroot
# if args.dtype == 'FP32':
# def_type=torch.float32
# elif args.dtype == 'FP16':
# # def_type=torch.float16 #Default : float32
# def_type=torch.bfloat16
# else:
# raise Exception('dtype not supported :', args.dtype)
#WARNING: Dataug (Kornia) expects images in the range [0, 1]
#transform_train = torchvision.transforms.Compose([
# torchvision.transforms.RandomHorizontalFlip(),
# torchvision.transforms.ToTensor(),
# torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #CIFAR10
#])
transform = torchvision.transforms.Compose([
transform = [
#torchvision.transforms.Grayscale(3), #MNIST
#torchvision.transforms.Resize((224,224), interpolation=2)#VGG
torchvision.transforms.ToTensor(),
# torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #CIFAR10
])
#torchvision.transforms.Normalize(MEAN, STD), #CIFAR10
# torchvision.transforms.Lambda(lambda tensor: tensor.to(def_type)),
]
transform_train = torchvision.transforms.Compose([
transform_train = [
#transforms.RandomHorizontalFlip(),
#transforms.RandomVerticalFlip(),
#torchvision.transforms.Grayscale(3), #MNIST
#torchvision.transforms.Resize((224,224), interpolation=2)
torchvision.transforms.ToTensor(),
])
#torchvision.transforms.Normalize(MEAN, STD), #CIFAR10
# torchvision.transforms.Lambda(lambda tensor: tensor.to(def_type)),
]
## RandAugment ##
#from RandAugment import RandAugment
@@ -49,20 +55,77 @@ transform_train = torchvision.transforms.Compose([
#transform_train.transforms.insert(0, RandAugment(n=rand_aug['N'], m=rand_aug['M']))
### Classic Dataset ###
BATCH_SIZE = args.batch_size
TEST_SIZE = BATCH_SIZE
# Load Dataset
if args.dataset == 'MNIST':
transform_train.insert(0, torchvision.transforms.Grayscale(3))
transform.insert(0, torchvision.transforms.Grayscale(3))
#MNIST
#data_train = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=transform_train)
#data_val = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=transform)
#data_test = torchvision.datasets.MNIST(dataroot, train=False, download=True, transform=transform)
val_set=False
data_train = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=torchvision.transforms.Compose(transform_train))
data_val = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=torchvision.transforms.Compose(transform))
data_test = torchvision.datasets.MNIST(dataroot, train=False, download=True, transform=torchvision.transforms.Compose(transform))
elif args.dataset == 'CIFAR10': #(32x32 RGB)
val_set=False
MEAN=(0.4914, 0.4822, 0.4465)
STD=(0.2023, 0.1994, 0.2010)
data_train = torchvision.datasets.CIFAR10(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform_train))
data_val = torchvision.datasets.CIFAR10(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform))
data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=download_data, transform=torchvision.transforms.Compose(transform))
elif args.dataset == 'CIFAR100': #(32x32 RGB)
val_set=False
MEAN=(0.4914, 0.4822, 0.4465)
STD=(0.2023, 0.1994, 0.2010)
data_train = torchvision.datasets.CIFAR100(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform_train))
data_val = torchvision.datasets.CIFAR100(dataroot, train=True, download=download_data, transform=torchvision.transforms.Compose(transform))
data_test = torchvision.datasets.CIFAR100(dataroot, train=False, download=download_data, transform=torchvision.transforms.Compose(transform))
elif args.dataset == 'TinyImageNet': #(Train:100k, Val:5k, Test:5k) (64x64 RGB)
image_size=64 #128 / 224
print('Using image size', image_size)
transform_train=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform_train
transform=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform
val_set=True
MEAN=(0.485, 0.456, 0.406)
STD=(0.229, 0.224, 0.225)
data_train = torchvision.datasets.ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/train'), transform=torchvision.transforms.Compose(transform_train))
data_val = torchvision.datasets.ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/val'), transform=torchvision.transforms.Compose(transform))
data_test = torchvision.datasets.ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/test'), transform=torchvision.transforms.Compose(transform))
elif args.dataset == 'ImageNet': #
image_size=128 #224
print('Using image size', image_size)
transform_train=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform_train
transform=[torchvision.transforms.Resize(image_size), torchvision.transforms.CenterCrop(image_size)]+transform
val_set=False
MEAN=(0.485, 0.456, 0.406)
STD=(0.229, 0.224, 0.225)
data_train = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'ImageNet/train'), transform=torchvision.transforms.Compose(transform_train))
data_val = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'ImageNet/train'), transform=torchvision.transforms.Compose(transform))
data_test = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'ImageNet/validation'), transform=torchvision.transforms.Compose(transform))
#CIFAR
data_train = torchvision.datasets.CIFAR10(dataroot, train=True, download=download_data, transform=transform_train)
data_val = torchvision.datasets.CIFAR10(dataroot, train=True, download=download_data, transform=transform)
data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=download_data, transform=transform)
else:
raise Exception('Unknown dataset')
# Ready dataloader
if not val_set : #Split Training set into Train/Val
#Validation set size [0, 1]
valid_size=0.1
train_subset_indices=range(int(len(data_train)*(1-valid_size)))
val_subset_indices=range(int(len(data_train)*(1-valid_size)),len(data_train))
#train_subset_indices=range(BATCH_SIZE*10)
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
from torch.utils.data import SubsetRandomSampler
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=args.workers, pin_memory=pin_memory)
dl_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=args.workers, pin_memory=pin_memory)
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=args.workers, pin_memory=pin_memory)
else:
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=args.workers, pin_memory=pin_memory)
dl_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE, shuffle=True, num_workers=args.workers, pin_memory=pin_memory)
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=args.workers, pin_memory=pin_memory)
#data_train = torchvision.datasets.CIFAR100(dataroot, train=True, download=download_data, transform=transform_train)
#data_val = torchvision.datasets.CIFAR100(dataroot, train=True, download=download_data, transform=transform)
#data_test = torchvision.datasets.CIFAR100(dataroot, train=False, download=download_data, transform=transform)
#SVHN
#trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=download_data, transform=transform_train)
@@ -76,19 +139,6 @@ data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=downloa
#data_train = torchvision.datasets.ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='train', transform=transform_train)
#data_test = torchvision.datasets.ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)
#Validation set size [0, 1]
valid_size=0.1
train_subset_indices=range(int(len(data_train)*(1-valid_size)))
val_subset_indices=range(int(len(data_train)*(1-valid_size)),len(data_train))
#train_subset_indices=range(BATCH_SIZE*10)
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
from torch.utils.data import SubsetRandomSampler
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
dl_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
#Cross Validation
'''
import numpy as np
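Note on the Train/Val split above: since the DataLoaders are built with shuffle=False and SubsetRandomSampler only shuffles within each index range, the 90/10 split is sequential over the dataset ordering. A minimal sketch of the resulting index ranges (assuming a 50k-sample training set such as CIFAR10):

valid_size = 0.1
n = 50000
train_subset_indices = range(int(n*(1-valid_size)))   #indices 0..44999
val_subset_indices = range(int(n*(1-valid_size)), n)  #indices 45000..49999
#The two ranges are disjoint, so no sample appears in both loaders.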

View file

@@ -18,13 +18,17 @@ import numpy as np
import copy
import transformations as TF
import torchvision
import higher
import higher_patch
from utils import clip_norm
from train_utils import compute_vaLoss
from datasets import MEAN, STD
norm = TF.Normalizer(MEAN, STD)
### Data augmenter ###
class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
"""Data augmentation module with learnable parameters.
@@ -46,19 +50,19 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
_fixed_mag (bool): Whether to lock the TF magnitudes.
_fixed_prob (bool): Whether to lock the TF probabilities.
_samples (list): Sampled TF index during last forward pass.
_mix_dist (bool): Whether we use a mix of a uniform distribution and the real distribution (TF probabilities). If False, only a uniform distribution is used.
_fixed_mix (bool): Whether we lock the mix distribution factor.
_temp (bool): Whether we use a mix of a uniform distribution and the real distribution (TF probabilities). If False, only a uniform distribution is used.
_fixed_temp (bool): Whether we lock the mix distribution factor.
_params (nn.ParameterDict): Learnable parameters.
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (i.e. we learn the magnitudes).
_reg_mask (list): Mask selecting the TF considered for the regularisation.
"""
def __init__(self, TF_dict, N_TF=1, mix_dist=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
def __init__(self, TF_dict, N_TF=1, temp=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
"""Init Data_augv5.
Args:
TF_dict (dict): A dictionary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
N_TF (int): Number of TF to be applied sequentially to each input. (default: 1)
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0.5)
temp (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-temp)*Uniform_distribution + temp*Real_distribution. If None is given, try to learn this parameter. (default: 0.5)
fixed_prob (bool): Whether to lock the TF probabilities. (default: False)
fixed_mag (bool): Whether to lock the TF magnitudes. (default: True)
shared_mag (bool): Whether to share a single magnitude parameter for all TF. (default: True)
@@ -88,27 +92,30 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
self._fixed_prob=fixed_prob
self._samples = []
self._mix_dist = False
if mix_dist != 0.0: #Mix dist
self._mix_dist = True
# self._temp = False
# if temp != 0.0: #Mix dist
# self._temp = True
self._fixed_mix=True
if mix_dist is None: #Learn Mix dist
self._fixed_mix = False
mix_dist=0.5
self._fixed_temp=True
if temp is None: #Learn Temp
print("WARNING: Learning Temperature parameter isn't working with this version (No grad)")
self._fixed_temp = False
temp=0.5
#Params
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
#"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
"prob": nn.Parameter(torch.ones(self._nb_tf)),
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
"mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
"temp": nn.Parameter(torch.tensor(temp))#.clamp(min=0.0,max=0.999))
})
for tf in self._TF_ignore_mag :
self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixed at max parameter
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag useless for the ignore_mag TF
if not self._shared_mag:
for tf in self._TF_ignore_mag :
self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixed at max parameter
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag useless for the ignore_mag TF
#Mag regularisation
if not self._fixed_mag:
@@ -117,7 +124,7 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
else:
TF_mag=[t for t in self._TF if t not in self._TF_ignore_mag] #TF w/ differentiable mag
self._reg_mask=[self._TF.index(t) for t in TF_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX, dtype=self._params['mag'].dtype) #Encourage max amplitude
#Prevent Identity
#print(TF.TF_identity)
@@ -137,28 +144,44 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
Tensor : Batch of transformed data.
"""
self._samples = torch.Tensor([])
if self._data_augmentation:# and TF.random.random() < 0.5:
if self._data_augmentation:
device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
# x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
## Sampling ##
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
if not self._mix_dist:
self._distrib = uniforme_dist
else:
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix of real/uniform distributions with mix_factor
temp = self._params["temp"].detach() if self._fixed_temp else self._params["temp"]
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
self._distrib = F.softmax(prob*temp, dim=0)
# prob = F.softmax(prob[1:], dim=0) #Bernoulli
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
self._samples=cat_distrib.sample([self._N_seqTF])
#Bernoulli (requires Identity at position 0)
#assert(self._TF[0]=="Identity")
# cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf-1), device=device)*self._distrib)
# bern_distrib = Bernoulli(torch.tensor([0.5], device=device))
# mask = bern_distrib.sample([self._N_seqTF, batch_size]).squeeze()
# self._samples=(cat_distrib.sample([self._N_seqTF])+1)*mask
for sample in self._samples:
## Transformations ##
x = self.apply_TF(x, sample)
# self._samples.to(device)
# for n in range(self._N_seqTF):
# # print('temp', (temp+0.3*n))
# self._distrib = F.softmax(prob*(temp+0.2*n), dim=0)
# # prob = F.softmax(prob[1:], dim=0) #Bernoulli
# cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
# new_sample=cat_distrib.sample()
# self._samples=torch.cat((self._samples.to(device).to(new_sample.dtype), new_sample.unsqueeze(dim=0)), dim=0)
# x = self.apply_TF(x, new_sample)
# print('sample',self._samples.shape)
return x
def apply_TF(self, x, sampled_TF):
@@ -204,20 +227,20 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
Args:
soft (bool): Whether to use a softmax function for TF probabilities. Tends to lock the probabilities if the learning rate is low, preventing them from being learned. (default: False)
"""
if not self._fixed_prob:
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
else:
self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
# if not self._fixed_prob:
# if soft :
# self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
# else:
# self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
# self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
if not self._fixed_mag:
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
if not self._fixed_mix:
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
if not self._fixed_temp:
self._params['temp'].data = self._params['temp'].data.clamp(min=0.0, max=0.999)
def loss_weight(self, mean_norm=False):
def loss_weight(self, batch_norm=True):
""" Weights for the loss.
Compute the weights for the loss of each input depending on which TF was applied to them.
Should be applied to the loss before reduction.
@@ -225,30 +248,37 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
Do not take into account the order of application of the TF. See Data_augV7.
Args:
mean_norm (bool): Whether to normalize weights by mean or by distribution. (Default: Normalize by distribution.)
Normalizing by mean would yield an exact normalization but can lead to unstable behavior of probabilities.
Normalizing by distribution is a statistical approximation of the exact normalization. It leads to smoother probability evolution but will only return 1 if mix_dist=1.
batch_norm (bool): Whether to normalize the mean of the weights. (Default: True)
Returns:
Tensor : Loss weights.
"""
if len(self._samples)==0 : return torch.tensor(1, device=self._params["prob"].device) #No samples = no weighting
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
#prob = F.softmax(prob, dim=0)
#Several sequential TF (warning: does not take the order of application into account!)
w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
for sample in self._samples:
for sample in self._samples.to(torch.long):
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
w_loss += tmp_w
if mean_norm:
w_loss = w_loss * prob
w_loss = torch.sum(w_loss,dim=1)
#w_loss=w_loss/w_loss.sum(dim=1, keepdim=True) #Bernoulli
#Normalizing by mean would yield an exact normalization but can lead to unstable behavior of probabilities.
w_loss = w_loss * prob
w_loss = torch.sum(w_loss,dim=1)
if batch_norm:
w_min = w_loss.min()
w_loss = w_loss-w_min if w_min<0 else w_loss
w_loss = w_loss/w_loss.mean() #mean(w_loss)=1
else:
w_loss = w_loss * prob/self._distrib #Weighting by the probabilities (divided by the distribution so the loss isn't shrunk)
w_loss = torch.sum(w_loss,dim=1)
#Normalizing by distribution is a statistical approximation of the exact normalization. It leads to smoother probability evolution but will only return 1 if temp=1.
# w_loss = w_loss * prob/self._distrib #Weighting by the probabilities (divided by the distribution so the loss isn't shrunk)
# w_loss = torch.sum(w_loss,dim=1)
return w_loss
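#Worked sketch of the weighting above (hypothetical numbers): with 3 TF, raw prob=[2.0, 1.0, 0.5]
#and N_seqTF=1, a sample that drew TF0 gets raw weight 2.0 and one that drew TF2 gets 0.5;
#with batch_norm=True the batch [2.0, 0.5] is rescaled to mean 1, giving [1.6, 0.4].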
def reg_loss(self, reg_factor=0.005):
@@ -310,6 +340,8 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
Returns:
nn.Parameter.
"""
if key == 'prob': #Override prob access
return F.softmax(self._params["prob"]*self._params["temp"], dim=0)
return self._params[key]
def __str__(self):
@@ -323,22 +355,20 @@ class Data_augV5(nn.Module): #Joint optimisation (mag, proba)
mag_param='Mag'
if self._fixed_mag: mag_param+= 'Fx'
if self._shared_mag: mag_param+= 'Sh'
if not self._mix_dist:
return "Data_augV5(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
elif self._fixed_mix:
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
# if not self._temp:
# return "Data_augV5(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
if self._fixed_temp:
return "Data_augV5(T%.1f%s-%dTFx%d-%s)" % (self._params['temp'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
else:
return "Data_augV5(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
return "Data_augV5(T%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
class Data_augV7(nn.Module): #Sequential probabilities
class Data_augV8(nn.Module): #Learning sequential probabilities
"""Data augmentation module with learnable parameters.
Applies transformations (TF) to batch of data.
Each TF is defined by a (name, probability of application, magnitude of distortion) tuple which can be learned. For the full definition of the TF, see transformations.py.
The TF probabilities define a distribution from which we sample the TF applied.
Replaces the use of single TF by TF sets, which are combinations of classic TF.
Attributes:
_data_augmentation (bool): Wether TF will be applied during forward pass.
_TF_dict (dict) : A dictionnary containing the data transformations (TF) to be applied.
@@ -350,37 +380,34 @@ class Data_augV7(nn.Module): #Sequential probabilities
_fixed_mag (bool): Whether to lock the TF magnitudes.
_fixed_prob (bool): Whether to lock the TF probabilities.
_samples (list): Sampled TF index during last forward pass.
_mix_dist (bool): Whether we use a mix of a uniform distribution and the real distribution (TF probabilities). If False, only a uniform distribution is used.
_fixed_mix (bool): Whether we lock the mix distribution factor.
_temp (bool): Whether we use a mix of a uniform distribution and the real distribution (TF probabilities). If False, only a uniform distribution is used.
_fixed_temp (bool): Whether we lock the mix distribution factor.
_params (nn.ParameterDict): Learnable parameters.
_reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to false (i.e. we learn the magnitudes).
_reg_mask (list): Mask selecting the TF considered for the regularisation.
"""
def __init__(self, TF_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
"""Init Data_augv7.
def __init__(self, TF_dict, N_TF=1, temp=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
"""Init Data_augv8.
Args:
TF_dict (dict): A dictionary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
N_TF (int): Number of TF to be applied sequentially to each input. Minimum 2, otherwise prefer using Data_augV5. (default: 2)
mix_dist (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-mix_dist)*Uniform_distribution + mix_dist*Real_distribution. If None is given, try to learn this parameter. (default: 0)
N_TF (int): Number of TF to be applied sequentially to each input. (default: 1)
temp (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-temp)*Uniform_distribution + temp*Real_distribution. If None is given, try to learn this parameter. (default: 0.5)
fixed_prob (bool): Whether to lock the TF probabilities. (default: False)
fixed_mag (bool): Whether to lock the TF magnitudes. (default: True)
shared_mag (bool): Whether to share a single magnitude parameter for all TF. (default: True)
TF_ignore_mag (set): TF for which magnitude should be ignored (either it's fixed or unused).
"""
super(Data_augV7, self).__init__()
super(Data_augV8, self).__init__()
assert len(TF_dict)>0
assert N_TF>=0
if N_TF<2:
print("WARNING: Data_augv7 isn't designed to use less than 2 sequentials TF. Please use Data_augv5 instead.")
self._data_augmentation = True
#TF
self._TF_dict = TF_dict
self._TF= list(self._TF_dict.keys())
self._TF_ignore_mag= TF_ignore_mag
self._TF_ignore_mag=TF_ignore_mag
self._nb_tf= len(self._TF)
self._N_seqTF = N_TF
@@ -395,58 +422,50 @@ class Data_augV7(nn.Module): #Sequential probabilities
self._fixed_prob=fixed_prob
self._samples = []
self._mix_dist = False
if mix_dist != 0.0: #Mix dist
self._mix_dist = True
# self._temp = False
# if temp != 0.0: #Mix dist
# self._temp = True
self._fixed_mix=True
if mix_dist is None: #Learn Mix dist
self._fixed_mix = False
mix_dist=0.5
self._fixed_temp=True
if temp is None: #Learn temp
print("WARNING: Learning Temperature parameter isn't working with this version (No grad)")
self._fixed_temp = False
temp=0.5
#TF sets
#import itertools
#itertools.product(range(self._nb_tf), repeat=self._N_seqTF)
#no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}} #Specific No consecutive ops
no_consecutive={idx for idx, t in enumerate(self._TF) if t not in {'Identity'}} #No consecutive same ops (except Identity)
cons_test = (lambda i, idxs: i in no_consecutive and len(idxs)!=0 and i==idxs[-1]) #Exclude selected consecutive
def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF (exclude cons_test arrangement)
TF_sets=[]
if set_size>1:
for i in range(n_TF):
if not cons_test(i, idx_prefix):
TF_sets += generate_TF_sets(n_TF, set_size=set_size-1, idx_prefix=idx_prefix+[i])
else:
TF_sets+=[[idx_prefix+[i]] for i in range(n_TF) if not cons_test(i, idx_prefix)]
return TF_sets
self._TF_sets=torch.ByteTensor(generate_TF_sets(self._nb_tf, self._N_seqTF)).squeeze()
self._nb_TF_sets=len(self._TF_sets)
print("Number of TF sets:",self._nb_TF_sets)
#print(self._TF_sets)
self._prob_mem=torch.zeros(self._nb_TF_sets)
#Params
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_TF_sets)/self._nb_TF_sets), #Distribution prob uniforme
#"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
# "prob": nn.Parameter(torch.ones([self._nb_tf for _ in range(self._N_seqTF)])),
"prob": nn.Parameter(torch.ones(self._nb_tf**self._N_seqTF)),
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
"mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
"temp": nn.Parameter(torch.tensor(temp))#.clamp(min=0.0,max=0.999))
})
#for tf in TF.TF_no_grad :
# if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixed at max parameter
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag useless for the ignore_mag TF
self._prob_mem=torch.zeros(self._nb_tf**self._N_seqTF)
if not self._shared_mag:
for tf in self._TF_ignore_mag :
self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixed at max parameter
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag useless for the ignore_mag TF
#Mag regularisation
if not self._fixed_mag:
if self._shared_mag :
self._reg_tgt = torch.FloatTensor(TF.PARAMETER_MAX) #Encourage max amplitude
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage max amplitude
else:
self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in self._TF_ignore_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude
TF_mag=[t for t in self._TF if t not in self._TF_ignore_mag] #TF w/ differentiable mag
self._reg_mask=[self._TF.index(t) for t in TF_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX, dtype=self._params['mag'].dtype) #Encourage max amplitude
#Prevent Identity
#print(TF.TF_identity)
#self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=0.0)
#for val in TF.TF_identity.keys():
# idx=[self._reg_mask.index(self._TF.index(t)) for t in TF_mag if t in TF.TF_identity[val]]
# self._reg_tgt[idx]=val
#print(TF_mag, self._reg_tgt)
def forward(self, x):
""" Main method of the Data augmentation module.
@@ -457,32 +476,54 @@ class Data_augV7(nn.Module): #Sequential probabilities
Returns:
Tensor : Batch of transformed data.
"""
self._samples = None
if self._data_augmentation:# and TF.random.random() < 0.5:
self._samples = torch.Tensor([])
if self._data_augmentation:
device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
# x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
## Sampling ##
uniforme_dist = torch.ones(1,self._nb_TF_sets,device=device).softmax(dim=1)
# if not self._temp:
# self._distrib = torch.ones(1,self._nb_tf**self._N_seqTF,device=device).softmax(dim=1)
# else:
# prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"] #Uniform dist
# # print(prob.shape)
# # prob = prob.view(1, -1)
# # prob = F.softmax(prob, dim=0)
if not self._mix_dist:
self._distrib = uniforme_dist
else:
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix of real/uniform distributions with mix_factor
# temp = self._params["temp"].detach() if self._fixed_temp else self._params["temp"] #Temperature
# self._distrib = F.softmax(temp*prob, dim=0)
# # self._distrib = (temp*prob+(1-temp)*uniforme_dist)#.softmax(dim=1) #Mix of real/uniform distributions with mix_factor
# # print(prob.shape)
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_TF_sets), device=device)*self._distrib)
sample = cat_distrib.sample()
self._samples=sample
TF_samples=self._TF_sets[sample,:].to(device) #[Batch_size, TFseq]
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
temp = self._params["temp"].detach() if self._fixed_temp else self._params["temp"] #Temperature
self._distrib = F.softmax(temp*prob, dim=0)
for i in range(self._N_seqTF):
cat_distrib= Categorical(probs=torch.ones((self._nb_tf**self._N_seqTF), device=device)*self._distrib)
samples=cat_distrib.sample([batch_size]) # (batch_size)
# print(samples.shape)
samples=torch.zeros((batch_size, self._nb_tf**self._N_seqTF), dtype=torch.bool, device=device).scatter_(dim=1, index=samples.unsqueeze(dim=1), value=1)
self._samples=samples
# print(samples.shape)
# print(samples)
samples=samples.view((batch_size,)+tuple([self._nb_tf for _ in range(self._N_seqTF)]))
# print(samples.shape)
# print(samples)
samples= torch.nonzero(samples)[:,1:].T #Find indexes (TF sequence) => (N_seqTF, batch_size)
# print(samples.shape)
#Bernoulli (requires Identity at position 0)
#assert(self._TF[0]=="Identity")
# cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf-1), device=device)*self._distrib)
# bern_distrib = Bernoulli(torch.tensor([0.5], device=device))
# mask = bern_distrib.sample([self._N_seqTF, batch_size]).squeeze()
# self._samples=(cat_distrib.sample([self._N_seqTF])+1)*mask
for sample in samples:
## Transformations ##
x = self.apply_TF(x, TF_samples[:,i])
x = self.apply_TF(x, sample)
return x
def apply_TF(self, x, sampled_TF):
@@ -526,37 +567,55 @@ class Data_augV7(nn.Module): #Sequential probabilities
Ensure that the parameter values stay in the right intervals. This should be called after each update of those parameters.
Args:
soft (bool): Whether to use a softmax function for TF probabilities. Not recommended as it tends to lock the probabilities, preventing them from being learned. (default: False)
soft (bool): Whether to use a softmax function for TF probabilities. Tends to lock the probabilities if the learning rate is low, preventing them from being learned. (default: False)
"""
if not self._fixed_prob:
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft', gets stuck in a uniform dist if lr is too low
else:
self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
# if not self._fixed_prob:
# if soft :
# self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
# else:
# self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
# self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
if not self._fixed_mag:
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
if not self._fixed_mix:
self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
if not self._fixed_temp:
self._params['temp'].data = self._params['temp'].data.clamp(min=0.0, max=0.999)
def loss_weight(self):
def loss_weight(self, batch_norm=True):
""" Weights for the loss.
Compute the weights for the loss of each input depending on which TF was applied to them.
Should be applied to the loss before reduction.
Args:
batch_norm (bool): Whether to normalize the mean of the weights. (Default: True)
Returns:
Tensor : Loss weights.
"""
if self._samples is None : return 1 #No samples = no weighting
device=self._params["prob"].device
if len(self._samples)==0 : return torch.tensor(1, device=device) #No samples = no weighting
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
w_loss = torch.zeros((self._samples.shape[0],self._nb_TF_sets), device=self._samples.device)
w_loss.scatter_(1, self._samples.view(-1,1), 1)
w_loss = w_loss * prob/self._distrib #Weighting by the probabilities (divided by the distribution so the loss isn't shrunk)
# print("prob",prob.shape)
# print(self._samples.shape)
#w_loss=w_loss/w_loss.sum(dim=1, keepdim=True) #Bernoulli
#Normalizing by mean would yield an exact normalization but can lead to unstable behavior of probabilities.
w_loss = self._samples * prob
w_loss = torch.sum(w_loss,dim=1)
# print("W_loss",w_loss.shape)
# print(w_loss)
if batch_norm:
w_min = w_loss.min()
w_loss = w_loss-w_min if w_min<0 else w_loss
w_loss = w_loss/w_loss.mean() #mean(w_loss)=1
#Normalizing by distribution is a statistical approximation of the exact normalization. It leads to smoother probability evolution but will only return 1 if temp=1.
# w_loss = w_loss * prob/self._distrib #Weighting by the probabilities (divided by the distribution so the loss isn't shrunk)
# w_loss = torch.sum(w_loss,dim=1)
return w_loss
def reg_loss(self, reg_factor=0.005):
@@ -573,30 +632,43 @@ class Data_augV7(nn.Module): #Sequential probabilities
else:
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') #Close to target ?
#max_mag_reg = - reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean') #Far from target ?
return max_mag_reg
def TF_prob(self):
""" Gives an estimation of the individual TF probabilities.
Be wary that the probabilities returned aren't exact. The TF distribution isn't fully represented by them.
Each probability should be taken individually. They only represent the chance for a specific TF to be picked at least once.
Returns:
Tensor containing the single TF probabilities of application.
"""
if torch.all(self._params['prob']!=self._prob_mem.to(self._params['prob'].device)): #Prevent recompute if the original prob didn't change
self._prob_mem=self._params['prob'].data.detach_()
self._single_TF_prob=torch.zeros(self._nb_tf)
for idx_tf in range(self._nb_tf):
for i, t_set in enumerate(self._TF_sets):
#uni, count = np.unique(t_set, return_counts=True)
#if idx_tf in uni:
# res[idx_tf]+=self._params['prob'][i]*int(count[np.where(uni==idx_tf)])
if idx_tf in t_set:
self._single_TF_prob[idx_tf]+=self._params['prob'][i]
# if not torch.all(self._params['prob']==self._prob_mem.to(self._params['prob'].device)): #Prevent recompute if the original prob didn't change
# self._prob_mem=self._params['prob'].data.detach_()
return self._single_TF_prob
# p = self._params['prob'].view([self._nb_tf for _ in range(self._N_seqTF)])
# # print('prob',p)
# self._single_TF_prob=p.mean(dim=[i+1 for i in range(self._N_seqTF-1)]) #Reduce to 1D tensor
# # print(self._single_TF_prob)
# self._single_TF_prob=F.softmax(self._single_TF_prob, dim=0)
# print('Soft',self._single_TF_prob)
p=F.softmax(self._params['prob']*self._params["temp"], dim=0) #Sampling dist
p=p.view([self._nb_tf for _ in range(self._N_seqTF)])
p=p.mean(dim=[i+1 for i in range(self._N_seqTF-1)]) #Reduce to 1D tensor
#Means over each dim
# dim_idx=tuple(range(self._N_seqTF))
# means=[]
# for d in dim_idx:
# dim_mean=list(dim_idx)
# dim_mean.remove(d)
# means.append(p.mean(dim=dim_mean).unsqueeze(dim=1))
# means=torch.cat(means,dim=1)
# print(means)
# p=means.mean(dim=1)
# print(p)
return p
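#Sketch of the reduction above (hypothetical shapes): with nb_tf=2 and N_seqTF=2, the softmaxed
#sequence distribution p has 4 entries, viewed as a (2,2) grid; p.mean(dim=[1]) collapses the
#trailing sequence positions, leaving an approximate per-TF score for the first position.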
def train(self, mode=True):
""" Set the module training mode.
@@ -607,7 +679,7 @@ class Data_augV7(nn.Module): #Sequential probabilities
#if mode is None :
# mode=self._data_augmentation
self.augment(mode=mode) #Useless if mode=None
super(Data_augV7, self).train(mode)
super(Data_augV8, self).train(mode)
return self
def eval(self):
@@ -654,12 +726,13 @@ class Data_augV7(nn.Module): #Sequential probabilities
mag_param='Mag'
if self._fixed_mag: mag_param+= 'Fx'
if self._shared_mag: mag_param+= 'Sh'
if not self._mix_dist:
return "Data_augV7(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
elif self._fixed_mix:
return "Data_augV7(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
# if not self._temp:
# return "Data_augV8(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
if self._fixed_temp:
return "Data_augV8(T%.1f%s-%dTFx%d-%s)" % (self._params['temp'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
else:
return "Data_augV7(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
return "Data_augV8(T%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
"""RandAugment implementation.
@@ -703,7 +776,7 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
self._shared_mag = True
self._fixed_mag = True
self._fixed_prob=True
self._fixed_mix=True
self._fixed_temp=True
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
@@ -716,17 +789,24 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
Returns:
Tensor : Batch of transformed data.
"""
if self._data_augmentation:# and TF.random.random() < 0.5:
if self._data_augmentation:
device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
# x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
## Sampling ## == sampled_ops = np.random.choice(transforms, N)
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*uniforme_dist)
self._samples=cat_distrib.sample([self._N_seqTF])
#Bernoulli (requires Identity at position 0)
# uniforme_dist = torch.ones(1,self._nb_tf-1,device=device).softmax(dim=1)
# cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf-1), device=device)*uniforme_dist)
# bern_distrib = Bernoulli(torch.tensor([0.5], device=device))
# mask = bern_distrib.sample([self._N_seqTF, batch_size]).squeeze()
# self._samples=(cat_distrib.sample([self._N_seqTF])+1)*mask
for sample in self._samples:
## Transformations ##
x = self.apply_TF(x, sample)
@@ -765,10 +845,10 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
"""
pass #No parameters to optimize
def loss_weight(self):
def loss_weight(self, batch_norm=False):
"""Not used
"""
return 1 #No samples = no weighting
return torch.tensor([1], device=self._params["prob"].device) #No samples = no weighting
def reg_loss(self, reg_factor=0.005):
"""Not used
@@ -949,18 +1029,22 @@ class Augmented_model(nn.Module):
self.augment(mode=True)
def forward(self, x):
def forward(self, x, copy_input=False): #Renamed from 'copy', which shadowed the copy module
""" Main method of the Augmented model.
Perform the forward pass of both modules.
Args:
x (Tensor): Batch of data.
copy_input (bool): Whether to alter a copy or the original input. It's recommended to use a copy for parallel use of the input. (Default: False)
Returns:
Tensor : Output of the networks. Should be logits.
"""
return self._mods['model'](self._mods['data_aug'](x))
if copy_input:
x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
return self._mods['model'](norm(self._mods['data_aug'](x)))
# return self._mods['model'](self._mods['data_aug'](x))
def augment(self, mode=True):
""" Set the augmentation mode.
@@ -970,6 +1054,12 @@ class Augmented_model(nn.Module):
"""
self._data_augmentation=mode
self._mods['data_aug'].augment(mode)
#ABN
# if mode :
# self._mods['model']['functional'].set_mode('augmented')
# else :
# self._mods['model']['functional'].set_mode('clean')
#### Encapsulation Meta Opt ####
def start_bilevel_opt(self, inner_it, hp_list, opt_param, dl_val):
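Aside on the sampling scheme used throughout this file: the TF distribution is a temperature-scaled softmax over the 'prob' logits, F.softmax(prob*temp, dim=0). A minimal standalone sketch of that behavior (illustrative values only):

import torch
import torch.nn.functional as F

prob = torch.tensor([2.0, 1.0, 0.5])  #learned logits for 3 TF
for temp in (0.0, 0.5, 0.999):
    #temp=0 recovers the uniform distribution; larger temp sharpens toward the logits.
    print(temp, F.softmax(prob * temp, dim=0))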

View file

@@ -0,0 +1,56 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
## Basic CNN ##
class LeNet(nn.Module):
"""Basic CNN.
"""
def __init__(self, num_inp, num_out):
"""Init LeNet.
"""
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(num_inp, 20, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(20, 50, 5)
self.pool2 = nn.MaxPool2d(2, 2)
#self.fc1 = nn.Linear(4*4*50, 500)
self.fc1 = nn.Linear(5*5*50, 500)
self.fc2 = nn.Linear(500, num_out)
def forward(self, x):
"""Main method of LeNet
"""
x = self.pool(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
def __str__(self):
""" Get name of model
"""
return "LeNet"
#MNIST
class MLPNet(nn.Module):
def __init__(self):
super(MLPNet, self).__init__()
self.fc1 = nn.Linear(28*28, 500)
self.fc2 = nn.Linear(500, 256)
self.fc3 = nn.Linear(256, 10)
def forward(self, x):
x = x.view(-1, 28*28)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def name(self):
return "MLP"

View file

@@ -0,0 +1,426 @@
import torch
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
# __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
# 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
# 'wide_resnet50_2', 'wide_resnet101_2']
# model_urls = {
# 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
# 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
# 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
# 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
# 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
# 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
# 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
# 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
# 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
# }
__all__ = ['ResNet_ABN', 'resnet18_ABN', 'resnet34_ABN', 'resnet50_ABN', 'resnet101_ABN',
'resnet152_ABN', 'resnext50_32x4d_ABN', 'resnext101_32x8d_ABN',
'wide_resnet50_2_ABN', 'wide_resnet101_2_ABN']
class aux_batchNorm(nn.Module):
def __init__(self, norm_layer, nb_features):
super(aux_batchNorm, self).__init__()
self.mode='clean'
self.bn=nn.ModuleDict({
'clean': norm_layer(nb_features),
'augmented': norm_layer(nb_features)
})
def forward(self, x):
if self.mode == 'mixed': #'is' compared identity, not string equality
running_mean=(self.bn['clean'].running_mean+self.bn['augmented'].running_mean)/2
running_var=(self.bn['clean'].running_var+self.bn['augmented'].running_var)/2
return nn.functional.batch_norm(x, running_mean, running_var, self.bn['clean'].weight, self.bn['clean'].bias)
return self.bn[self.mode](x)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock_ABN(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock_ABN, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
#self.bn1 = norm_layer(planes)
self.bn1 = aux_batchNorm(norm_layer, planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
#self.bn2 = norm_layer(planes)
self.bn2 = aux_batchNorm(norm_layer, planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck_ABN(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck_ABN, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
#self.bn1 = norm_layer(width)
self.bn1 = aux_batchNorm(norm_layer, width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
# self.bn2 = norm_layer(width)
self.bn2 = aux_batchNorm(norm_layer, width)
self.conv3 = conv1x1(width, planes * self.expansion)
# self.bn3 = norm_layer(planes * self.expansion)
self.bn3 = aux_batchNorm(norm_layer, planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet_ABN(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet_ABN, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
#self.bn1 = norm_layer(self.inplanes)
self.bn1 = aux_batchNorm(norm_layer, self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
            print('WARNING: zero_init_residual not implemented with ABN')
# for m in self.modules():
# if isinstance(m, Bottleneck):
# nn.init.constant_(m.bn3.weight, 0)
# elif isinstance(m, BasicBlock):
# nn.init.constant_(m.bn2.weight, 0)
        # Keeping a memory of the BN layers is not functional with Higher
# self.bn_layers=[]
# for m in self.modules():
# if isinstance(m, aux_batchNorm):
# self.bn_layers.append(m)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
#norm_layer(planes * block.expansion),
aux_batchNorm(norm_layer, planes * block.expansion)
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
def set_mode(self, mode):
# for bn in self.bn_layers:
for m in self.modules():
if isinstance(m, aux_batchNorm):
m.mode=mode
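
# Editor's illustrative usage sketch (not part of the original file). It assumes
# aux_batchNorm switches between two sets of BN statistics via mode names such as
# 'clean' and 'augmented'; those names are an assumption, substitute whatever modes
# the actual aux_batchNorm implementation defines.
def _example_abn_step(model, x_clean, x_aug, target, criterion):
    """Route clean and augmented batches through separate BN statistics."""
    model.set_mode('clean')      # assumed mode name for clean-data statistics
    loss = criterion(model(x_clean), target)
    model.set_mode('augmented')  # assumed mode name for auxiliary statistics
    return loss + criterion(model(x_aug), target)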
# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
# model = ResNet(block, layers, **kwargs)
# if pretrained:
# state_dict = load_state_dict_from_url(model_urls[arch],
# progress=progress)
# model.load_state_dict(state_dict)
# return model
def resnet18_ABN(pretrained=False, progress=True, **kwargs):
"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
# return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
# **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(BasicBlock_ABN, [2, 2, 2, 2], **kwargs)
def resnet34_ABN(pretrained=False, progress=True, **kwargs):
"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
# return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
# **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(BasicBlock_ABN, [3, 4, 6, 3], **kwargs)
def resnet50_ABN(pretrained=False, progress=True, **kwargs):
"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
# return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
# **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
def resnet101_ABN(pretrained=False, progress=True, **kwargs):
"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
# return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
# **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
def resnet152_ABN(pretrained=False, progress=True, **kwargs):
"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
# return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
# **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 8, 36, 3], **kwargs)
def resnext50_32x4d_ABN(pretrained=False, progress=True, **kwargs):
"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
# return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
# pretrained, progress, **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
def resnext101_32x8d_ABN(pretrained=False, progress=True, **kwargs):
"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
# return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
# pretrained, progress, **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
def wide_resnet50_2_ABN(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
# return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
# pretrained, progress, **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
def wide_resnet101_2_ABN(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
# return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
# pretrained, progress, **kwargs)
if(pretrained):
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)

View file

@@ -0,0 +1,618 @@
'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
https://github.com/yechengxi/deconvolution
'''
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import conv
from torch.nn.modules.utils import _pair
from functools import partial
__all__ = ['ResNet18_DC', 'ResNet34_DC', 'ResNet50_DC', 'ResNet101_DC', 'ResNet152_DC', 'WRN_DC26_10']
### Deconvolution ###
#iteratively solve for inverse sqrt of a matrix
def isqrt_newton_schulz_autograd(A, numIters):
dim = A.shape[0]
normA=A.norm()
Y = A.div(normA)
I = torch.eye(dim,dtype=A.dtype,device=A.device)
Z = torch.eye(dim,dtype=A.dtype,device=A.device)
for i in range(numIters):
T = 0.5*(3.0*I - Z@Y)
Y = Y@T
Z = T@Z
#A_sqrt = Y*torch.sqrt(normA)
A_isqrt = Z / torch.sqrt(normA)
return A_isqrt
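
# Editor's illustrative check (not in the original file): verifies that the
# Newton-Schulz iteration above returns an approximate inverse square root,
# i.e. A_isqrt @ A @ A_isqrt ~ I, for a well-conditioned SPD matrix.
def _check_isqrt(dim=8, numIters=10):
    M = torch.randn(dim, dim)
    A = M @ M.t() + dim * torch.eye(dim)  # symmetric positive definite, well conditioned
    A_isqrt = isqrt_newton_schulz_autograd(A, numIters)
    err = (A_isqrt @ A @ A_isqrt - torch.eye(dim)).norm()
    return err  # small for well-conditioned A; increase numIters for tighter accuracy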
def isqrt_newton_schulz_autograd_batch(A, numIters):
batchSize,dim,_ = A.shape
normA=A.view(batchSize, -1).norm(2, 1).view(batchSize, 1, 1)
Y = A.div(normA)
I = torch.eye(dim,dtype=A.dtype,device=A.device).unsqueeze(0).expand_as(A)
Z = torch.eye(dim,dtype=A.dtype,device=A.device).unsqueeze(0).expand_as(A)
for i in range(numIters):
T = 0.5*(3.0*I - Z.bmm(Y))
Y = Y.bmm(T)
Z = T.bmm(Z)
#A_sqrt = Y*torch.sqrt(normA)
A_isqrt = Z / torch.sqrt(normA)
return A_isqrt
#deconvolve channels
class ChannelDeconv(nn.Module):
def __init__(self, block, eps=1e-2,n_iter=5,momentum=0.1,sampling_stride=3):
super(ChannelDeconv, self).__init__()
self.eps = eps
self.n_iter=n_iter
self.momentum=momentum
self.block = block
self.register_buffer('running_mean1', torch.zeros(block, 1))
#self.register_buffer('running_cov', torch.eye(block))
self.register_buffer('running_deconv', torch.eye(block))
self.register_buffer('running_mean2', torch.zeros(1, 1))
self.register_buffer('running_var', torch.ones(1, 1))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
self.sampling_stride=sampling_stride
def forward(self, x):
x_shape = x.shape
if len(x.shape)==2:
x=x.view(x.shape[0],x.shape[1],1,1)
if len(x.shape)==3:
            print('Error! Unsupported tensor shape.')
N, C, H, W = x.size()
B = self.block
#take the first c channels out for deconv
c=int(C/B)*B
if c==0:
            print('Error! block should be set to a smaller value.')
#step 1. remove mean
if c!=C:
x1=x[:,:c].permute(1,0,2,3).contiguous().view(B,-1)
else:
x1=x.permute(1,0,2,3).contiguous().view(B,-1)
if self.sampling_stride > 1 and H >= self.sampling_stride and W >= self.sampling_stride:
x1_s = x1[:,::self.sampling_stride**2]
else:
x1_s=x1
mean1 = x1_s.mean(-1, keepdim=True)
if self.num_batches_tracked==0:
self.running_mean1.copy_(mean1.detach())
if self.training:
self.running_mean1.mul_(1-self.momentum)
self.running_mean1.add_(mean1.detach()*self.momentum)
else:
mean1 = self.running_mean1
x1=x1-mean1
#step 2. calculate deconv@x1 = cov^(-0.5)@x1
if self.training:
cov = x1_s @ x1_s.t() / x1_s.shape[1] + self.eps * torch.eye(B, dtype=x.dtype, device=x.device)
deconv = isqrt_newton_schulz_autograd(cov, self.n_iter)
if self.num_batches_tracked==0:
#self.running_cov.copy_(cov.detach())
self.running_deconv.copy_(deconv.detach())
if self.training:
#self.running_cov.mul_(1-self.momentum)
#self.running_cov.add_(cov.detach()*self.momentum)
self.running_deconv.mul_(1 - self.momentum)
self.running_deconv.add_(deconv.detach() * self.momentum)
else:
# cov = self.running_cov
deconv = self.running_deconv
x1 =deconv@x1
        # reshape to N,c,H,W
x1 = x1.view(c, N, H, W).contiguous().permute(1,0,2,3)
# normalize the remaining channels
if c!=C:
x_tmp=x[:, c:].view(N,-1)
if self.sampling_stride > 1 and H>=self.sampling_stride and W>=self.sampling_stride:
x_s = x_tmp[:, ::self.sampling_stride ** 2]
else:
x_s = x_tmp
mean2=x_s.mean()
var=x_s.var()
if self.num_batches_tracked == 0:
self.running_mean2.copy_(mean2.detach())
self.running_var.copy_(var.detach())
if self.training:
self.running_mean2.mul_(1 - self.momentum)
self.running_mean2.add_(mean2.detach() * self.momentum)
self.running_var.mul_(1 - self.momentum)
self.running_var.add_(var.detach() * self.momentum)
else:
mean2 = self.running_mean2
var = self.running_var
x_tmp = (x[:, c:] - mean2) / (var + self.eps).sqrt()
x1 = torch.cat([x1, x_tmp], dim=1)
if self.training:
self.num_batches_tracked.add_(1)
if len(x_shape)==2:
x1=x1.view(x_shape)
return x1
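
# Editor's illustrative usage (not in the original file): ChannelDeconv acts as a
# blockwise whitening layer on feature maps; `block` must not exceed the channel count.
def _example_channel_deconv():
    layer = ChannelDeconv(block=16)
    feats = torch.randn(4, 32, 8, 8)  # N, C, H, W with C >= block
    return layer(feats)               # same shape, channels whitened in blocks of 16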
#An alternative implementation
class Delinear(nn.Module):
__constants__ = ['bias', 'in_features', 'out_features']
def __init__(self, in_features, out_features, bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=512):
super(Delinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
if block > in_features:
block = in_features
else:
if in_features%block!=0:
block=math.gcd(block,in_features)
print('block size set to:', block)
self.block = block
self.momentum = momentum
self.n_iter = n_iter
self.eps = eps
self.register_buffer('running_mean', torch.zeros(self.block))
self.register_buffer('running_deconv', torch.eye(self.block))
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input):
if self.training:
# 1. reshape
X=input.view(-1, self.block)
# 2. subtract mean
X_mean = X.mean(0)
X = X - X_mean.unsqueeze(0)
self.running_mean.mul_(1 - self.momentum)
self.running_mean.add_(X_mean.detach() * self.momentum)
# 3. calculate COV, COV^(-0.5), then deconv
# Cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
Id = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
Cov = torch.addmm(self.eps, Id, 1. / X.shape[0], X.t(), X)
deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
# track stats for evaluation
self.running_deconv.mul_(1 - self.momentum)
self.running_deconv.add_(deconv.detach() * self.momentum)
else:
X_mean = self.running_mean
deconv = self.running_deconv
w = self.weight.view(-1, self.block) @ deconv
b = self.bias
if self.bias is not None:
b = b - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
w = w.view(self.weight.shape)
return F.linear(input, w, b)
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None
)
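
# Editor's illustrative usage (not in the original file): Delinear is a drop-in
# replacement for nn.Linear that whitens its input blockwise before the affine map.
def _example_delinear():
    layer = Delinear(in_features=128, out_features=10, block=64)
    x = torch.randn(32, 128)
    return layer(x)  # shape (32, 10)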
class FastDeconv(conv._ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,groups=1,bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3,freeze=False,freeze_iter=100):
self.momentum = momentum
self.n_iter = n_iter
self.eps = eps
self.counter=0
self.track_running_stats=True
super(FastDeconv, self).__init__(
in_channels, out_channels, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation),
False, _pair(0), groups, bias, padding_mode='zeros')
if block > in_channels:
block = in_channels
else:
if in_channels%block!=0:
block=math.gcd(block,in_channels)
if groups>1:
#grouped conv
block=in_channels//groups
self.block=block
self.num_features = kernel_size**2 *block
if groups==1:
self.register_buffer('running_mean', torch.zeros(self.num_features))
self.register_buffer('running_deconv', torch.eye(self.num_features))
else:
self.register_buffer('running_mean', torch.zeros(kernel_size ** 2 * in_channels))
self.register_buffer('running_deconv', torch.eye(self.num_features).repeat(in_channels // block, 1, 1))
self.sampling_stride=sampling_stride*stride
self.counter=0
self.freeze_iter=freeze_iter
self.freeze=freeze
def forward(self, x):
N, C, H, W = x.shape
B = self.block
frozen=self.freeze and (self.counter>self.freeze_iter)
if self.training and self.track_running_stats:
self.counter+=1
self.counter %= (self.freeze_iter * 10)
if self.training and (not frozen):
            # 1. im2col: N x cols x pixels -> N*pixels x cols
if self.kernel_size[0]>1:
X = torch.nn.functional.unfold(x, self.kernel_size,self.dilation,self.padding,self.sampling_stride).transpose(1, 2).contiguous()
else:
#channel wise
X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride**2,:]
if self.groups==1:
# (C//B*N*pixels,k*k*B)
X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
else:
X=X.view(-1,X.shape[-1])
# 2. subtract mean
X_mean = X.mean(0)
X = X - X_mean.unsqueeze(0)
# 3. calculate COV, COV^(-0.5), then deconv
if self.groups==1:
#Cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
Id=torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
Cov = torch.addmm(self.eps, Id, 1. / X.shape[0], X.t(), X)
deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
else:
X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(self.groups, self.num_features, self.num_features)
Cov = torch.baddbmm(self.eps, Id, 1. / X.shape[1], X.transpose(1, 2), X)
deconv = isqrt_newton_schulz_autograd_batch(Cov, self.n_iter)
if self.track_running_stats:
self.running_mean.mul_(1 - self.momentum)
self.running_mean.add_(X_mean.detach() * self.momentum)
# track stats for evaluation
self.running_deconv.mul_(1 - self.momentum)
self.running_deconv.add_(deconv.detach() * self.momentum)
else:
X_mean = self.running_mean
deconv = self.running_deconv
#4. X * deconv * conv = X * (deconv * conv)
if self.groups==1:
w = self.weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1,self.num_features) @ deconv
b = self.bias - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
else:
w = self.weight.view(C//B, -1,self.num_features)@deconv
b = self.bias - (w @ (X_mean.view( -1,self.num_features,1))).view(self.bias.shape)
w = w.view(self.weight.shape)
x= F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)
return x
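
# Editor's illustrative usage (not in the original file): FastDeconv is a drop-in
# replacement for nn.Conv2d that whitens the im2col patches before convolving,
# which is why the networks below skip batch normalization when deconv is used.
def _example_fast_deconv():
    layer = FastDeconv(in_channels=3, out_channels=16, kernel_size=3, padding=1)
    x = torch.randn(8, 3, 32, 32)
    return layer(x)  # shape (8, 16, 32, 32)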
### ResNet
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1, deconv=None):
super(BasicBlock, self).__init__()
if deconv:
self.conv1 = deconv(in_planes, planes, kernel_size=3, stride=stride, padding=1)
self.conv2 = deconv(planes, planes, kernel_size=3, stride=1, padding=1)
self.deconv = True
else:
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.deconv = False
self.shortcut = nn.Sequential()
if not deconv:
self.bn1 = nn.BatchNorm2d(planes)
self.bn2 = nn.BatchNorm2d(planes)
#self.bn1 = nn.GroupNorm(planes//16,planes)
#self.bn2 = nn.GroupNorm(planes//16,planes)
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
#nn.GroupNorm(self.expansion * planes//16,self.expansion * planes)
)
else:
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
deconv(in_planes, self.expansion*planes, kernel_size=1, stride=stride)
)
def forward(self, x):
if self.deconv:
out = F.relu(self.conv1(x))
out = self.conv2(out)
out += self.shortcut(x)
out = F.relu(out)
return out
else: #self.batch_norm:
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1, deconv=None):
super(Bottleneck, self).__init__()
if deconv:
self.deconv = True
self.conv1 = deconv(in_planes, planes, kernel_size=1)
self.conv2 = deconv(planes, planes, kernel_size=3, stride=stride, padding=1)
self.conv3 = deconv(planes, self.expansion*planes, kernel_size=1)
else:
self.deconv = False
self.bn1 = nn.BatchNorm2d(planes)
self.bn2 = nn.BatchNorm2d(planes)
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
self.shortcut = nn.Sequential()
if not deconv:
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes)
)
else:
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
deconv(in_planes, self.expansion * planes, kernel_size=1, stride=stride)
)
def forward(self, x):
"""
No batch normalization for deconv.
"""
if self.deconv:
out = F.relu((self.conv1(x)))
out = F.relu((self.conv2(out)))
out = self.conv3(out)
out += self.shortcut(x)
out = F.relu(out)
return out
else:
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10, deconv=None,channel_deconv=None):
super(ResNet, self).__init__()
self.in_planes = 64
if deconv:
self.deconv = True
self.conv1 = deconv(3, 64, kernel_size=3, stride=1, padding=1)
else:
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
if not deconv:
self.bn1 = nn.BatchNorm2d(64)
        # This line is quite recent; take extra care if results degrade.
if channel_deconv:
self.deconv1=channel_deconv()
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, deconv=deconv)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, deconv=deconv)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, deconv=deconv)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, deconv=deconv)
self.linear = nn.Linear(512*block.expansion, num_classes)
def _make_layer(self, block, planes, num_blocks, stride, deconv):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride, deconv))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
if hasattr(self,'bn1'):
out = F.relu(self.bn1(self.conv1(x)))
else:
out = F.relu(self.conv1(x))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
if hasattr(self, 'deconv1'):
out = self.deconv1(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
def_deconv = partial(FastDeconv,bias=True, eps=1e-5, n_iter=5,block=64,sampling_stride=3)
#channel_deconv=partial(ChannelDeconv, block=512,eps=1e-5, n_iter=5,sampling_stride=3) # Not necessarily recommended
def ResNet18_DC(num_classes,deconv=def_deconv,channel_deconv=None):
return ResNet(BasicBlock, [2,2,2,2],num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
def ResNet34_DC(num_classes,deconv=def_deconv,channel_deconv=None):
return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
def ResNet50_DC(num_classes,deconv=def_deconv,channel_deconv=None):
return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
def ResNet101_DC(num_classes,deconv=def_deconv,channel_deconv=None):
return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
def ResNet152_DC(num_classes,deconv=def_deconv,channel_deconv=None):
return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
class Wide_ResNet_Cifar_DC(nn.Module):
def __init__(self, block, layers, wfactor, num_classes=10, deconv=None, channel_deconv=None):
super(Wide_ResNet_Cifar_DC, self).__init__()
self.depth=layers[0]*6+2
self.widen_factor=wfactor
self.inplanes = 16
self.conv1 = deconv(3, 16, kernel_size=3, stride=1, padding=1)
if channel_deconv:
self.deconv1=channel_deconv()
# self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
# self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, 16*wfactor, layers[0], stride=1, deconv=deconv)
self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2, deconv=deconv)
self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2, deconv=deconv)
self.avgpool = nn.AvgPool2d(8, stride=1)
self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, num_blocks, stride, deconv):
# downsample = None
# if stride != 1 or self.inplanes != planes * block.expansion:
# downsample = nn.Sequential(
# nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
# nn.BatchNorm2d(planes * block.expansion)
# )
# layers = []
# layers.append(block(self.inplanes, planes, stride, downsample))
# self.inplanes = planes * block.expansion
# for _ in range(1, blocks):
# layers.append(block(self.inplanes, planes))
# return nn.Sequential(*layers)
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.inplanes, planes, stride, deconv))
self.inplanes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
# x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
        if hasattr(self, 'deconv1'):
            x = self.deconv1(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def __str__(self):
""" Get name of model
"""
return "Wide_ResNet_cifar_DC%d_%d"%(self.depth,self.widen_factor)
def WRN_DC26_10(depth=26, width=10, deconv=def_deconv, channel_deconv=None, **kwargs):
assert (depth - 2) % 6 == 0
n = int((depth - 2) / 6)
return Wide_ResNet_Cifar_DC(BasicBlock, [n, n, n], width, deconv=deconv,channel_deconv=channel_deconv, **kwargs)
def test():
    net = ResNet18_DC(num_classes=10)
y = net(torch.randn(1,3,32,32))
print(y.size())
# test()

View file

@@ -0,0 +1,98 @@
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import numpy as np
_bn_momentum = 0.1
def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
def conv_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
init.xavier_uniform_(m.weight, gain=np.sqrt(2))
init.constant_(m.bias, 0)
elif classname.find('BatchNorm') != -1:
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
class WideBasic(nn.Module):
def __init__(self, in_planes, planes, dropout_rate, stride=1):
super(WideBasic, self).__init__()
assert dropout_rate==0.0, 'dropout layer not used'
self.bn1 = nn.BatchNorm2d(in_planes, momentum=_bn_momentum)
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
#self.dropout = nn.Dropout(p=dropout_rate)
self.bn2 = nn.BatchNorm2d(planes, momentum=_bn_momentum)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
)
def forward(self, x):
# out = self.dropout(self.conv1(F.relu(self.bn1(x))))
out = self.conv1(F.relu(self.bn1(x)))
out = self.conv2(F.relu(self.bn2(out)))
out += self.shortcut(x)
return out
class WideResNet(nn.Module):
def __init__(self, depth, widen_factor, dropout_rate, num_classes):
super(WideResNet, self).__init__()
self.depth=depth
self.widen_factor=widen_factor
self.in_planes = 16
assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
n = int((depth - 4) / 6)
k = widen_factor
nStages = [16, 16*k, 32*k, 64*k]
self.conv1 = conv3x3(3, nStages[0])
self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
self.bn1 = nn.BatchNorm2d(nStages[3], momentum=_bn_momentum)
self.linear = nn.Linear(nStages[3], num_classes)
# self.apply(conv_init)
def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, dropout_rate, stride))
self.in_planes = planes
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv1(x)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = F.relu(self.bn1(out))
# out = F.avg_pool2d(out, 8)
out = F.adaptive_avg_pool2d(out, (1, 1))
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
def __str__(self):
""" Get name of model
"""
return "Wide_ResNet%d_%d"%(self.depth,self.widen_factor)

View file

@@ -0,0 +1,119 @@
"""
wide resnet for cifar in pytorch
Reference:
[1] S. Zagoruyko and N. Komodakis. Wide residual networks. In BMVC, 2016.
"""
import torch
import torch.nn as nn
import math
#from models.resnet_cifar import BasicBlock
def conv3x3(in_planes, out_planes, stride=1):
" 3x3 convolution with padding "
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
expansion=1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Wide_ResNet_Cifar(nn.Module):
def __init__(self, block, layers, wfactor, num_classes=10):
super(Wide_ResNet_Cifar, self).__init__()
self.depth=layers[0]*6+2
self.widen_factor=wfactor
self.inplanes = 16
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, 16*wfactor, layers[0])
self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2)
self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2)
self.avgpool = nn.AvgPool2d(8, stride=1)
self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion)
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def __str__(self):
""" Get name of model
"""
return "Wide_ResNet_cifar%d_%d"%(self.depth,self.widen_factor)
def wide_resnet_cifar(depth, width, **kwargs):
assert (depth - 2) % 6 == 0
n = int((depth - 2) / 6)
return Wide_ResNet_Cifar(BasicBlock, [n, n, n], width, **kwargs)
if __name__=='__main__':
net = wide_resnet_cifar(20, 10)
y = net(torch.randn(1, 3, 32, 32))
print(isinstance(net, Wide_ResNet_Cifar))
print(y.size())

View file

@@ -1063,3 +1063,362 @@ class AugmentedDataset(VisionDataset):
def __str__(self):
return "CIFAR10(Sup:{}-Unsup:{}-{}TF)".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF))
class Data_augV7(nn.Module): # Sequential probabilities
    """Data augmentation module with learnable parameters.

        Applies transformations (TF) to batches of data.
        Each TF is defined by a (name, probability of application, magnitude of distortion) tuple which can be learned. For the full definition of the TF, see transformations.py.
        The TF probabilities define a distribution from which we sample the TF applied.

        Replaces individual TF with TF sets, which are combinations of classic TF.

        Attributes:
            _data_augmentation (bool): Whether TF will be applied during the forward pass.
            _TF_dict (dict) : A dictionary containing the data transformations (TF) to be applied.
            _TF (list) : List of TF names.
            _TF_ignore_mag (set): TF for which the magnitude should be ignored (either fixed or unused).
            _nb_tf (int) : Number of TF used.
            _N_seqTF (int) : Number of TF to be applied sequentially to each input.
            _shared_mag (bool) : Whether to share a single magnitude parameter for all TF. Beware of using shared mag with basic color TF, as their lowest magnitude is at PARAMETER_MAX/2.
            _fixed_mag (bool): Whether to lock the TF magnitudes.
            _fixed_prob (bool): Whether to lock the TF probabilities.
            _samples (list): TF set indexes sampled during the last forward pass.
            _fixed_temp (bool): Whether we lock the mix-distribution factor (temperature). The sampling distribution is a mix of a uniform distribution and the real distribution (TF probabilities), weighted by this factor.
            _params (nn.ParameterDict): Learnable parameters.
            _reg_tgt (Tensor): Target for the magnitude regularisation. Only used when _fixed_mag is set to False (i.e. we learn the magnitudes).
            _reg_mask (list): Mask selecting the TF considered for the regularisation.
    """
    def __init__(self, TF_dict, N_TF=1, temp=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True, TF_ignore_mag=TF.TF_ignore_mag):
        """Init Data_augV7.

            Args:
                TF_dict (dict): A dictionary containing the data transformations (TF) to be applied. (default: use all available TF from transformations.py)
                N_TF (int): Number of TF to be applied sequentially to each input. Minimum 2, otherwise prefer using Data_augV5. (default: 1)
                temp (float): Proportion [0.0, 1.0] of the real distribution used for sampling/selection of the TF. Distribution = (1-temp)*Uniform_distribution + temp*Real_distribution. If None is given, try to learn this parameter. (default: 0.5)
                fixed_prob (bool): Whether to lock the TF probabilities. (default: False)
                fixed_mag (bool): Whether to lock the TF magnitudes. (default: True)
                shared_mag (bool): Whether to share a single magnitude parameter for all TF. (default: True)
                TF_ignore_mag (set): TF for which the magnitude should be ignored (either fixed or unused).
        """
super(Data_augV7, self).__init__()
assert len(TF_dict)>0
assert N_TF>=0
if N_TF<2:
print("WARNING: Data_augv7 isn't designed to use less than 2 sequentials TF. Please use Data_augv5 instead.")
self._data_augmentation = True
#TF
self._TF_dict = TF_dict
self._TF= list(self._TF_dict.keys())
self._TF_ignore_mag= TF_ignore_mag
self._nb_tf= len(self._TF)
self._N_seqTF = N_TF
#Mag
self._shared_mag = shared_mag
self._fixed_mag = fixed_mag
if not self._fixed_mag and len([tf for tf in self._TF if tf not in self._TF_ignore_mag])==0:
print("WARNING: Mag would be fixed as current TF doesn't allow gradient propagation:",self._TF)
self._fixed_mag=True
#Distribution
self._fixed_prob=fixed_prob
self._samples = []
# self._temp = False
# if temp != 0.0: #Mix dist
# self._temp = True
self._fixed_temp=True
if temp is None: #Learn Temperature
print("WARNING: Learning Temperature parameter isn't working with this version (No grad)")
self._fixed_temp = False
temp=0.5
#TF sets
#import itertools
#itertools.product(range(self._nb_tf), repeat=self._N_seqTF)
#no_consecutive={idx for idx, t in enumerate(self._TF) if t in {'FlipUD', 'FlipLR'}} #Specific No consecutive ops
no_consecutive={idx for idx, t in enumerate(self._TF) if t not in {'Identity'}} #No consecutive same ops (except Identity)
cons_test = (lambda i, idxs: i in no_consecutive and len(idxs)!=0 and i==idxs[-1]) #Exclude selected consecutive
def generate_TF_sets(n_TF, set_size, idx_prefix=[]): #Generate every arrangement (with reuse) of TF (exclude cons_test arrangement)
TF_sets=[]
if set_size>1:
for i in range(n_TF):
if not cons_test(i, idx_prefix):
TF_sets += generate_TF_sets(n_TF, set_size=set_size-1, idx_prefix=idx_prefix+[i])
else:
TF_sets+=[[idx_prefix+[i]] for i in range(n_TF) if not cons_test(i, idx_prefix)]
return TF_sets
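
        # Editor's illustrative note (not in the original file): with TF names
        # ['Identity', A, B] (indexes 0,1,2) and set_size=2, generate_TF_sets
        # enumerates every ordered pair except immediate repeats of non-Identity TF:
        # [0,0],[0,1],[0,2],[1,0],[1,2],[2,0],[2,1] -> 7 of the 9 arrangements.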
self._TF_sets=torch.ByteTensor(generate_TF_sets(self._nb_tf, self._N_seqTF)).squeeze()
self._nb_TF_sets=len(self._TF_sets)
print("Number of TF sets:",self._nb_TF_sets)
#print(self._TF_sets)
self._prob_mem=torch.zeros(self._nb_TF_sets)
#Params
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
self._params = nn.ParameterDict({
#"prob": nn.Parameter(torch.ones(self._nb_TF_sets)/self._nb_TF_sets), #Distribution prob uniforme
"prob": nn.Parameter(torch.ones(self._nb_TF_sets)),
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
"temp": nn.Parameter(torch.tensor(temp))#.clamp(min=0.0,max=0.999))
})
        #for tf in TF.TF_no_grad :
        #    if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) # TF fixed at max parameter
        #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data # Mag useless for the ignore_mag TF
        # Magnitude regularisation
        if not self._fixed_mag:
            if self._shared_mag :
                self._reg_tgt = torch.tensor(float(TF.PARAMETER_MAX)) # Encourage max amplitude
            else:
                self._reg_mask=[idx for idx,t in enumerate(self._TF) if t not in self._TF_ignore_mag]
                self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) # Encourage max amplitude
    def forward(self, x):
        """ Main method of the Data augmentation module.

            Args:
                x (Tensor): Batch of data.

            Returns:
                Tensor : Batch of transformed data.
        """
        self._samples = None
        if self._data_augmentation: # and TF.random.random() < 0.5:
            device = x.device
            batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]

            x = copy.deepcopy(x) # Avoid modifying the samples by reference (problematic for parallel use)

            ## Sampling ##
            uniform_dist = torch.ones(1,self._nb_TF_sets,device=device).softmax(dim=1)
            prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
            prob = F.softmax(prob, dim=0)
            temp = self._params["temp"].detach() if self._fixed_temp else self._params["temp"]
            self._distrib = (temp*prob+(1-temp)*uniform_dist) # Mix of the real / uniform distributions with the mix factor

            cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_TF_sets), device=device)*self._distrib)
            sample = cat_distrib.sample()

            self._samples=sample
            TF_samples=self._TF_sets[sample,:].to(device) #[Batch_size, TFseq]
for i in range(self._N_seqTF):
## Transformations ##
x = self.apply_TF(x, TF_samples[:,i])
return x
def apply_TF(self, x, sampled_TF):
""" Applies the sampled transformations.
Args:
x (Tensor): Batch of data.
sampled_TF (Tensor): Indexes of the TF to be applied to each element of data.
Returns:
                Tensor: Batch of transformed data.
"""
device = x.device
batch_size, channels, h, w = x.shape
smps_x=[]
for tf_idx in range(self._nb_tf):
            mask = sampled_TF==tf_idx # Create selection mask
            smp_x = x[mask] # torch.masked_select()? (Would require expanding the mask to the same dims)

            if smp_x.shape[0]!=0: # if there's data to transform
                magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
                if self._fixed_mag: magnitude=magnitude.detach() # Fmodel systematically tries to track the gradients of every parameter

                tf=self._TF[tf_idx]

                #In place
                #x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)

                #Out of place
                smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)

                idx= mask.nonzero()
                idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) # There is surely a simpler way...
                x=x.scatter(dim=0, index=idx, src=smp_x)
return x
    def adjust_param(self, soft=False): #Detach from gradient ?
        """ Enforce limitations to the learned parameters.

            Ensure that the parameter values stay in the right intervals. This should be called after each update of those parameters.

            Args:
                soft (bool): Whether to use a softmax function for the TF probabilities. Not recommended as it tends to lock the probabilities, preventing them from being learned. (default: False)
        """
        # if not self._fixed_prob:
        #     if soft :
        #         self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) # Too 'soft': gets stuck in a uniform distribution if the lr is too low
        #     else:
        #         self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
        #         self._params['prob'].data = self._params['prob']/sum(self._params['prob']) # Constraint: sum(p)=1
if not self._fixed_mag:
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
if not self._fixed_temp:
self._params['temp'].data = self._params['temp'].data.clamp(min=0.0, max=0.999)
    def loss_weight(self, batch_norm=True):
        """ Weights for the loss.
            Compute the weight of the loss for each input depending on which TF set was applied to it.

            Should be applied to the loss before reduction.

            Unlike Data_augV5, the weighting is done per TF set (ordered sequence of TF).

            Args:
                batch_norm (bool): Whether to normalize the mean of the weights. (Default: True)

            Returns:
                Tensor : Loss weights.
        """
        if self._samples is None or len(self._samples)==0 : return torch.tensor(1, device=self._params["prob"].device) # No samples = no weighting

        prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
        # prob = F.softmax(prob, dim=0)

        w_loss = torch.zeros((self._samples.shape[0],self._nb_TF_sets), device=self._samples.device)
        w_loss.scatter_(1, self._samples.view(-1,1), 1)

        # Normalizing by the mean would yield an exact normalization but can lead to unstable behavior of the probabilities.
        w_loss = w_loss * prob
w_loss = torch.sum(w_loss,dim=1)
if batch_norm:
w_min = w_loss.min()
w_loss = w_loss-w_min if w_min<0 else w_loss
w_loss = w_loss/w_loss.mean() #mean(w_loss)=1
        # Normalizing by the distribution is a statistical approximation of the exact normalization. It leads to a smoother evolution of the probabilities but will only return 1 if temp=1.
        # w_loss = w_loss * prob/self._distrib # Weighting by the probabilities (divided by the distribution so as not to shrink the loss)
        # w_loss = torch.sum(w_loss,dim=1)

        # if mean_norm:
        #     w_loss = w_loss * prob
        #     w_loss = torch.sum(w_loss,dim=1)
        #     w_loss = w_loss/w_loss.mean() #mean(w_loss)=1
        # else:
        #     w_loss = w_loss * prob/self._distrib # Weighting by the probabilities (divided by the distribution so as not to shrink the loss)
        #     w_loss = torch.sum(w_loss,dim=1)
return w_loss
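
    # Editor's illustrative note (not in the original file): with 3 TF sets,
    # prob = [0.2, 0.3, 0.5] and samples = [2, 0], the one-hot scatter gives
    # w_loss = [[0,0,1],[1,0,0]] -> [0.5, 0.2]; batch_norm then rescales the
    # weights to mean 1: [0.5, 0.2] / 0.35 = [1.43, 0.57].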
    def reg_loss(self, reg_factor=0.005):
        """ Regularisation term used to learn the magnitudes.
            Uses an L2 loss to encourage high-magnitude TF.

            Args:
                reg_factor (float): Factor by which the regularisation loss is multiplied. (default: 0.005)

            Returns:
                Tensor containing the regularisation loss value.
        """
if self._fixed_mag or self._fixed_prob: #Not enough DOF
return torch.tensor(0)
else:
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
mags = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
return max_mag_reg
    def TF_prob(self):
        """ Gives an estimation of the individual TF probabilities.

            Be wary that the returned probabilities aren't exact. The TF distribution isn't fully represented by them.
            Each probability should be taken individually. It only represents the chance for a specific TF to be picked at least once.

            Returns:
                Tensor containing the single-TF probabilities of application.
        """
        if torch.all(self._params['prob']!=self._prob_mem.to(self._params['prob'].device)): # Prevent recomputation if the original prob didn't change
            self._prob_mem=self._params['prob'].data.detach_()
prob = F.softmax(self._params["prob"]*self._params["temp"], dim=0)
self._single_TF_prob=torch.zeros(self._nb_tf)
for idx_tf in range(self._nb_tf):
for i, t_set in enumerate(self._TF_sets):
#uni, count = np.unique(t_set, return_counts=True)
#if idx_tf in uni:
# res[idx_tf]+=self._params['prob'][i]*int(count[np.where(uni==idx_tf)])
if idx_tf in t_set:
self._single_TF_prob[idx_tf]+=prob[i]
return self._single_TF_prob
    def train(self, mode=True):
        """ Set the module training mode.

            Args:
                mode (bool): Whether to learn the parameters of the module. (default: True)
        """
        #if mode is None :
        #    mode=self._data_augmentation
        self.augment(mode=mode) # Useless if mode is None
super(Data_augV7, self).train(mode)
return self
def eval(self):
""" Set the module to evaluation mode.
"""
return self.train(mode=False)
    def augment(self, mode=True):
        """ Set the augmentation mode.

            Args:
                mode (bool): Whether to perform data augmentation on the forward pass. (default: True)
        """
        self._data_augmentation=mode

    def is_augmenting(self):
        """ Return whether data augmentation is applied.

            Returns:
                bool : True if data augmentation is applied.
        """
        return self._data_augmentation
def __getitem__(self, key):
"""Access to the learnable parameters
Args:
key (string): Name of the learnable parameter to access.
Returns:
nn.Parameter.
"""
if key == 'prob': #Override prob access
return self.TF_prob()
return self._params[key]
def __str__(self):
"""Name of the module
Returns:
                String containing the name of the module as well as its higher-level parameters.
"""
dist_param=''
if self._fixed_prob: dist_param+='Fx'
mag_param='Mag'
if self._fixed_mag: mag_param+= 'Fx'
if self._shared_mag: mag_param+= 'Sh'
# if not self._temp:
# return "Data_augV7(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
if self._fixed_temp:
return "Data_augV7(T%.1f%s-%dTFx%d-%s)" % (self._params['temp'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
else:
return "Data_augV7(T%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)

View file

@@ -2,17 +2,18 @@ from utils import *
if __name__ == "__main__":
#'''
files=[
"../res/HPsearch/log/Aug_mod(Data_augV5(Mix0.5-14TFx3-Mag)-ResNet)-200 epochs (dataug:0)- 1 in_it-0.json",
#"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-0.json",
#"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-1.json",
#"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-2.json",
#"res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
]
#files = ["../res/benchmark/CIFAR10/log/RandAugment(N%d-M%.2f)-%s-200 epochs -%s.json"%(3,0.17,'wide_resnet50_2', str(run)) for run in range(3)]
#files = ["../res/benchmark/CIFAR10/log/Aug_mod(RandAug(14TFx%d-Mag%d)-%s)-200 epochs (dataug:0)- 0 in_it-%s.json"%(2,1,'resnet18', str(run)) for run in range(1)]
files = ["../res/benchmark/CIFAR10/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-Mag)-%s)-200 epochs (dataug:0)- 3 in_it-%s.json"%(0.5,3,'resnet18', str(run)) for run in range(1)]
'''
# files=[
# "../res/log/Aug_mod(Data_augV5(Mix0.5-18TFx3-Mag)-efficientnet-b1)-200 epochs (dataug 0)- 1 in_it__AL2.json",
# #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-0.json",
# #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-1.json",
# #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-2.json",
# #"res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
# ]
files = ["../res/benchmark/CIFAR10/log/RandAugment(N%d-M%.2f)-%s-200 epochs -%s.json"%(3,1,'resnet18', str(run)) for run in range(3)]
#files = ["../res/benchmark/CIFAR10/log/Aug_mod(RandAug(18TFx%d-Mag%d)-%s)-200 epochs (dataug:0)- 0 in_it-%s.json"%(2,1,'resnet18', str(run)) for run in range(3)]
#files = ["../res/benchmark/CIFAR10/log/Aug_mod(Data_augV5(Mix%.1f-18TFx%d-Mag)-%s)-200 epochs (dataug:0)- 1 in_it-%s.json"%(0.5,3,'resnet18', str(run)) for run in range(3)]
for idx, file in enumerate(files):
#legend+=str(idx)+'-'+file+'\n'
@@ -20,7 +21,41 @@ if __name__ == "__main__":
data = json.load(json_file)
plot_resV2(data['Log'], fig_name=file.replace("/log","").replace(".json",""), param_names=data['Param_names'], f1=True)
#plot_TF_influence(data['Log'], param_names=data['Param_names'])
#'''
'''
#Res print
# '''
nb_run=3
accs = []
aug_accs = []
f1_max = []
f1_min = []
times = []
mem = []
files = ["../res/benchmark/log/Aug_mod(Data_augV5(T0.5-19TFx3-Mag)-resnet18)-200 epochs (dataug 0)- 1 in_it__%s.json"%(str(run)) for run in range(1, nb_run+1)]
for idx, file in enumerate(files):
#legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
accs.append(data['Accuracy'])
aug_accs.append(data['Aug_Accuracy'][1])
times.append(data['Time'][0])
mem.append(data['Memory'][1])
acc_idx = [x['acc'] for x in data['Log']].index(data['Accuracy'])
f1_max.append(max(data['Log'][acc_idx]['f1'])*100)
f1_min.append(min(data['Log'][acc_idx]['f1'])*100)
print(idx, accs[-1], aug_accs[-1])
print(files[0])
print("Acc : %.2f ~ %.2f / Aug_Acc %d: %.2f ~ %.2f"%(np.mean(accs), np.std(accs), data['Aug_Accuracy'][0], np.mean(aug_accs), np.std(aug_accs)))
print("F1 max : %.2f ~ %.2f / F1 min : %.2f ~ %.2f"%(np.mean(f1_max), np.std(f1_max), np.mean(f1_min), np.std(f1_min)))
print("Time (h): %.2f ~ %.2f"%(np.mean(times)/3600, np.std(times)/3600))
print("Mem (MB): %d ~ %d"%(np.mean(mem), np.std(mem)))
# '''
## Loss , Acc, Proba = f(epoch) ##
#plot_compare(filenames=files, fig_name="res/compare")
@@ -79,37 +114,21 @@ if __name__ == "__main__":
plt.close()
'''
#Res print
'''
nb_run=3
accs = []
times = []
files = ["res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-Mag)-LeNet)-150epochs(dataug:0)-1in_it-%s.json"%str(run) for run in range(nb_run)]
for idx, file in enumerate(files):
#legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
accs.append(data['Accuracy'])
times.append(data['Time'][0])
print(idx, data['Accuracy'])
print(files[0], np.mean(accs), np.std(accs), np.mean(times))
'''
'''
#HP search
inner_its = [1]
dist_mix = [0.3, 0.5, 0.8, 1.0] #Uniform
N_seq_TF= [5]
N_seq_TF= [3]
nb_run= 3
for n_inner_iter in inner_its:
for n_tf in N_seq_TF:
for dist in dist_mix:
#files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-MagFxSh)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(dist, n_tf, str(run)) for run in range(nb_run)]
files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Uniform-14TFx%d-MagFxSh)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(n_tf, str(run)) for run in range(nb_run)]
files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Mix%.1f-14TFx%d-Mag)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(dist, n_tf, str(run)) for run in range(nb_run)]
#files = ["../res/HPsearch/log/Aug_mod(Data_augV5(Uniform-14TFx%d-MagFxSh)-ResNet)-200 epochs (dataug:0)- 1 in_it-%s.json"%(n_tf, str(run)) for run in range(nb_run)]
accs = []
times = []
for idx, file in enumerate(files):

View file

@@ -2,22 +2,14 @@
"""
import sys
from LeNet import *
from dataug import *
#from utils import *
from train_utils import *
from transformations import TF_loader
# from arg_parser import *
postfix='-MetaScheduler2'
TF_loader=TF_loader()
device = torch.device('cuda') #Select device to use
if device == torch.device('cpu'):
device_name = 'CPU'
else:
device_name = torch.cuda.get_device_name(device)
torch.backends.cudnn.benchmark = True # Faster if the input size stays the same # Not recommended for reproducibility
#Improve reproducibility
@@ -27,46 +19,68 @@ np.random.seed(0)
##########################################
if __name__ == "__main__":
#Task to perform
tasks={
#'classic',
'aug_model'
}
args = parser.parse_args()
print(args)
res_folder=args.res_folder
postfix=args.postfix
if args.dtype == 'FP32':
def_type=torch.float32
elif args.dtype == 'FP16':
# def_type=torch.float16 #Default : float32
def_type=torch.bfloat16
else:
raise Exception('dtype not supported :', args.dtype)
torch.set_default_dtype(def_type) #Default : float32
device = torch.device(args.device) #Select device to use
if device == torch.device('cpu'):
device_name = 'CPU'
else:
device_name = torch.cuda.get_device_name(device)
#Parameters
n_inner_iter = 1
epochs = 100
n_inner_iter = args.K
epochs = args.epochs
dataug_epoch_start=0
Nb_TF_seq=3
Nb_TF_seq= args.N
optim_param={
'Meta':{
'optim':'Adam',
'lr':1e-3, #1e-2
'epoch_start': 2, #0 / 2 (Resnet?)
'reg_factor': 0.001,
'scheduler': 'multiStep', #None, 'multiStep'
'lr':args.mlr,
'epoch_start': args.meta_epoch_start, #0 / 2 (Resnet?)
'reg_factor': args.mag_reg,
'scheduler': None, #None, 'multiStep'
},
'Inner':{
'optim': 'SGD',
'lr':1e-1, #1e-2/1e-1 (ResNet)
'momentum':0.9, #0.9
'decay':0.0005, #0.0005
'nesterov':False, #False (True: Bad behavior w/ Data_aug)
'scheduler':'cosine', #None, 'cosine', 'multiStep', 'exponential'
'lr':args.lr, #1e-2/1e-1 (ResNet)
'momentum':args.momentum, #0.9
'weight_decay':args.decay, #0.0005
'nesterov':args.nesterov, #False (True: Bad behavior w/ Data_aug)
'scheduler': args.scheduler, #None, 'cosine', 'multiStep', 'exponential'
'warmup':{
                'multiplier': args.warmup, #2 # larger batch_size => larger multiplier # No warmup = 0
'epochs': 5
}
}
}
#Models
#model = LeNet(3,10)
#model = ResNet(num_classes=10)
import torchvision.models as models
#model=models.resnet18()
model_name = 'resnet18' #'wide_resnet50_2' #'resnet18' #str(model)
model = getattr(models.resnet, model_name)(pretrained=False, num_classes=len(dl_train.dataset.classes))
#Info params
F1=True
sample_save=None
print_f= epochs/4
#Load network
model, model_name= load_model(args.model, num_classes=len(dl_train.dataset.classes), pretrained=args.pretrained)
#### Classic ####
if 'classic' in tasks:
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
torch.cuda.reset_max_memory_cached() #reset_peak_stats
if not args.augment:
if device_name != 'CPU':
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
torch.cuda.reset_max_memory_cached() #reset_peak_stats
t0 = time.perf_counter()
model = model.to(device)
@@ -74,12 +88,17 @@ if __name__ == "__main__":
print("{} on {} for {} epochs{}".format(model_name, device_name, epochs, postfix))
#print("RandAugment(N{}-M{:.2f})-{} on {} for {} epochs{}".format(rand_aug['N'],rand_aug['M'],model_name, device_name, epochs, postfix))
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10)
log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=print_f)
#log= train_classic_higher(model=model, epochs=epochs)
exec_time=time.perf_counter() - t0
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
if device_name != 'CPU':
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
else:
max_allocated = 0.0
max_cached=0.0
####
print('-'*9)
times = [x["time"] for x in log]
@@ -94,7 +113,7 @@ if __name__ == "__main__":
filename = "{}-{} epochs".format(model_name,epochs)+postfix
#print("RandAugment-",model_name,": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
#filename = "RandAugment(N{}-M{:.2f})-{}-{} epochs".format(rand_aug['N'],rand_aug['M'],model_name,epochs)+postfix
with open("../res/log/%s.json" % filename, "w+") as f:
with open(res_folder+"log/%s.json" % filename, "w+") as f:
try:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
@@ -103,7 +122,7 @@ if __name__ == "__main__":
print(sys.exc_info()[1])
try:
plot_resV2(log, fig_name="../res/"+filename)
plot_resV2(log, fig_name=res_folder+filename, f1=F1)
except:
print("Failed to plot res")
print(sys.exc_info()[1])
@@ -112,55 +131,63 @@ if __name__ == "__main__":
print('-'*9)
#### Augmented Model ####
if 'aug_model' in tasks:
tf_config='../config/invScale_wide_tf_config.json'#'../config/base_tf_config.json'
tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(tf_config)
else:
# tf_config='../config/invScale_wide_tf_config.json'#'../config/invScale_wide_tf_config.json'#'../config/base_tf_config.json'
tf_dict, tf_ignore_mag =TF_loader.load_TF_dict(args.tf_config)
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
torch.cuda.reset_max_memory_cached() #reset_peak_stats
if device_name != 'CPU':
torch.cuda.reset_max_memory_allocated() #reset_peak_stats
torch.cuda.reset_max_memory_cached() #reset_peak_stats
t0 = time.perf_counter()
model = Higher_model(model, model_name) #run_dist_dataugV3
dataug_mod = 'Data_augV8' if args.learn_seq else 'Data_augV5'
if n_inner_iter !=0:
aug_model = Augmented_model(
Data_augV5(TF_dict=tf_dict,
globals()[dataug_mod](TF_dict=tf_dict,
N_TF=Nb_TF_seq,
mix_dist=0.5,
temp=args.temp,
fixed_prob=False,
fixed_mag=False,
shared_mag=False,
fixed_mag=args.fixed_mag,
shared_mag=args.shared_mag,
TF_ignore_mag=tf_ignore_mag), model).to(device)
else:
aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=Nb_TF_seq), model).to(device)
print("{} on {} for {} epochs - {} inner_it{}".format(str(aug_model), device_name, epochs, n_inner_iter, postfix))
log= run_dist_dataugV3(model=aug_model,
log, aug_acc = run_dist_dataugV3(model=aug_model,
epochs=epochs,
inner_it=n_inner_iter,
dataug_epoch_start=dataug_epoch_start,
opt_param=optim_param,
print_freq=20,
unsup_loss=1,
hp_opt=False,
save_sample_freq=None)
augment_loss=args.augment_loss,
hp_opt=False, #False #['lr', 'momentum', 'weight_decay']
print_freq=print_f,
save_sample_freq=sample_save)
exec_time=time.perf_counter() - t0
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
if device_name != 'CPU':
max_allocated = torch.cuda.max_memory_allocated()/(1024.0 * 1024.0)
max_cached = torch.cuda.max_memory_cached()/(1024.0 * 1024.0) #torch.cuda.max_memory_reserved() #MB
else:
max_allocated = 0.0
max_cached = 0.0
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]),
"Aug_Accuracy": [args.augment_loss, aug_acc],
"Time": (np.mean(times),np.std(times), exec_time),
'Optimizer': optim_param,
"Device": device_name,
"Memory": [max_allocated, max_cached],
"TF_config": tf_config,
"TF_config": args.tf_config,
"Param_names": aug_model.TF_names(),
"Log": log}
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)+postfix
with open("../res/log/%s.json" % filename, "w+") as f:
print(str(aug_model),": acc", out["Accuracy"], "/ aug_acc", out["Aug_Accuracy"][1] , "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{}_epochs-{}_in_it-AL{}".format(str(aug_model),epochs,n_inner_iter,args.augment_loss)+postfix
with open(res_folder+"log/%s.json" % filename, "w+") as f:
try:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
@@ -168,7 +195,7 @@ if __name__ == "__main__":
print("Failed to save logs :",f.name)
print(sys.exc_info()[1])
try:
plot_resV2(log, fig_name="../res/"+filename, param_names=aug_model.TF_names())
plot_resV2(log, fig_name=res_folder+filename, param_names=aug_model.TF_names(), f1=F1)
except:
print("Failed to plot res")
print(sys.exc_info()[1])


@@ -1,29 +1,35 @@
""" Utilities function for training.
"""
import sys
import torch
#import torch.optim
import torchvision
#import torchvision
import higher
import higher_patch
from datasets import *
from utils import *
from transformations import Normalizer, translate, zero_stack
norm = Normalizer(MEAN, STD)
confmat = ConfusionMatrix(num_classes=len(dl_test.dataset.classes))
def test(model):
max_grad = 1 #Max gradient value #Limits catastrophic drops
def test(model, augment=0):
"""Evaluate a model on test data.
Args:
model (nn.Module): Model to test.
augment (int): Number of augmented examples for each sample. (Default : 0)
Returns:
(float, Tensor) Returns the accuracy and F1 score of the model.
"""
device = next(model.parameters()).device
model.eval()
# model['model']['functional'].set_mode('mixed') #ABN
#for i, (features, labels) in enumerate(dl_test):
# features,labels = features.to(device), labels.to(device)
@@ -34,12 +40,30 @@ def test(model):
correct = 0
total = 0
#loss = []
global confmat
confmat.reset()
with torch.no_grad():
for features, labels in dl_test:
features,labels = features.to(device), labels.to(device)
outputs = model(features)
if augment>0: #Test Time Augmentation
model.augment(True)
# V2
features=torch.cat([features for _ in range(augment)], dim=0) # (B,C,H,W)=>(B*augment,C,H,W)
outputs=model(features)
outputs=torch.cat([o.unsqueeze(dim=0) for o in outputs.chunk(chunks=augment, dim=0)],dim=0) # (B*augment,nb_class)=>(augment,B,nb_class)
w_losses=model['data_aug'].loss_weight(batch_norm=False) #(B*augment) if Dataug
if w_losses.shape[0]==1: #RandAugment
outputs=torch.sum(outputs, axis=0)/augment #mean
else: #Dataug
w_losses=torch.cat([w.unsqueeze(dim=0) for w in w_losses.chunk(chunks=augment, dim=0)], dim=0) #(augment, B)
w_losses = w_losses / w_losses.sum(axis=0, keepdim=True) #sum(w_losses)=1 for each sample
outputs=torch.sum(outputs*w_losses.unsqueeze(dim=2).expand_as(outputs), axis=0)/augment #Weighted mean
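# Shape walkthrough (illustrative): outputs is (augment, B, nb_class) and w_losses is
# (augment, B); after normalization the augment weights of each sample sum to 1, so the
# weighted sum over dim 0 (scaled by 1/augment) yields one (B, nb_class) prediction per
# sample. The extra 1/augment factor does not change the argmax.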
else:
outputs = model(features)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
@@ -74,10 +98,12 @@ def compute_vaLoss(model, dl_it, dl):
xs, ys = next(dl_it)
xs, ys = xs.to(device), ys.to(device)
model.eval() #Validation without transformations!
return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys)
model.eval() #Validation without transformations!
# model['model']['functional'].set_mode('mixed') #ABN
# return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys)
return F.cross_entropy(model(xs), ys)
def mixed_loss(xs, ys, model, unsup_factor=1):
def mixed_loss(xs, ys, model, unsup_factor=1, augment=1):
"""Evaluate a model on a batch of data.
Compute a combination of losses:
@@ -94,40 +120,64 @@ def mixed_loss(xs, ys, model, unsup_factor=1):
ys (Tensor): Batch of labels.
model (nn.Module): Augmented model (see dataug.py).
unsup_factor (float): Factor by which unsupervised CE and KL div loss are multiplied.
augment (int): Number of augmented examples for each sample. (Default : 1)
Returns:
(Tensor) Mixed loss if there's data augmentation, just supervised CE loss otherwise.
"""
#TODO: add a check to catch non-augmented models and fall back to the classic loss
if unsup_factor!=0 and model.is_augmenting():
if unsup_factor!=0 and model.is_augmenting() and augment>0:
# Supervised loss (classic)
# Supervised loss - Cross-entropy
model.augment(mode=False)
sup_logits = model(xs)
model.augment(mode=True)
log_sup = F.log_softmax(sup_logits, dim=1)
sup_loss = F.cross_entropy(log_sup, ys)
sup_loss = F.nll_loss(log_sup, ys)
# sup_loss = F.cross_entropy(log_sup, ys)
# Unsupervised loss
aug_logits = model(xs)
w_loss = model['data_aug'].loss_weight() #Weight loss
log_aug = F.log_softmax(aug_logits, dim=1)
aug_loss = F.cross_entropy(log_aug, ys , reduction='none')
aug_loss = (aug_loss * w_loss).mean()
if augment>1:
# Unsupervised loss - Cross-Entropy
xs_a=torch.cat([xs for _ in range(augment)], dim=0) # (B,C,H,W)=>(B*augment,C,H,W)
ys_a=torch.cat([ys for _ in range(augment)], dim=0)
aug_logits=model(xs_a) # (B*augment,nb_class)
w_loss=model['data_aug'].loss_weight() #(B*augment) if Dataug
log_aug = F.log_softmax(aug_logits, dim=1)
aug_loss = F.nll_loss(log_aug, ys_a , reduction='none')
# aug_loss = F.cross_entropy(log_aug, ys_a , reduction='none')
aug_loss = (aug_loss * w_loss).mean()
#KL divergence loss (w/ logits) - Prediction/Distribution similarity
kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
kl_loss = (w_loss * kl_loss).mean()
#KL divergence loss (w/ logits) - Prediction/Distribution similarity
sup_logits_a=torch.cat([sup_logits for _ in range(augment)], dim=0)
log_sup_a=torch.cat([log_sup for _ in range(augment)], dim=0)
kl_loss = (F.softmax(sup_logits_a, dim=1)*(log_sup_a-log_aug)).sum(dim=-1)
kl_loss = (w_loss * kl_loss).mean()
else:
# Unsupervised loss - Cross-Entropy
aug_logits = model(xs)
w_loss = model['data_aug'].loss_weight() #Weight loss
log_aug = F.log_softmax(aug_logits, dim=1)
aug_loss = F.nll_loss(log_aug, ys , reduction='none')
# aug_loss = F.cross_entropy(log_aug, ys , reduction='none')
aug_loss = (aug_loss * w_loss).mean()
#KL divergence loss (w/ logits) - Prediction/Distribution similarity
kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
kl_loss = (w_loss * kl_loss).mean()
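# Note: (softmax(p)*(log_softmax(p)-log_softmax(q))).sum(-1) is the per-sample
# KL(P_sup || P_aug), kept unreduced so each sample can be weighted by w_loss.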
loss = sup_loss + unsup_factor * (aug_loss + kl_loss)
else: #Supervised loss (classic)
else: #Supervised loss - Cross-Entropy
sup_logits = model(xs)
log_sup = F.log_softmax(sup_logits, dim=1)
loss = F.cross_entropy(log_sup, ys)
loss = F.cross_entropy(sup_logits, ys)
# log_sup = F.log_softmax(sup_logits, dim=1)
# loss = F.cross_entropy(log_sup, ys)
return loss
@@ -150,7 +200,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
optim = torch.optim.SGD(model.parameters(),
lr=opt_param['Inner']['lr'],
momentum=opt_param['Inner']['momentum'],
weight_decay=opt_param['Inner']['decay'],
weight_decay=opt_param['Inner']['weight_decay'],
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
#Scheduler
@@ -168,6 +218,9 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
elif opt_param['Inner']['scheduler'] is not None:
raise ValueError("Lr scheduler unknown : %s"%opt_param['Inner']['scheduler'])
# from warmup_scheduler import GradualWarmupScheduler
# inner_scheduler=GradualWarmupScheduler(optim, multiplier=2, total_epoch=5, after_scheduler=inner_scheduler)
#Training
model.train()
dl_val_it = iter(dl_val)
@@ -188,9 +241,12 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
loss.backward()
optim.step()
# print_graph(loss, '../samples/torchvision_WRN') #to visualize computational graph
# sys.exit()
if inner_scheduler is not None:
inner_scheduler.step()
# print(optim.param_groups[0]['lr'])
#### Tests ####
tf = time.perf_counter()
@@ -222,7 +278,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
return log
def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, print_freq=1, unsup_loss=1, hp_opt=False, save_sample_freq=None):
def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, unsup_loss=1, augment_loss=1, hp_opt=False, print_freq=1, save_sample_freq=None):
"""Training of an augmented model with higher.
This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py).
@@ -237,9 +293,10 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
epochs (int): Number of epochs to perform. (default: 1)
inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1)
dataug_epoch_start (int): Epoch when to start data augmentation. (default: 0)
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
unsup_loss (float): Proportion of the unsupervised loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
augment_loss (int): Number of augmented examples for each sample in loss computation. (Default : 1)
hp_opt (bool or list): Whether to learn inner optimizer parameters; if so, a list of the parameter names to learn. (default: False)
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
save_sample_freq (int): Number of epochs between saves of samples of data. If set to None, no sample will be saved. (default: None)
Returns:
@@ -247,6 +304,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
"""
device = next(model.parameters()).device
log = []
# kl_log={"prob":[], "mag":[]}
dl_val_it = iter(dl_val)
high_grad_track = True
@@ -261,12 +319,12 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
inner_opt = torch.optim.SGD(model['model']['original'].parameters(),
lr=opt_param['Inner']['lr'],
momentum=opt_param['Inner']['momentum'],
weight_decay=opt_param['Inner']['decay'],
weight_decay=opt_param['Inner']['weight_decay'],
nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9
diffopt = model['model'].get_diffopt(
inner_opt,
grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
grad_callback=(lambda grads: clip_norm(grads, max_norm=max_grad)),
track_higher_grads=high_grad_track)
#Scheduler
@@ -281,23 +339,33 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
elif opt_param['Inner']['scheduler']=='exponential':
#inner_scheduler=torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.1) #Wrong gamma
inner_scheduler=torch.optim.lr_scheduler.LambdaLR(inner_opt, lambda epoch: (1 - epoch / epochs) ** 0.9)
elif opt_param['Inner']['scheduler'] is not None:
elif not(opt_param['Inner']['scheduler'] is None or opt_param['Inner']['scheduler']==''):
raise ValueError("Lr scheduler unknown : %s"%opt_param['Inner']['scheduler'])
#Warmup
if opt_param['Inner']['warmup']['multiplier']>=1:
from warmup_scheduler import GradualWarmupScheduler
inner_scheduler=GradualWarmupScheduler(inner_opt,
multiplier=opt_param['Inner']['warmup']['multiplier'],
total_epoch=opt_param['Inner']['warmup']['epochs'],
after_scheduler=inner_scheduler)
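# GradualWarmupScheduler (presumably the pytorch-gradual-warmup-lr package) ramps the
# lr up to multiplier*base_lr over `total_epoch` epochs, then hands off to `after_scheduler`.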
#Meta Opt
hyper_param = list(model['data_aug'].parameters())
if hp_opt :
if hp_opt : #(deprecated)
for param_group in diffopt.param_groups:
for param in list(opt_param['Inner'].keys())[1:]:
# print(param_group)
for param in hp_opt:
param_group[param]=torch.tensor(param_group[param]).to(device).requires_grad_()
hyper_param += [param_group[param]]
meta_opt = torch.optim.Adam(hyper_param, lr=opt_param['Meta']['lr']) #lr=1e-2
meta_opt = torch.optim.Adam(hyper_param, lr=opt_param['Meta']['lr'])
#Meta-Scheduler (deprecated)
meta_scheduler=None
if opt_param['Meta']['scheduler']=='multiStep':
meta_scheduler=torch.optim.lr_scheduler.MultiStepLR(meta_opt,
milestones=[int(epochs/3), int(epochs*2/3)],# int(epochs*2.7/3)],
gamma=3.16)#10)
gamma=3.16)
elif opt_param['Meta']['scheduler'] is not None:
raise ValueError("Lr scheduler unknown : %s"%opt_param['Meta']['scheduler'])
@@ -314,7 +382,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
for i, (xs, ys) in enumerate(dl_train):
xs, ys = xs.to(device), ys.to(device)
if(unsup_loss==0):
#Uniform method
logits = model(xs) # modified `params` can also be passed as a kwarg
@@ -327,31 +395,35 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
else:
#Mixed method
loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss)
loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss, augment=augment_loss)
#print_graph(loss) #to visualize computational graph
# print_graph(loss, '../samples/pytorch_WRN') #to visualize computational graph
# sys.exit()
#t = time.process_time()
# t = time.process_time()
diffopt.step(loss)#(opt.zero_grad, loss.backward, opt.step)
#print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
# print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
if(high_grad_track and i>0 and i%inner_it==0 and epoch>=opt_param['Meta']['epoch_start']): #Perform Meta step
#print("meta")
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss(opt_param['Meta']['reg_factor'])
model.train()
#print_graph(val_loss) #to visualize computational graph
val_loss.backward()
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=max_grad, norm_type=2) #Prevent exploding grad with RNN
# print("Grad mix",model['data_aug']["temp"].grad)
# prv_param=model['data_aug']._params
meta_opt.step()
# kl_log["prob"].append(F.kl_div(prv_param["prob"],model['data_aug']["prob"], reduction='batchmean').item())
# kl_log["mag"].append(F.kl_div(prv_param["mag"],model['data_aug']["mag"], reduction='batchmean').item())
#Adjust Hyper-parameters
model['data_aug'].adjust_param() #Constraint: sum(proba)=1
model['data_aug'].adjust_param()
if hp_opt:
for param_group in diffopt.param_groups:
for param in list(opt_param['Inner'].keys())[1:]:
param_group[param].data = param_group[param].data.clamp(min=1e-4)
for param in hp_opt:
param_group[param].data = param_group[param].data.clamp(min=1e-5)
#Reset gradients
diffopt.detach_()
@@ -374,18 +446,26 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
diff_param_group['lr'] = param_group['lr']
if meta_scheduler is not None:
meta_scheduler.step()
# if epoch<epochs/3:
# model['data_aug']['temp'].data=torch.tensor(0.5, device=device)
# elif epoch>epochs/3 and epoch<(epochs*2/3):
# model['data_aug']['temp'].data=torch.tensor(0.75, device=device)
# elif epoch>(epochs*2/3):
# model['data_aug']['temp'].data=torch.tensor(1.0, device=device)
# model['data_aug']['temp'].data=torch.tensor(1./3+2/3*(epoch/epochs), device=device)
# print('Temp',model['data_aug']['temp'])
if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
try:
viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
model.train()
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch))
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
model.eval()
except:
print("Couldn't save samples epoch"+epoch)
print("Couldn't save samples epoch %d : %s"%(epoch, str(sys.exc_info()[1])))
pass
if(not val_loss): #Compute val loss for logs
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
@@ -405,8 +485,8 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
"param": param,
}
if not model['data_aug']._fixed_mix: data["mix_dist"]=model['data_aug']['mix_dist'].item()
if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'].item(), 'momentum': p_grp['momentum'].item()} for p_grp in diffopt.param_groups]
if not model['data_aug']._fixed_temp: data["temp"]=model['data_aug']['temp'].item()
if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'], 'momentum': p_grp['momentum']} for p_grp in diffopt.param_groups]
log.append(data)
#############
#### Print ####
@@ -420,11 +500,19 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
print('Data Augmentation : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
if not model['data_aug']._fixed_prob: print('TF Proba :', ["{0:0.4f}".format(p) for p in model['data_aug']['prob']])
#print('proba grad',model['data_aug']['prob'].grad)
if not model['data_aug']._fixed_mag: print('TF Mag :', ["{0:0.4f}".format(m) for m in model['data_aug']['mag']])
if not model['data_aug']._fixed_mag:
if model['data_aug']._shared_mag:
print('TF Mag :', "{0:0.4f}".format(model['data_aug']['mag']))
else:
print('TF Mag :', ["{0:0.4f}".format(m) for m in model['data_aug']['mag']])
#print('Mag grad',model['data_aug']['mag'].grad)
if not model['data_aug']._fixed_mix: print('Mix:', model['data_aug']['mix_dist'].item())
if not model['data_aug']._fixed_temp: print('Temp:', model['data_aug']['temp'].item())
#print('Reg loss:', model['data_aug'].reg_loss().item())
# if len(kl_log["prob"])!=0:
# print("KL prob : mean %f, std %f, max %f, min %f"%(np.mean(kl_log["prob"]), np.std(kl_log["prob"]), max(kl_log["prob"]), min(kl_log["prob"])))
# print("KL mag : mean %f, std %f, max %f, min %f"%(np.mean(kl_log["mag"]), np.std(kl_log["mag"]), max(kl_log["mag"]), min(kl_log["mag"])))
# kl_log={"prob":[], "mag":[]}
if hp_opt :
for param_group in diffopt.param_groups:
print('Opt param - lr:', param_group['lr'].item(),'- momentum:', param_group['momentum'].item())
@@ -439,63 +527,66 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
high_grad_track = True
diffopt = model['model'].get_diffopt(
inner_opt,
grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
grad_callback=(lambda grads: clip_norm(grads, max_norm=max_grad)),
track_higher_grads=high_grad_track)
return log
aug_acc, aug_f1 = test(model, augment=augment_loss)
def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, unsup_loss=1):
"""Simple training of an augmented model with higher.
return log, aug_acc
This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py).
Ex : Augmented_model(Data_augV5(...), Higher_model(model))
#OLD
# def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, unsup_loss=1):
# """Simple training of an augmented model with higher.
Training loss can either be computed directly from augmented inputs (unsup_loss=0).
However, it is recommended to use the mixed loss computation, which combines original and augmented inputs to compute the loss (unsup_loss>0).
# This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py).
# Ex : Augmented_model(Data_augV5(...), Higher_model(model))
Does not support LR scheduler.
# Training loss can either be computed directly from augmented inputs (unsup_loss=0).
# However, it is recommended to use the mixed loss computation, which combines original and augmented inputs to compute the loss (unsup_loss>0).
Args:
model (nn.Module): Augmented model to train.
opt_param (dict): Dictionary containing optimizer parameters.
epochs (int): Number of epochs to perform. (default: 1)
inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1)
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
unsup_loss (float): Proportion of the unsup_loss loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
# Does not support LR scheduler.
# Args:
# model (nn.Module): Augmented model to train.
# opt_param (dict): Dictionary containing optimizer parameters.
# epochs (int): Number of epochs to perform. (default: 1)
# inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1)
# print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
# unsup_loss (float): Proportion of the unsup_loss loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
Returns:
(dict) A dictionary containing a whole state of the trained network.
"""
device = next(model.parameters()).device
# Returns:
# (dict) A dictionary containing a whole state of the trained network.
# """
# device = next(model.parameters()).device
## Optimizers ##
hyper_param = list(model['data_aug'].parameters())
model.start_bilevel_opt(inner_it=inner_it, hp_list=hyper_param, opt_param=opt_param, dl_val=dl_val)
# ## Optimizers ##
# hyper_param = list(model['data_aug'].parameters())
# model.start_bilevel_opt(inner_it=inner_it, hp_list=hyper_param, opt_param=opt_param, dl_val=dl_val)
model.train()
# model.train()
for epoch in range(1, epochs+1):
t0 = time.process_time()
# for epoch in range(1, epochs+1):
# t0 = time.process_time()
for i, (xs, ys) in enumerate(dl_train):
xs, ys = xs.to(device), ys.to(device)
# for i, (xs, ys) in enumerate(dl_train):
# xs, ys = xs.to(device), ys.to(device)
#Methode mixed
loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss)
# #Methode mixed
# loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss)
model.step(loss) #(opt.zero_grad, loss.backward, opt.step) + automatic meta-optimisation
# model.step(loss) #(opt.zero_grad, loss.backward, opt.step) + automatic meta-optimisation
tf = time.process_time()
# tf = time.process_time()
#### Print ####
if(print_freq and epoch%print_freq==0):
print('-'*9)
print('Epoch : %d/%d'%(epoch,epochs))
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', model.val_loss().item())
if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data)
if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
if not model['data_aug']._fixed_mix: print('Mix:', model['data_aug']['mix_dist'].item())
#############
# #### Print ####
# if(print_freq and epoch%print_freq==0):
# print('-'*9)
# print('Epoch : %d/%d'%(epoch,epochs))
# print('Time : %.00f'%(tf - t0))
# print('Train loss :',loss.item(), '/ val loss', model.val_loss().item())
# if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data)
# if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
# if not model['data_aug']._fixed_temp: print('Temp:', model['data_aug']['temp'].item())
# #############
return model['model'].state_dict()
# return model['model'].state_dict()


@@ -21,8 +21,8 @@ import json
#TF that don't have use for magnitude parameter.
TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend', 'identity', 'flip'}
#TF whose implementation doesn't allow gradient propagation.
TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize', 'posterize','solarize'}
#TF whose implementation doesn't allow gradient propagation.
TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize', 'posterize','solarize'} #Numpy implementation would be better ?
#TF for which magnitude should be ignored (Magnitude fixed).
TF_ignore_mag= TF_no_mag | TF_no_grad
@@ -38,6 +38,17 @@ PARAMETER_MIN = 0.01
# PARAMETER_MIN:{'Rotate','TranslateX','TranslateY','ShearX','ShearY'},
#}
class Normalizer(object):
def __init__(self, mean, std):
self.mean=torch.tensor(mean)
self.std=torch.tensor(std)
def __call__(self, x):
# return x.sub_(self.mean.to(x.device)[None, :, None, None]).div_(self.std.to(x.device)[None, :, None, None])
return kornia.color.normalize(x, self.mean, self.std)
def reverse(self, x):
# return x.mul_(self.std.to(x.device)[None, :, None, None]).add_(self.mean.to(x.device)[None, :, None, None])
return kornia.color.denormalize(x, self.mean, self.std).to(torch.half).to(torch.float)
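# Usage sketch (MEAN/STD as defined for the current dataset):
# norm = Normalizer(MEAN, STD)
# x_n = norm(x) # batched normalization via kornia
# x_back = norm.reverse(x_n) # denormalize; the half-precision round-trip quantizes values slightly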
class TF_loader(object):
""" Transformations builder.
@@ -91,13 +102,15 @@ class TF_loader(object):
else:
raise Exception("Unknown TF axis : %s in %s"%(tf['function'], self._filename))
elif tf['function'] in {'translate', 'shear'}:
rand_fct= 'invScale_rand_floats' if tf['param']['invScale'] else 'rand_floats'
self._TF_dict[tf['name']]=self.build_lambda(tf['function'], rand_fct, tf['param']['min'], tf['param']['max'], tf['param']['absolute'], tf['param']['axis'])
# elif tf['function'] in {'translate', 'shear'}:
# rand_fct= 'invScale_rand_floats' if tf['param']['invScale'] else 'rand_floats'
# self._TF_dict[tf['name']]=self.build_lambda(tf['function'], rand_fct, tf['param']['min'], tf['param']['max'], tf['param']['absolute'], tf['param']['axis'])
else:
axis = tf['param']['axis'] if 'axis' in tf['param'].keys() else None
absolute = tf['param']['absolute'] if 'absolute' in tf['param'].keys() else True
rand_fct= 'invScale_rand_floats' if tf['param']['invScale'] else 'rand_floats'
self._TF_dict[tf['name']]=self.build_lambda(tf['function'], rand_fct, tf['param']['min'], tf['param']['max'])
self._TF_dict[tf['name']]=self.build_lambda(tf['function'], rand_fct, tf['param']['min'], tf['param']['max'], absolute, axis)
return self._TF_dict, self._TF_ignore_mag
@@ -130,7 +143,7 @@ class TF_loader(object):
size=x.shape[0],
mag=mag,
minval=minval,
maxval=maxval)))
maxval=max_val_fct(max(x.shape[2],x.shape[3])))))
elif axis =='X':
return (lambda x, mag:
globals()[fct_name](
@@ -419,6 +432,46 @@ def posterize(x, bits):
return float_image(x & mask)
def cutout(img, length):
"""
Args:
img (Tensor): Batch of images. Expect image value between [0, 1].
length (Tensor): The length (in pixels) of each square patch.
Returns:
Tensor: Images, each with a single square hole of size length x length cut out.
"""
device = img.device
(batch_size, channels, h, w) = img.shape
# mask = np.ones((h, w), np.float32)
mask = torch.ones((batch_size, h, w), device=device)
length=length.type(torch.uint8)
# y = np.random.randint(h)
# x = np.random.randint(w)
y = torch.randint(low=0, high=h, size=(batch_size,)).to(device)
x = torch.randint(low=0, high=w, size=(batch_size,)).to(device)
# y1 = np.clip(y - length // 2, 0, h)
# y2 = np.clip(y + length // 2, 0, h)
# x1 = np.clip(x - length // 2, 0, w)
# x2 = np.clip(x + length // 2, 0, w)
y1 = (y - length // 2).clamp(min=0, max=h)
y2 = (y + length // 2).clamp(min=0, max=h)
x1 = (x - length // 2).clamp(min=0, max=w)
x2 = (x + length // 2).clamp(min=0, max=w)
# mask[y1: y2, x1: x2] = 0.
for idx in range(batch_size): #Not optimized for batches
mask[idx, y1[idx]: y2[idx], x1[idx]: x2[idx]]= 0.
# mask = torch.from_numpy(mask)
# mask = mask.expand_as(img)
mask = mask.unsqueeze(dim=1).expand_as(img)
img = img * mask
return img
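# Illustrative call: one 16x16 hole per image in a batch of 8 CIFAR-sized images.
# imgs = torch.rand(8, 3, 32, 32)
# holed = cutout(imgs, length=torch.full((8,), 16.))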
import torch.nn.functional as F
def solarize(x, thresholds):
"""Invert all pixel values above a threshold.


@@ -3,17 +3,88 @@
"""
import numpy as np
import json, math, time, os
import matplotlib
matplotlib.use('Agg') #https://stackoverflow.com/questions/4706451/how-to-save-a-figure-remotely-with-pylab
import matplotlib.pyplot as plt
import copy
import gc
from torchviz import make_dot
import torch
import torch.nn.functional as F
import time
from nets.LeNet import *
from nets.wideresnet import *
from nets.wideresnet_cifar import *
import nets.resnet_abn as resnet_abn
import nets.resnet_deconv as resnet_DC
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import url_map as EfficientNet_map
import torchvision.models as models
def load_model(model, num_classes, pretrained=False):
if model in models.resnet.__all__ :
model_name = model #'resnet18' #'resnet34' #'wide_resnet50_2'
if pretrained :
print("Using pretrained weights")
model = getattr(models.resnet, model_name)(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
else:
model = getattr(models.resnet, model_name)(pretrained=False, num_classes=num_classes)
elif model in models.vgg.__all__ :
model_name = model #'vgg11', 'vgg1_bn'
if pretrained :
print("Using pretrained weights")
model = getattr(models.vgg, model_name)(pretrained=True)
num_ftrs = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
else :
model = getattr(models.vgg, model_name)(pretrained=False, num_classes=num_classes)
elif model in models.densenet.__all__ :
model_name = model #'densenet121' #'densenet201'
if pretrained :
print("Using pretrained weights")
model = getattr(models.densenet, model_name)(pretrained=True)
num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, num_classes)
else:
model = getattr(models.densenet, model_name)(pretrained=False, num_classes=num_classes)
elif model == 'LeNet':
if pretrained :
print("Pretrained weights not available")
model = LeNet(3,num_classes)
model_name=str(model)
elif model == 'WideResNet':
if pretrained :
print("Pretrained weights not available")
# model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_classes)
# model = WideResNet(16, 4, dropout_rate=0.0, num_classes=num_classes)
model = wide_resnet_cifar(26, 10, num_classes=num_classes)
# model = wide_resnet_cifar(20, 10, num_classes=num_classes)
model_name=str(model)
elif model in EfficientNet_map.keys():
model_name=model # efficientnet-b0 , efficientnet-b1, efficientnet-b4
if pretrained: #ImageNet or Advprop (normally better performance but different normalization)
print("Using pretrained weights")
model = EfficientNet.from_pretrained(model_name, advprop=False)
else:
model = EfficientNet.from_name(model_name)
elif model in resnet_abn.__all__ :
if pretrained :
print("Pretrained weights not available")
model_name=model
model = getattr(resnet_abn, model_name)(pretrained=False, num_classes=num_classes)
elif model in resnet_DC.__all__:
if pretrained :
print("Pretrained weights not available")
model_name = model
model = getattr(resnet_DC, model_name)(num_classes=num_classes)
else:
raise Exception('Unknown model : %s'%model)
return model, model_name
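# Example: model, model_name = load_model('resnet18', num_classes=10, pretrained=False)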
class ConfusionMatrix(object):
""" Confusion matrix.
@@ -120,7 +191,8 @@ class ConfusionMatrix(object):
f1=f1.mean()
return f1
def print_graph(PyTorch_obj, fig_name='graph'):
#from torchviz import make_dot
def print_graph(PyTorch_obj=torch.randn(1, 3, 32, 32), fig_name='graph'):
"""Save the computational graph.
Args:
@@ -128,7 +200,7 @@ def print_graph(PyTorch_obj, fig_name='graph'):
fig_name (string): Relative path where to save the graph. (default: graph)
"""
graph=make_dot(PyTorch_obj)
graph.format = 'pdf' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
graph.format = 'png' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
graph.render(fig_name)
def plot_resV2(log, fig_name='res', param_names=None, f1=True):
@@ -157,8 +229,13 @@ def plot_resV2(log, fig_name='res', param_names=None, f1=True):
#'''
#print(log[0]["f1"])
if isinstance(log[0]["f1"], list):
for c in range(len(log[0]["f1"])):
ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c), ls='--')
if len(log[0]["f1"])>10:
print("Plotting results : Too many class for F1, plotting only min/max")
ax[1, 0].plot(epochs,[max(x["f1"])*100 for x in log], label='F1-Max', ls='--')
ax[1, 0].plot(epochs,[min(x["f1"])*100 for x in log], label='F1-Min', ls='--')
else:
for c in range(len(log[0]["f1"])):
ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c), ls='--')
else:
ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1', ls='--')
#'''
@@ -251,7 +328,6 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
fig_name (string): Relative path where to save the graph. (default: data_sample)
weight_labels (Tensor): Weights associated to each labels. (default: None)
"""
sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()
plt.figure(figsize=(10,10))
@@ -262,7 +338,7 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
plt.grid(False)
plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
label = str(labels[i].item())
if weight_labels is not None : label+= (" - p %.2f" % weight_labels[i].item())
if torch.is_tensor(weight_labels): label+= (" - p %.2f" % weight_labels[i].item())
plt.xlabel(label)
plt.savefig(fig_name)
@@ -348,10 +424,13 @@ def clip_norm(tensors, max_norm, norm_type=2):
else:
total_norm = 0
for t in tensors:
if t is None:
continue
param_norm = t.norm(norm_type)
total_norm += param_norm.item() ** norm_type
total_norm = total_norm ** (1. / norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef >= 1:
return tensors
return [t.mul(clip_coef) for t in tensors]
#return [t.mul(clip_coef) for t in tensors]
return [t if t is None else t.mul(clip_coef) for t in tensors]
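# Sketch of intended use as higher's grad_callback (see get_diffopt above):
# clipped = clip_norm(grads, max_norm=max_grad)
# Unlike torch.nn.utils.clip_grad_norm_, this returns rescaled copies rather than
# modifying gradients in place, and passes None gradients through untouched.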