Salvador tests

This commit is contained in:
root 2019-12-12 16:38:13 -05:00
parent 6c0597e7ea
commit aade27011a
57 changed files with 29210 additions and 0 deletions

585
salvador/train_dataug.py Normal file
View file

@ -0,0 +1,585 @@
import datetime
import os
import time
import sys
import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms
from PIL import ImageEnhance
import random
import utils
from fastprogress import master_bar, progress_bar
import numpy as np
## DATA AUG ##
import higher
from dataug import *
from dataug_utils import *
tf_names = [
## Geometric TF ##
'Identity',
'FlipUD',
'FlipLR',
'Rotate',
'TranslateX',
'TranslateY',
'ShearX',
'ShearY',
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast',
'Color',
'Brightness',
'Sharpness',
'Posterize',
'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
]
def compute_vaLoss(model, dl_it, dl):
device = next(model.parameters()).device
try:
xs, ys = next(dl_it)
except StopIteration: #Fin epoch val
dl_it = iter(dl)
xs, ys = next(dl_it)
xs, ys = xs.to(device), ys.to(device)
model.eval() #Validation sans transfornations !
return F.cross_entropy(model(xs), ys)
def model_copy(src,dst, patch_copy=True, copy_grad=True):
#model=copy.deepcopy(fmodel) #Pas approprie, on ne souhaite que les poids/grad (pas tout fmodel et ses etats)
dst.load_state_dict(src.state_dict()) #Do not copy gradient !
if patch_copy:
dst['model'].load_state_dict(src['model'].state_dict()) #Copie donnee manquante ?
dst['data_aug'].load_state_dict(src['data_aug'].state_dict())
#Copie des gradients
if copy_grad:
for paramName, paramValue, in src.named_parameters():
for netCopyName, netCopyValue, in dst.named_parameters():
if paramName == netCopyName:
netCopyValue.grad = paramValue.grad
#netCopyValue=copy.deepcopy(paramValue)
try: #Data_augV4
dst['data_aug']._input_info = src['data_aug']._input_info
dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix
except:
pass
def optim_copy(dopt, opt):
#inner_opt.load_state_dict(diffopt.state_dict()) #Besoin sauver etat otpim (momentum, etc.) => Ne copie pas le state...
#opt_param=higher.optim.get_trainable_opt_params(diffopt)
for group_idx, group in enumerate(opt.param_groups):
# print('gp idx',group_idx)
for p_idx, p in enumerate(group['params']):
opt.state[p]=dopt.state[group_idx][p_idx]
#############
class Lambda(nn.Module):
"Create a layer that simply calls `func` with `x`"
def __init__(self, func):
super().__init__()
self.func=func
def forward(self, x): return self.func(x)
class SubsetSampler(torch.utils.data.SubsetRandomSampler):
def __init__(self, indices):
super().__init__(indices)
def __iter__(self):
return (self.indices[i] for i in range(len(self.indices)))
def __len__(self):
return len(self.indices)
def sharpness(img, factor):
sharpness_factor = random.uniform(1, factor)
sharp = ImageEnhance.Sharpness(img)
sharped = sharp.enhance(sharpness_factor)
return sharped
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, master_bar):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
header = 'Epoch: {}'.format(epoch)
for _, (image, target) in metric_logger.log_every(data_loader, header=header, parent=master_bar):
image, target = image.to(device), target.to(device)
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc1 = utils.accuracy(output, target)[0]
batch_size = image.shape[0]
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.update(loss=loss.item())
confmat.update(target.flatten(), output.argmax(1).flatten())
return metric_logger.loss.global_avg, confmat
def evaluate(model, criterion, data_loader, device):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
header = 'Test:'
missed = []
with torch.no_grad():
for i, (image, target) in metric_logger.log_every(data_loader, leave=False, header=header, parent=None):
image, target = image.to(device), target.to(device)
output = model(image)
loss = criterion(output, target)
if target.item() != output.topk(1)[1].item():
missed.append(data_loader.dataset.imgs[data_loader.sampler.indices[i]])
confmat.update(target.flatten(), output.argmax(1).flatten())
acc1 = utils.accuracy(output, target)[0]
batch_size = image.shape[0]
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.update(loss=loss.item())
return metric_logger.loss.global_avg, missed, confmat
def get_train_valid_loader(args, augment, random_seed, train_size=0.5, test_size=0.1, shuffle=True, num_workers=4, pin_memory=True):
"""
Utility function for loading and returning train and valid
multi-process iterators over the CIFAR-10 dataset. A sample
9x9 grid of the images can be optionally displayed.
If using CUDA, num_workers should be set to 1 and pin_memory to True.
Params
------
- data_dir: path directory to the dataset.
- batch_size: how many samples per batch to load.
- augment: whether to apply the data augmentation scheme
mentioned in the paper. Only applied on the train split.
- random_seed: fix seed for reproducibility.
- valid_size: percentage split of the training set used for
the validation set. Should be a float in the range [0, 1].
- shuffle: whether to shuffle the train/validation indices.
- show_sample: plot 9x9 sample grid of the dataset.
- num_workers: number of subprocesses to use when loading the dataset.
- pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
True if using GPU.
Returns
-------
- train_loader: training set iterator.
- valid_loader: validation set iterator.
"""
error_msg = "[!] test_size should be in the range [0, 1]."
assert ((test_size >= 0) and (test_size <= 1)), error_msg
# normalize = transforms.Normalize(
# mean=[0.4914, 0.4822, 0.4465],
# std=[0.2023, 0.1994, 0.2010],
# )
# define transforms
if augment:
train_transform = transforms.Compose([
# transforms.ColorJitter(brightness=0.3),
# transforms.Lambda(lambda img: sharpness(img, 5)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# normalize,
])
valid_transform = transforms.Compose([
# transforms.ColorJitter(brightness=0.3),
# transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# normalize,
])
else:
train_transform = transforms.Compose([
transforms.ToTensor(),
# normalize,
])
valid_transform = transforms.Compose([
transforms.ToTensor(),
# normalize,
])
# load the dataset
train_dataset = torchvision.datasets.ImageFolder(
root=args.data_path, transform=train_transform
)
test_dataset = torchvision.datasets.ImageFolder(
root=args.data_path, transform=valid_transform
)
num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(test_size * num_train))
if shuffle:
np.random.seed(random_seed)
np.random.shuffle(indices)
train_idx, test_idx = indices[split:], indices[:split]
train_idx, valid_idx = train_idx[:int(len(train_idx)*train_size)], train_idx[int(len(train_idx)*train_size):]
print("\nTrain", len(train_idx), "\nValid", len(valid_idx), "\nTest", len(test_idx))
train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) if not args.test_only else SubsetSampler(train_idx)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx) if not args.test_only else SubsetSampler(valid_idx)
test_sampler = SubsetSampler(test_idx)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size if not args.test_only else 1, sampler=train_sampler,
num_workers=num_workers, pin_memory=pin_memory,
)
valid_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size if not args.test_only else 1, sampler=valid_sampler,
num_workers=num_workers, pin_memory=pin_memory,
)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=1, sampler=test_sampler,
num_workers=num_workers, pin_memory=pin_memory,
)
imgs = np.asarray(train_dataset.imgs)
# print('Train')
# print(imgs[train_idx])
#print('Valid')
#print(imgs[valid_idx])
return (train_loader, valid_loader, test_loader)
def main(args):
print(args)
device = torch.device(args.device)
torch.backends.cudnn.benchmark = True
#augment = True if not args.test_only else False
augment = False
data_loader, dl_val, data_loader_test = get_train_valid_loader(args=args, pin_memory=True, augment=augment,
num_workers=args.workers, train_size=0.99, test_size=0.2, random_seed=999)
print("Creating model")
model = torchvision.models.__dict__[args.model](pretrained=True)
flat = list(model.children())
body, head = nn.Sequential(*flat[:-2]), nn.Sequential(flat[-2], Lambda(func=lambda x: torch.flatten(x, 1)), nn.Linear(flat[-1].in_features, len(data_loader.dataset.classes)))
model = nn.Sequential(body, head)
# model.fc = nn.Linear(model.fc.in_features, 2)
# import ipdb; ipdb.set_trace()
criterion = nn.CrossEntropyLoss().to(device)
# optimizer = torch.optim.SGD(
# model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
'''
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
'''
es = utils.EarlyStopping()
if args.test_only:
model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
model = model.to(device)
print('TEST')
_, missed, _ = evaluate(model, criterion, data_loader_test, device=device)
print(missed)
print('TRAIN')
_, missed, _ = evaluate(model, criterion, data_loader, device=device)
print(missed)
return
model = model.to(device)
print("Start training")
start_time = time.time()
mb = master_bar(range(args.epochs))
"""
for epoch in mb:
_, train_confmat = train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, mb)
lr_scheduler.step( (epoch+1)*len(data_loader) )
val_loss, _, valid_confmat = evaluate(model, criterion, data_loader_test, device=device)
es(val_loss, model)
# print('Valid Missed')
# print(valid_missed)
# print('Train')
# print(train_confmat)
print('Valid')
print(valid_confmat)
# if es.early_stop:
# break
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
"""
#######
inner_it = args.inner_it
dataug_epoch_start=0
print_freq=1
KLdiv=False
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
#model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
val_loss=torch.tensor(0) #Necessaire si pas de metastep sur une epoch
dl_val_it = iter(dl_val)
countcopy=0
#if inner_it!=0:
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=args.lr) #lr=1e-2
#inner_opt = torch.optim.SGD(model['model'].parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) #lr=1e-2 / momentum=0.9
inner_opt = torch.optim.Adam(model['model'].parameters(), lr=args.lr, weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
inner_opt,
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
high_grad_track = True
if inner_it == 0:
high_grad_track=False
model.train()
model.augment(mode=False)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
i=0
for epoch in mb:
metric_logger = utils.MetricLogger(delimiter=" ")
confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
header = 'Epoch: {}'.format(epoch)
t0 = time.process_time()
for _, (image, target) in metric_logger.log_every(data_loader, header=header, parent=mb):
#for i, (xs, ys) in enumerate(dl_train):
#print_torch_mem("it"+str(i))
i+=1
image, target = image.to(device), target.to(device)
if(not KLdiv):
#Methode uniforme
logits = fmodel(image) # modified `params` can also be passed as a kwarg
output = F.log_softmax(logits, dim=1)
loss = F.cross_entropy(output, target, reduction='none') # no need to call loss.backwards()
if fmodel._data_augmentation: #Weight loss
w_loss = fmodel['data_aug'].loss_weight()#.to(device)
loss = loss * w_loss
loss = loss.mean()
else:
#Methode KL div
fmodel.augment(mode=False)
sup_logits = fmodel(xs)
log_sup=F.log_softmax(sup_logits, dim=1)
fmodel.augment(mode=True)
loss = F.cross_entropy(log_sup, ys)
if fmodel._data_augmentation:
aug_logits = fmodel(xs)
log_aug=F.log_softmax(aug_logits, dim=1)
aug_loss=0
if epoch>50: #debut differe ?
#KL div w/ logits - Similarite predictions (distributions)
aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
aug_loss=aug_loss.sum(dim=-1)
#aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none')
w_loss = fmodel['data_aug'].loss_weight() #Weight loss
aug_loss = (w_loss * aug_loss).mean()
aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
#print(aug_loss)
unsupp_coeff = 1
loss += aug_loss * unsupp_coeff
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
if(high_grad_track and i%inner_it==0): #Perform Meta step
#print("meta")
#Peu utile si high_grad_track = False
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
#print_graph(val_loss)
val_loss.backward()
countcopy+=1
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
#if epoch>50:
meta_opt.step()
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
#model['data_aug'].next_TF_set()
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
acc1 = utils.accuracy(output, target)[0]
batch_size = image.shape[0]
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.update(loss=loss.item())
confmat.update(target.flatten(), output.argmax(1).flatten())
if(not high_grad_track and (torch.cuda.memory_cached()/1024.0**2)>20000):
countcopy+=1
print_torch_mem("copy")
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
#Necessaire pour reset higher (Accumule les fast_param meme avec track_higher_grads = False)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
print_torch_mem("copy")
if(not high_grad_track):
countcopy+=1
print_torch_mem("end copy")
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
#Necessaire pour reset higher (Accumule les fast_param meme avec track_higher_grads = False)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
print_torch_mem("end copy")
tf = time.process_time()
#### Print ####
if(print_freq and epoch%print_freq==0):
print('-'*9)
print('Epoch : %d'%(epoch))
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
print('TF Proba :', model['data_aug']['prob'].data)
#print('proba grad',model['data_aug']['prob'].grad)
print('TF Mag :', model['data_aug']['mag'].data)
#print('Mag grad',model['data_aug']['mag'].grad)
#print('Reg loss:', model['data_aug'].reg_loss().item())
#print('Aug loss', aug_loss.item())
#############
#### Log ####
#print(type(model['data_aug']) is dataug.Data_augV5)
'''
param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": param #if isinstance(model['data_aug'], Data_augV5)
#else [p.item() for p in model['data_aug']['prob']],
}
log.append(data)
'''
#############
train_confmat=confmat
lr_scheduler.step( (epoch+1)*len(data_loader) )
test_loss, _, test_confmat = evaluate(model, criterion, data_loader_test, device=device)
es(test_loss, model)
# print('Valid Missed')
# print(valid_missed)
# print('Train')
# print(train_confmat)
print('Test')
print(test_confmat)
# if es.early_stop:
# break
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
def parse_args():
import argparse
parser = argparse.ArgumentParser(description='PyTorch Classification Training')
parser.add_argument('--data-path', default='/Salvador', help='dataset')
parser.add_argument('--model', default='resnet50', help='model')
parser.add_argument('--device', default='cuda:1', help='device')
parser.add_argument('-b', '--batch-size', default=4, type=int)
parser.add_argument('--epochs', default=3, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=4e-5, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
parser.add_argument('--in_it', '--inner_it', default=0, type=int,
metavar='N', help='higher inner_it',
dest='inner_it')
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
main(args)