Changed file permissions + simplified Augmented_dataset usage

Harle, Antoine (Contractor) 2019-12-04 14:48:11 -05:00
parent adaac437b6
commit b26fbcd2a2
619 changed files with 41 additions and 13049 deletions

higher/datasets.py Normal file → Executable file

@@ -125,10 +125,11 @@ class AugmentedDataset(VisionDataset):
aug_image = augmentation_transforms.apply_policy(chosen_policy, image, use_mean_std=False) #Cast image to float
#aug_image = augmentation_transforms.cutout_numpy(aug_image)
self.unsup_data+=[aug_image]
self.unsup_data+=[(aug_image*255.).astype(self.sup_data.dtype)]#Cast float image to uint8
self.unsup_targets+=[self.sup_targets[idx]]
self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8
#self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8
self.unsup_data=np.array(self.unsup_data)
self.data= np.concatenate((self.sup_data, self.unsup_data), axis=0)
self.targets= np.concatenate((self.sup_targets, self.unsup_targets), axis=0)
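
The hunk above moves the float-to-uint8 cast inside the augmentation loop: each augmented image is converted back to the dtype of the supervised data as soon as it is produced, and the final np.array call only stacks the list. Below is a minimal, self-contained sketch of that per-image cast; fake_apply_policy is a hypothetical stand-in for augmentation_transforms.apply_policy, and random uint8 images stand in for self.sup_data.

import numpy as np

# Hypothetical stand-in for augmentation_transforms.apply_policy:
# it returns a float image in [0, 1], which is what the cast below assumes.
def fake_apply_policy(image):
    return image.astype(np.float32) / 255.

sup_data = np.random.randint(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)  # toy stand-in for self.sup_data
unsup_data = []
for image in sup_data:
    aug_image = fake_apply_policy(image)                        # float image, as apply_policy returns
    unsup_data += [(aug_image * 255.).astype(sup_data.dtype)]   # cast each image back to uint8 immediately
unsup_data = np.array(unsup_data)                               # stack only; no second whole-array cast needed
assert unsup_data.dtype == sup_data.dtype and unsup_data.shape == sup_data.shape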
@@ -159,15 +160,13 @@ train_subset_indices=range(int(len(data_train)/2))
val_subset_indices=range(int(len(data_train)/2),len(data_train))
#train_subset_indices=range(BATCH_SIZE*10)
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
### Augmented Dataset ###
data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2)))
data_train_aug.augement_data(aug_copy=1)
print(data_train_aug)
dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True)
#data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2)))
#data_train_aug.augement_data(aug_copy=10)
#print(data_train_aug)
#dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True)
dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
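
This hunk toggles between two ways of building dl_train: directly from data_train with a SubsetRandomSampler over the first half of the indices, or from the AugmentedDataset wrapper. Below is a minimal, runnable sketch of the sampler-based split only, with a toy TensorDataset in place of the real data_train and a placeholder BATCH_SIZE; the AugmentedDataset path is omitted because it depends on the rest of the repository.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

BATCH_SIZE = 8  # placeholder value for this sketch
# Toy dataset standing in for data_train (a torchvision dataset in the real script).
data_train = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))

train_subset_indices = range(int(len(data_train) / 2))                  # first half for training
val_subset_indices = range(int(len(data_train) / 2), len(data_train))   # second half for validation

# shuffle stays False because the sampler already draws the subset indices in random order.
dl_train = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False,
                      sampler=SubsetRandomSampler(train_subset_indices))
dl_val = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False,
                    sampler=SubsetRandomSampler(val_subset_indices))

for images, labels in dl_train:
    print(images.shape, labels.shape)  # e.g. torch.Size([8, 3, 32, 32]) torch.Size([8])
    break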