From adaac437b66f9c18fe73722bc4551592a1062186 Mon Sep 17 00:00:00 2001 From: "Harle, Antoine (Contracteur)" Date: Wed, 4 Dec 2019 12:58:11 -0500 Subject: [PATCH] Fix cast in Augmented Dataset --- higher/datasets.py | 36 +- ...p:25000-Unsup:25000)-LeNet-100 epochs.json | 810 ------------------ higher/test_dataug.py | 5 +- 3 files changed, 29 insertions(+), 822 deletions(-) delete mode 100644 higher/res/log/CIFAR10(Sup:25000-Unsup:25000)-LeNet-100 epochs.json diff --git a/higher/datasets.py b/higher/datasets.py index 9b00973..a04588e 100644 --- a/higher/datasets.py +++ b/higher/datasets.py @@ -34,6 +34,8 @@ from PIL import Image import augmentation_transforms import numpy as np +download_data=False + class AugmentedDataset(VisionDataset): def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None): @@ -63,9 +65,21 @@ class AugmentedDataset(VisionDataset): self._TF = [ - 'Invert', 'Cutout', 'Sharpness', 'AutoContrast', 'Posterize', - 'ShearX', 'TranslateX', 'TranslateY', 'ShearY', 'Rotate', - 'Equalize', 'Contrast', 'Color', 'Solarize', 'Brightness' + 'Invert', + 'Cutout', + 'Sharpness', + 'AutoContrast', + 'Posterize', + 'ShearX', + 'TranslateX', + 'TranslateY', + 'ShearY', + 'Rotate', + 'Equalize', + 'Contrast', + 'Color', + 'Solarize', + 'Brightness' ] self._op_list =[] self.prob=0.5 @@ -108,13 +122,13 @@ class AugmentedDataset(VisionDataset): for _ in range(aug_copy): chosen_policy = policies[np.random.choice(len(policies))] - aug_image = augmentation_transforms.apply_policy(chosen_policy, image) + aug_image = augmentation_transforms.apply_policy(chosen_policy, image, use_mean_std=False) #Cast en float image #aug_image = augmentation_transforms.cutout_numpy(aug_image) self.unsup_data+=[aug_image] self.unsup_targets+=[self.sup_targets[idx]] - self.unsup_data=np.array(self.unsup_data).astype(self.sup_data.dtype) + self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8 self.data= np.concatenate((self.sup_data, self.unsup_data), axis=0) self.targets= np.concatenate((self.sup_targets, self.unsup_targets), axis=0) @@ -133,12 +147,12 @@ class AugmentedDataset(VisionDataset): return self.dataset_info['length'] def __str__(self): - return "CIFAR10(Sup:{}-Unsup:{})".format(self.dataset_info['sup'], self.dataset_info['unsup']) + return "CIFAR10(Sup:{}-Unsup:{}-{}TF)".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF)) ### Classic Dataset ### -data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform) -#data_val = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform) -data_test = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=transform) +data_train = torchvision.datasets.CIFAR10("./data", train=True, download=download_data, transform=transform) +#data_val = torchvision.datasets.CIFAR10("./data", train=True, download=download_data, transform=transform) +data_test = torchvision.datasets.CIFAR10("./data", train=False, download=download_data, transform=transform) train_subset_indices=range(int(len(data_train)/2)) @@ -149,8 +163,8 @@ val_subset_indices=range(int(len(data_train)/2),len(data_train)) dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices)) ### Augmented Dataset ### -data_train_aug = AugmentedDataset("./data", train=True, download=True, transform=transform, subset=(0,int(len(data_train)/2))) -#data_train_aug.augement_data(aug_copy=1) +data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2))) +data_train_aug.augement_data(aug_copy=1) print(data_train_aug) dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True) diff --git a/higher/res/log/CIFAR10(Sup:25000-Unsup:25000)-LeNet-100 epochs.json b/higher/res/log/CIFAR10(Sup:25000-Unsup:25000)-LeNet-100 epochs.json deleted file mode 100644 index 62206a8..0000000 --- a/higher/res/log/CIFAR10(Sup:25000-Unsup:25000)-LeNet-100 epochs.json +++ /dev/null @@ -1,810 +0,0 @@ -{ - "Accuracy": 59.15, - "Time": [ - 2.891577087839999, - 0.04480297073115646 - ], - "Device": "TITAN RTX", - "Log": [ - { - "epoch": 0, - "train_loss": 2.238740921020508, - "val_loss": 2.1966593265533447, - "acc": 18.39, - "time": 3.142668161000003, - "param": null - }, - { - "epoch": 1, - "train_loss": 2.1030218601226807, - "val_loss": 1.9304202795028687, - "acc": 31.4, - "time": 2.866064168000001, - "param": null - }, - { - "epoch": 2, - "train_loss": 2.102165460586548, - "val_loss": 1.7866367101669312, - "acc": 34.08, - "time": 2.8656765240000013, - "param": null - }, - { - "epoch": 3, - "train_loss": 1.9763004779815674, - "val_loss": 1.6346896886825562, - "acc": 44.06, - "time": 2.8573803359999985, - "param": null - }, - { - "epoch": 4, - "train_loss": 1.9886680841445923, - "val_loss": 1.5025888681411743, - "acc": 45.91, - "time": 2.884063083000001, - "param": null - }, - { - "epoch": 5, - "train_loss": 1.8460354804992676, - "val_loss": 1.4987385272979736, - "acc": 47.11, - "time": 2.8803015360000046, - "param": null - }, - { - "epoch": 6, - "train_loss": 1.8773751258850098, - "val_loss": 1.3948158025741577, - "acc": 46.54, - "time": 2.8545809470000023, - "param": null - }, - { - "epoch": 7, - "train_loss": 1.9329789876937866, - "val_loss": 1.3192082643508911, - "acc": 49.96, - "time": 2.8968874109999945, - "param": null - }, - { - "epoch": 8, - "train_loss": 1.877233624458313, - "val_loss": 1.4252185821533203, - "acc": 51.3, - "time": 2.856139868999989, - "param": null - }, - { - "epoch": 9, - "train_loss": 1.8099722862243652, - "val_loss": 1.4050142765045166, - "acc": 52.78, - "time": 2.9680558680000075, - "param": null - }, - { - "epoch": 10, - "train_loss": 1.7314376831054688, - "val_loss": 1.2530089616775513, - "acc": 54.54, - "time": 2.8540503600000022, - "param": null - }, - { - "epoch": 11, - "train_loss": 1.714591383934021, - "val_loss": 1.4025185108184814, - "acc": 54.3, - "time": 2.8576988419999907, - "param": null - }, - { - "epoch": 12, - "train_loss": 1.6348106861114502, - "val_loss": 1.2283456325531006, - "acc": 56.1, - "time": 2.8865886950000004, - "param": null - }, - { - "epoch": 13, - "train_loss": 1.6358189582824707, - "val_loss": 1.2930848598480225, - "acc": 55.75, - "time": 2.872028806000003, - "param": null - }, - { - "epoch": 14, - "train_loss": 1.7428033351898193, - "val_loss": 1.185242772102356, - "acc": 58.11, - "time": 2.8523812309999954, - "param": null - }, - { - "epoch": 15, - "train_loss": 1.584743857383728, - "val_loss": 1.2095788717269897, - "acc": 58.13, - "time": 2.8867363259999905, - "param": null - }, - { - "epoch": 16, - "train_loss": 1.637380838394165, - "val_loss": 1.2247447967529297, - "acc": 57.89, - "time": 2.927922326000001, - "param": null - }, - { - "epoch": 17, - "train_loss": 1.3542273044586182, - "val_loss": 1.1683770418167114, - "acc": 57.79, - "time": 2.9269400579999996, - "param": null - }, - { - "epoch": 18, - "train_loss": 1.5338425636291504, - "val_loss": 1.1944950819015503, - "acc": 57.9, - "time": 2.8685170960000193, - "param": null - }, - { - "epoch": 19, - "train_loss": 1.406481146812439, - "val_loss": 1.164185643196106, - "acc": 58.67, - "time": 2.8701969949999864, - "param": null - }, - { - "epoch": 20, - "train_loss": 1.4967502355575562, - "val_loss": 1.3286499977111816, - "acc": 57.86, - "time": 2.8738660139999865, - "param": null - }, - { - "epoch": 21, - "train_loss": 1.4469761848449707, - "val_loss": 1.2245217561721802, - "acc": 59.15, - "time": 2.8806329870000127, - "param": null - }, - { - "epoch": 22, - "train_loss": 1.3278372287750244, - "val_loss": 1.3220281600952148, - "acc": 58.65, - "time": 2.8654193290000194, - "param": null - }, - { - "epoch": 23, - "train_loss": 1.310671329498291, - "val_loss": 1.4258655309677124, - "acc": 58.62, - "time": 2.893132035000008, - "param": null - }, - { - "epoch": 24, - "train_loss": 1.2683414220809937, - "val_loss": 1.2353028059005737, - "acc": 58.18, - "time": 2.8943645079999953, - "param": null - }, - { - "epoch": 25, - "train_loss": 1.2522042989730835, - "val_loss": 1.3762693405151367, - "acc": 57.19, - "time": 2.902485732999992, - "param": null - }, - { - "epoch": 26, - "train_loss": 1.2245501279830933, - "val_loss": 1.4818042516708374, - "acc": 58.7, - "time": 2.9064084750000063, - "param": null - }, - { - "epoch": 27, - "train_loss": 1.0753016471862793, - "val_loss": 1.2684273719787598, - "acc": 57.39, - "time": 2.8864203069999803, - "param": null - }, - { - "epoch": 28, - "train_loss": 1.204155445098877, - "val_loss": 1.6392837762832642, - "acc": 58.03, - "time": 2.8522731609999994, - "param": null - }, - { - "epoch": 29, - "train_loss": 1.0476032495498657, - "val_loss": 1.9471677541732788, - "acc": 57.36, - "time": 2.8621515780000095, - "param": null - }, - { - "epoch": 30, - "train_loss": 1.0602569580078125, - "val_loss": 1.7120074033737183, - "acc": 57.4, - "time": 2.892253502000017, - "param": null - }, - { - "epoch": 31, - "train_loss": 0.8457854390144348, - "val_loss": 1.7540744543075562, - "acc": 57.55, - "time": 2.876599782999989, - "param": null - }, - { - "epoch": 32, - "train_loss": 0.7922731041908264, - "val_loss": 2.066330671310425, - "acc": 55.91, - "time": 2.8669344470000055, - "param": null - }, - { - "epoch": 33, - "train_loss": 0.6974552273750305, - "val_loss": 1.8737841844558716, - "acc": 57.54, - "time": 2.8812602219999803, - "param": null - }, - { - "epoch": 34, - "train_loss": 0.6545730829238892, - "val_loss": 2.59554386138916, - "acc": 56.33, - "time": 2.992003705000002, - "param": null - }, - { - "epoch": 35, - "train_loss": 0.6207562685012817, - "val_loss": 2.878373861312866, - "acc": 56.64, - "time": 2.880951809999999, - "param": null - }, - { - "epoch": 36, - "train_loss": 0.643873929977417, - "val_loss": 2.8013312816619873, - "acc": 56.3, - "time": 2.8770664789999785, - "param": null - }, - { - "epoch": 37, - "train_loss": 0.5519305467605591, - "val_loss": 2.3760335445404053, - "acc": 56.42, - "time": 2.860177633999996, - "param": null - }, - { - "epoch": 38, - "train_loss": 0.4566989541053772, - "val_loss": 3.261263370513916, - "acc": 56.38, - "time": 2.8768970639999907, - "param": null - }, - { - "epoch": 39, - "train_loss": 0.4070911407470703, - "val_loss": 3.162980079650879, - "acc": 55.65, - "time": 2.8753553720000014, - "param": null - }, - { - "epoch": 40, - "train_loss": 0.3369447588920593, - "val_loss": 3.6949386596679688, - "acc": 56.16, - "time": 2.8803950129999976, - "param": null - }, - { - "epoch": 41, - "train_loss": 0.233261376619339, - "val_loss": 3.1482434272766113, - "acc": 56.81, - "time": 2.958535100000006, - "param": null - }, - { - "epoch": 42, - "train_loss": 0.2671070098876953, - "val_loss": 3.852823495864868, - "acc": 56.28, - "time": 2.8413516090000144, - "param": null - }, - { - "epoch": 43, - "train_loss": 0.17192059755325317, - "val_loss": 3.9386160373687744, - "acc": 56.43, - "time": 2.882356039000001, - "param": null - }, - { - "epoch": 44, - "train_loss": 0.21115460991859436, - "val_loss": 3.674778938293457, - "acc": 56.6, - "time": 2.8627468419999786, - "param": null - }, - { - "epoch": 45, - "train_loss": 0.17167963087558746, - "val_loss": 3.678553819656372, - "acc": 55.82, - "time": 2.8513568369999973, - "param": null - }, - { - "epoch": 46, - "train_loss": 0.1660262942314148, - "val_loss": 4.021941661834717, - "acc": 56.93, - "time": 2.883665002000015, - "param": null - }, - { - "epoch": 47, - "train_loss": 0.14190584421157837, - "val_loss": 4.224549770355225, - "acc": 57.28, - "time": 2.8793066700000054, - "param": null - }, - { - "epoch": 48, - "train_loss": 0.13566425442695618, - "val_loss": 3.8729546070098877, - "acc": 56.01, - "time": 2.8533336880000206, - "param": null - }, - { - "epoch": 49, - "train_loss": 0.20210444927215576, - "val_loss": 4.181947231292725, - "acc": 56.79, - "time": 2.871362753999989, - "param": null - }, - { - "epoch": 50, - "train_loss": 0.17691804468631744, - "val_loss": 4.109798908233643, - "acc": 56.21, - "time": 2.868851712000037, - "param": null - }, - { - "epoch": 51, - "train_loss": 0.09427379816770554, - "val_loss": 4.5040059089660645, - "acc": 56.41, - "time": 2.8451589609999814, - "param": null - }, - { - "epoch": 52, - "train_loss": 0.1147536113858223, - "val_loss": 4.073654651641846, - "acc": 57.34, - "time": 2.849203299999999, - "param": null - }, - { - "epoch": 53, - "train_loss": 0.053406428545713425, - "val_loss": 4.1925435066223145, - "acc": 57.14, - "time": 2.9053060059999893, - "param": null - }, - { - "epoch": 54, - "train_loss": 0.09250318259000778, - "val_loss": 4.594822406768799, - "acc": 56.91, - "time": 2.8818305060000284, - "param": null - }, - { - "epoch": 55, - "train_loss": 0.040889639407396317, - "val_loss": 4.507361888885498, - "acc": 56.76, - "time": 2.9061625310000068, - "param": null - }, - { - "epoch": 56, - "train_loss": 0.033295318484306335, - "val_loss": 4.082192897796631, - "acc": 57.37, - "time": 2.895834288000003, - "param": null - }, - { - "epoch": 57, - "train_loss": 0.03692072629928589, - "val_loss": 5.112850189208984, - "acc": 57.24, - "time": 2.905829856999958, - "param": null - }, - { - "epoch": 58, - "train_loss": 0.08262491971254349, - "val_loss": 5.430809497833252, - "acc": 57.25, - "time": 2.8767505010000036, - "param": null - }, - { - "epoch": 59, - "train_loss": 0.018249312415719032, - "val_loss": 4.355240821838379, - "acc": 56.89, - "time": 2.865761673999998, - "param": null - }, - { - "epoch": 60, - "train_loss": 0.03834836557507515, - "val_loss": 4.825883388519287, - "acc": 56.69, - "time": 2.8549083070000165, - "param": null - }, - { - "epoch": 61, - "train_loss": 0.07849375903606415, - "val_loss": 4.652444362640381, - "acc": 57.56, - "time": 2.8533472180000103, - "param": null - }, - { - "epoch": 62, - "train_loss": 0.04940357804298401, - "val_loss": 5.252978801727295, - "acc": 57.81, - "time": 2.8411212999999975, - "param": null - }, - { - "epoch": 63, - "train_loss": 0.035617221146821976, - "val_loss": 4.783238410949707, - "acc": 56.78, - "time": 2.8682083010000383, - "param": null - }, - { - "epoch": 64, - "train_loss": 0.04857274517416954, - "val_loss": 4.7291789054870605, - "acc": 57.41, - "time": 2.876162964999992, - "param": null - }, - { - "epoch": 65, - "train_loss": 0.02348227985203266, - "val_loss": 4.526753902435303, - "acc": 57.38, - "time": 2.896928296999988, - "param": null - }, - { - "epoch": 66, - "train_loss": 0.014324082992970943, - "val_loss": 5.439307689666748, - "acc": 57.11, - "time": 2.8848618569999758, - "param": null - }, - { - "epoch": 67, - "train_loss": 0.023290712386369705, - "val_loss": 5.721205234527588, - "acc": 58.09, - "time": 2.8774091299999895, - "param": null - }, - { - "epoch": 68, - "train_loss": 0.01625976897776127, - "val_loss": 6.3620829582214355, - "acc": 57.82, - "time": 2.8771049090000247, - "param": null - }, - { - "epoch": 69, - "train_loss": 0.024740245193243027, - "val_loss": 4.560676097869873, - "acc": 58.33, - "time": 2.88398736299996, - "param": null - }, - { - "epoch": 70, - "train_loss": 0.024858517572283745, - "val_loss": 4.982525825500488, - "acc": 58.51, - "time": 2.8906606280000346, - "param": null - }, - { - "epoch": 71, - "train_loss": 0.02504434995353222, - "val_loss": 5.8997955322265625, - "acc": 58.4, - "time": 2.907554915999981, - "param": null - }, - { - "epoch": 72, - "train_loss": 0.014550256542861462, - "val_loss": 4.25344705581665, - "acc": 58.63, - "time": 2.909489785000005, - "param": null - }, - { - "epoch": 73, - "train_loss": 0.013704107142984867, - "val_loss": 4.419747352600098, - "acc": 58.4, - "time": 2.934780938000017, - "param": null - }, - { - "epoch": 74, - "train_loss": 0.0007794767734594643, - "val_loss": 5.364877700805664, - "acc": 58.34, - "time": 2.8807501869999896, - "param": null - }, - { - "epoch": 75, - "train_loss": 0.024075284600257874, - "val_loss": 4.890963554382324, - "acc": 58.39, - "time": 2.890908232999948, - "param": null - }, - { - "epoch": 76, - "train_loss": 0.0007155562052503228, - "val_loss": 5.429594039916992, - "acc": 58.47, - "time": 2.9814038840000308, - "param": null - }, - { - "epoch": 77, - "train_loss": 0.025066755712032318, - "val_loss": 5.024533748626709, - "acc": 58.43, - "time": 2.90384117800005, - "param": null - }, - { - "epoch": 78, - "train_loss": 0.024619707837700844, - "val_loss": 5.127645492553711, - "acc": 58.48, - "time": 2.889764621999973, - "param": null - }, - { - "epoch": 79, - "train_loss": 0.011625911109149456, - "val_loss": 5.643911361694336, - "acc": 58.44, - "time": 2.8605676199999834, - "param": null - }, - { - "epoch": 80, - "train_loss": 0.035187769681215286, - "val_loss": 4.96016263961792, - "acc": 58.22, - "time": 3.0620140229999606, - "param": null - }, - { - "epoch": 81, - "train_loss": 0.012686162255704403, - "val_loss": 4.861446380615234, - "acc": 58.29, - "time": 3.0467026740000165, - "param": null - }, - { - "epoch": 82, - "train_loss": 0.03516190126538277, - "val_loss": 5.382493495941162, - "acc": 58.59, - "time": 2.948570172000018, - "param": null - }, - { - "epoch": 83, - "train_loss": 0.044635310769081116, - "val_loss": 5.259823799133301, - "acc": 58.32, - "time": 2.9322751760000187, - "param": null - }, - { - "epoch": 84, - "train_loss": 0.02375461347401142, - "val_loss": 4.794026851654053, - "acc": 58.32, - "time": 2.904587699999979, - "param": null - }, - { - "epoch": 85, - "train_loss": 0.0005525348242372274, - "val_loss": 6.069370746612549, - "acc": 58.49, - "time": 2.9163807219999853, - "param": null - }, - { - "epoch": 86, - "train_loss": 0.011588472872972488, - "val_loss": 4.31306266784668, - "acc": 58.26, - "time": 2.90974307700003, - "param": null - }, - { - "epoch": 87, - "train_loss": 0.04098305106163025, - "val_loss": 5.128759860992432, - "acc": 58.38, - "time": 2.9223316519999685, - "param": null - }, - { - "epoch": 88, - "train_loss": 0.0007651126361452043, - "val_loss": 5.77998685836792, - "acc": 58.31, - "time": 2.895599219000019, - "param": null - }, - { - "epoch": 89, - "train_loss": 0.01347972359508276, - "val_loss": 5.071125030517578, - "acc": 58.44, - "time": 2.8723491999999737, - "param": null - }, - { - "epoch": 90, - "train_loss": 0.0595603846013546, - "val_loss": 5.926196575164795, - "acc": 58.55, - "time": 2.933795826999983, - "param": null - }, - { - "epoch": 91, - "train_loss": 0.010888534598052502, - "val_loss": 4.442563533782959, - "acc": 58.29, - "time": 2.873865481999985, - "param": null - }, - { - "epoch": 92, - "train_loss": 0.01195440161973238, - "val_loss": 5.685233116149902, - "acc": 58.4, - "time": 2.908180643000037, - "param": null - }, - { - "epoch": 93, - "train_loss": 0.012302059680223465, - "val_loss": 5.39638090133667, - "acc": 58.33, - "time": 2.8950546359999976, - "param": null - }, - { - "epoch": 94, - "train_loss": 0.022545015439391136, - "val_loss": 4.769997596740723, - "acc": 58.38, - "time": 2.8618274910000423, - "param": null - }, - { - "epoch": 95, - "train_loss": 0.032540563493967056, - "val_loss": 6.151810646057129, - "acc": 58.28, - "time": 2.8651707379999607, - "param": null - }, - { - "epoch": 96, - "train_loss": 0.012408148497343063, - "val_loss": 5.0970139503479, - "acc": 58.3, - "time": 2.875970102999986, - "param": null - }, - { - "epoch": 97, - "train_loss": 0.024478793144226074, - "val_loss": 5.438201904296875, - "acc": 58.36, - "time": 2.868522979999966, - "param": null - }, - { - "epoch": 98, - "train_loss": 0.03768819198012352, - "val_loss": 6.296820640563965, - "acc": 58.3, - "time": 2.888704816000029, - "param": null - }, - { - "epoch": 99, - "train_loss": 0.03524984419345856, - "val_loss": 5.514108657836914, - "acc": 58.17, - "time": 2.8756691419999925, - "param": null - } - ] -} \ No newline at end of file diff --git a/higher/test_dataug.py b/higher/test_dataug.py index 5782ae2..aee7999 100644 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -69,7 +69,7 @@ if __name__ == "__main__": 'aug_dataset', #'aug_model' } - n_inner_iter = 0 + n_inner_iter = 1 epochs = 100 dataug_epoch_start=0 @@ -104,6 +104,9 @@ if __name__ == "__main__": #### Augmented Dataset #### if 'aug_dataset' in tasks: + + xs, ys = next(iter(dl_train)) + viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_{}'.format(str(data_train_aug))) t0 = time.process_time() model = LeNet(3,10).to(device) #model = WideResNet(num_classes=10, wrn_size=16).to(device)