# smart_augmentation/Gradient-Descent-The-Ultimate-Optimizer/dataset_aug.py
# Harle, Antoine (Contracteur) 3ae3e02e59 Initial Commit
# 2019-11-08 11:28:06 -05:00
#
# 52 lines
# 1.5 KiB
# Python

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
class MNIST_aug(Dataset):
    """Toy dataset of random RGB images with a stage-dependent random crop.

    Despite its name, this class loads no MNIST data. ``__init__`` builds ten
    random 3x48x48 uint8 images and converts them to PIL images. Each
    ``__getitem__`` applies the current crop transform and returns the result
    as a tensor via ``TF.to_tensor``. Call :meth:`set_stage` between epochs to
    change the crop size (stage 0 -> 32x32, stage 1 -> 28x28).
    """

    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    # Crop size (height, width) used for each supported stage.
    _STAGE_CROP_SIZES = {0: (32, 32), 1: (28, 28)}

    def __init__(self):
        # torch.ByteTensor(...) returns *uninitialized* memory; draw defined
        # random pixel values instead so the demo images are well defined.
        pixels = torch.randint(0, 256, (10, 3, 48, 48), dtype=torch.uint8)
        self.images = [TF.to_pil_image(x) for x in pixels]
        self.set_stage(0)  # initial stage

    def __getitem__(self, index):
        """Return image ``index`` after the current crop, as a tensor."""
        image = self.images[index]
        # Just apply your transformations here
        image = self.crop(image)
        return TF.to_tensor(image)

    def set_stage(self, stage):
        """Select the random-crop size used by ``__getitem__``.

        Args:
            stage: 0 for (32, 32) crops, 1 for (28, 28) crops.

        Raises:
            ValueError: if ``stage`` is not a supported stage. Previously an
                unknown stage was silently ignored, leaving ``self.crop``
                unset (or stale).
        """
        try:
            height, width = self._STAGE_CROP_SIZES[stage]
        except KeyError:
            raise ValueError(
                'Unknown stage {!r}; expected one of {}'.format(
                    stage, sorted(self._STAGE_CROP_SIZES))) from None
        print('Using ({}, {}) crops'.format(height, width))
        self.crop = transforms.RandomCrop((height, width))

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.images)
if __name__ == "__main__":
    # Demo: one pass with 32x32 crops, switch stage, then a pass with 28x28
    # crops. The __main__ guard is required because num_workers > 0 starts
    # worker processes; under the "spawn" start method each worker re-imports
    # this module and would otherwise re-run this demo.
    dataset = MNIST_aug()  # was MyData(), an undefined name
    loader = DataLoader(dataset,
                        batch_size=2,
                        num_workers=2,
                        shuffle=True)
    for batch_idx, data in enumerate(loader):
        print('Batch idx {}, data shape {}'.format(
            batch_idx, data.shape))
    # With the default persistent_workers=False, DataLoader starts fresh
    # workers for each new iterator, so the next pass sees the new crop.
    loader.dataset.set_stage(1)
    for batch_idx, data in enumerate(loader):
        print('Batch idx {}, data shape {}'.format(
            batch_idx, data.shape))