Comment Confmat + Cross-Val (without Skorch) + minor improvements

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-03 17:46:32 -05:00
parent 385bc9977c
commit be8491268a
4 changed files with 133 additions and 33 deletions

View file

@@ -3,7 +3,6 @@
MNIST / CIFAR10
"""
import torch
from torch.utils.data import SubsetRandomSampler
from torch.utils.data.dataset import ConcatDataset
import torchvision
@@ -72,26 +71,90 @@ data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=downloa
#Validation set size [0, 1]
#valid_size=0.1
valid_size=0.1
#train_subset_indices=range(int(len(data_train)*(1-valid_size)))
#val_subset_indices=range(int(len(data_train)*(1-valid_size)),len(data_train))
#train_subset_indices=range(BATCH_SIZE*10)
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
#from torch.utils.data import SubsetRandomSampler
#dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
#dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
#Cross Validation
'''
from skorch.dataset import CVSplit
cvs = CVSplit(cv=5)
import numpy as np
cvs = CVSplit(cv=valid_size, stratified=True) #Stratified =True for unbalanced dataset #ShuffleSplit
def next_CVSplit():
train_subset, val_subset = cvs(data_train)
train_subset, val_subset = cvs(data_train, y=np.array(data_train.targets))
dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
return dl_train, dl_val
dl_train, dl_val = next_CVSplit()
dl_train, dl_val = next_CVSplit()
'''
import numpy as np
from sklearn.model_selection import ShuffleSplit
from sklearn.model_selection import StratifiedShuffleSplit
class CVSplit(object):
"""Class that perform train/valid split on a dataset.
Inspired from : https://skorch.readthedocs.io/en/latest/user/dataset.html
Attributes:
_stratified (bool): Whether the split should be stratified. Recommended to be True for unbalanced datasets.
_val_size (float, int): If float, should be between 0.0 and 1.0 and represents the proportion of the dataset to include in the validation split.
If int, represents the absolute number of validation samples.
_data (Dataset): Dataset to split.
_targets (np.array): Targets of the dataset used if _stratified is set to True.
_cv (BaseShuffleSplit): scikit-learn object used to perform the split.
"""
def __init__(self, data, val_size=0.1, stratified=True):
""" Intialize CVSplit.
Args:
data (Dataset): Dataset to split.
val_size (float, int): If float, should be between 0.0 and 1.0 and represents the proportion of the dataset to include in the validation split.
If int, represents the absolute number of validation samples. (Default: 0.1)
stratified (bool): Whether the split should be stratified. Recommended to be True for unbalanced datasets.
"""
self._stratified=stratified
self._val_size=val_size
self._data=data
if self._stratified:
cv_cls = StratifiedShuffleSplit
self._targets = np.array(data.targets) #targets of the wrapped dataset, used for stratification
else:
cv_cls = ShuffleSplit
self._cv= cv_cls(test_size=val_size, random_state=0)
def next_split(self):
""" Get next cross-validation split.
Returns:
Train DataLoader, Validation DataLoader
"""
args=(np.arange(len(self._data)),)
if self._stratified:
args = args + (self._targets,)
idx_train, idx_valid = next(iter(self._cv.split(*args)))
train_subset = torch.utils.data.Subset(self._data, idx_train)
val_subset = torch.utils.data.Subset(self._data, idx_valid)
dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
return dl_train, dl_val
cvs = CVSplit(data_train, val_size=valid_size)
dl_train, dl_val = cvs.next_split()
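The split is meant to be re-drawn regularly, as the training loop later in this commit does by calling cvs.next_split() at every epoch. A minimal usage sketch (illustration only; it assumes this file is importable as a module named `datasets`, a hypothetical name, and uses placeholders for the real training/evaluation code):

from datasets import cvs  #hypothetical module name for this file

for epoch in range(3):
    dl_train, dl_val = cvs.next_split()  #draw a new stratified train/valid split each epoch
    for xs, ys in dl_train:
        pass  #training step on the current training fold
    for xs, ys in dl_val:
        pass  #evaluation on the held-out validation fold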

View file

@@ -79,7 +79,7 @@ if __name__ == "__main__":
}
#Parameters
n_inner_iter = 1
epochs = 150
epochs = 2
dataug_epoch_start=0
optim_param={
'Meta':{
@@ -156,7 +156,7 @@ if __name__ == "__main__":
inner_it=n_inner_iter,
dataug_epoch_start=dataug_epoch_start,
opt_param=optim_param,
print_freq=20,
print_freq=1,
unsup_loss=1,
hp_opt=False,
save_sample_freq=None)

View file

@@ -177,7 +177,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Accuracy max:', accuracy)
print('F1 :', f1)
print('F1 :', ["{0:0.4f}".format(i) for i in f1])
#### Log ####
data={
@@ -185,7 +185,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"f1": f1.cpu().numpy().tolist(),
"f1": f1.tolist(),
"time": tf - t0,
"param": None,
@@ -253,7 +253,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
for epoch in range(1, epochs+1):
t0 = time.perf_counter()
dl_train, dl_val = next_CVSplit()
dl_train, dl_val = cvs.next_split()
dl_val_it = iter(dl_val)
for i, (xs, ys) in enumerate(dl_train):
@@ -333,7 +333,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"f1": f1.cpu().numpy().tolist(),
"f1": f1.tolist(),
"time": tf - t0,
"param": param,
@@ -349,11 +349,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Accuracy max:', max([x["acc"] for x in log]))
print('F1 :', f1)
print('F1 :', ["{0:0.4f}".format(i) for i in f1])
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data)
if not model['data_aug']._fixed_prob: print('TF Proba :', ["{0:0.4f}".format(p) for p in model['data_aug']['prob']])
#print('proba grad',model['data_aug']['prob'].grad)
if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
if not model['data_aug']._fixed_mag: print('TF Mag :', ["{0:0.4f}".format(m) for m in model['data_aug']['mag']])
#print('Mag grad',model['data_aug']['mag'].grad)
if not model['data_aug']._fixed_mix: print('Mix:', model['data_aug']['mix_dist'].item())
#print('Reg loss:', model['data_aug'].reg_loss().item())

View file

@@ -15,31 +15,76 @@ import torch.nn.functional as F
import time
class ConfusionMatrix(object):
""" Confusion matrix.
Helps compute the confusion matrix and F1 scores.
Example use ::
confmat = ConfusionMatrix(...)
confmat.reset()
for data in dataset:
...
confmat.update(...)
confmat.f1_metric(...)
Attributes:
num_classes (int): Number of classes.
mat (Tensor): Confusion matrix. Filled by update method.
"""
def __init__(self, num_classes):
""" Initialize ConfusionMatrix.
Args:
num_classes (int): Number of classes.
"""
self.num_classes = num_classes
self.mat = None
def update(self, a, b):
def update(self, target, pred):
""" Update the confusion matrix.
Args:
target (Tensor): Target labels.
pred (Tensor): Predicted labels.
"""
n = self.num_classes
if self.mat is None:
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
with torch.no_grad():
k = (a >= 0) & (a < n)
inds = n * a[k].to(torch.int64) + b[k]
k = (target >= 0) & (target < n)
inds = n * target[k].to(torch.int64) + pred[k]
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
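For reference, a small self-contained illustration (not part of this commit) of the bincount encoding used by update(): each (target, pred) pair is mapped to the single index n*target + pred, so a bincount over those indices, reshaped to (n, n), accumulates the confusion matrix with rows as true labels and columns as predictions:

import torch

n = 3
target = torch.tensor([0, 1, 2, 1])
pred = torch.tensor([0, 2, 2, 1])
inds = n * target + pred  #tensor([0, 5, 8, 4])
mat = torch.bincount(inds, minlength=n**2).reshape(n, n)
#mat == [[1, 0, 0],
#        [0, 1, 1],
#        [0, 0, 1]]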
def reset(self):
""" Reset the Confusion matrix.
"""
if self.mat is not None:
self.mat.zero_()
def compute(self):
h = self.mat.float()
acc_global = torch.diag(h).sum() / h.sum()
acc = torch.diag(h) / h.sum(1)
return acc_global, acc
def f1_metric(self, average=None):
#https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
""" Compute the F1 score.
Inspired by:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
Args:
average (str): Type of averaging performed on the data. (Default: None)
``None``:
The scores for each class are returned.
``'micro'``:
Calculate metrics globally by counting the total true positives,
false negatives and false positives.
``'macro'``:
Calculate metrics for each label, and find their unweighted
mean. This does not take label imbalance into account.
Returns:
Tensor containing the F1 score: a single value if an average is requested, one value per class (shape (num_classes,)) otherwise.
"""
h = self.mat.float()
TP = torch.diag(h)
TN = []
@@ -75,14 +120,6 @@ class ConfusionMatrix(object):
f1=f1.mean()
return f1
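The body of f1_metric is mostly elided by the diff. As an illustration of the approach the docstring describes (per-class counts taken from the confusion matrix, then optional micro/macro averaging), a hedged standalone sketch, not the repository's exact implementation, could look like this:

import torch

def f1_from_confmat(mat, average=None, eps=1e-12):
    h = mat.float()
    TP = torch.diag(h)      #correctly predicted, per class
    FP = h.sum(0) - TP      #predicted as the class but wrong
    FN = h.sum(1) - TP      #belonging to the class but missed
    if average == 'micro':  #pool the counts over all classes first
        TP, FP, FN = TP.sum(), FP.sum(), FN.sum()
    precision = TP / (TP + FP + eps)
    recall = TP / (TP + FN + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return f1.mean() if average == 'macro' else f1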
def __str__(self):
acc_global, acc = self.compute()
return (
'global correct: {:.1f}\n'
'average row correct: {}').format(
acc_global.item() * 100,
['{:.1f}'.format(i) for i in (acc * 100).tolist()])
def print_graph(PyTorch_obj, fig_name='graph'):
"""Save the computational graph.