Comment Confmat + Cross-Val (without Skorch) + minor improvements
parent 385bc9977c
commit be8491268a
4 changed files with 133 additions and 33 deletions
@@ -3,7 +3,6 @@
 MNIST / CIFAR10
 """
 import torch
-from torch.utils.data import SubsetRandomSampler
 from torch.utils.data.dataset import ConcatDataset
 import torchvision

@@ -72,26 +71,90 @@ data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=downloa

#Validation set size [0, 1]
#valid_size=0.1
valid_size=0.1
#train_subset_indices=range(int(len(data_train)*(1-valid_size)))
#val_subset_indices=range(int(len(data_train)*(1-valid_size)),len(data_train))
#train_subset_indices=range(BATCH_SIZE*10)
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)

#from torch.utils.data import SubsetRandomSampler
#dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
#dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

#Cross Validation (previous Skorch-based version, kept for reference)
'''
from skorch.dataset import CVSplit
import numpy as np
cvs = CVSplit(cv=5)
cvs = CVSplit(cv=valid_size, stratified=True) #Stratified=True for unbalanced datasets #ShuffleSplit

def next_CVSplit():

    train_subset, val_subset = cvs(data_train, y=np.array(data_train.targets))
    dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
    dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)

    return dl_train, dl_val

dl_train, dl_val = next_CVSplit()
'''

import numpy as np
from sklearn.model_selection import ShuffleSplit
from sklearn.model_selection import StratifiedShuffleSplit

class CVSplit(object):
    """Class that performs a train/valid split on a dataset.

        Inspired by: https://skorch.readthedocs.io/en/latest/user/dataset.html

        Attributes:
            _stratified (bool): Whether the split should be stratified. Recommended to be True for unbalanced datasets.
            _val_size (float, int): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the validation split.
                If int, represents the absolute number of validation samples.
            _data (Dataset): Dataset to split.
            _targets (np.array): Targets of the dataset, used if _stratified is set to True.
            _cv (BaseShuffleSplit): scikit-learn object used to perform the split.
    """
    def __init__(self, data, val_size=0.1, stratified=True):
        """Initialize CVSplit.

            Args:
                data (Dataset): Dataset to split.
                val_size (float, int): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the validation split.
                    If int, represents the absolute number of validation samples. (Default: 0.1)
                stratified (bool): Whether the split should be stratified. Recommended to be True for unbalanced datasets. (Default: True)
        """
        self._stratified=stratified
        self._val_size=val_size

        self._data=data
        if self._stratified:
            cv_cls = StratifiedShuffleSplit
            self._targets= np.array(data.targets) #targets of the wrapped dataset itself
        else:
            cv_cls = ShuffleSplit

        self._cv= cv_cls(test_size=val_size, random_state=0) #NB: a fixed random_state makes every split() call regenerate the same permutation

    def next_split(self):
        """Get the next cross-validation split.

            Returns:
                Train DataLoader, Validation DataLoader
        """
        args=(np.arange(len(self._data)),)
        if self._stratified:
            args = args + (self._targets,)

        idx_train, idx_valid = next(iter(self._cv.split(*args)))

        train_subset = torch.utils.data.Subset(self._data, idx_train)
        val_subset = torch.utils.data.Subset(self._data, idx_valid)

        dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
        dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)

        return dl_train, dl_val

cvs = CVSplit(data_train, val_size=valid_size)
dl_train, dl_val = cvs.next_split()
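A quick way to sanity-check the stratified behaviour (a sketch, not part of this commit; it assumes a torchvision-style dataset exposing a targets list, as CIFAR10 does):

    from collections import Counter

    cvs_check = CVSplit(data_train, val_size=0.1, stratified=True)
    dl_tr, dl_va = cvs_check.next_split()
    #torch.utils.data.Subset stores the selected indices in .indices
    val_labels = [data_train.targets[i] for i in dl_va.dataset.indices]
    print(Counter(val_labels)) #expect ~500 samples per CIFAR10 class at val_size=0.1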
@@ -79,7 +79,7 @@ if __name__ == "__main__":
    }
    #Parameters
    n_inner_iter = 1
-   epochs = 150
+   epochs = 2
    dataug_epoch_start=0
    optim_param={
        'Meta':{
@@ -156,7 +156,7 @@ if __name__ == "__main__":
        inner_it=n_inner_iter,
        dataug_epoch_start=dataug_epoch_start,
        opt_param=optim_param,
-       print_freq=20,
+       print_freq=1,
        unsup_loss=1,
        hp_opt=False,
        save_sample_freq=None)
@@ -177,7 +177,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
            print('Time : %.00f'%(tf - t0))
            print('Train loss :',loss.item(), '/ val loss', val_loss.item())
            print('Accuracy max:', accuracy)
-           print('F1 :', f1)
+           print('F1 :', ["{0:0.4f}".format(i) for i in f1])

        #### Log ####
        data={
@@ -185,7 +185,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
            "train_loss": loss.item(),
            "val_loss": val_loss.item(),
            "acc": accuracy,
-           "f1": f1.cpu().numpy().tolist(),
+           "f1": f1.tolist(),
            "time": tf - t0,

            "param": None,
@@ -253,7 +253,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
    for epoch in range(1, epochs+1):
        t0 = time.perf_counter()

-       dl_train, dl_val = next_CVSplit()
+       dl_train, dl_val = cvs.next_split()
        dl_val_it = iter(dl_val)

        for i, (xs, ys) in enumerate(dl_train):
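Calling cvs.next_split() at each epoch is meant to re-draw the train/valid partition, turning the loop into a ShuffleSplit-style Monte-Carlo cross-validation. Note, however, that scikit-learn splitters built with a fixed random_state=0 (as above) regenerate the identical permutation on every split() call, so as written each epoch receives the same partition. A minimal sketch of a variant that actually re-samples at every call (hypothetical; ResamplingCVSplit is an illustrative name, not part of this commit):

    import numpy as np

    class ResamplingCVSplit(CVSplit):
        """CVSplit variant drawing a fresh split at every call."""
        def next_split(self):
            #Re-seed the underlying splitter so each epoch gets a new partition
            self._cv.random_state = np.random.randint(2**31)
            return super().next_split()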
@@ -333,7 +333,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
            "train_loss": loss.item(),
            "val_loss": val_loss.item(),
            "acc": accuracy,
-           "f1": f1.cpu().numpy().tolist(),
+           "f1": f1.tolist(),
            "time": tf - t0,

            "param": param,
@@ -349,11 +349,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
            print('Time : %.00f'%(tf - t0))
            print('Train loss :',loss.item(), '/ val loss', val_loss.item())
            print('Accuracy max:', max([x["acc"] for x in log]))
-           print('F1 :', f1)
+           print('F1 :', ["{0:0.4f}".format(i) for i in f1])
            print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
-           if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data)
+           if not model['data_aug']._fixed_prob: print('TF Proba :', ["{0:0.4f}".format(p) for p in model['data_aug']['prob']])
            #print('proba grad',model['data_aug']['prob'].grad)
-           if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
+           if not model['data_aug']._fixed_mag: print('TF Mag :', ["{0:0.4f}".format(m) for m in model['data_aug']['mag']])
            #print('Mag grad',model['data_aug']['mag'].grad)
            if not model['data_aug']._fixed_mix: print('Mix:', model['data_aug']['mix_dist'].item())
            #print('Reg loss:', model['data_aug'].reg_loss().item())
@@ -15,31 +15,76 @@ import torch.nn.functional as F
import time

class ConfusionMatrix(object):
    """ Confusion matrix.

        Helps computing the confusion matrix and F1 scores.

        Example use ::
            confmat = ConfusionMatrix(...)

            confmat.reset()
            for data in dataset:
                ...
                confmat.update(...)

            confmat.f1_metric(...)

        Attributes:
            num_classes (int): Number of classes.
            mat (Tensor): Confusion matrix. Filled by the update method.
    """
    def __init__(self, num_classes):
        """ Initialize ConfusionMatrix.

            Args:
                num_classes (int): Number of classes.
        """
        self.num_classes = num_classes
        self.mat = None

-   def update(self, a, b):
+   def update(self, target, pred):
        """ Update the confusion matrix.

            Args:
                target (Tensor): Target labels.
                pred (Tensor): Prediction.
        """
        n = self.num_classes
        if self.mat is None:
-           self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
+           self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
        with torch.no_grad():
-           k = (a >= 0) & (a < n)
-           inds = n * a[k].to(torch.int64) + b[k]
+           k = (target >= 0) & (target < n)
+           inds = n * target[k].to(torch.int64) + pred[k]
            #each (target, pred) pair is encoded as the flat index target*n + pred, then counted
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        """ Reset the confusion matrix.
        """
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        return acc_global, acc

    def f1_metric(self, average=None):
-       #https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
        """ Compute the F1 score.

            Inspired by:
                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
                https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6

            Args:
                average (str): Type of averaging performed on the data. (Default: None)
                    ``None``:
                        The scores for each class are returned.
                    ``'micro'``:
                        Calculate metrics globally by counting the total true positives,
                        false negatives and false positives.
                    ``'macro'``:
                        Calculate metrics for each label, and find their unweighted
                        mean. This does not take label imbalance into account.
            Return:
                Tensor containing the F1 score. Its shape is either 1, if there was averaging, or (num_classes).
        """

        h = self.mat.float()
        TP = torch.diag(h)
        TN = []
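The middle of f1_metric is elided in this view. For reference, per-class F1 follows from the confusion matrix along these lines (an illustrative sketch using the orientation set by update(), rows = targets and columns = predictions; not necessarily the elided code):

    h = self.mat.float()
    TP = torch.diag(h)  #correct predictions, per class
    FP = h.sum(0) - TP  #predicted as class c while the target was another class
    FN = h.sum(1) - TP  #target was class c while the prediction was another class
    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    f1 = 2 * precision * recall / (precision + recall) #shape: (num_classes,)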
@@ -75,14 +120,6 @@ class ConfusionMatrix(object):
            f1=f1.mean()
        return f1

-   def __str__(self):
-       acc_global, acc = self.compute()
-       return (
-           'global correct: {:.1f}\n'
-           'average row correct: {}').format(
-               acc_global.item() * 100,
-               ['{:.1f}'.format(i) for i in (acc * 100).tolist()])

def print_graph(PyTorch_obj, fig_name='graph'):
    """Save the computational graph.
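Put together, the new ConfusionMatrix can be driven like this (a sketch; model and dl_val are assumptions borrowed from the training scripts above, with CIFAR10's 10 classes, and targets and predictions are assumed to live on the same device):

    confmat = ConfusionMatrix(num_classes=10)
    confmat.reset()
    with torch.no_grad():
        for xs, ys in dl_val:
            preds = model(xs).argmax(dim=1)
            confmat.update(ys, preds)
    print(confmat.f1_metric(average='macro')) #single averaged F1 score
    print(confmat.f1_metric(average=None))    #per-class F1, shape (num_classes,)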