Cross-validation splits + new process-time measurement (train utils)

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-03 15:08:22 -05:00
parent bce882de38
commit 385bc9977c
3 changed files with 51 additions and 30 deletions

View file

@ -19,6 +19,8 @@ download_data=False
num_workers=2 #4
#Pin GPU memory
pin_memory=False #True :+ GPU memory / + Lent
#Data storage folder
dataroot="../data"
#ATTENTION : Dataug (Kornia) expects images in the range [0, 1]
#transform_train = torchvision.transforms.Compose([
@ -41,7 +43,6 @@ transform_train = torchvision.transforms.Compose([
#transform_train.transforms.insert(0, RandAugment(n=2, m=30))
### Classic Dataset ###
dataroot="../data"
#MNIST
#data_train = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=transform_train)
@ -70,11 +71,27 @@ data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=downloa
#data_test = torchvision.datasets.ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)
train_subset_indices=range(int(len(data_train)/2))
val_subset_indices=range(int(len(data_train)/2),len(data_train))
#Validation set size [0, 1]
#valid_size=0.1
#train_subset_indices=range(int(len(data_train)*(1-valid_size)))
#val_subset_indices=range(int(len(data_train)*(1-valid_size)),len(data_train))
#train_subset_indices=range(BATCH_SIZE*10)
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
#dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
#dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
#Cross Validation
from skorch.dataset import CVSplit
# Splitter configured with cv=5; next_CVSplit() calls it on data_train.
cvs = CVSplit(cv=5)
def next_CVSplit():
    """Build train/validation DataLoaders from a split of ``data_train``.

    Calls the module-level ``cvs`` splitter on ``data_train`` and wraps each
    of the two resulting subsets in a shuffling ``DataLoader`` that uses the
    module-level ``BATCH_SIZE``, ``num_workers`` and ``pin_memory`` settings.

    Returns:
        A ``(dl_train, dl_val)`` tuple of DataLoaders.

    NOTE(review): skorch's CVSplit appears to return only the first fold of
    its underlying splitter on every call, so repeated calls may yield the
    same partition rather than the "next" one -- confirm against the skorch
    version in use.
    """
    train_part, val_part = cvs(data_train)

    # Both loaders share exactly the same configuration.
    loader_options = {
        "batch_size": BATCH_SIZE,
        "shuffle": True,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    train_loader = torch.utils.data.DataLoader(train_part, **loader_options)
    val_loader = torch.utils.data.DataLoader(val_part, **loader_options)
    return train_loader, val_loader
# Create the train/validation loaders from a cross-validation split.
dl_train, dl_val = next_CVSplit()