Fix cast in Augmented Dataset

This commit is contained in:
Harle, Antoine (Contracteur) 2019-12-04 12:58:11 -05:00
parent 2ee8022c2f
commit adaac437b6
3 changed files with 29 additions and 822 deletions

View file

@ -34,6 +34,8 @@ from PIL import Image
import augmentation_transforms
import numpy as np
download_data=False
class AugmentedDataset(VisionDataset):
def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None):
@ -63,9 +65,21 @@ class AugmentedDataset(VisionDataset):
self._TF = [
'Invert', 'Cutout', 'Sharpness', 'AutoContrast', 'Posterize',
'ShearX', 'TranslateX', 'TranslateY', 'ShearY', 'Rotate',
'Equalize', 'Contrast', 'Color', 'Solarize', 'Brightness'
'Invert',
'Cutout',
'Sharpness',
'AutoContrast',
'Posterize',
'ShearX',
'TranslateX',
'TranslateY',
'ShearY',
'Rotate',
'Equalize',
'Contrast',
'Color',
'Solarize',
'Brightness'
]
self._op_list =[]
self.prob=0.5
@ -108,13 +122,13 @@ class AugmentedDataset(VisionDataset):
for _ in range(aug_copy):
chosen_policy = policies[np.random.choice(len(policies))]
aug_image = augmentation_transforms.apply_policy(chosen_policy, image)
aug_image = augmentation_transforms.apply_policy(chosen_policy, image, use_mean_std=False) #Cast en float image
#aug_image = augmentation_transforms.cutout_numpy(aug_image)
self.unsup_data+=[aug_image]
self.unsup_targets+=[self.sup_targets[idx]]
self.unsup_data=np.array(self.unsup_data).astype(self.sup_data.dtype)
self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8
self.data= np.concatenate((self.sup_data, self.unsup_data), axis=0)
self.targets= np.concatenate((self.sup_targets, self.unsup_targets), axis=0)
@ -133,12 +147,12 @@ class AugmentedDataset(VisionDataset):
return self.dataset_info['length']
def __str__(self):
return "CIFAR10(Sup:{}-Unsup:{})".format(self.dataset_info['sup'], self.dataset_info['unsup'])
return "CIFAR10(Sup:{}-Unsup:{}-{}TF)".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF))
### Classic Dataset ###
data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
#data_val = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
data_test = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=transform)
data_train = torchvision.datasets.CIFAR10("./data", train=True, download=download_data, transform=transform)
#data_val = torchvision.datasets.CIFAR10("./data", train=True, download=download_data, transform=transform)
data_test = torchvision.datasets.CIFAR10("./data", train=False, download=download_data, transform=transform)
train_subset_indices=range(int(len(data_train)/2))
@ -149,8 +163,8 @@ val_subset_indices=range(int(len(data_train)/2),len(data_train))
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
### Augmented Dataset ###
data_train_aug = AugmentedDataset("./data", train=True, download=True, transform=transform, subset=(0,int(len(data_train)/2)))
#data_train_aug.augement_data(aug_copy=1)
data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2)))
data_train_aug.augement_data(aug_copy=1)
print(data_train_aug)
dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True)

View file

