mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 20:20:46 +02:00
Rangement ("tidying up")
This commit is contained in:
parent
ca3367d19f
commit
4166922c34
453 changed files with 9797 additions and 7 deletions
202
Old/salvador/utils.py
Executable file
@@ -0,0 +1,202 @@
from __future__ import print_function
from collections import defaultdict, deque
import datetime
import math
import time
import torch
import numpy as np

import os
from fastprogress import progress_bar


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{global_avg:.4f}"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
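
# Minimal usage sketch for SmoothedValue (illustrative only; the fmt keys come
# from __str__ above):
#
#   loss_meter = SmoothedValue(window_size=20, fmt="{median:.4f} ({global_avg:.4f})")
#   for step in range(100):
#       loss_meter.update(1.0 / (step + 1))
#   print(loss_meter)       # windowed median and running global average
#   print(loss_meter.max)   # largest value still inside the window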


class ConfusionMatrix(object):
    """Accumulates a num_classes x num_classes confusion matrix on the fly.

    Rows index the first argument passed to update() (typically the ground-truth
    labels) and columns the second (typically the predictions).
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, a, b):
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            k = (a >= 0) & (a < n)  # keep only valid class indices
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        self.mat.zero_()

    def compute(self):
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)  # per-row (per-class) accuracy
        return acc_global, acc

    def __str__(self):
        acc_global, acc = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}').format(
                acc_global.item() * 100,
                ['{:.1f}'.format(i) for i in (acc * 100).tolist()])
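
# Minimal usage sketch for ConfusionMatrix (illustrative only; `model` and
# `val_loader` are assumed stand-ins for a classifier and a DataLoader):
#
#   cm = ConfusionMatrix(num_classes=10)
#   for images, targets in val_loader:
#       preds = model(images).argmax(dim=1)
#       cm.update(targets, preds)
#   acc_global, per_class_acc = cm.compute()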


class MetricLogger(object):
    """Keeps a dict of named SmoothedValue meters and formats them for logging."""

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, parent, header=None, **kwargs):
        """Generator wrapping `iterable` in a fastprogress progress_bar.

        Yields (index, item) pairs and refreshes the bar comment with the
        current meter values after every item.
        """
        if not header:
            header = ''
        log_msg = self.delimiter.join([
            '{meters}'
        ])

        progress = progress_bar(iterable, parent=parent, **kwargs)

        for idx, obj in enumerate(progress):
            yield idx, obj
            progress.comment = log_msg.format(meters=str(self))

        print('{header} {meters}'.format(header=header, meters=str(self)))
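
# Minimal usage sketch for MetricLogger.log_every (illustrative only; `mb` is
# assumed to be a fastprogress master_bar, and `model`, `criterion`,
# `train_loader` are stand-ins):
#
#   logger = MetricLogger(delimiter="  ")
#   for _, (images, targets) in logger.log_every(train_loader, parent=mb):
#       loss = criterion(model(images), targets)
#       logger.update(loss=loss)
#   print(logger.loss.global_avg)   # meters are reachable as attributes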


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target[None])

        res = []
        for k in topk:
            correct_k = correct[:k].flatten().sum(dtype=torch.float32)
            res.append(correct_k * (100.0 / batch_size))
        return res
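
# Minimal usage sketch for accuracy() (illustrative only, with random tensors):
#
#   logits = torch.randn(8, 10)               # batch of 8 samples, 10 classes
#   labels = torch.randint(0, 10, (8,))
#   top1, top5 = accuracy(logits, labels, topk=(1, 5))   # percentages as 0-dim tensors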


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, augmented_model=False):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                           Default: 0
            augmented_model (bool): If True, `model` is expected to be a dict-like
                            object whose 'model' entry holds the network to checkpoint.
                            Default: False
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta

        self.augmented_model = augmented_model

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score - self.delta:
            self.counter += 1
            # print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            # if self.counter >= self.patience:
            #     self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model when the validation loss decreases.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        if not self.augmented_model:
            torch.save(model.state_dict(), 'checkpoint.pt')
        else:
            torch.save(model['model'].state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss
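
# Minimal end-to-end sketch for EarlyStopping (illustrative only; train_one_epoch,
# evaluate, model and the loaders are assumed stand-ins). Note that in this version
# the `early_stop` flag is never set (that branch is commented out above), so the
# caller compares `counter` against `patience` itself:
#
#   early_stopper = EarlyStopping(patience=7, verbose=True)
#   for epoch in range(max_epochs):
#       train_one_epoch(model, train_loader)
#       val_loss = evaluate(model, val_loader)
#       early_stopper(val_loss, model)   # saves checkpoint.pt whenever val_loss improves
#       if early_stopper.counter >= early_stopper.patience:
#           break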