# smart_augmentation/Old/salvador/train.py

import datetime
import os
import time
import sys
import torch
import torch.utils.data
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import ImageEnhance
import random
import utils
from fastprogress import master_bar, progress_bar
import numpy as np
## DATA AUG ##
import higher
from dataug import *
from dataug_utils import *
tf_names = [
    ## Geometric TF ##
    'Identity',
    'FlipUD',
    'FlipLR',
    'Rotate',
    'TranslateX',
    'TranslateY',
    'ShearX',
    'ShearY',

    ## Color TF (expect images in the range [0, 1]) ##
    #'Contrast',
    #'Color',
    #'Brightness',
    #'Sharpness',
    #'Posterize',
    #'Solarize',  # expects images in [0, 1]; not efficient on batches
]
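# NOTE: these names are assumed to match keys of TF.TF_dict coming from the
# local `dataug` module (pulled in above via `from dataug import *`); an
# unknown name would raise a KeyError in the dict comprehension in main().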
class Lambda(nn.Module):
    "Create a layer that simply calls `func` with `x`"
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)
class SubsetSampler(torch.utils.data.SubsetRandomSampler):
    """Sampler that yields the given indices in a fixed, deterministic order."""
    def __init__(self, indices):
        super().__init__(indices)

    def __iter__(self):
        return (self.indices[i] for i in range(len(self.indices)))

    def __len__(self):
        return len(self.indices)
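# Illustration (my reading of the override above, not stated in the repo):
#   SubsetRandomSampler([3, 1, 2]) yields 3, 1, 2 in a fresh random order
#   every epoch, while SubsetSampler([3, 1, 2]) always yields 3, then 1,
#   then 2. evaluate() below relies on this determinism to map loader
#   position i back to sampler.indices[i] when collecting missed images.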
def sharpness(img, factor):
    """Sharpen a PIL image by a random factor drawn uniformly from [1, factor]."""
    sharpness_factor = random.uniform(1, factor)
    sharp = ImageEnhance.Sharpness(img)
    return sharp.enhance(sharpness_factor)
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, master_bar, Kldiv=False):
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
    header = 'Epoch: {}'.format(epoch)
    for _, (image, target) in metric_logger.log_every(data_loader, header=header, parent=master_bar):
        image, target = image.to(device), target.to(device)
        if not Kldiv:
            output = model(image)
            #output = F.log_softmax(output, dim=1)
            loss = criterion(output, target)  # no explicit softmax: CrossEntropyLoss applies log-softmax internally
        else:  # Consumes ~2x memory (two forward passes per batch)
            # Supervised pass with augmentation disabled
            model.augment(mode=False)
            output = model(image)
            model.augment(mode=True)
            log_sup = F.log_softmax(output, dim=1)
            sup_loss = F.nll_loss(log_sup, target)  # nll_loss on log-probs (F.cross_entropy here would apply softmax twice)

            # Second pass on the same batch, this time augmented
            aug_output = model(image)
            log_aug = F.log_softmax(aug_output, dim=1)
            aug_loss = F.nll_loss(log_aug, target)

            # KL divergence between the two prediction distributions:
            # KL(p_sup || p_aug) = sum_c p_sup(c) * (log p_sup(c) - log p_aug(c))
            KL_loss = F.softmax(output, dim=1) * (log_sup - log_aug)
            KL_loss = KL_loss.sum(dim=-1)
            #KL_loss = F.kl_div(aug_logits, sup_logits, reduction='none')
            KL_loss = KL_loss.mean()

            unsupp_coeff = 1
            loss = sup_loss + (aug_loss + KL_loss) * unsupp_coeff
            #print(sup_loss.item(), (aug_loss + KL_loss).item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc1 = utils.accuracy(output, target)[0]
        batch_size = image.shape[0]
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.update(loss=loss.item())
        confmat.update(target.flatten(), output.argmax(1).flatten())

    return metric_logger.loss.global_avg, confmat
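# Sketch of the objective optimized in the Kldiv branch above (my reading of
# the code, not a formula stated anywhere in this repo):
#   L = CE(p_sup, y) + unsupp_coeff * ( CE(p_aug, y) + KL(p_sup || p_aug) )
# where p_sup is the prediction on the raw batch (augmentation disabled) and
# p_aug the prediction on the augmented batch, so the KL term pulls the
# augmented predictions toward the clean ones.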
def evaluate(model, criterion, data_loader, device):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
    header = 'Test:'
    missed = []
    with torch.no_grad():
        for i, (image, target) in metric_logger.log_every(data_loader, leave=False, header=header, parent=None):
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)

            # Record misclassified files. This assumes batch_size == 1 and a
            # deterministic SubsetSampler, so loader position i maps back to
            # sampler.indices[i].
            if target.item() != output.topk(1)[1].item():
                missed.append(data_loader.dataset.imgs[data_loader.sampler.indices[i]])

            confmat.update(target.flatten(), output.argmax(1).flatten())
            acc1 = utils.accuracy(output, target)[0]
            batch_size = image.shape[0]
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.update(loss=loss.item())

    return metric_logger.loss.global_avg, missed, confmat
def get_train_valid_loader(args, augment, random_seed, valid_size=0.1, shuffle=True, num_workers=4, pin_memory=True):
    """
    Utility function for loading and returning train and valid multi-process
    iterators over an ImageFolder dataset rooted at args.data_path.

    Params
    ------
    - args: parsed command-line arguments; uses args.data_path,
      args.batch_size and args.test_only.
    - augment: whether to apply the data augmentation scheme. Only applied
      on the train split.
    - random_seed: fixed seed for a reproducible train/valid split.
    - valid_size: fraction of the training set held out for validation.
      Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the train/validation indices.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using a GPU.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert (0 <= valid_size <= 1), error_msg
    # normalize = transforms.Normalize(
    #     mean=[0.4914, 0.4822, 0.4465],
    #     std=[0.2023, 0.1994, 0.2010],
    # )

    # define transforms
    if augment:
        train_transform = transforms.Compose([
            # transforms.ColorJitter(brightness=0.3),
            # transforms.Lambda(lambda img: sharpness(img, 5)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # normalize,
        ])
        valid_transform = transforms.Compose([
            # transforms.ColorJitter(brightness=0.3),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            # normalize,
        ])
        valid_transform = transforms.Compose([
            transforms.ToTensor(),
            # normalize,
        ])
    # load the dataset
    train_dataset = torchvision.datasets.ImageFolder(
        root=args.data_path, transform=train_transform
    )
    valid_dataset = torchvision.datasets.ImageFolder(
        root=args.data_path, transform=valid_transform
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) if not args.test_only else SubsetSampler(train_idx)
    valid_sampler = SubsetSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size if not args.test_only else 1, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=1, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    imgs = np.asarray(train_dataset.imgs)
    # print('Train')
    # print(imgs[train_idx])
    # print('Valid')
    # print(imgs[valid_idx])

    # Sanity check: per-class sample counts (assumes a two-class dataset)
    tgt = [0, 0]
    for _, targets in train_loader:
        for target in targets:
            tgt[target] += 1
    print("Train targets:", tgt)

    tgt = [0, 0]
    for _, targets in valid_loader:
        for target in targets:
            tgt[target] += 1
    print("Valid targets:", tgt)

    return (train_loader, valid_loader)
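# Usage sketch (hypothetical values; assumes an `args` namespace as produced
# by parse_args() below):
#   train_loader, valid_loader = get_train_valid_loader(
#       args=args, augment=False, random_seed=999,
#       valid_size=0.3, num_workers=args.workers, pin_memory=True)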
def main(args):
    print(args)
    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Baseline flip augmentation only when training with -a flip
    augment = not args.test_only and args.augment == 'flip'
    print("Augment", augment)

    data_loader, data_loader_test = get_train_valid_loader(args=args, pin_memory=True, augment=augment,
                                                           num_workers=args.workers, valid_size=0.3, random_seed=999)

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=True)
    # Split the torchvision model into a body and a new head sized for the dataset's classes
    flat = list(model.children())
    body = nn.Sequential(*flat[:-2])
    head = nn.Sequential(flat[-2], Lambda(func=lambda x: torch.flatten(x, 1)),
                         nn.Linear(flat[-1].in_features, len(data_loader.dataset.classes)))
    model = nn.Sequential(body, head)

    Kldiv = False
    if not args.test_only and (args.augment == 'Rand' or args.augment == 'RandKL'):
        tf_dict = {k: TF.TF_dict[k] for k in tf_names}
        model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
        if args.augment == 'RandKL':
            Kldiv = True

        # Scale the magnitude parameter of the augmentation module
        model['data_aug']['mag'].data = model['data_aug']['mag'].data * args.magnitude
        print("Augmodel")

    # model.fc = nn.Linear(model.fc.in_features, 2)
    # import ipdb; ipdb.set_trace()
    criterion = nn.CrossEntropyLoss().to(device)
    # optimizer = torch.optim.SGD(
    #     model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Polynomial (power 0.9) decay of the LR over the total number of steps
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

    es = utils.EarlyStopping() if not (args.augment == 'Rand' or args.augment == 'RandKL') else utils.EarlyStopping(augmented_model=True)
    if args.test_only:
        model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
        model = model.to(device)
        print('TEST')
        _, missed, _ = evaluate(model, criterion, data_loader_test, device=device)
        print(missed)
        print('TRAIN')
        _, missed, _ = evaluate(model, criterion, data_loader, device=device)
        print(missed)
        return

    model = model.to(device)
    print("Start training")
    start_time = time.time()
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        _, train_confmat = train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, mb, Kldiv)
        lr_scheduler.step((epoch + 1) * len(data_loader))
        val_loss, _, valid_confmat = evaluate(model, criterion, data_loader_test, device=device)
        es(val_loss, model)  # early-stopping tracker; assumed to save checkpoint.pt for the best val loss

        # print('Valid Missed')
        # print(valid_missed)
        # print('Train')
        # print(train_confmat)
        # print('Valid')
        # print(valid_confmat)
        # if es.early_stop:
        #     break

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--data-path', default='/github/smart_augmentation/salvador/data', help='dataset')
    parser.add_argument('--model', default='resnet18', help='model')  # 'resnet18'
    parser.add_argument('--device', default='cuda:0', help='device')
    parser.add_argument('-b', '--batch-size', default=8, type=int)
    parser.add_argument('--epochs', default=3, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                        help='number of data loading workers (default: 0)')
    parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=4e-5, type=float,
                        metavar='W', help='weight decay (default: 4e-5)',
                        dest='weight_decay')
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument('-a', '--augment', default='None', type=str,
                        metavar='AUG', help="data augmentation mode: 'None', 'flip', 'Rand' or 'RandKL'",
                        dest='augment')
    parser.add_argument('-m', '--magnitude', default=1.0, type=float,
                        metavar='MAG', help='augmentation magnitude',
                        dest='magnitude')

    args = parser.parse_args()
    return args
if __name__ == "__main__":
    args = parse_args()
    main(args)
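# Example invocations (paths and values are illustrative only):
#   python train.py --data-path ./data --model resnet18 -b 8 --epochs 3
#   python train.py -a RandKL -m 2.0    # RandAugment + KL consistency loss
#   python train.py --test-only         # re-evaluate a saved checkpoint.pt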