Salvador tests

This commit is contained in:
root 2019-12-12 16:38:13 -05:00
parent 6c0597e7ea
commit aade27011a
57 changed files with 29210 additions and 0 deletions

375
salvador/train.py Normal file
View file

@ -0,0 +1,375 @@
import datetime
import os
import time
import sys
import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms
from PIL import ImageEnhance
import random
import utils
from fastprogress import master_bar, progress_bar
import numpy as np
## DATA AUG ##
import higher
from dataug import *
from dataug_utils import *
tf_names = [
## Geometric TF ##
'Identity',
'FlipUD',
'FlipLR',
'Rotate',
'TranslateX',
'TranslateY',
'ShearX',
'ShearY',
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast',
'Color',
'Brightness',
'Sharpness',
'Posterize',
'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
]
class Lambda(nn.Module):
"Create a layer that simply calls `func` with `x`"
def __init__(self, func):
super().__init__()
self.func=func
def forward(self, x): return self.func(x)
class SubsetSampler(torch.utils.data.SubsetRandomSampler):
def __init__(self, indices):
super().__init__(indices)
def __iter__(self):
return (self.indices[i] for i in range(len(self.indices)))
def __len__(self):
return len(self.indices)
def sharpness(img, factor):
sharpness_factor = random.uniform(1, factor)
sharp = ImageEnhance.Sharpness(img)
sharped = sharp.enhance(sharpness_factor)
return sharped
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, master_bar, Kldiv=False):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
header = 'Epoch: {}'.format(epoch)
for _, (image, target) in metric_logger.log_every(data_loader, header=header, parent=master_bar):
image, target = image.to(device), target.to(device)
if not Kldiv :
output = model(image)
#output = F.log_softmax(output, dim=1)
loss = criterion(output, target) #Pas de softmax ?
else : #Consume x2 memory
model.augment(mode=False)
output = model(image)
model.augment(mode=True)
log_sup=F.log_softmax(output, dim=1)
sup_loss = F.cross_entropy(log_sup, target)
aug_output = model(image)
log_aug=F.log_softmax(aug_output, dim=1)
aug_loss=F.cross_entropy(log_aug, target)
#KL div w/ logits - Similarite predictions (distributions)
KL_loss = F.softmax(output, dim=1)*(log_sup-log_aug)
KL_loss = KL_loss.sum(dim=-1)
#KL_loss = F.kl_div(aug_logits, sup_logits, reduction='none')
KL_loss = KL_loss.mean()
unsupp_coeff = 1
loss = sup_loss + (aug_loss + KL_loss) * unsupp_coeff
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc1 = utils.accuracy(output, target)[0]
batch_size = image.shape[0]
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.update(loss=loss.item())
confmat.update(target.flatten(), output.argmax(1).flatten())
return metric_logger.loss.global_avg, confmat
def evaluate(model, criterion, data_loader, device):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
header = 'Test:'
missed = []
with torch.no_grad():
for i, (image, target) in metric_logger.log_every(data_loader, leave=False, header=header, parent=None):
image, target = image.to(device), target.to(device)
output = model(image)
loss = criterion(output, target)
if target.item() != output.topk(1)[1].item():
missed.append(data_loader.dataset.imgs[data_loader.sampler.indices[i]])
confmat.update(target.flatten(), output.argmax(1).flatten())
acc1 = utils.accuracy(output, target)[0]
batch_size = image.shape[0]
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.update(loss=loss.item())
return metric_logger.loss.global_avg, missed, confmat
def get_train_valid_loader(args, augment, random_seed, valid_size=0.1, shuffle=True, num_workers=4, pin_memory=True):
"""
Utility function for loading and returning train and valid
multi-process iterators over the CIFAR-10 dataset. A sample
9x9 grid of the images can be optionally displayed.
If using CUDA, num_workers should be set to 1 and pin_memory to True.
Params
------
- data_dir: path directory to the dataset.
- batch_size: how many samples per batch to load.
- augment: whether to apply the data augmentation scheme
mentioned in the paper. Only applied on the train split.
- random_seed: fix seed for reproducibility.
- valid_size: percentage split of the training set used for
the validation set. Should be a float in the range [0, 1].
- shuffle: whether to shuffle the train/validation indices.
- show_sample: plot 9x9 sample grid of the dataset.
- num_workers: number of subprocesses to use when loading the dataset.
- pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
True if using GPU.
Returns
-------
- train_loader: training set iterator.
- valid_loader: validation set iterator.
"""
error_msg = "[!] valid_size should be in the range [0, 1]."
assert ((valid_size >= 0) and (valid_size <= 1)), error_msg
# normalize = transforms.Normalize(
# mean=[0.4914, 0.4822, 0.4465],
# std=[0.2023, 0.1994, 0.2010],
# )
# define transforms
if augment:
train_transform = transforms.Compose([
# transforms.ColorJitter(brightness=0.3),
# transforms.Lambda(lambda img: sharpness(img, 5)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# normalize,
])
valid_transform = transforms.Compose([
# transforms.ColorJitter(brightness=0.3),
# transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# normalize,
])
else:
train_transform = transforms.Compose([
transforms.ToTensor(),
# normalize,
])
valid_transform = transforms.Compose([
transforms.ToTensor(),
# normalize,
])
# load the dataset
train_dataset = torchvision.datasets.ImageFolder(
root=args.data_path, transform=train_transform
)
valid_dataset = torchvision.datasets.ImageFolder(
root=args.data_path, transform=valid_transform
)
num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))
if shuffle:
#np.random.seed(random_seed)
np.random.shuffle(indices)
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) if not args.test_only else SubsetSampler(train_idx)
valid_sampler = SubsetSampler(valid_idx)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size if not args.test_only else 1, sampler=train_sampler,
num_workers=num_workers, pin_memory=pin_memory,
)
valid_loader = torch.utils.data.DataLoader(
valid_dataset, batch_size=1, sampler=valid_sampler,
num_workers=num_workers, pin_memory=pin_memory,
)
imgs = np.asarray(train_dataset.imgs)
# print('Train')
# print(imgs[train_idx])
#print('Valid')
#print(imgs[valid_idx])
tgt = [0,0]
for _, targets in train_loader:
for target in targets:
tgt[target]+=1
print("Train targets :", tgt)
tgt = [0,0]
for _, targets in valid_loader:
for target in targets:
tgt[target]+=1
print("Valid targets :", tgt)
return (train_loader, valid_loader)
def main(args):
print(args)
device = torch.device(args.device)
torch.backends.cudnn.benchmark = True
#augment = True if not args.test_only else False
if not args.test_only and args.augment=='flip' : augment = True
else : augment = False
print("Augment", augment)
data_loader, data_loader_test = get_train_valid_loader(args=args, pin_memory=True, augment=augment,
num_workers=args.workers, valid_size=0.3, random_seed=999)
print("Creating model")
model = torchvision.models.__dict__[args.model](pretrained=True)
flat = list(model.children())
body, head = nn.Sequential(*flat[:-2]), nn.Sequential(flat[-2], Lambda(func=lambda x: torch.flatten(x, 1)), nn.Linear(flat[-1].in_features, len(data_loader.dataset.classes)))
model = nn.Sequential(body, head)
Kldiv=False
if not args.test_only and (args.augment=='Rand' or args.augment=='RandKL'):
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
if args.augment=='RandKL': Kldiv=True
print("Augmodel")
# model.fc = nn.Linear(model.fc.in_features, 2)
# import ipdb; ipdb.set_trace()
criterion = nn.CrossEntropyLoss().to(device)
# optimizer = torch.optim.SGD(
# model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
es = utils.EarlyStopping()
if args.test_only:
model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
model = model.to(device)
print('TEST')
_, missed, _ = evaluate(model, criterion, data_loader_test, device=device)
print(missed)
print('TRAIN')
_, missed, _ = evaluate(model, criterion, data_loader, device=device)
print(missed)
return
model = model.to(device)
print("Start training")
start_time = time.time()
mb = master_bar(range(args.epochs))
for epoch in mb:
_, train_confmat = train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, mb, Kldiv)
lr_scheduler.step( (epoch+1)*len(data_loader) )
val_loss, _, valid_confmat = evaluate(model, criterion, data_loader_test, device=device)
es(val_loss, model)
# print('Valid Missed')
# print(valid_missed)
# print('Train')
# print(train_confmat)
print('Valid')
print(valid_confmat)
# if es.early_stop:
# break
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
def parse_args():
import argparse
parser = argparse.ArgumentParser(description='PyTorch Classification Training')
parser.add_argument('--data-path', default='/Salvador', help='dataset')
parser.add_argument('--model', default='resnet18', help='model') #'resnet18'
parser.add_argument('--device', default='cuda:1', help='device')
parser.add_argument('-b', '--batch-size', default=8, type=int)
parser.add_argument('--epochs', default=3, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=4e-5, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
parser.add_argument('-a', '--augment', default='None', type=str,
metavar='N', help='Data augment',
dest='augment')
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
main(args)