mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 20:20:46 +02:00
Rangement
This commit is contained in:
parent
ca3367d19f
commit
4166922c34
453 changed files with 9797 additions and 7 deletions
|
@ -1,52 +0,0 @@
|
|||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
class MNIST_aug(Dataset):
|
||||
|
||||
training_file = 'training.pt'
|
||||
test_file = 'test.pt'
|
||||
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
|
||||
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
|
||||
|
||||
def __init__(self):
|
||||
self.images = [TF.to_pil_image(x) for x in torch.ByteTensor(10, 3, 48, 48)]
|
||||
self.set_stage(0) # initial stage
|
||||
|
||||
def __getitem__(self, index):
|
||||
image = self.images[index]
|
||||
|
||||
# Just apply your transformations here
|
||||
image = self.crop(image)
|
||||
x = TF.to_tensor(image)
|
||||
return x
|
||||
|
||||
def set_stage(self, stage):
|
||||
if stage == 0:
|
||||
print('Using (32, 32) crops')
|
||||
self.crop = transforms.RandomCrop((32, 32))
|
||||
elif stage == 1:
|
||||
print('Using (28, 28) crops')
|
||||
self.crop = transforms.RandomCrop((28, 28))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
|
||||
dataset = MyData()
|
||||
loader = DataLoader(dataset,
|
||||
batch_size=2,
|
||||
num_workers=2,
|
||||
shuffle=True)
|
||||
|
||||
for batch_idx, data in enumerate(loader):
|
||||
print('Batch idx {}, data shape {}'.format(
|
||||
batch_idx, data.shape))
|
||||
|
||||
loader.dataset.set_stage(1)
|
||||
|
||||
for batch_idx, data in enumerate(loader):
|
||||
print('Batch idx {}, data shape {}'.format(
|
||||
batch_idx, data.shape))
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue