mirror of https://github.com/AntoineHX/smart_augmentation.git, synced 2025-05-04 12:10:45 +02:00
import torch

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF


class MNIST_aug(Dataset):
    """Toy dataset whose augmentation (random crop size) can be switched at runtime."""

    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self):
        # Dummy data: ten uninitialized 3x48x48 byte images converted to PIL images.
        self.images = [TF.to_pil_image(x) for x in torch.ByteTensor(10, 3, 48, 48)]
        self.set_stage(0)  # initial stage

    def __getitem__(self, index):
        image = self.images[index]

        # Just apply your transformations here
        image = self.crop(image)
        x = TF.to_tensor(image)
        return x

    def set_stage(self, stage):
        # Switch the crop size used for augmentation without rebuilding the dataset.
        if stage == 0:
            print('Using (32, 32) crops')
            self.crop = transforms.RandomCrop((32, 32))
        elif stage == 1:
            print('Using (28, 28) crops')
            self.crop = transforms.RandomCrop((28, 28))

    def __len__(self):
        return len(self.images)


dataset = MNIST_aug()
loader = DataLoader(dataset,
                    batch_size=2,
                    num_workers=2,
                    shuffle=True)

# Stage 0: (32, 32) crops
for batch_idx, data in enumerate(loader):
    print('Batch idx {}, data shape {}'.format(
        batch_idx, data.shape))

# Switch the augmentation on the underlying dataset, then iterate again.
loader.dataset.set_stage(1)

# Stage 1: (28, 28) crops
for batch_idx, data in enumerate(loader):
    print('Batch idx {}, data shape {}'.format(
        batch_idx, data.shape))
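
# A minimal sketch of driving the stage switch from a training loop instead of
# by hand: keep the (32, 32) crop for the first half of training, then switch
# to (28, 28). `num_epochs` and the switch point are arbitrary values chosen
# here for illustration only.
num_epochs = 4
for epoch in range(num_epochs):
    loader.dataset.set_stage(0 if epoch < num_epochs // 2 else 1)
    for batch_idx, data in enumerate(loader):
        pass  # a training step on `data` would go here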