diff --git a/higher/smart_aug/datasets.py b/higher/smart_aug/datasets.py
index 431749e..1ec3241 100755
--- a/higher/smart_aug/datasets.py
+++ b/higher/smart_aug/datasets.py
@@ -3,7 +3,6 @@ MNIST / CIFAR10
 """
 import torch
-from torch.utils.data import SubsetRandomSampler
 from torch.utils.data.dataset import ConcatDataset
 
 import torchvision
 
@@ -72,26 +71,90 @@ data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=downloa
 
 #Validation set size [0, 1]
-#valid_size=0.1
+valid_size=0.1
 #train_subset_indices=range(int(len(data_train)*(1-valid_size)))
 #val_subset_indices=range(int(len(data_train)*(1-valid_size)),len(data_train))
 #train_subset_indices=range(BATCH_SIZE*10)
 #val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
 
+#from torch.utils.data import SubsetRandomSampler
 #dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
 #dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
 dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
 
 #Cross Validation
+'''
 from skorch.dataset import CVSplit
-cvs = CVSplit(cv=5)
+import numpy as np
+cvs = CVSplit(cv=valid_size, stratified=True) #stratified=True for unbalanced datasets #ShuffleSplit
 
 def next_CVSplit():
 
-    train_subset, val_subset = cvs(data_train)
+    train_subset, val_subset = cvs(data_train, y=np.array(data_train.targets))
     dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
     dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
 
     return dl_train, dl_val
 
-dl_train, dl_val = next_CVSplit()
\ No newline at end of file
+dl_train, dl_val = next_CVSplit()
+'''
+
+import numpy as np
+from sklearn.model_selection import ShuffleSplit
+from sklearn.model_selection import StratifiedShuffleSplit
+
+class CVSplit(object):
+    """Class that performs a train/valid split on a dataset.
+
+        Inspired by: https://skorch.readthedocs.io/en/latest/user/dataset.html
+
+        Attributes:
+            _stratified (bool): Whether the split should be stratified. Recommended to be True for unbalanced datasets.
+            _val_size (float, int): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the validation split.
+                If int, represents the absolute number of validation samples.
+            _data (Dataset): Dataset to split.
+            _targets (np.array): Targets of the dataset, used only if _stratified is True.
+            _cv (BaseShuffleSplit): scikit-learn object used to split.
+    """
+    def __init__(self, data, val_size=0.1, stratified=True):
+        """ Initialize CVSplit.
+
+            Args:
+                data (Dataset): Dataset to split.
+                val_size (float, int): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the validation split.
+                    If int, represents the absolute number of validation samples. (Default: 0.1)
+                stratified (bool): Whether the split should be stratified. Recommended to be True for unbalanced datasets.
+ """ + self._stratified=stratified + self._val_size=val_size + + self._data=data + if self._stratified: + cv_cls = StratifiedShuffleSplit + self._targets= np.array(data_train.targets) + else: + cv_cls = ShuffleSplit + + self._cv= cv_cls(test_size=val_size, random_state=0) + + def next_split(self): + """ Get next cross-validation split. + + Returns: + Train DataLoader, Validation DataLoader + """ + args=(np.arange(len(self._data)),) + if self._stratified: + args = args + (self._targets,) + + idx_train, idx_valid = next(iter(self._cv.split(*args))) + + train_subset = torch.utils.data.Subset(self._data, idx_train) + val_subset = torch.utils.data.Subset(self._data, idx_valid) + + dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory) + dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory) + + return dl_train, dl_val + +cvs = CVSplit(data_train, val_size=valid_size) +dl_train, dl_val = cvs.next_split() \ No newline at end of file diff --git a/higher/smart_aug/test_dataug.py b/higher/smart_aug/test_dataug.py index ac81561..34d25f3 100755 --- a/higher/smart_aug/test_dataug.py +++ b/higher/smart_aug/test_dataug.py @@ -79,7 +79,7 @@ if __name__ == "__main__": } #Parameters n_inner_iter = 1 - epochs = 150 + epochs = 2 dataug_epoch_start=0 optim_param={ 'Meta':{ @@ -156,7 +156,7 @@ if __name__ == "__main__": inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, opt_param=optim_param, - print_freq=20, + print_freq=1, unsup_loss=1, hp_opt=False, save_sample_freq=None) diff --git a/higher/smart_aug/train_utils.py b/higher/smart_aug/train_utils.py index 4b95c9e..deafa06 100755 --- a/higher/smart_aug/train_utils.py +++ b/higher/smart_aug/train_utils.py @@ -177,7 +177,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1): print('Time : %.00f'%(tf - t0)) print('Train loss :',loss.item(), '/ val loss', val_loss.item()) print('Accuracy max:', accuracy) - print('F1 :', f1) + print('F1 :', ["{0:0.4f}".format(i) for i in f1]) #### Log #### data={ @@ -185,7 +185,7 @@ def train_classic(model, opt_param, epochs=1, print_freq=1): "train_loss": loss.item(), "val_loss": val_loss.item(), "acc": accuracy, - "f1": f1.cpu().numpy().tolist(), + "f1": f1.tolist(), "time": tf - t0, "param": None, @@ -253,7 +253,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start for epoch in range(1, epochs+1): t0 = time.perf_counter() - dl_train, dl_val = next_CVSplit() + dl_train, dl_val = cvs.next_split() dl_val_it = iter(dl_val) for i, (xs, ys) in enumerate(dl_train): @@ -333,7 +333,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start "train_loss": loss.item(), "val_loss": val_loss.item(), "acc": accuracy, - "f1": f1.cpu().numpy().tolist(), + "f1": f1.tolist(), "time": tf - t0, "param": param, @@ -349,11 +349,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start print('Time : %.00f'%(tf - t0)) print('Train loss :',loss.item(), '/ val loss', val_loss.item()) print('Accuracy max:', max([x["acc"] for x in log])) - print('F1 :', f1) + print('F1 :', ["{0:0.4f}".format(i) for i in f1]) print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start)) - if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data) + if not model['data_aug']._fixed_prob: print('TF Proba :', ["{0:0.4f}".format(p) for p in 
             #print('proba grad',model['data_aug']['prob'].grad)
-            if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
+            if not model['data_aug']._fixed_mag: print('TF Mag :', ["{0:0.4f}".format(m) for m in model['data_aug']['mag']])
             #print('Mag grad',model['data_aug']['mag'].grad)
             if not model['data_aug']._fixed_mix: print('Mix:', model['data_aug']['mix_dist'].item())
             #print('Reg loss:', model['data_aug'].reg_loss().item())
diff --git a/higher/smart_aug/utils.py b/higher/smart_aug/utils.py
index 9dedfdc..4a4b5ae 100755
--- a/higher/smart_aug/utils.py
+++ b/higher/smart_aug/utils.py
@@ -15,31 +15,76 @@ import torch.nn.functional as F
 import time
 
 class ConfusionMatrix(object):
+    """ Confusion matrix.
+
+        Helps compute the confusion matrix and F1 scores.
+
+        Example use ::
+
+            confmat = ConfusionMatrix(...)
+
+            confmat.reset()
+            for data in dataset:
+                ...
+                confmat.update(...)
+
+            confmat.f1_metric(...)
+
+        Attributes:
+            num_classes (int): Number of classes.
+            mat (Tensor): Confusion matrix. Filled by the update method.
+    """
     def __init__(self, num_classes):
+        """ Initialize ConfusionMatrix.
+
+            Args:
+                num_classes (int): Number of classes.
+        """
         self.num_classes = num_classes
         self.mat = None
 
-    def update(self, a, b):
+    def update(self, target, pred):
+        """ Update the confusion matrix.
+
+            Args:
+                target (Tensor): Target labels.
+                pred (Tensor): Predictions.
+        """
         n = self.num_classes
         if self.mat is None:
-            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
+            self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
         with torch.no_grad():
-            k = (a >= 0) & (a < n)
-            inds = n * a[k].to(torch.int64) + b[k]
+            k = (target >= 0) & (target < n)
+            inds = n * target[k].to(torch.int64) + pred[k]
             self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
 
     def reset(self):
+        """ Reset the confusion matrix.
+        """
         if self.mat is not None:
             self.mat.zero_()
 
-    def compute(self):
-        h = self.mat.float()
-        acc_global = torch.diag(h).sum() / h.sum()
-        acc = torch.diag(h) / h.sum(1)
-        return acc_global, acc
-
     def f1_metric(self, average=None):
-        #https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
+        """ Compute the F1 score.
+
+            Inspired by:
+                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
+                https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
+
+            Args:
+                average (str): Type of averaging performed on the data. (Default: None)
+                    ``None``:
+                        The scores for each class are returned.
+                    ``'micro'``:
+                        Calculate metrics globally by counting the total true positives,
+                        false negatives and false positives.
+                    ``'macro'``:
+                        Calculate metrics for each label, and find their unweighted
+                        mean. This does not take label imbalance into account.
+
+            Returns:
+                Tensor containing the F1 score. Its shape is (1,) if averaging was performed, or (num_classes,) otherwise.
+        """
+
         h = self.mat.float()
         TP = torch.diag(h)
         TN = []
@@ -75,14 +120,6 @@ class ConfusionMatrix(object):
             f1=f1.mean()
 
         return f1
 
-    def __str__(self):
-        acc_global, acc = self.compute()
-        return (
-            'global correct: {:.1f}\n'
-            'average row correct: {}').format(
-                acc_global.item() * 100,
-                ['{:.1f}'.format(i) for i in (acc * 100).tolist()])
-
 def print_graph(PyTorch_obj, fig_name='graph'):
     """Save the computational graph.
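
Usage sketch (editor's illustration, not part of the diff): the new CVSplit class in datasets.py is meant to be re-queried once per epoch, which is exactly what the run_dist_dataugV3 change in train_utils.py does. Assuming datasets.py exposes data_train and valid_size as above, and that `epochs` is defined by the caller, a minimal training loop looks like:

    from datasets import CVSplit, data_train, valid_size

    # One CVSplit instance owns the dataset and the (optionally stratified) splitter.
    cvs = CVSplit(data_train, val_size=valid_size, stratified=True)

    for epoch in range(1, epochs+1):
        # Draw a fresh train/validation split for this epoch,
        # mirroring the cvs.next_split() call added in run_dist_dataugV3.
        dl_train, dl_val = cvs.next_split()
        for xs, ys in dl_train:
            ...  # training step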
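Usage sketch for the documented ConfusionMatrix (editor's illustration; `model`, `dl_val` and `device` are assumed to exist, as in train_utils.py). update() folds each (target, pred) pair into a single index n*target + pred and histograms them with torch.bincount; f1_metric() then derives per-class precision P and recall R from the matrix and returns F1 = 2*P*R / (P+R):

    import torch
    from utils import ConfusionMatrix

    confmat = ConfusionMatrix(num_classes=10)  # e.g. CIFAR10

    confmat.reset()
    with torch.no_grad():
        for xs, ys in dl_val:
            # Accumulate (target, prediction) pairs into the matrix.
            pred = model(xs.to(device)).argmax(dim=1)
            confmat.update(ys.to(device), pred)

    f1_per_class = confmat.f1_metric()             # shape (num_classes,)
    f1_macro = confmat.f1_metric(average='macro')  # unweighted mean over classes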