mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Mise a jour de toute les modifs... (Higher: Ajout deux TF, modification val loss, ajout prob dans sample image, ...)
This commit is contained in:
parent
e75fb96716
commit
c8ce6c8024
6 changed files with 299 additions and 64 deletions
|
@ -35,6 +35,8 @@ import augmentation_transforms
|
|||
import numpy as np
|
||||
|
||||
download_data=False
|
||||
num_workers=0
|
||||
pin_memory=False
|
||||
|
||||
class AugmentedDataset(VisionDataset):
|
||||
def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None):
|
||||
|
@ -281,14 +283,14 @@ train_subset_indices=range(int(len(data_train)/2))
|
|||
val_subset_indices=range(int(len(data_train)/2),len(data_train))
|
||||
#train_subset_indices=range(BATCH_SIZE*10)
|
||||
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
|
||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
|
||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
|
||||
|
||||
### Augmented Dataset ###
|
||||
#data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2)))
|
||||
#data_train_aug.augement_data(aug_copy=10)
|
||||
#print(data_train_aug)
|
||||
#dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True)
|
||||
#dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
|
||||
|
||||
|
||||
dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
|
||||
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False)
|
||||
dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
|
||||
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue