Comment Confmat + Cross-Val (without Skorch) + minor improvements

This commit is contained in:
Harle, Antoine (Contractor) 2020-02-03 17:46:32 -05:00
parent 385bc9977c
commit be8491268a
4 changed files with 133 additions and 33 deletions


@@ -15,31 +15,76 @@ import torch.nn.functional as F
import time
class ConfusionMatrix(object):
""" Confusion matrix.
Helps compute the confusion matrix and F1 scores.
Example use ::
confmat = ConfusionMatrix(...)
confmat.reset()
for data in dataset:
...
confmat.update(...)
confmat.f1_metric(...)
Attributes:
num_classes (int): Number of classes.
mat (Tensor): Confusion matrix. Filled by the update method.
"""
def __init__(self, num_classes):
""" Initialize ConfusionMatrix.
Args:
num_classes (int): Number of classes.
"""
self.num_classes = num_classes
self.mat = None
-def update(self, a, b):
+def update(self, target, pred):
""" Update the confusion matrix.
Args:
target (Tensor): Target labels.
pred (Tensor): Predicted labels.
"""
n = self.num_classes
if self.mat is None:
-self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
+self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
with torch.no_grad():
-k = (a >= 0) & (a < n)
-inds = n * a[k].to(torch.int64) + b[k]
+k = (target >= 0) & (target < n)
+inds = n * target[k].to(torch.int64) + pred[k]
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
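As an aside, update fills the whole matrix with one bincount call: each (target, pred) pair is flattened to the single index n * target + pred, so counting those indices and reshaping recovers the n-by-n matrix (rows are targets, columns are predictions). A minimal sketch with made-up tensors:

import torch

n = 3
target = torch.tensor([0, 1, 2, 2])  # made-up ground-truth labels
pred = torch.tensor([0, 2, 2, 1])    # made-up predictions

inds = n * target + pred  # flatten each (target, pred) pair to one index in [0, n*n)
mat = torch.bincount(inds, minlength=n**2).reshape(n, n)
print(mat)
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 1]])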
def reset(self):
""" Reset the Confusion matrix.
"""
if self.mat is not None:
self.mat.zero_()
def compute(self):
""" Compute the global accuracy and per-class accuracy (recall) from the matrix.
"""
h = self.mat.float()
acc_global = torch.diag(h).sum() / h.sum()
acc = torch.diag(h) / h.sum(1)
return acc_global, acc
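On the made-up matrix above, compute would give a global accuracy of 0.5 (2 correct out of 4) and per-class accuracies of [1.0, 0.0, 0.5]:

h = mat.float()
acc_global = torch.diag(h).sum() / h.sum()  # (1 + 0 + 1) / 4 = 0.5
acc = torch.diag(h) / h.sum(1)              # [1/1, 0/1, 1/2]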
def f1_metric(self, average=None):
-#https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
""" Compute the F1 score.
Inspired from :
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
Args:
average (str): Type of averaging performed on the data. (Default: None)
``None``:
The scores for each class are returned.
``'micro'``:
Calculate metrics globally by counting the total true positives,
false negatives and false positives.
``'macro'``:
Calculate metrics for each label, and find their unweighted
mean. This does not take label imbalance into account.
Return:
Tensor containing the F1 score. Its shape is (1,) if averaging was performed, or (num_classes,) otherwise.
"""
h = self.mat.float()
TP = torch.diag(h)
TN = []
@@ -75,14 +120,6 @@ class ConfusionMatrix(object):
f1 = f1.mean()
return f1
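The hunk above elides the lines that accumulate TN, FP and FN, so the exact body is not shown; a self-contained sketch of per-class / micro / macro F1 from a confusion matrix (an approximation of the idea, not necessarily this commit's code) could read:

import torch

def f1_from_confmat(mat, average=None):
    # mat: (n, n) confusion matrix with targets as rows, predictions as columns.
    h = mat.float()
    TP = torch.diag(h)
    FN = h.sum(1) - TP  # per-class false negatives
    FP = h.sum(0) - TP  # per-class false positives
    if average == 'micro':
        # Pool counts globally; for single-label problems micro-F1 equals accuracy.
        precision = TP.sum() / (TP.sum() + FP.sum())
        recall = TP.sum() / (TP.sum() + FN.sum())
        return 2 * precision * recall / (precision + recall)
    f1 = 2 * TP / (2 * TP + FP + FN)  # per-class F1
    return f1.mean() if average == 'macro' else f1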
def __str__(self):
acc_global, acc = self.compute()
return (
'global correct: {:.1f}\n'
'average row correct: {}').format(
acc_global.item() * 100,
['{:.1f}'.format(i) for i in (acc * 100).tolist()])
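Tying it together, the class docstring's skeleton expands to an evaluation loop along these lines (model and val_loader are placeholder names, not part of this commit):

confmat = ConfusionMatrix(num_classes=10)
confmat.reset()
with torch.no_grad():
    for inputs, labels in val_loader:        # hypothetical DataLoader
        preds = model(inputs).argmax(dim=1)  # hypothetical classifier
        confmat.update(labels, preds)
print(confmat.f1_metric(average='macro'))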
def print_graph(PyTorch_obj, fig_name='graph'):
"""Save the computational graph.