mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Comment Confmat + Cross-Val (sans Skorch) + minor improv
This commit is contained in:
parent
385bc9977c
commit
be8491268a
4 changed files with 133 additions and 33 deletions
|
@ -15,31 +15,76 @@ import torch.nn.functional as F
|
|||
import time
|
||||
|
||||
class ConfusionMatrix(object):
|
||||
""" Confusion matrix.
|
||||
|
||||
Helps computing the confusion matrix and F1 scores.
|
||||
|
||||
Example use ::
|
||||
confmat = ConfusionMatrix(...)
|
||||
|
||||
confmat.reset()
|
||||
for data in dataset:
|
||||
...
|
||||
confmat.update(...)
|
||||
|
||||
confmat.f1_metric(...)
|
||||
|
||||
Attributes:
|
||||
num_classes (int): Number of classes.
|
||||
mat (Tensor): Confusion matrix. Filled by update method.
|
||||
"""
|
||||
def __init__(self, num_classes):
|
||||
""" Initialize ConfusionMatrix.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes.
|
||||
"""
|
||||
self.num_classes = num_classes
|
||||
self.mat = None
|
||||
|
||||
def update(self, a, b):
|
||||
def update(self, target, pred):
|
||||
""" Update the confusion matrix.
|
||||
|
||||
Args:
|
||||
target (Tensor): Target labels.
|
||||
pred (Tensor): Prediction.
|
||||
"""
|
||||
n = self.num_classes
|
||||
if self.mat is None:
|
||||
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
|
||||
self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
|
||||
with torch.no_grad():
|
||||
k = (a >= 0) & (a < n)
|
||||
inds = n * a[k].to(torch.int64) + b[k]
|
||||
k = (target >= 0) & (target < n)
|
||||
inds = n * target[k].to(torch.int64) + pred[k]
|
||||
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
|
||||
|
||||
def reset(self):
|
||||
""" Reset the Confusion matrix.
|
||||
|
||||
"""
|
||||
if self.mat is not None:
|
||||
self.mat.zero_()
|
||||
|
||||
def compute(self):
|
||||
h = self.mat.float()
|
||||
acc_global = torch.diag(h).sum() / h.sum()
|
||||
acc = torch.diag(h) / h.sum(1)
|
||||
return acc_global, acc
|
||||
|
||||
def f1_metric(self, average=None):
|
||||
#https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
|
||||
""" Compute the F1 score.
|
||||
|
||||
Inspired from :
|
||||
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
|
||||
https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
|
||||
|
||||
Args:
|
||||
average (str): Type of averaging performed on the data. (Default: None)
|
||||
``None``:
|
||||
The scores for each class are returned.
|
||||
``'micro'``:
|
||||
Calculate metrics globally by counting the total true positives,
|
||||
false negatives and false positives.
|
||||
``'macro'``:
|
||||
Calculate metrics for each label, and find their unweighted
|
||||
mean. This does not take label imbalance into account.
|
||||
Return:
|
||||
Tensor containing the F1 score. It's shape is either 1, if there was averaging, or (num_classes).
|
||||
"""
|
||||
|
||||
h = self.mat.float()
|
||||
TP = torch.diag(h)
|
||||
TN = []
|
||||
|
@ -75,14 +120,6 @@ class ConfusionMatrix(object):
|
|||
f1=f1.mean()
|
||||
return f1
|
||||
|
||||
def __str__(self):
|
||||
acc_global, acc = self.compute()
|
||||
return (
|
||||
'global correct: {:.1f}\n'
|
||||
'average row correct: {}').format(
|
||||
acc_global.item() * 100,
|
||||
['{:.1f}'.format(i) for i in (acc * 100).tolist()])
|
||||
|
||||
def print_graph(PyTorch_obj, fig_name='graph'):
|
||||
"""Save the computational graph.
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue