|
int | BATCH_SIZE = 300 |
|
int | TEST_SIZE = BATCH_SIZE |
|
bool | download_data = False |
|
int | num_workers = 2 |
|
bool | pin_memory = False |
|
| transform |
|
| data_train = torchvision.datasets.CIFAR10("../data", train=True, download=download_data, transform=transform) |
| Classic dataset: training data.
|
|
| data_test = torchvision.datasets.CIFAR10("../data", train=False, download=download_data, transform=transform) |
|
| train_subset_indices = range(int(len(data_train)/2)) |
|
| val_subset_indices = range(int(len(data_train)/2),len(data_train)) |
|
| dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory) |
|
| dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory) |
|
| dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory) |
|
Dataset definition.
MNIST / CIFAR10