mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
200 lines
5.9 KiB
Python
200 lines
5.9 KiB
Python
|
from __future__ import print_function
|
||
|
from collections import defaultdict, deque
|
||
|
import datetime
|
||
|
import math
|
||
|
import time
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
|
||
|
import os
|
||
|
from fastprogress import progress_bar
|
||
|
|
||
|
class SmoothedValue(object):
    """Accumulate a stream of values and report windowed or global statistics.

    The most recent ``window_size`` values are kept in a bounded deque and
    feed ``median``, ``avg``, ``max`` and ``value``; ``total`` and ``count``
    cover every update ever made and feed ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        """Create an empty tracker.

        Args:
            window_size: Maximum number of recent values retained.
            fmt: ``str.format`` template used by ``__str__``. It may reference
                ``median``, ``avg``, ``global_avg``, ``max`` and ``value``.
                Defaults to the global average with four decimals.
        """
        self.fmt = "{global_avg:.4f}" if fmt is None else fmt
        self.count = 0
        self.total = 0.0
        self.deque = deque(maxlen=window_size)

    def update(self, value, n=1):
        """Record ``value``, counting it ``n`` times in the global totals.

        The value is appended to the window only once, whatever ``n`` is.
        """
        self.deque.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        """Median of the values currently held in the window."""
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the values currently held in the window (float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every update: ``total / count``."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently held in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        """Render the statistics through the ``fmt`` template."""
        stats = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**stats)
|
||
|
|
||
|
|
||
|
class ConfusionMatrix(object):
    """Accumulate an ``num_classes x num_classes`` count matrix of label pairs.

    Entry ``[i, j]`` counts the pairs where ``a == i`` and ``b == j`` across
    all calls to ``update``. The matrix is allocated lazily on the first
    ``update`` so that it lives on the same device as the incoming tensor.
    """

    def __init__(self, num_classes):
        """Store the number of classes; the matrix itself is created lazily."""
        self.num_classes = num_classes
        self.mat = None

    def update(self, a, b):
        """Add the pairs ``(a[i], b[i])`` to the matrix.

        Args:
            a: Tensor of row indices. Entries outside ``[0, num_classes)`` are
                ignored, together with the matching entries of ``b``.
                Presumably the ground-truth labels - confirm against callers.
            b: Tensor of column indices with the same shape as ``a``.
                Presumably the predicted labels - confirm against callers.
        """
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            valid = (a >= 0) & (a < n)
            # Flatten each (row, col) pair into one index so a single
            # bincount can tally every cell of the matrix at once.
            flat_index = n * a[valid].to(torch.int64) + b[valid]
            counts = torch.bincount(flat_index, minlength=n**2)
            self.mat += counts.reshape(n, n)

    def reset(self):
        """Zero the accumulated counts in place.

        Does nothing when no ``update`` has happened yet (previously this
        raised ``AttributeError`` because the matrix was still ``None``).
        """
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        """Return ``(acc_global, acc)`` derived from the accumulated counts.

        Returns:
            acc_global: Scalar tensor, sum of the diagonal over the total count.
            acc: Tensor with one entry per row, diagonal count over the row
                sum. Rows with no samples yield ``nan``.
        """
        h = self.mat.float()
        diagonal = torch.diag(h)
        acc_global = diagonal.sum() / h.sum()
        acc = diagonal / h.sum(1)
        return acc_global, acc

    def __str__(self):
        """Format the global and per-row accuracies as percentages."""
        acc_global, acc = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}').format(
                acc_global.item() * 100,
                ['{:.1f}'.format(i) for i in (acc * 100).tolist()])
|
||
|
|
||
|
|
||
|
class MetricLogger(object):
    """Collect named meters and render them as one delimited status line.

    Meters are created on demand (as ``SmoothedValue``) the first time a name
    is passed to ``update`` and can be read back as attributes, e.g.
    ``logger.loss``.
    """

    def __init__(self, delimiter="\t"):
        """Create an empty logger whose meters are joined by ``delimiter``."""
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one value to each named meter.

        Tensor values are converted with ``.item()``; every value must then
        be a ``float`` or ``int``.
        """
        for name, value in kwargs.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            assert isinstance(value, (float, int)), (
                "meter '{}' expects a float or int, got {}".format(
                    name, type(value).__name__))
            self.meters[name].update(value)

    def __getattr__(self, attr):
        """Resolve unknown attributes to meters of the same name.

        ``meters`` is read straight from the instance ``__dict__``: going
        through ``self.meters`` would re-enter ``__getattr__`` endlessly on an
        instance whose ``__init__`` has not run (``copy.copy``, unpickling,
        ``__new__``), ending in ``RecursionError`` instead of
        ``AttributeError``.
        """
        state = self.__dict__
        meters = state.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in state:
            return state[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        """Return ``"name: meter"`` for every meter, joined by the delimiter."""
        return self.delimiter.join(
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        )

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, parent, header=None, **kwargs):
        """Yield ``(index, item)`` from ``iterable`` behind a progress bar.

        After each item is consumed by the caller, the bar's comment is
        refreshed with the current meters. Once the iterable is exhausted a
        summary line ``"<header> <meters>"`` is printed.

        Args:
            iterable: Items to iterate over.
            parent: Passed to ``progress_bar`` as its ``parent`` argument.
            header: Optional prefix for the final summary line.
            **kwargs: Forwarded unchanged to ``progress_bar``.
        """
        if not header:
            header = ''

        bar = progress_bar(iterable, parent=parent, **kwargs)

        for idx, obj in enumerate(bar):
            yield idx, obj
            # Refreshed after the caller has processed the item, so the
            # comment reflects meters updated during this iteration.
            bar.comment = str(self)

        print('{header} {meters}'.format(header=header, meters=str(self)))
|
||
|
|
||
|
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: Score tensor; the k best entries are taken along dimension 1.
            Assumed to be shaped (batch, classes) - confirm against callers.
        target: Tensor of expected class indices; its first dimension is
            used as the batch size.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one float32 scalar tensor per entry of ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Rank once for the largest k; smaller k values reuse a prefix.
        top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1]
        hits = top_indices.t().eq(target[None])

        return [
            hits[:k].flatten().sum(dtype=torch.float32) * (100.0 / n_samples)
            for k in topk
        ]
|
||
|
|
||
|
class EarlyStopping:
    """Track validation loss, checkpoint the model on improvement and count
    the consecutive calls without improvement.

    NOTE: in this implementation ``early_stop`` is never set to ``True``; the
    trigger is deliberately left disabled. Callers that want to stop must
    compare ``counter`` with ``patience`` themselves.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Stored only; not acted upon by this class.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Tolerance applied when comparing a new score with the best one.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf rather than np.Inf: the capitalised alias was removed in
        # NumPy 2.0 and would raise AttributeError here.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model):
        """Record a new validation loss and checkpoint ``model`` if it counts
        as an improvement; otherwise increment ``counter``.

        Args:
            val_loss: Validation loss for the current evaluation.
            model: Object exposing ``state_dict()``; saved on improvement.
        """
        # Work with the negated loss so that "higher is better".
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        # NOTE(review): with delta > 0 a score up to ``delta`` worse than the
        # best still takes the improvement branch below - confirm intended.
        elif score < self.best_score - self.delta:
            self.counter += 1
            # The early_stop trigger (counter >= patience) is intentionally
            # not applied here; see the class docstring.
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save ``model.state_dict()`` to ``checkpoint.pt`` in the current
        working directory and remember ``val_loss`` as the new minimum."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss
|