from __future__ import print_function

from collections import defaultdict, deque
import datetime
import math
import os
import time

import numpy as np
import torch

try:
    from fastprogress import progress_bar
except ImportError:
    # fastprogress is only needed by MetricLogger.log_every; keep the rest of
    # the module importable when it is not installed.
    progress_bar = None


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    Window statistics (``median``, ``avg``, ``max``, ``value``) only see the
    last ``window_size`` values passed to :meth:`update`; ``global_avg`` is
    computed from the running ``total`` and ``count`` over all updates.
    """

    def __init__(self, window_size=20, fmt=None):
        """
        Args:
            window_size (int): Maximum number of recent values kept in the
                window. Default: 20
            fmt (str): ``str.format`` template used by ``__str__``. It may
                reference ``median``, ``avg``, ``global_avg``, ``max`` and
                ``value``. Default: ``"{global_avg:.4f}"``
        """
        if fmt is None:
            fmt = "{global_avg:.4f}"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value``, weighting it by ``n`` in the global average.

        The window stores ``value`` once regardless of ``n``.
        """
        self.deque.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (computed in float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """``total / count`` over every update.

        Raises ZeroDivisionError if no value has been recorded yet.
        """
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        # Every statistic is evaluated eagerly, so formatting a meter that has
        # never been updated raises regardless of which fields ``fmt`` uses.
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class ConfusionMatrix(object):
    """Accumulate an ``num_classes x num_classes`` count matrix.

    Entry ``[i, j]`` counts the pairs seen by :meth:`update` with ``a == i``
    and ``b == j``.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # Allocated lazily on the first update so it lands on the input device.
        self.mat = None

    def update(self, a, b):
        """Add the pairs ``(a, b)`` to the matrix.

        Elements where ``a`` falls outside ``[0, num_classes)`` are ignored.
        ``b`` is not range-checked here.
        # NOTE(review): presumably ``a`` is the target and ``b`` the
        # prediction, both flattened integer tensors -- confirm with callers.
        """
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            valid = (a >= 0) & (a < n)
            # Encode each (row, col) pair as a single flat index row * n + col.
            inds = n * a[valid].to(torch.int64) + b[valid]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        """Zero the accumulated counts; a no-op before the first update."""
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        """Return ``(acc_global, acc)``.

        ``acc_global`` is the trace divided by the total count; ``acc`` is the
        diagonal divided by the row sums (one value per row).
        """
        h = self.mat.float()
        diagonal = torch.diag(h)
        acc_global = diagonal.sum() / h.sum()
        acc = diagonal / h.sum(1)
        return acc_global, acc

    def __str__(self):
        acc_global, acc = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}').format(
                acc_global.item() * 100,
                ['{:.1f}'.format(i) for i in (acc * 100).tolist()])


class MetricLogger(object):
    """Collection of named :class:`SmoothedValue` meters with progress logging.

    Meters are created on first use and are reachable as attributes, e.g.
    ``logger.loss`` after ``logger.update(loss=...)``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; tensors are converted with ``.item()``."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # __getattr__ only runs after normal lookup has failed. Read ``meters``
        # through __dict__: if it is missing (e.g. while copying/unpickling,
        # before __init__ has run) ``self.meters`` would re-enter this method
        # and recurse until RecursionError instead of raising AttributeError.
        meters = self.__dict__.get("meters", {})
        if attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        )

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing meter."""
        self.meters[name] = meter

    def log_every(self, iterable, parent, header=None, **kwargs):
        """Yield ``(idx, obj)`` from ``iterable`` while showing a progress bar.

        After each item is consumed the bar's comment is set to the current
        meter summary; once the iterable is exhausted a final
        ``"{header} {meters}"`` line is printed.

        Args:
            iterable: Items to iterate over.
            parent: Passed to ``fastprogress.progress_bar`` as ``parent``.
            header (str): Prefix for the final summary line. Default: ''
            **kwargs: Forwarded to ``fastprogress.progress_bar``.

        Raises:
            ImportError: If ``fastprogress`` is not installed.
        """
        if progress_bar is None:
            raise ImportError(
                "MetricLogger.log_every requires the 'fastprogress' package")
        if not header:
            header = ''
        bar = progress_bar(iterable, parent=parent, **kwargs)
        for idx, obj in enumerate(bar):
            yield idx, obj
            # Runs when the consumer asks for the next item, i.e. after the
            # caller has had a chance to update the meters for this step.
            bar.comment = str(self)
        print('{header} {meters}'.format(header=header, meters=str(self)))


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified
    values of k.

    Args:
        output (Tensor): Scores; the top ``max(topk)`` entries along dim 1
            are taken as predictions.
        target (Tensor): Labels compared against the predicted indices; its
            first dimension is used as the batch size.
        topk (tuple of int): Values of k to evaluate. Default: (1,)

    Returns:
        list of Tensor: One scalar tensor per k holding the percentage of
        samples whose target is among the top-k predictions.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        # Transpose to (maxk, batch) so row i holds the i-th best guesses.
        pred = pred.t()
        correct = pred.eq(target[None])

        res = []
        for k in topk:
            correct_k = correct[:k].flatten().sum(dtype=torch.float32)
            res.append(correct_k * (100.0 / batch_size))
        return res


class EarlyStopping:
    """Track the best validation loss and checkpoint the model when it improves.

    NOTE: the patience check in ``__call__`` is currently disabled, so
    ``early_stop`` is never set to True by this class; ``counter`` still
    counts consecutive calls without improvement and callers may compare it
    with ``patience`` themselves.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss
                improved. Stored only; see the class note. Default: 7
            verbose (bool): If True, prints a message for each validation loss
                improvement. Default: False
            delta (float): Minimum change in the monitored quantity to qualify
                as an improvement. Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` works everywhere.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model):
        """Update the tracker with a new validation loss.

        The model is checkpointed on the first call and whenever the new
        score is not worse than ``best_score - delta``; otherwise ``counter``
        is incremented.
        """
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        # NOTE(review): with delta > 0 a loss up to ``delta`` worse than the
        # best is still accepted as an improvement here, which does not match
        # the documented meaning of ``delta`` -- confirm whether
        # ``best_score + delta`` was intended.
        elif score < self.best_score - self.delta:
            self.counter += 1
            # Patience-based stopping is intentionally disabled:
            # print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            # if self.counter >= self.patience:
            #     self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.

        Writes ``model.state_dict()`` to ``checkpoint.pt`` in the current
        working directory and records ``val_loss`` as the new minimum.
        '''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss