mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Changement permission fichiers + Simplification utilisation Augmented_dataset
This commit is contained in:
parent
adaac437b6
commit
b26fbcd2a2
619 changed files with 41 additions and 13049 deletions
15
higher/datasets.py
Normal file → Executable file
15
higher/datasets.py
Normal file → Executable file
|
@ -125,10 +125,11 @@ class AugmentedDataset(VisionDataset):
|
|||
aug_image = augmentation_transforms.apply_policy(chosen_policy, image, use_mean_std=False) #Cast en float image
|
||||
#aug_image = augmentation_transforms.cutout_numpy(aug_image)
|
||||
|
||||
self.unsup_data+=[aug_image]
|
||||
self.unsup_data+=[(aug_image*255.).astype(self.sup_data.dtype)]#Cast float image to uint8
|
||||
self.unsup_targets+=[self.sup_targets[idx]]
|
||||
|
||||
self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8
|
||||
#self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8
|
||||
self.unsup_data=np.array(self.unsup_data)
|
||||
self.data= np.concatenate((self.sup_data, self.unsup_data), axis=0)
|
||||
self.targets= np.concatenate((self.sup_targets, self.unsup_targets), axis=0)
|
||||
|
||||
|
@ -159,15 +160,13 @@ train_subset_indices=range(int(len(data_train)/2))
|
|||
val_subset_indices=range(int(len(data_train)/2),len(data_train))
|
||||
#train_subset_indices=range(BATCH_SIZE*10)
|
||||
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
|
||||
|
||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
|
||||
|
||||
### Augmented Dataset ###
|
||||
data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2)))
|
||||
data_train_aug.augement_data(aug_copy=1)
|
||||
print(data_train_aug)
|
||||
|
||||
dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True)
|
||||
#data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2)))
|
||||
#data_train_aug.augement_data(aug_copy=10)
|
||||
#print(data_train_aug)
|
||||
#dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True)
|
||||
|
||||
|
||||
dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue