diff --git a/higher/datasets.py b/higher/datasets.py index 3688eda..09fa1ee 100755 --- a/higher/datasets.py +++ b/higher/datasets.py @@ -197,8 +197,8 @@ class AugmentedDatasetV2(VisionDataset): 'Color', 'Brightness', 'Sharpness', - #'Posterize', - #'Solarize', + 'Posterize', + 'Solarize', 'Invert', 'AutoContrast', @@ -206,8 +206,9 @@ class AugmentedDatasetV2(VisionDataset): ] self._op_list =[] self.prob=0.5 + self.mag_range=(1, 10) for tf in self._TF: - for mag in range(1, 10): + for mag in range(self.mag_range[0], self.mag_range[1]): self._op_list+=[(tf, self.prob, mag)] self._nb_op = len(self._op_list) @@ -267,7 +268,7 @@ class AugmentedDatasetV2(VisionDataset): return self.dataset_info['unsup']#self.dataset_info['length'] def __str__(self): - return "CIFAR10(Sup:{}-Unsup:{}-{}TF)".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF)) + return "CIFAR10(Sup:{}-Unsup:{}-{}TF(Mag{}-{}))".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF), self.mag_range[0], self.mag_range[1]) ### Classic Dataset ###