@ -1,810 +0,0 @@
{
"Accuracy": 59.15,
"Time": [
2.891577087839999,
0.04480297073115646
],
"Device": "TITAN RTX",
"Log": [
{
"epoch": 0,
"train_loss": 2.238740921020508,
"val_loss": 2.1966593265533447,
"acc": 18.39,
"time": 3.142668161000003,
"param": null
},
{
"epoch": 1,
"train_loss": 2.1030218601226807,
"val_loss": 1.9304202795028687,
"acc": 31.4,
"time": 2.866064168000001,
"param": null
},
{
"epoch": 2,
"train_loss": 2.102165460586548,
"val_loss": 1.7866367101669312,
"acc": 34.08,
"time": 2.8656765240000013,
"param": null
},
{
"epoch": 3,
"train_loss": 1.9763004779815674,
"val_loss": 1.6346896886825562,
"acc": 44.06,
"time": 2.8573803359999985,
"param": null
},
{
"epoch": 4,
"train_loss": 1.9886680841445923,
"val_loss": 1.5025888681411743,
"acc": 45.91,
"time": 2.884063083000001,
"param": null
},
{
"epoch": 5,
"train_loss": 1.8460354804992676,
"val_loss": 1.4987385272979736,
"acc": 47.11,
"time": 2.8803015360000046,
"param": null
},
{
"epoch": 6,
"train_loss": 1.8773751258850098,
"val_loss": 1.3948158025741577,
"acc": 46.54,
"time": 2.8545809470000023,
"param": null
},
{
"epoch": 7,
"train_loss": 1.9329789876937866,
"val_loss": 1.3192082643508911,
"acc": 49.96,
"time": 2.8968874109999945,
"param": null
},
{
"epoch": 8,
"train_loss": 1.877233624458313,
"val_loss": 1.4252185821533203,
"acc": 51.3,
"time": 2.856139868999989,
"param": null
},
{
"epoch": 9,
"train_loss": 1.8099722862243652,
"val_loss": 1.4050142765045166,
"acc": 52.78,
"time": 2.9680558680000075,
"param": null
},
{
"epoch": 10,
"train_loss": 1.7314376831054688,
"val_loss": 1.2530089616775513,
"acc": 54.54,
"time": 2.8540503600000022,
"param": null
},
{
"epoch": 11,
"train_loss": 1.714591383934021,
"val_loss": 1.4025185108184814,
"acc": 54.3,
"time": 2.8576988419999907,
"param": null
},
{
"epoch": 12,
"train_loss": 1.6348106861114502,
"val_loss": 1.2283456325531006,
"acc": 56.1,
"time": 2.8865886950000004,
"param": null
},
{
"epoch": 13,
"train_loss": 1.6358189582824707,
"val_loss": 1.2930848598480225,
"acc": 55.75,
"time": 2.872028806000003,
"param": null
},
{
"epoch": 14,
"train_loss": 1.7428033351898193,
"val_loss": 1.185242772102356,
"acc": 58.11,
"time": 2.8523812309999954,
"param": null
},
{
"epoch": 15,
"train_loss": 1.584743857383728,
"val_loss": 1.2095788717269897,
"acc": 58.13,
"time": 2.8867363259999905,
"param": null
},
{
"epoch": 16,
"train_loss": 1.637380838394165,
"val_loss": 1.2247447967529297,
"acc": 57.89,
"time": 2.927922326000001,
"param": null
},
{
"epoch": 17,
"train_loss": 1.3542273044586182,
"val_loss": 1.1683770418167114,
"acc": 57.79,
"time": 2.9269400579999996,
"param": null
},
{
"epoch": 18,
"train_loss": 1.5338425636291504,
"val_loss": 1.1944950819015503,
"acc": 57.9,
"time": 2.8685170960000193,
"param": null
},
{
"epoch": 19,
"train_loss": 1.406481146812439,
"val_loss": 1.164185643196106,
"acc": 58.67,
"time": 2.8701969949999864,
"param": null
},
{
"epoch": 20,
"train_loss": 1.4967502355575562,
"val_loss": 1.3286499977111816,
"acc": 57.86,
"time": 2.8738660139999865,
"param": null
},
{
"epoch": 21,
"train_loss": 1.4469761848449707,
"val_loss": 1.2245217561721802,
"acc": 59.15,
"time": 2.8806329870000127,
"param": null
},
{
"epoch": 22,
"train_loss": 1.3278372287750244,
"val_loss": 1.3220281600952148,
"acc": 58.65,
"time": 2.8654193290000194,
"param": null
},
{
"epoch": 23,
"train_loss": 1.310671329498291,
"val_loss": 1.4258655309677124,
"acc": 58.62,
"time": 2.893132035000008,
"param": null
},
{
"epoch": 24,
"train_loss": 1.2683414220809937,
"val_loss": 1.2353028059005737,
"acc": 58.18,
"time": 2.8943645079999953,
"param": null
},
{
"epoch": 25,
"train_loss": 1.2522042989730835,
"val_loss": 1.3762693405151367,
"acc": 57.19,
"time": 2.902485732999992,
"param": null
},
{
"epoch": 26,
"train_loss": 1.2245501279830933,
"val_loss": 1.4818042516708374,
"acc": 58.7,
"time": 2.9064084750000063,
"param": null
},
{
"epoch": 27,
"train_loss": 1.0753016471862793,
"val_loss": 1.2684273719787598,
"acc": 57.39,
"time": 2.8864203069999803,
"param": null
},
{
"epoch": 28,
"train_loss": 1.204155445098877,
"val_loss": 1.6392837762832642,
"acc": 58.03,
"time": 2.8522731609999994,
"param": null
},
{
"epoch": 29,
"train_loss": 1.0476032495498657,
"val_loss": 1.9471677541732788,
"acc": 57.36,
"time": 2.8621515780000095,
"param": null
},
{
"epoch": 30,
"train_loss": 1.0602569580078125,
"val_loss": 1.7120074033737183,
"acc": 57.4,
"time": 2.892253502000017,
"param": null
},
{
"epoch": 31,
"train_loss": 0.8457854390144348,
"val_loss": 1.7540744543075562,
"acc": 57.55,
"time": 2.876599782999989,
"param": null
},
{
"epoch": 32,
"train_loss": 0.7922731041908264,
"val_loss": 2.066330671310425,
"acc": 55.91,
"time": 2.8669344470000055,
"param": null
},
{
"epoch": 33,
"train_loss": 0.6974552273750305,
"val_loss": 1.8737841844558716,
"acc": 57.54,
"time": 2.8812602219999803,
"param": null
},
{
"epoch": 34,
"train_loss": 0.6545730829238892,
"val_loss": 2.59554386138916,
"acc": 56.33,
"time": 2.992003705000002,
"param": null
},
{
"epoch": 35,
"train_loss": 0.6207562685012817,
"val_loss": 2.878373861312866,
"acc": 56.64,
"time": 2.880951809999999,
"param": null
},
{
"epoch": 36,
"train_loss": 0.643873929977417,
"val_loss": 2.8013312816619873,
"acc": 56.3,
"time": 2.8770664789999785,
"param": null
},
{
"epoch": 37,
"train_loss": 0.5519305467605591,
"val_loss": 2.3760335445404053,
"acc": 56.42,
"time": 2.860177633999996,
"param": null
},
{
"epoch": 38,
"train_loss": 0.4566989541053772,
"val_loss": 3.261263370513916,
"acc": 56.38,
"time": 2.8768970639999907,
"param": null
},
{
"epoch": 39,
"train_loss": 0.4070911407470703,
"val_loss": 3.162980079650879,
"acc": 55.65,
"time": 2.8753553720000014,
"param": null
},
{
"epoch": 40,
"train_loss": 0.3369447588920593,
"val_loss": 3.6949386596679688,
"acc": 56.16,
"time": 2.8803950129999976,
"param": null
},
{
"epoch": 41,
"train_loss": 0.233261376619339,
"val_loss": 3.1482434272766113,
"acc": 56.81,
"time": 2.958535100000006,
"param": null
},
{
"epoch": 42,
"train_loss": 0.2671070098876953,
"val_loss": 3.852823495864868,
"acc": 56.28,
"time": 2.8413516090000144,
"param": null
},
{
"epoch": 43,
"train_loss": 0.17192059755325317,
"val_loss": 3.9386160373687744,
"acc": 56.43,
"time": 2.882356039000001,
"param": null
},
{
"epoch": 44,
"train_loss": 0.21115460991859436,
"val_loss": 3.674778938293457,
"acc": 56.6,
"time": 2.8627468419999786,
"param": null
},
{
"epoch": 45,
"train_loss": 0.17167963087558746,
"val_loss": 3.678553819656372,
"acc": 55.82,
"time": 2.8513568369999973,
"param": null
},
{
"epoch": 46,
"train_loss": 0.1660262942314148,
"val_loss": 4.021941661834717,
"acc": 56.93,
"time": 2.883665002000015,
"param": null
},
{
"epoch": 47,
"train_loss": 0.14190584421157837,
"val_loss": 4.224549770355225,
"acc": 57.28,
"time": 2.8793066700000054,
"param": null
},
{
"epoch": 48,
"train_loss": 0.13566425442695618,
"val_loss": 3.8729546070098877,
"acc": 56.01,
"time": 2.8533336880000206,
"param": null
},
{
"epoch": 49,
"train_loss": 0.20210444927215576,
"val_loss": 4.181947231292725,
"acc": 56.79,
"time": 2.871362753999989,
"param": null
},
{
"epoch": 50,
"train_loss": 0.17691804468631744,
"val_loss": 4.109798908233643,
"acc": 56.21,
"time": 2.868851712000037,
"param": null
},
{
"epoch": 51,
"train_loss": 0.09427379816770554,
"val_loss": 4.5040059089660645,
"acc": 56.41,
"time": 2.8451589609999814,
"param": null
},
{
"epoch": 52,
"train_loss": 0.1147536113858223,
"val_loss": 4.073654651641846,
"acc": 57.34,
"time": 2.849203299999999,
"param": null
},
{
"epoch": 53,
"train_loss": 0.053406428545713425,
"val_loss": 4.1925435066223145,
"acc": 57.14,
"time": 2.9053060059999893,
"param": null
},
{
"epoch": 54,
"train_loss": 0.09250318259000778,
"val_loss": 4.594822406768799,
"acc": 56.91,
"time": 2.8818305060000284,
"param": null
},
{
"epoch": 55,
"train_loss": 0.040889639407396317,
"val_loss": 4.507361888885498,
"acc": 56.76,
"time": 2.9061625310000068,
"param": null
},
{
"epoch": 56,
"train_loss": 0.033295318484306335,
"val_loss": 4.082192897796631,
"acc": 57.37,
"time": 2.895834288000003,
"param": null
},
{
"epoch": 57,
"train_loss": 0.03692072629928589,
"val_loss": 5.112850189208984,
"acc": 57.24,
"time": 2.905829856999958,
"param": null
},
{
"epoch": 58,
"train_loss": 0.08262491971254349,
"val_loss": 5.430809497833252,
"acc": 57.25,
"time": 2.8767505010000036,
"param": null
},
{
"epoch": 59,
"train_loss": 0.018249312415719032,
"val_loss": 4.355240821838379,
"acc": 56.89,
"time": 2.865761673999998,
"param": null
},
{
"epoch": 60,
"train_loss": 0.03834836557507515,
"val_loss": 4.825883388519287,
"acc": 56.69,
"time": 2.8549083070000165,
"param": null
},
{
"epoch": 61,
"train_loss": 0.07849375903606415,
"val_loss": 4.652444362640381,
"acc": 57.56,
"time": 2.8533472180000103,
"param": null
},
{
"epoch": 62,
"train_loss": 0.04940357804298401,
"val_loss": 5.252978801727295,
"acc": 57.81,
"time": 2.8411212999999975,
"param": null
},
{
"epoch": 63,
"train_loss": 0.035617221146821976,
"val_loss": 4.783238410949707,
"acc": 56.78,
"time": 2.8682083010000383,
"param": null
},
{
"epoch": 64,
"train_loss": 0.04857274517416954,
"val_loss": 4.7291789054870605,
"acc": 57.41,
"time": 2.876162964999992,
"param": null
},
{
"epoch": 65,
"train_loss": 0.02348227985203266,
"val_loss": 4.526753902435303,
"acc": 57.38,
"time": 2.896928296999988,
"param": null
},
{
"epoch": 66,
"train_loss": 0.014324082992970943,
"val_loss": 5.439307689666748,
"acc": 57.11,
"time": 2.8848618569999758,
"param": null
},
{
"epoch": 67,
"train_loss": 0.023290712386369705,
"val_loss": 5.721205234527588,
"acc": 58.09,
"time": 2.8774091299999895,
"param": null
},
{
"epoch": 68,
"train_loss": 0.01625976897776127,
"val_loss": 6.3620829582214355,
"acc": 57.82,
"time": 2.8771049090000247,
"param": null
},
{
"epoch": 69,
"train_loss": 0.024740245193243027,
"val_loss": 4.560676097869873,
"acc": 58.33,
"time": 2.88398736299996,
"param": null
},
{
"epoch": 70,
"train_loss": 0.024858517572283745,
"val_loss": 4.982525825500488,
"acc": 58.51,
"time": 2.8906606280000346,
"param": null
},
{
"epoch": 71,
"train_loss": 0.02504434995353222,
"val_loss": 5.8997955322265625,
"acc": 58.4,
"time": 2.907554915999981,
"param": null
},
{
"epoch": 72,
"train_loss": 0.014550256542861462,
"val_loss": 4.25344705581665,
"acc": 58.63,
"time": 2.909489785000005,
"param": null
},
{
"epoch": 73,
"train_loss": 0.013704107142984867,
"val_loss": 4.419747352600098,
"acc": 58.4,
"time": 2.934780938000017,
"param": null
},
{
"epoch": 74,
"train_loss": 0.0007794767734594643,
"val_loss": 5.364877700805664,
"acc": 58.34,
"time": 2.8807501869999896,
"param": null
},
{
"epoch": 75,
"train_loss": 0.024075284600257874,
"val_loss": 4.890963554382324,
"acc": 58.39,
"time": 2.890908232999948,
"param": null
},
{
"epoch": 76,
"train_loss": 0.0007155562052503228,
"val_loss": 5.429594039916992,
"acc": 58.47,
"time": 2.9814038840000308,
"param": null
},
{
"epoch": 77,
"train_loss": 0.025066755712032318,
"val_loss": 5.024533748626709,
"acc": 58.43,
"time": 2.90384117800005,
"param": null
},
{
"epoch": 78,
"train_loss": 0.024619707837700844,
"val_loss": 5.127645492553711,
"acc": 58.48,
"time": 2.889764621999973,
"param": null
},
{
"epoch": 79,
"train_loss": 0.011625911109149456,
"val_loss": 5.643911361694336,
"acc": 58.44,
"time": 2.8605676199999834,
"param": null
},
{
"epoch": 80,
"train_loss": 0.035187769681215286,
"val_loss": 4.96016263961792,
"acc": 58.22,
"time": 3.0620140229999606,
"param": null
},
{
"epoch": 81,
"train_loss": 0.012686162255704403,
"val_loss": 4.861446380615234,
"acc": 58.29,
"time": 3.0467026740000165,
"param": null
},
{
"epoch": 82,
"train_loss": 0.03516190126538277,
"val_loss": 5.382493495941162,
"acc": 58.59,
"time": 2.948570172000018,
"param": null
},
{
"epoch": 83,
"train_loss": 0.044635310769081116,
"val_loss": 5.259823799133301,
"acc": 58.32,
"time": 2.9322751760000187,
"param": null
},
{
"epoch": 84,
"train_loss": 0.02375461347401142,
"val_loss": 4.794026851654053,
"acc": 58.32,
"time": 2.904587699999979,
"param": null
},
{
"epoch": 85,
"train_loss": 0.0005525348242372274,
"val_loss": 6.069370746612549,
"acc": 58.49,
"time": 2.9163807219999853,
"param": null
},
{
"epoch": 86,
"train_loss": 0.011588472872972488,
"val_loss": 4.31306266784668,
"acc": 58.26,
"time": 2.90974307700003,
"param": null
},
{
"epoch": 87,
"train_loss": 0.04098305106163025,
"val_loss": 5.128759860992432,
"acc": 58.38,
"time": 2.9223316519999685,
"param": null
},
{
"epoch": 88,
"train_loss": 0.0007651126361452043,
"val_loss": 5.77998685836792,
"acc": 58.31,
"time": 2.895599219000019,
"param": null
},
{
"epoch": 89,
"train_loss": 0.01347972359508276,
"val_loss": 5.071125030517578,
"acc": 58.44,
"time": 2.8723491999999737,
"param": null
},
{
"epoch": 90,
"train_loss": 0.0595603846013546,
"val_loss": 5.926196575164795,
"acc": 58.55,
"time": 2.933795826999983,
"param": null
},
{
"epoch": 91,
"train_loss": 0.010888534598052502,
"val_loss": 4.442563533782959,
"acc": 58.29,
"time": 2.873865481999985,
"param": null
},
{
"epoch": 92,
"train_loss": 0.01195440161973238,
"val_loss": 5.685233116149902,
"acc": 58.4,
"time": 2.908180643000037,
"param": null
},
{
"epoch": 93,
"train_loss": 0.012302059680223465,
"val_loss": 5.39638090133667,
"acc": 58.33,
"time": 2.8950546359999976,
"param": null
},
{
"epoch": 94,
"train_loss": 0.022545015439391136,
"val_loss": 4.769997596740723,
"acc": 58.38,
"time": 2.8618274910000423,
"param": null
},
{
"epoch": 95,
"train_loss": 0.032540563493967056,
"val_loss": 6.151810646057129,
"acc": 58.28,
"time": 2.8651707379999607,
"param": null
},
{
"epoch": 96,
"train_loss": 0.012408148497343063,
"val_loss": 5.0970139503479,
"acc": 58.3,
"time": 2.875970102999986,
"param": null
},
{
"epoch": 97,
"train_loss": 0.024478793144226074,
"val_loss": 5.438201904296875,
"acc": 58.36,
"time": 2.868522979999966,
"param": null
},
{
"epoch": 98,
"train_loss": 0.03768819198012352,
"val_loss": 6.296820640563965,
"acc": 58.3,
"time": 2.888704816000029,
"param": null
},
{
"epoch": 99,
"train_loss": 0.03524984419345856,
"val_loss": 5.514108657836914,
"acc": 58.17,
"time": 2.8756691419999925,
"param": null
}
]
}

View file

@ -69,7 +69,7 @@ if __name__ == "__main__":
'aug_dataset',
#'aug_model'
}
n_inner_iter = 0
n_inner_iter = 1
epochs = 100
dataug_epoch_start=0
@ -104,6 +104,9 @@ if __name__ == "__main__":
#### Augmented Dataset ####
if 'aug_dataset' in tasks:
xs, ys = next(iter(dl_train))
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_{}'.format(str(data_train_aug)))
t0 = time.process_time()
model = LeNet(3,10).to(device)
#model = WideResNet(num_classes=10, wrn_size=16).to(device)