Augmented Dataset functional

Harle, Antoine (Contractor) 2019-12-04 12:28:32 -05:00
parent 33ef7afd04
commit 2ee8022c2f
26 changed files with 64488 additions and 123 deletions

@@ -35,14 +35,15 @@ import augmentation_transforms
import numpy as np
class AugmentedDataset(VisionDataset):
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None):
super(AugmentedDataset, self).__init__(root, transform=transform, target_transform=target_transform)
supervised_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=transform)
self.sup_data = supervised_dataset.data
self.sup_targets = supervised_dataset.targets
self.sup_data = supervised_dataset.data if not subset else supervised_dataset.data[subset[0]:subset[1]]
self.sup_targets = supervised_dataset.targets if not subset else supervised_dataset.targets[subset[0]:subset[1]]
assert len(self.sup_data)==len(self.sup_targets)
for idx, img in enumerate(self.sup_data):
self.sup_data[idx]= Image.fromarray(img) #to PIL Image
@@ -53,11 +54,19 @@ class AugmentedDataset(VisionDataset):
self.data= self.sup_data
self.targets= self.sup_targets
self.dataset_info= {
'name': 'CIFAR10',
'sup': len(self.sup_data),
'unsup': len(self.unsup_data),
'length': len(self.sup_data)+len(self.unsup_data),
}
self._TF = [
'Invert', 'Cutout', 'Sharpness', 'AutoContrast', 'Posterize',
'ShearX', 'TranslateX', 'TranslateY', 'ShearY', 'Rotate',
'Equalize', 'Contrast', 'Color', 'Solarize', 'Brightness']
'Equalize', 'Contrast', 'Color', 'Solarize', 'Brightness'
]
self._op_list =[]
self.prob=0.5
for tf in self._TF:
@@ -95,6 +104,8 @@ class AugmentedDataset(VisionDataset):
policies += [[op_1, op_2]]
for idx, image in enumerate(self.sup_data):
if idx % max(1, self.dataset_info['sup']//5) == 0: print("Augmenting data... ", idx,"/", self.dataset_info['sup'])
for _ in range(aug_copy):
chosen_policy = policies[np.random.choice(len(policies))]
aug_image = augmentation_transforms.apply_policy(chosen_policy, image)
@@ -103,42 +114,47 @@ class AugmentedDataset(VisionDataset):
self.unsup_data+=[aug_image]
self.unsup_targets+=[self.sup_targets[idx]]
print(type(self.data), type(self.sup_data), type(self.unsup_data))
print(len(self.data), len(self.sup_data), len(self.unsup_data))
#self.data= self.sup_data+self.unsup_data
self.unsup_data=np.array(self.unsup_data).astype(self.sup_data.dtype)
self.data= np.concatenate((self.sup_data, self.unsup_data), axis=0)
print(len(self.data))
self.targets= self.sup_targets+self.unsup_targets
self.targets= np.concatenate((self.sup_targets, self.unsup_targets), axis=0)
assert len(self.unsup_data)==len(self.unsup_targets)
assert len(self.data)==len(self.targets)
self.dataset_info['unsup']=len(self.unsup_data)
self.dataset_info['length']=self.dataset_info['sup']+self.dataset_info['unsup']
def len_supervised(self):
return len(self.sup_data)
return self.dataset_info['sup']
def len_unsupervised(self):
return len(self.unsup_data)
return self.dataset_info['unsup']
def __len__(self):
return len(self.data)
return self.dataset_info['length']
def __str__(self):
return "CIFAR10(Sup:{}-Unsup:{})".format(self.dataset_info['sup'], self.dataset_info['unsup'])
### Classic Dataset ###
data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
#print(len(data_train))
#data_train = AugmentedDataset("./data", train=True, download=True, transform=transform)
#print(len(data_train), data_train.len_supervised(), data_train.len_unsupervised())
#data_train.augement_data()
#print(len(data_train), data_train.len_supervised(), data_train.len_unsupervised())
#data_val = torchvision.datasets.CIFAR10(
# "./data", train=True, download=True, transform=transform
#)
data_test = torchvision.datasets.CIFAR10(
"./data", train=False, download=True, transform=transform
)
#'''
#data_val = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
data_test = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=transform)
train_subset_indices=range(int(len(data_train)/2))
val_subset_indices=range(int(len(data_train)/2),len(data_train))
#train_subset_indices=range(BATCH_SIZE*10)
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
### Augmented Dataset ###
data_train_aug = AugmentedDataset("./data", train=True, download=True, transform=transform, subset=(0,int(len(data_train)/2)))
#data_train_aug.augement_data(aug_copy=1)
print(data_train_aug)
dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True)
dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False)
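For reference, a minimal usage sketch of the new subset argument and the supervised/unsupervised length bookkeeping added above. Names are taken from this diff (transform and BATCH_SIZE are assumed to be defined earlier in the file; augement_data is the augmentation method as spelled in this code, shown commented out above); this is a sketch, not part of the commit:
half = 25000 # first half of the 50000 CIFAR10 training images used as the supervised subset
data_train_aug = AugmentedDataset("./data", train=True, download=True, transform=transform, subset=(0, half))
data_train_aug.augement_data(aug_copy=1) # one augmented copy per supervised image
assert len(data_train_aug) == data_train_aug.len_supervised() + data_train_aug.len_unsupervised()
dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True)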

15 binary image files changed (14 added, 1 deleted); previews not shown.

@@ -0,0 +1,810 @@
{
"Accuracy": 59.15,
"Time": [
2.891577087839999,
0.04480297073115646
],
"Device": "TITAN RTX",
"Log": [
{
"epoch": 0,
"train_loss": 2.238740921020508,
"val_loss": 2.1966593265533447,
"acc": 18.39,
"time": 3.142668161000003,
"param": null
},
{
"epoch": 1,
"train_loss": 2.1030218601226807,
"val_loss": 1.9304202795028687,
"acc": 31.4,
"time": 2.866064168000001,
"param": null
},
{
"epoch": 2,
"train_loss": 2.102165460586548,
"val_loss": 1.7866367101669312,
"acc": 34.08,
"time": 2.8656765240000013,
"param": null
},
{
"epoch": 3,
"train_loss": 1.9763004779815674,
"val_loss": 1.6346896886825562,
"acc": 44.06,
"time": 2.8573803359999985,
"param": null
},
{
"epoch": 4,
"train_loss": 1.9886680841445923,
"val_loss": 1.5025888681411743,
"acc": 45.91,
"time": 2.884063083000001,
"param": null
},
{
"epoch": 5,
"train_loss": 1.8460354804992676,
"val_loss": 1.4987385272979736,
"acc": 47.11,
"time": 2.8803015360000046,
"param": null
},
{
"epoch": 6,
"train_loss": 1.8773751258850098,
"val_loss": 1.3948158025741577,
"acc": 46.54,
"time": 2.8545809470000023,
"param": null
},
{
"epoch": 7,
"train_loss": 1.9329789876937866,
"val_loss": 1.3192082643508911,
"acc": 49.96,
"time": 2.8968874109999945,
"param": null
},
{
"epoch": 8,
"train_loss": 1.877233624458313,
"val_loss": 1.4252185821533203,
"acc": 51.3,
"time": 2.856139868999989,
"param": null
},
{
"epoch": 9,
"train_loss": 1.8099722862243652,
"val_loss": 1.4050142765045166,
"acc": 52.78,
"time": 2.9680558680000075,
"param": null
},
{
"epoch": 10,
"train_loss": 1.7314376831054688,
"val_loss": 1.2530089616775513,
"acc": 54.54,
"time": 2.8540503600000022,
"param": null
},
{
"epoch": 11,
"train_loss": 1.714591383934021,
"val_loss": 1.4025185108184814,
"acc": 54.3,
"time": 2.8576988419999907,
"param": null
},
{
"epoch": 12,
"train_loss": 1.6348106861114502,
"val_loss": 1.2283456325531006,
"acc": 56.1,
"time": 2.8865886950000004,
"param": null
},
{
"epoch": 13,
"train_loss": 1.6358189582824707,
"val_loss": 1.2930848598480225,
"acc": 55.75,
"time": 2.872028806000003,
"param": null
},
{
"epoch": 14,
"train_loss": 1.7428033351898193,
"val_loss": 1.185242772102356,
"acc": 58.11,
"time": 2.8523812309999954,
"param": null
},
{
"epoch": 15,
"train_loss": 1.584743857383728,
"val_loss": 1.2095788717269897,
"acc": 58.13,
"time": 2.8867363259999905,
"param": null
},
{
"epoch": 16,
"train_loss": 1.637380838394165,
"val_loss": 1.2247447967529297,
"acc": 57.89,
"time": 2.927922326000001,
"param": null
},
{
"epoch": 17,
"train_loss": 1.3542273044586182,
"val_loss": 1.1683770418167114,
"acc": 57.79,
"time": 2.9269400579999996,
"param": null
},
{
"epoch": 18,
"train_loss": 1.5338425636291504,
"val_loss": 1.1944950819015503,
"acc": 57.9,
"time": 2.8685170960000193,
"param": null
},
{
"epoch": 19,
"train_loss": 1.406481146812439,
"val_loss": 1.164185643196106,
"acc": 58.67,
"time": 2.8701969949999864,
"param": null
},
{
"epoch": 20,
"train_loss": 1.4967502355575562,
"val_loss": 1.3286499977111816,
"acc": 57.86,
"time": 2.8738660139999865,
"param": null
},
{
"epoch": 21,
"train_loss": 1.4469761848449707,
"val_loss": 1.2245217561721802,
"acc": 59.15,
"time": 2.8806329870000127,
"param": null
},
{
"epoch": 22,
"train_loss": 1.3278372287750244,
"val_loss": 1.3220281600952148,
"acc": 58.65,
"time": 2.8654193290000194,
"param": null
},
{
"epoch": 23,
"train_loss": 1.310671329498291,
"val_loss": 1.4258655309677124,
"acc": 58.62,
"time": 2.893132035000008,
"param": null
},
{
"epoch": 24,
"train_loss": 1.2683414220809937,
"val_loss": 1.2353028059005737,
"acc": 58.18,
"time": 2.8943645079999953,
"param": null
},
{
"epoch": 25,
"train_loss": 1.2522042989730835,
"val_loss": 1.3762693405151367,
"acc": 57.19,
"time": 2.902485732999992,
"param": null
},
{
"epoch": 26,
"train_loss": 1.2245501279830933,
"val_loss": 1.4818042516708374,
"acc": 58.7,
"time": 2.9064084750000063,
"param": null
},
{
"epoch": 27,
"train_loss": 1.0753016471862793,
"val_loss": 1.2684273719787598,
"acc": 57.39,
"time": 2.8864203069999803,
"param": null
},
{
"epoch": 28,
"train_loss": 1.204155445098877,
"val_loss": 1.6392837762832642,
"acc": 58.03,
"time": 2.8522731609999994,
"param": null
},
{
"epoch": 29,
"train_loss": 1.0476032495498657,
"val_loss": 1.9471677541732788,
"acc": 57.36,
"time": 2.8621515780000095,
"param": null
},
{
"epoch": 30,
"train_loss": 1.0602569580078125,
"val_loss": 1.7120074033737183,
"acc": 57.4,
"time": 2.892253502000017,
"param": null
},
{
"epoch": 31,
"train_loss": 0.8457854390144348,
"val_loss": 1.7540744543075562,
"acc": 57.55,
"time": 2.876599782999989,
"param": null
},
{
"epoch": 32,
"train_loss": 0.7922731041908264,
"val_loss": 2.066330671310425,
"acc": 55.91,
"time": 2.8669344470000055,
"param": null
},
{
"epoch": 33,
"train_loss": 0.6974552273750305,
"val_loss": 1.8737841844558716,
"acc": 57.54,
"time": 2.8812602219999803,
"param": null
},
{
"epoch": 34,
"train_loss": 0.6545730829238892,
"val_loss": 2.59554386138916,
"acc": 56.33,
"time": 2.992003705000002,
"param": null
},
{
"epoch": 35,
"train_loss": 0.6207562685012817,
"val_loss": 2.878373861312866,
"acc": 56.64,
"time": 2.880951809999999,
"param": null
},
{
"epoch": 36,
"train_loss": 0.643873929977417,
"val_loss": 2.8013312816619873,
"acc": 56.3,
"time": 2.8770664789999785,
"param": null
},
{
"epoch": 37,
"train_loss": 0.5519305467605591,
"val_loss": 2.3760335445404053,
"acc": 56.42,
"time": 2.860177633999996,
"param": null
},
{
"epoch": 38,
"train_loss": 0.4566989541053772,
"val_loss": 3.261263370513916,
"acc": 56.38,
"time": 2.8768970639999907,
"param": null
},
{
"epoch": 39,
"train_loss": 0.4070911407470703,
"val_loss": 3.162980079650879,
"acc": 55.65,
"time": 2.8753553720000014,
"param": null
},
{
"epoch": 40,
"train_loss": 0.3369447588920593,
"val_loss": 3.6949386596679688,
"acc": 56.16,
"time": 2.8803950129999976,
"param": null
},
{
"epoch": 41,
"train_loss": 0.233261376619339,
"val_loss": 3.1482434272766113,
"acc": 56.81,
"time": 2.958535100000006,
"param": null
},
{
"epoch": 42,
"train_loss": 0.2671070098876953,
"val_loss": 3.852823495864868,
"acc": 56.28,
"time": 2.8413516090000144,
"param": null
},
{
"epoch": 43,
"train_loss": 0.17192059755325317,
"val_loss": 3.9386160373687744,
"acc": 56.43,
"time": 2.882356039000001,
"param": null
},
{
"epoch": 44,
"train_loss": 0.21115460991859436,
"val_loss": 3.674778938293457,
"acc": 56.6,
"time": 2.8627468419999786,
"param": null
},
{
"epoch": 45,
"train_loss": 0.17167963087558746,
"val_loss": 3.678553819656372,
"acc": 55.82,
"time": 2.8513568369999973,
"param": null
},
{
"epoch": 46,
"train_loss": 0.1660262942314148,
"val_loss": 4.021941661834717,
"acc": 56.93,
"time": 2.883665002000015,
"param": null
},
{
"epoch": 47,
"train_loss": 0.14190584421157837,
"val_loss": 4.224549770355225,
"acc": 57.28,
"time": 2.8793066700000054,
"param": null
},
{
"epoch": 48,
"train_loss": 0.13566425442695618,
"val_loss": 3.8729546070098877,
"acc": 56.01,
"time": 2.8533336880000206,
"param": null
},
{
"epoch": 49,
"train_loss": 0.20210444927215576,
"val_loss": 4.181947231292725,
"acc": 56.79,
"time": 2.871362753999989,
"param": null
},
{
"epoch": 50,
"train_loss": 0.17691804468631744,
"val_loss": 4.109798908233643,
"acc": 56.21,
"time": 2.868851712000037,
"param": null
},
{
"epoch": 51,
"train_loss": 0.09427379816770554,
"val_loss": 4.5040059089660645,
"acc": 56.41,
"time": 2.8451589609999814,
"param": null
},
{
"epoch": 52,
"train_loss": 0.1147536113858223,
"val_loss": 4.073654651641846,
"acc": 57.34,
"time": 2.849203299999999,
"param": null
},
{
"epoch": 53,
"train_loss": 0.053406428545713425,
"val_loss": 4.1925435066223145,
"acc": 57.14,
"time": 2.9053060059999893,
"param": null
},
{
"epoch": 54,
"train_loss": 0.09250318259000778,
"val_loss": 4.594822406768799,
"acc": 56.91,
"time": 2.8818305060000284,
"param": null
},
{
"epoch": 55,
"train_loss": 0.040889639407396317,
"val_loss": 4.507361888885498,
"acc": 56.76,
"time": 2.9061625310000068,
"param": null
},
{
"epoch": 56,
"train_loss": 0.033295318484306335,
"val_loss": 4.082192897796631,
"acc": 57.37,
"time": 2.895834288000003,
"param": null
},
{
"epoch": 57,
"train_loss": 0.03692072629928589,
"val_loss": 5.112850189208984,
"acc": 57.24,
"time": 2.905829856999958,
"param": null
},
{
"epoch": 58,
"train_loss": 0.08262491971254349,
"val_loss": 5.430809497833252,
"acc": 57.25,
"time": 2.8767505010000036,
"param": null
},
{
"epoch": 59,
"train_loss": 0.018249312415719032,
"val_loss": 4.355240821838379,
"acc": 56.89,
"time": 2.865761673999998,
"param": null
},
{
"epoch": 60,
"train_loss": 0.03834836557507515,
"val_loss": 4.825883388519287,
"acc": 56.69,
"time": 2.8549083070000165,
"param": null
},
{
"epoch": 61,
"train_loss": 0.07849375903606415,
"val_loss": 4.652444362640381,
"acc": 57.56,
"time": 2.8533472180000103,
"param": null
},
{
"epoch": 62,
"train_loss": 0.04940357804298401,
"val_loss": 5.252978801727295,
"acc": 57.81,
"time": 2.8411212999999975,
"param": null
},
{
"epoch": 63,
"train_loss": 0.035617221146821976,
"val_loss": 4.783238410949707,
"acc": 56.78,
"time": 2.8682083010000383,
"param": null
},
{
"epoch": 64,
"train_loss": 0.04857274517416954,
"val_loss": 4.7291789054870605,
"acc": 57.41,
"time": 2.876162964999992,
"param": null
},
{
"epoch": 65,
"train_loss": 0.02348227985203266,
"val_loss": 4.526753902435303,
"acc": 57.38,
"time": 2.896928296999988,
"param": null
},
{
"epoch": 66,
"train_loss": 0.014324082992970943,
"val_loss": 5.439307689666748,
"acc": 57.11,
"time": 2.8848618569999758,
"param": null
},
{
"epoch": 67,
"train_loss": 0.023290712386369705,
"val_loss": 5.721205234527588,
"acc": 58.09,
"time": 2.8774091299999895,
"param": null
},
{
"epoch": 68,
"train_loss": 0.01625976897776127,
"val_loss": 6.3620829582214355,
"acc": 57.82,
"time": 2.8771049090000247,
"param": null
},
{
"epoch": 69,
"train_loss": 0.024740245193243027,
"val_loss": 4.560676097869873,
"acc": 58.33,
"time": 2.88398736299996,
"param": null
},
{
"epoch": 70,
"train_loss": 0.024858517572283745,
"val_loss": 4.982525825500488,
"acc": 58.51,
"time": 2.8906606280000346,
"param": null
},
{
"epoch": 71,
"train_loss": 0.02504434995353222,
"val_loss": 5.8997955322265625,
"acc": 58.4,
"time": 2.907554915999981,
"param": null
},
{
"epoch": 72,
"train_loss": 0.014550256542861462,
"val_loss": 4.25344705581665,
"acc": 58.63,
"time": 2.909489785000005,
"param": null
},
{
"epoch": 73,
"train_loss": 0.013704107142984867,
"val_loss": 4.419747352600098,
"acc": 58.4,
"time": 2.934780938000017,
"param": null
},
{
"epoch": 74,
"train_loss": 0.0007794767734594643,
"val_loss": 5.364877700805664,
"acc": 58.34,
"time": 2.8807501869999896,
"param": null
},
{
"epoch": 75,
"train_loss": 0.024075284600257874,
"val_loss": 4.890963554382324,
"acc": 58.39,
"time": 2.890908232999948,
"param": null
},
{
"epoch": 76,
"train_loss": 0.0007155562052503228,
"val_loss": 5.429594039916992,
"acc": 58.47,
"time": 2.9814038840000308,
"param": null
},
{
"epoch": 77,
"train_loss": 0.025066755712032318,
"val_loss": 5.024533748626709,
"acc": 58.43,
"time": 2.90384117800005,
"param": null
},
{
"epoch": 78,
"train_loss": 0.024619707837700844,
"val_loss": 5.127645492553711,
"acc": 58.48,
"time": 2.889764621999973,
"param": null
},
{
"epoch": 79,
"train_loss": 0.011625911109149456,
"val_loss": 5.643911361694336,
"acc": 58.44,
"time": 2.8605676199999834,
"param": null
},
{
"epoch": 80,
"train_loss": 0.035187769681215286,
"val_loss": 4.96016263961792,
"acc": 58.22,
"time": 3.0620140229999606,
"param": null
},
{
"epoch": 81,
"train_loss": 0.012686162255704403,
"val_loss": 4.861446380615234,
"acc": 58.29,
"time": 3.0467026740000165,
"param": null
},
{
"epoch": 82,
"train_loss": 0.03516190126538277,
"val_loss": 5.382493495941162,
"acc": 58.59,
"time": 2.948570172000018,
"param": null
},
{
"epoch": 83,
"train_loss": 0.044635310769081116,
"val_loss": 5.259823799133301,
"acc": 58.32,
"time": 2.9322751760000187,
"param": null
},
{
"epoch": 84,
"train_loss": 0.02375461347401142,
"val_loss": 4.794026851654053,
"acc": 58.32,
"time": 2.904587699999979,
"param": null
},
{
"epoch": 85,
"train_loss": 0.0005525348242372274,
"val_loss": 6.069370746612549,
"acc": 58.49,
"time": 2.9163807219999853,
"param": null
},
{
"epoch": 86,
"train_loss": 0.011588472872972488,
"val_loss": 4.31306266784668,
"acc": 58.26,
"time": 2.90974307700003,
"param": null
},
{
"epoch": 87,
"train_loss": 0.04098305106163025,
"val_loss": 5.128759860992432,
"acc": 58.38,
"time": 2.9223316519999685,
"param": null
},
{
"epoch": 88,
"train_loss": 0.0007651126361452043,
"val_loss": 5.77998685836792,
"acc": 58.31,
"time": 2.895599219000019,
"param": null
},
{
"epoch": 89,
"train_loss": 0.01347972359508276,
"val_loss": 5.071125030517578,
"acc": 58.44,
"time": 2.8723491999999737,
"param": null
},
{
"epoch": 90,
"train_loss": 0.0595603846013546,
"val_loss": 5.926196575164795,
"acc": 58.55,
"time": 2.933795826999983,
"param": null
},
{
"epoch": 91,
"train_loss": 0.010888534598052502,
"val_loss": 4.442563533782959,
"acc": 58.29,
"time": 2.873865481999985,
"param": null
},
{
"epoch": 92,
"train_loss": 0.01195440161973238,
"val_loss": 5.685233116149902,
"acc": 58.4,
"time": 2.908180643000037,
"param": null
},
{
"epoch": 93,
"train_loss": 0.012302059680223465,
"val_loss": 5.39638090133667,
"acc": 58.33,
"time": 2.8950546359999976,
"param": null
},
{
"epoch": 94,
"train_loss": 0.022545015439391136,
"val_loss": 4.769997596740723,
"acc": 58.38,
"time": 2.8618274910000423,
"param": null
},
{
"epoch": 95,
"train_loss": 0.032540563493967056,
"val_loss": 6.151810646057129,
"acc": 58.28,
"time": 2.8651707379999607,
"param": null
},
{
"epoch": 96,
"train_loss": 0.012408148497343063,
"val_loss": 5.0970139503479,
"acc": 58.3,
"time": 2.875970102999986,
"param": null
},
{
"epoch": 97,
"train_loss": 0.024478793144226074,
"val_loss": 5.438201904296875,
"acc": 58.36,
"time": 2.868522979999966,
"param": null
},
{
"epoch": 98,
"train_loss": 0.03768819198012352,
"val_loss": 6.296820640563965,
"acc": 58.3,
"time": 2.888704816000029,
"param": null
},
{
"epoch": 99,
"train_loss": 0.03524984419345856,
"val_loss": 5.514108657836914,
"acc": 58.17,
"time": 2.8756691419999925,
"param": null
}
]
}
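A small sketch for re-reading a result file like the one above and recomputing its summary fields from the per-epoch "Log" entries (the path is hypothetical; the keys match this JSON and the code that writes it below):
import json
import numpy as np

with open("res/log/some_run.json") as f: # hypothetical path
    res = json.load(f)

accs = [e["acc"] for e in res["Log"]]
times = [e["time"] for e in res["Log"]]
print("Best accuracy:", max(accs)) # matches res["Accuracy"]
print("Epoch time:", np.mean(times), "+/-", np.std(times)) # matches res["Time"]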

higher/test_brutus.py (new file, 83 lines)

@@ -0,0 +1,83 @@
from model import *
from dataug import *
#from utils import *
from train_utils import *
tf_names = [
## Geometric TF ##
'Identity',
'FlipUD',
'FlipLR',
'Rotate',
'TranslateX',
'TranslateY',
'ShearX',
'ShearY',
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast',
'Color',
'Brightness',
'Sharpness',
'Posterize',
'Solarize', #=> Image in [0,1] #Not optimal for batches
]
device = torch.device('cuda')
if device == torch.device('cpu'):
device_name = 'CPU'
else:
device_name = torch.cuda.get_device_name(device)
##########################################
if __name__ == "__main__":
res_folder="res/brutus-tests/"
epochs= 150
inner_its = [1]
dist_mix = [1]
dataug_epoch_starts= [0]
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
TF_nb = [len(tf_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
N_seq_TF= [2, 3, 4]
mag_setup = [(True,True), (False, False)]
#prob_setup = [True, False]
nb_run= 3
try:
os.mkdir(res_folder)
os.mkdir(res_folder+"log/")
except FileExistsError:
pass
for n_inner_iter in inner_its:
for dataug_epoch_start in dataug_epoch_starts:
for n_tf in N_seq_TF:
for dist in dist_mix:
#for i in TF_nb:
for m_setup in mag_setup:
#for p_setup in prob_setup:
for run in range(nb_run):
if n_inner_iter == 0 and (m_setup!=(True,True) or p_setup!=True): continue #Other setups are pointless without meta-optimization
if n_tf ==2 and m_setup==(True,True): continue #Results already obtained
#keys = list(TF.TF_dict.keys())[0:i]
#ntf_dict = {k: TF.TF_dict[k] for k in keys}
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=False, fixed_mag=m_setup[0], shared_mag=m_setup[1]), LeNet(3,10)).to(device)
print(str(aug_model), 'on', device_name)
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=20, loss_patience=None)
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
print(str(aug_model),": acc", out["Accuracy"], "in :", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{}epochs(dataug:{})-{}in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter,run)
with open(res_folder+"log/%s.json" % filename, "w+") as f:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
#plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names)
print('-'*9)
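A quick sketch of the grid this script sweeps and the run count it implies (values copied from the constants above; the n_tf/mag_setup guard prunes combinations already run, while the n_inner_iter guard never fires here since inner_its is [1]):
import itertools

inner_its = [1]; dataug_epoch_starts = [0]; N_seq_TF = [2, 3, 4]
dist_mix = [1]; mag_setup = [(True, True), (False, False)]; nb_run = 3

grid = list(itertools.product(inner_its, dataug_epoch_starts, N_seq_TF, dist_mix, mag_setup, range(nb_run)))
kept = [g for g in grid if not (g[2] == 2 and g[4] == (True, True))] # same skip rule as above
print(len(grid), "combinations,", len(kept), "actually trained") # 18 combinations, 15 actually trained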

@@ -64,109 +64,97 @@ else:
##########################################
if __name__ == "__main__":
n_inner_iter = 1
epochs = 200
tasks={
#'classic',
'aug_dataset',
#'aug_model'
}
n_inner_iter = 0
epochs = 100
dataug_epoch_start=0
#### Classic ####
'''
model = LeNet(3,10).to(device)
#model = WideResNet(num_classes=10, wrn_size=16).to(device)
#model = Augmented_model(Data_augV3(mix_dist=0.0), LeNet(3,10)).to(device)
#model.augment(mode=False)
if 'classic' in tasks:
t0 = time.process_time()
model = LeNet(3,10).to(device)
#model = WideResNet(num_classes=10, wrn_size=16).to(device)
#model = Augmented_model(Data_augV3(mix_dist=0.0), LeNet(3,10)).to(device)
#model.augment(mode=False)
print(str(model), 'on', device_name)
log= train_classic(model=model, epochs=epochs)
#log= train_classic_higher(model=model, epochs=epochs)
print(str(model), 'on', device_name)
log= train_classic(model=model, epochs=epochs)
#log= train_classic_higher(model=model, epochs=epochs)
####
plot_res(log, fig_name="res/{}-{} epochs".format(str(model),epochs))
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Log": log}
print(str(model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
with open("res/log/%s.json" % "{}-{} epochs".format(str(model),epochs), "w+") as f:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
print('-'*9)
'''
#### Augmented Model ####
'''
t0 = time.process_time()
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
#tf_dict = TF.TF_dict
#aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device)
print(str(aug_model), 'on', device_name)
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Log": log}
print(str(model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs".format(str(model),epochs)
with open("res/log/%s.json" % filename, "w+") as f:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
with open("res/log/%s.json" % filename, "w+") as f:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
plot_res(log, fig_name="res/"+filename)
plot_resV2(log, fig_name="res/"+filename, param_names=tf_names)
print('Execution Time : %.00f '%(time.process_time() - t0))
print('-'*9)
'''
#### TF tests ####
#'''
res_folder="res/brutus-tests/"
epochs= 150
inner_its = [1]
dist_mix = [1]
dataug_epoch_starts= [0]
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
TF_nb = [len(tf_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
N_seq_TF= [2, 3, 4]
mag_setup = [(True,True), (False, False)]
#prob_setup = [True, False]
nb_run= 3
print('Execution Time : %.00f '%(time.process_time() - t0))
print('-'*9)
try:
os.mkdir(res_folder)
os.mkdir(res_folder+"log/")
except FileExistsError:
pass
for n_inner_iter in inner_its:
for dataug_epoch_start in dataug_epoch_starts:
for n_tf in N_seq_TF:
for dist in dist_mix:
#for i in TF_nb:
for m_setup in mag_setup:
#for p_setup in prob_setup:
for run in range(nb_run):
if n_inner_iter == 0 and (m_setup!=(True,True) or p_setup!=True): continue #Other setups are pointless without meta-optimization
if n_tf ==2 and m_setup==(True,True): continue #Results already obtained
#keys = list(TF.TF_dict.keys())[0:i]
#ntf_dict = {k: TF.TF_dict[k] for k in keys}
#### Augmented Dataset ####
if 'aug_dataset' in tasks:
t0 = time.process_time()
model = LeNet(3,10).to(device)
#model = WideResNet(num_classes=10, wrn_size=16).to(device)
#model = Augmented_model(Data_augV3(mix_dist=0.0), LeNet(3,10)).to(device)
#model.augment(mode=False)
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=False, fixed_mag=m_setup[0], shared_mag=m_setup[1]), LeNet(3,10)).to(device)
print(str(aug_model), 'on', device_name)
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=20, loss_patience=None)
print(str(model), 'on', device_name)
log= train_classic(model=model, epochs=epochs)
#log= train_classic_higher(model=model, epochs=epochs)
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
print(str(aug_model),": acc", out["Accuracy"], "in :", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{}epochs(dataug:{})-{}in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter,run)
with open(res_folder+"log/%s.json" % filename, "w+") as f:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Log": log}
print(str(model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{}-{} epochs".format(str(data_train_aug),str(model),epochs)
with open("res/log/%s.json" % filename, "w+") as f:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
#plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names)
print('-'*9)
plot_res(log, fig_name="res/"+filename)
#'''
print('Execution Time : %.00f '%(time.process_time() - t0))
print('-'*9)
#### Augmented Model ####
if 'aug_model' in tasks:
t0 = time.process_time()
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
#aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device)
print(str(aug_model), 'on', device_name)
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
####
print('-'*9)
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
with open("res/log/%s.json" % filename, "w+") as f:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')
plot_resV2(log, fig_name="res/"+filename, param_names=tf_names)
print('Execution Time : %.00f '%(time.process_time() - t0))
print('-'*9)

@@ -616,10 +616,10 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
#if epoch>50:
meta_opt.step()
model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
#model['data_aug'].next_TF_set()
if epoch>50:
meta_opt.step()
model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
#model['data_aug'].next_TF_set()
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
@@ -683,8 +683,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
model.augment(mode=True)
if inner_it != 0: high_grad_track = True
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
#print("Copy ", countcopy)
return log