From 99f15b8946dd202fcd01a656f475765ac027e6bc Mon Sep 17 00:00:00 2001 From: "Harle, Antoine (Contracteur)" Date: Fri, 24 Jan 2020 15:10:08 -0500 Subject: [PATCH] Test doxygen --- .gitignore | 1 + higher/smart_aug/datasets.py | 9 ++++- higher/smart_aug/model.py | 12 ++++++ higher/smart_aug/transformations.py | 57 +++++++++++++++++------------ 4 files changed, 54 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index 76591f9..5da30ff 100755 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ /higher/data /higher/samples +/higher/doc /Gradient-Descent-The-Ultimate-Optimizer/data /FAR-HO/data /__pycache__ diff --git a/higher/smart_aug/datasets.py b/higher/smart_aug/datasets.py index 51deacd..9d4b5a6 100755 --- a/higher/smart_aug/datasets.py +++ b/higher/smart_aug/datasets.py @@ -6,12 +6,17 @@ import torch from torch.utils.data import SubsetRandomSampler import torchvision +#Train/Validation batch size. BATCH_SIZE = 300 -TEST_SIZE = 300 +#Test batch size. +TEST_SIZE = BATCH_SIZE #TEST_SIZE = 10000 #legerement +Rapide / + Consomation memoire ! +#Wether to download data. download_data=False +#Number of worker to use. num_workers=2 #4 +#Pin GPU memory pin_memory=False #True :+ GPU memory / + Lent #ATTENTION : Dataug (Kornia) Expect image in the range of [0, 1] @@ -37,8 +42,10 @@ transform = torchvision.transforms.Compose([ #) ### Classic Dataset ### +#Training data data_train = torchvision.datasets.CIFAR10("../data", train=True, download=download_data, transform=transform) #data_val = torchvision.datasets.CIFAR10("../data", train=True, download=download_data, transform=transform) +#Testing data data_test = torchvision.datasets.CIFAR10("../data", train=False, download=download_data, transform=transform) train_subset_indices=range(int(len(data_train)/2)) diff --git a/higher/smart_aug/model.py b/higher/smart_aug/model.py index a38bbae..4b377ce 100755 --- a/higher/smart_aug/model.py +++ b/higher/smart_aug/model.py @@ -5,7 +5,13 @@ import torch.nn.functional as F ## Basic CNN ## class LeNet(nn.Module): + """Basic CNN. + + """ def __init__(self, num_inp, num_out): + """Init LeNet. + + """ super(LeNet, self).__init__() self.conv1 = nn.Conv2d(num_inp, 20, 5) self.pool = nn.MaxPool2d(2, 2) @@ -15,6 +21,9 @@ class LeNet(nn.Module): self.fc2 = nn.Linear(500, num_out) def forward(self, x): + """Main method of LeNet + + """ x = self.pool(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) x = x.view(x.size(0), -1) @@ -23,4 +32,7 @@ class LeNet(nn.Module): return x def __str__(self): + """ Get name of model + + """ return "LeNet" diff --git a/higher/smart_aug/transformations.py b/higher/smart_aug/transformations.py index f7b661e..8534584 100755 --- a/higher/smart_aug/transformations.py +++ b/higher/smart_aug/transformations.py @@ -18,15 +18,18 @@ import torch import kornia import random +#TF that don't have use for magnitude parameter. +TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'} +#TF which implemetation doesn't allow gradient propagaition. +TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'} +#TF for which magnitude should be ignored (Magnitude fixed). +TF_ignore_mag= TF_no_mag | TF_no_grad -TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'} #TF that don't have use for magnitude parameter. -TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'} #TF which implemetation doesn't allow gradient propagaition. -TF_ignore_mag= TF_no_mag | TF_no_grad #TF for which magnitude should be ignored (Magnitude fixed). +# What is the max 'level' a transform could be predicted +PARAMETER_MAX = 1 +# What is the min 'level' a transform could be predicted +PARAMETER_MIN = 0.1 -PARAMETER_MAX = 1 # What is the max 'level' a transform could be predicted -PARAMETER_MIN = 0.1 # What is the min 'level' a transform could be predicted - -### Available TF for Dataug ### # Dictionnary mapping tranformations identifiers to their function. # Each value of the dict should be a lambda function taking a (batch of data, magnitude of transformations) tuple as input and returns a batch of data. TF_dict={ #Dataugv5+ @@ -416,13 +419,16 @@ def blend(x,y,alpha): return res #Not working -def auto_contrast(x): #PAS OPTIMISE POUR DES BATCH #EXTRA LENT - # Optimisation : Application de LUT efficace / Calcul d'histogramme par batch/channel - print("Warning : Pas encore check !") - (batch_size, channels, h, w) = x.shape - x = int_image(x) #Expect image in the range of [0, 1] - #print('Start',x[0]) - for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image +def auto_contrast(x): + """NOT TESTED - EXTRA SLOW + + """ + # Optimisation : Application de LUT efficace / Calcul d'histogramme par batch/channel + print("Warning : Pas encore check !") + (batch_size, channels, h, w) = x.shape + x = int_image(x) #Expect image in the range of [0, 1] + #print('Start',x[0]) + for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image #print(img.shape) for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel #print(chan.shape) @@ -449,19 +455,22 @@ def auto_contrast(x): #PAS OPTIMISE POUR DES BATCH #EXTRA LENT chan[chan==ix]=n_ix x[im_idx, chan_idx]=chan - #print('End',x[0]) - return float_image(x) + #print('End',x[0]) + return float_image(x) -def equalize(x): #PAS OPTIMISE POUR DES BATCH - raise Exception(self, "not implemented") - # Optimisation : Application de LUT efficace / Calcul d'histogramme par batch/channel - (batch_size, channels, h, w) = x.shape - x = int_image(x) #Expect image in the range of [0, 1] - #print('Start',x[0]) - for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image +def equalize(x): + """ NOT WORKING + + """ + raise Exception(self, "not implemented") + # Optimisation : Application de LUT efficace / Calcul d'histogramme par batch/channel + (batch_size, channels, h, w) = x.shape + x = int_image(x) #Expect image in the range of [0, 1] + #print('Start',x[0]) + for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation par image #print(img.shape) for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): # Operation par channel #print(chan.shape) hist = torch.histc(chan, bins=256, min=0, max=255) #PAS DIFFERENTIABLE - return float_image(x) \ No newline at end of file + return float_image(x) \ No newline at end of file