Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)
Cross Validation splits + new process time measurement (train utils)
This commit is contained in:
parent bce882de38
commit 385bc9977c

3 changed files with 51 additions and 30 deletions
@@ -19,6 +19,8 @@ download_data=False
num_workers=2 #4
#Pin GPU memory
pin_memory=False #True : + GPU memory / slower
#Data storage folder
dataroot="../data"

#ATTENTION : Dataug (Kornia) expects images in the range [0, 1]
#transform_train = torchvision.transforms.Compose([
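
For reference, the Kornia-based augmentation (Dataug) mentioned above works on float tensors in [0, 1]. A minimal sketch of a pipeline that produces that range, not taken from this commit: torchvision's ToTensor() already rescales PIL/uint8 images to [0, 1], so no extra rescaling is needed before the augmentation step.

# Editor's sketch (not in this commit): ToTensor() maps HxWxC uint8 [0, 255]
# images to CxHxW float tensors in [0, 1], the range Kornia-based Dataug expects.
import torchvision

transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),  # example geometric augmentation
    torchvision.transforms.ToTensor(),              # -> float32 in [0, 1]
])
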
@@ -41,7 +43,6 @@ transform_train = torchvision.transforms.Compose([
#transform_train.transforms.insert(0, RandAugment(n=2, m=30))

### Classic Dataset ###
dataroot="../data"

#MNIST
#data_train = torchvision.datasets.MNIST(dataroot, train=True, download=True, transform=transform_train)
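
The commented RandAugment(n=2, m=30) line refers to an external RandAugment implementation; which package provides it is not visible in this hunk. As a hedged alternative, recent torchvision releases ship their own RandAugment transform, which can be prepended the same way (it expects PIL or uint8 images, so it should sit before ToTensor()):

# Editor's sketch (not in this commit): torchvision's built-in RandAugment,
# inserted at the front of the existing Compose pipeline. magnitude=9 is
# torchvision's default; its 0-30 scale is not directly comparable to m=30 above.
import torchvision

transform_train.transforms.insert(
    0, torchvision.transforms.RandAugment(num_ops=2, magnitude=9))
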
@@ -70,11 +71,27 @@ data_test = torchvision.datasets.CIFAR10(dataroot, train=False, download=downloa
#data_test = torchvision.datasets.ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)


train_subset_indices=range(int(len(data_train)/2))
val_subset_indices=range(int(len(data_train)/2),len(data_train))
#Validation set size [0, 1]
#valid_size=0.1
#train_subset_indices=range(int(len(data_train)*(1-valid_size)))
#val_subset_indices=range(int(len(data_train)*(1-valid_size)),len(data_train))
#train_subset_indices=range(BATCH_SIZE*10)
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)

dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
#dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
#dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

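The commented valid_size block above sketches a fractional hold-out split over contiguous index ranges. A hedged variant, not part of this commit, that shuffles the indices before splitting; the resulting lists drop straight into the SubsetRandomSampler-based loaders defined above (SubsetRandomSampler already shuffles within its index set, which is why those DataLoaders keep shuffle=False):

# Editor's sketch (not in this commit): shuffled fractional train/validation split.
import torch

valid_size = 0.1                                # held-out fraction, as in the comment above
perm = torch.randperm(len(data_train)).tolist() # random permutation of all training indices
split = int(len(data_train) * (1 - valid_size))
train_subset_indices = perm[:split]             # e.g. 90% of the shuffled indices
val_subset_indices = perm[split:]               # remaining 10%
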
#Cross Validation
from skorch.dataset import CVSplit
cvs = CVSplit(cv=5)

def next_CVSplit():
    #Draw a train/validation split via skorch's CVSplit and build the matching DataLoaders

    train_subset, val_subset = cvs(data_train)
    dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
    dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)

    return dl_train, dl_val

dl_train, dl_val = next_CVSplit()
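
next_CVSplit() rebuilds the train/validation DataLoaders from a skorch CVSplit; with an integer cv, CVSplit draws its split from a k-fold splitter internally, so whether repeated calls actually rotate through different folds depends on skorch's internals. A hedged alternative sketch, not from this commit, that makes the rotation explicit by cycling through sklearn KFold folds:

# Editor's sketch (not in this commit): explicit 5-fold rotation; each call
# returns DataLoaders for the next fold. Reuses data_train, BATCH_SIZE,
# num_workers and pin_memory defined above.
import itertools
import numpy as np
import torch
from sklearn.model_selection import KFold

kfold = KFold(n_splits=5, shuffle=True, random_state=0)
fold_cycle = itertools.cycle(kfold.split(np.arange(len(data_train))))

def next_fold_loaders():
    #Advance to the next fold and build its train/validation DataLoaders
    train_idx, val_idx = next(fold_cycle)
    train_subset = torch.utils.data.Subset(data_train, train_idx.tolist())
    val_subset = torch.utils.data.Subset(data_train, val_idx.tolist())
    dl_train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True,
                                           num_workers=num_workers, pin_memory=pin_memory)
    dl_val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True,
                                         num_workers=num_workers, pin_memory=pin_memory)
    return dl_train, dl_val

Calling dl_train, dl_val = next_fold_loaders() at the start of each epoch (or of each outer cross-validation iteration) then mirrors the next_CVSplit() usage above.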