mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Confmat / F1 + Minor fix
This commit is contained in:
parent
250ce2c3cf
commit
3ccacd0366
5 changed files with 120 additions and 32 deletions
|
@ -14,6 +14,75 @@ import torch.nn.functional as F
|
|||
|
||||
import time
|
||||
|
||||
class ConfusionMatrix(object):
|
||||
def __init__(self, num_classes):
|
||||
self.num_classes = num_classes
|
||||
self.mat = None
|
||||
|
||||
def update(self, a, b):
|
||||
n = self.num_classes
|
||||
if self.mat is None:
|
||||
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
|
||||
with torch.no_grad():
|
||||
k = (a >= 0) & (a < n)
|
||||
inds = n * a[k].to(torch.int64) + b[k]
|
||||
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
|
||||
|
||||
def reset(self):
|
||||
if self.mat is not None:
|
||||
self.mat.zero_()
|
||||
|
||||
def compute(self):
|
||||
h = self.mat.float()
|
||||
acc_global = torch.diag(h).sum() / h.sum()
|
||||
acc = torch.diag(h) / h.sum(1)
|
||||
return acc_global, acc
|
||||
|
||||
def f1_metric(self, average=None):
|
||||
#https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
|
||||
h = self.mat.float()
|
||||
TP = torch.diag(h)
|
||||
TN = []
|
||||
FP = []
|
||||
FN = []
|
||||
for c in range(self.num_classes):
|
||||
idx = torch.ones(self.num_classes).bool()
|
||||
idx[c] = 0
|
||||
# all non-class samples classified as non-class
|
||||
TN.append(self.mat[idx.nonzero()[:, None], idx.nonzero()].sum()) #conf_matrix[idx[:, None], idx].sum() - conf_matrix[idx, c].sum()
|
||||
# all non-class samples classified as class
|
||||
FP.append(self.mat[idx, c].sum())
|
||||
# all class samples not classified as class
|
||||
FN.append(self.mat[c, idx].sum())
|
||||
|
||||
#print('Class {}\nTP {}, TN {}, FP {}, FN {}'.format(c, TP[c], TN[c], FP[c], FN[c]))
|
||||
|
||||
tp = (TP/h.sum(1))#.sum()
|
||||
tn = (torch.tensor(TN, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
|
||||
fp = (torch.tensor(FP, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
|
||||
fn = (torch.tensor(FN, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
|
||||
|
||||
if average=="micro":
|
||||
tp, tn, fp, fn = tp.sum(), tn.sum(), fp.sum(), fn.sum()
|
||||
|
||||
epsilon = 1e-7
|
||||
precision = tp / (tp + fp + epsilon)
|
||||
recall = tp / (tp + fn + epsilon)
|
||||
|
||||
f1 = 2* (precision*recall) / (precision + recall + epsilon)
|
||||
|
||||
if average=="macro":
|
||||
f1=f1.mean()
|
||||
return f1
|
||||
|
||||
def __str__(self):
|
||||
acc_global, acc = self.compute()
|
||||
return (
|
||||
'global correct: {:.1f}\n'
|
||||
'average row correct: {}').format(
|
||||
acc_global.item() * 100,
|
||||
['{:.1f}'.format(i) for i in (acc * 100).tolist()])
|
||||
|
||||
def print_graph(PyTorch_obj, fig_name='graph'):
|
||||
"""Save the computational graph.
|
||||
|
||||
|
@ -42,8 +111,21 @@ def plot_resV2(log, fig_name='res', param_names=None):
|
|||
ax[0, 0].plot(epochs,[x["val_loss"] for x in log], label='Val')
|
||||
ax[0, 0].legend()
|
||||
|
||||
ax[1, 0].set_title('Acc')
|
||||
ax[1, 0].plot(epochs,[x["acc"] for x in log])
|
||||
ax[1, 0].set_title('Test')
|
||||
ax[1, 0].plot(epochs,[x["acc"] for x in log], label='Acc')
|
||||
|
||||
if "f1" in log[0].keys():
|
||||
ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1')
|
||||
'''
|
||||
#print(log[0]["f1"])
|
||||
if len(log[0]["f1"])==1:
|
||||
ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1')
|
||||
else:
|
||||
for c in range(len(log[0]["f1"])):
|
||||
ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c))
|
||||
'''
|
||||
|
||||
ax[1, 0].legend()
|
||||
|
||||
if log[0]["param"]!= None:
|
||||
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
|
||||
|
@ -73,7 +155,7 @@ def plot_resV2(log, fig_name='res', param_names=None):
|
|||
plt.sca(ax[1, 2]), plt.xticks(rotation=90)
|
||||
|
||||
|
||||
fig_name = fig_name.replace('.',',')
|
||||
fig_name = fig_name.replace('.',',').replace(',,/','../')
|
||||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue