Augmented Dataset fonctionnel
|
@ -35,14 +35,15 @@ import augmentation_transforms
|
|||
import numpy as np
|
||||
|
||||
class AugmentedDataset(VisionDataset):
|
||||
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
|
||||
def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None):
|
||||
|
||||
super(AugmentedDataset, self).__init__(root, transform=transform, target_transform=target_transform)
|
||||
|
||||
supervised_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=transform)
|
||||
|
||||
self.sup_data = supervised_dataset.data
|
||||
self.sup_targets = supervised_dataset.targets
|
||||
self.sup_data = supervised_dataset.data if not subset else supervised_dataset.data[subset[0]:subset[1]]
|
||||
self.sup_targets = supervised_dataset.targets if not subset else supervised_dataset.targets[subset[0]:subset[1]]
|
||||
assert len(self.sup_data)==len(self.sup_targets)
|
||||
|
||||
for idx, img in enumerate(self.sup_data):
|
||||
self.sup_data[idx]= Image.fromarray(img) #to PIL Image
|
||||
|
@ -53,11 +54,19 @@ class AugmentedDataset(VisionDataset):
|
|||
self.data= self.sup_data
|
||||
self.targets= self.sup_targets
|
||||
|
||||
self.dataset_info= {
|
||||
'name': 'CIFAR10',
|
||||
'sup': len(self.sup_data),
|
||||
'unsup': len(self.unsup_data),
|
||||
'length': len(self.sup_data)+len(self.unsup_data),
|
||||
}
|
||||
|
||||
|
||||
self._TF = [
|
||||
'Invert', 'Cutout', 'Sharpness', 'AutoContrast', 'Posterize',
|
||||
'ShearX', 'TranslateX', 'TranslateY', 'ShearY', 'Rotate',
|
||||
'Equalize', 'Contrast', 'Color', 'Solarize', 'Brightness']
|
||||
'Equalize', 'Contrast', 'Color', 'Solarize', 'Brightness'
|
||||
]
|
||||
self._op_list =[]
|
||||
self.prob=0.5
|
||||
for tf in self._TF:
|
||||
|
@ -95,6 +104,8 @@ class AugmentedDataset(VisionDataset):
|
|||
policies += [[op_1, op_2]]
|
||||
|
||||
for idx, image in enumerate(self.sup_data):
|
||||
if (idx/self.dataset_info['sup'])%0.2==0: print("Augmenting data... ", idx,"/", self.dataset_info['sup'])
|
||||
|
||||
for _ in range(aug_copy):
|
||||
chosen_policy = policies[np.random.choice(len(policies))]
|
||||
aug_image = augmentation_transforms.apply_policy(chosen_policy, image)
|
||||
|
@ -103,42 +114,47 @@ class AugmentedDataset(VisionDataset):
|
|||
self.unsup_data+=[aug_image]
|
||||
self.unsup_targets+=[self.sup_targets[idx]]
|
||||
|
||||
print(type(self.data), type(self.sup_data), type(self.unsup_data))
|
||||
print(len(self.data), len(self.sup_data), len(self.unsup_data))
|
||||
#self.data= self.sup_data+self.unsup_data
|
||||
self.unsup_data=np.array(self.unsup_data).astype(self.sup_data.dtype)
|
||||
self.data= np.concatenate((self.sup_data, self.unsup_data), axis=0)
|
||||
print(len(self.data))
|
||||
self.targets= self.sup_targets+self.unsup_targets
|
||||
self.targets= np.concatenate((self.sup_targets, self.unsup_targets), axis=0)
|
||||
|
||||
assert len(self.unsup_data)==len(self.unsup_targets)
|
||||
assert len(self.data)==len(self.targets)
|
||||
self.dataset_info['unsup']=len(self.unsup_data)
|
||||
self.dataset_info['length']=self.dataset_info['sup']+self.dataset_info['unsup']
|
||||
|
||||
def len_supervised(self):
|
||||
return len(self.sup_data)
|
||||
return self.dataset_info['sup']
|
||||
|
||||
def len_unsupervised(self):
|
||||
return len(self.unsup_data)
|
||||
return self.dataset_info['unsup']
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
return self.dataset_info['length']
|
||||
|
||||
def __str__(self):
|
||||
return "CIFAR10(Sup:{}-Unsup:{})".format(self.dataset_info['sup'], self.dataset_info['unsup'])
|
||||
|
||||
### Classic Dataset ###
|
||||
data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
|
||||
#print(len(data_train))
|
||||
#data_train = AugmentedDataset("./data", train=True, download=True, transform=transform)
|
||||
#print(len(data_train), data_train.len_supervised(), data_train.len_unsupervised())
|
||||
#data_train.augement_data()
|
||||
#print(len(data_train), data_train.len_supervised(), data_train.len_unsupervised())
|
||||
#data_val = torchvision.datasets.CIFAR10(
|
||||
# "./data", train=True, download=True, transform=transform
|
||||
#)
|
||||
data_test = torchvision.datasets.CIFAR10(
|
||||
"./data", train=False, download=True, transform=transform
|
||||
)
|
||||
#'''
|
||||
#data_val = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
|
||||
data_test = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=transform)
|
||||
|
||||
|
||||
train_subset_indices=range(int(len(data_train)/2))
|
||||
val_subset_indices=range(int(len(data_train)/2),len(data_train))
|
||||
#train_subset_indices=range(BATCH_SIZE*10)
|
||||
#val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
|
||||
|
||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
|
||||
|
||||
### Augmented Dataset ###
|
||||
data_train_aug = AugmentedDataset("./data", train=True, download=True, transform=transform, subset=(0,int(len(data_train)/2)))
|
||||
#data_train_aug.augement_data(aug_copy=1)
|
||||
print(data_train_aug)
|
||||
|
||||
dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True)
|
||||
|
||||
|
||||
dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
|
||||
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False)
|
||||
|
|
After Width: | Height: | Size: 179 KiB |
After Width: | Height: | Size: 392 KiB |
After Width: | Height: | Size: 446 KiB |
After Width: | Height: | Size: 324 KiB |
After Width: | Height: | Size: 191 KiB |
After Width: | Height: | Size: 156 KiB |
After Width: | Height: | Size: 200 KiB |
After Width: | Height: | Size: 203 KiB |
After Width: | Height: | Size: 292 KiB |
After Width: | Height: | Size: 192 KiB |
After Width: | Height: | Size: 293 KiB |
Before Width: | Height: | Size: 43 KiB |
After Width: | Height: | Size: 170 KiB |
After Width: | Height: | Size: 408 KiB |
After Width: | Height: | Size: 336 KiB |
|
@ -0,0 +1,810 @@
|
|||
{
|
||||
"Accuracy": 59.15,
|
||||
"Time": [
|
||||
2.891577087839999,
|
||||
0.04480297073115646
|
||||
],
|
||||
"Device": "TITAN RTX",
|
||||
"Log": [
|
||||
{
|
||||
"epoch": 0,
|
||||
"train_loss": 2.238740921020508,
|
||||
"val_loss": 2.1966593265533447,
|
||||
"acc": 18.39,
|
||||
"time": 3.142668161000003,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 1,
|
||||
"train_loss": 2.1030218601226807,
|
||||
"val_loss": 1.9304202795028687,
|
||||
"acc": 31.4,
|
||||
"time": 2.866064168000001,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 2,
|
||||
"train_loss": 2.102165460586548,
|
||||
"val_loss": 1.7866367101669312,
|
||||
"acc": 34.08,
|
||||
"time": 2.8656765240000013,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 3,
|
||||
"train_loss": 1.9763004779815674,
|
||||
"val_loss": 1.6346896886825562,
|
||||
"acc": 44.06,
|
||||
"time": 2.8573803359999985,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 4,
|
||||
"train_loss": 1.9886680841445923,
|
||||
"val_loss": 1.5025888681411743,
|
||||
"acc": 45.91,
|
||||
"time": 2.884063083000001,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 5,
|
||||
"train_loss": 1.8460354804992676,
|
||||
"val_loss": 1.4987385272979736,
|
||||
"acc": 47.11,
|
||||
"time": 2.8803015360000046,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 6,
|
||||
"train_loss": 1.8773751258850098,
|
||||
"val_loss": 1.3948158025741577,
|
||||
"acc": 46.54,
|
||||
"time": 2.8545809470000023,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 7,
|
||||
"train_loss": 1.9329789876937866,
|
||||
"val_loss": 1.3192082643508911,
|
||||
"acc": 49.96,
|
||||
"time": 2.8968874109999945,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 8,
|
||||
"train_loss": 1.877233624458313,
|
||||
"val_loss": 1.4252185821533203,
|
||||
"acc": 51.3,
|
||||
"time": 2.856139868999989,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 9,
|
||||
"train_loss": 1.8099722862243652,
|
||||
"val_loss": 1.4050142765045166,
|
||||
"acc": 52.78,
|
||||
"time": 2.9680558680000075,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 10,
|
||||
"train_loss": 1.7314376831054688,
|
||||
"val_loss": 1.2530089616775513,
|
||||
"acc": 54.54,
|
||||
"time": 2.8540503600000022,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 11,
|
||||
"train_loss": 1.714591383934021,
|
||||
"val_loss": 1.4025185108184814,
|
||||
"acc": 54.3,
|
||||
"time": 2.8576988419999907,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 12,
|
||||
"train_loss": 1.6348106861114502,
|
||||
"val_loss": 1.2283456325531006,
|
||||
"acc": 56.1,
|
||||
"time": 2.8865886950000004,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 13,
|
||||
"train_loss": 1.6358189582824707,
|
||||
"val_loss": 1.2930848598480225,
|
||||
"acc": 55.75,
|
||||
"time": 2.872028806000003,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 14,
|
||||
"train_loss": 1.7428033351898193,
|
||||
"val_loss": 1.185242772102356,
|
||||
"acc": 58.11,
|
||||
"time": 2.8523812309999954,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 15,
|
||||
"train_loss": 1.584743857383728,
|
||||
"val_loss": 1.2095788717269897,
|
||||
"acc": 58.13,
|
||||
"time": 2.8867363259999905,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 16,
|
||||
"train_loss": 1.637380838394165,
|
||||
"val_loss": 1.2247447967529297,
|
||||
"acc": 57.89,
|
||||
"time": 2.927922326000001,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 17,
|
||||
"train_loss": 1.3542273044586182,
|
||||
"val_loss": 1.1683770418167114,
|
||||
"acc": 57.79,
|
||||
"time": 2.9269400579999996,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 18,
|
||||
"train_loss": 1.5338425636291504,
|
||||
"val_loss": 1.1944950819015503,
|
||||
"acc": 57.9,
|
||||
"time": 2.8685170960000193,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 19,
|
||||
"train_loss": 1.406481146812439,
|
||||
"val_loss": 1.164185643196106,
|
||||
"acc": 58.67,
|
||||
"time": 2.8701969949999864,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 20,
|
||||
"train_loss": 1.4967502355575562,
|
||||
"val_loss": 1.3286499977111816,
|
||||
"acc": 57.86,
|
||||
"time": 2.8738660139999865,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 21,
|
||||
"train_loss": 1.4469761848449707,
|
||||
"val_loss": 1.2245217561721802,
|
||||
"acc": 59.15,
|
||||
"time": 2.8806329870000127,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 22,
|
||||
"train_loss": 1.3278372287750244,
|
||||
"val_loss": 1.3220281600952148,
|
||||
"acc": 58.65,
|
||||
"time": 2.8654193290000194,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 23,
|
||||
"train_loss": 1.310671329498291,
|
||||
"val_loss": 1.4258655309677124,
|
||||
"acc": 58.62,
|
||||
"time": 2.893132035000008,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 24,
|
||||
"train_loss": 1.2683414220809937,
|
||||
"val_loss": 1.2353028059005737,
|
||||
"acc": 58.18,
|
||||
"time": 2.8943645079999953,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 25,
|
||||
"train_loss": 1.2522042989730835,
|
||||
"val_loss": 1.3762693405151367,
|
||||
"acc": 57.19,
|
||||
"time": 2.902485732999992,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 26,
|
||||
"train_loss": 1.2245501279830933,
|
||||
"val_loss": 1.4818042516708374,
|
||||
"acc": 58.7,
|
||||
"time": 2.9064084750000063,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 27,
|
||||
"train_loss": 1.0753016471862793,
|
||||
"val_loss": 1.2684273719787598,
|
||||
"acc": 57.39,
|
||||
"time": 2.8864203069999803,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 28,
|
||||
"train_loss": 1.204155445098877,
|
||||
"val_loss": 1.6392837762832642,
|
||||
"acc": 58.03,
|
||||
"time": 2.8522731609999994,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 29,
|
||||
"train_loss": 1.0476032495498657,
|
||||
"val_loss": 1.9471677541732788,
|
||||
"acc": 57.36,
|
||||
"time": 2.8621515780000095,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 30,
|
||||
"train_loss": 1.0602569580078125,
|
||||
"val_loss": 1.7120074033737183,
|
||||
"acc": 57.4,
|
||||
"time": 2.892253502000017,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 31,
|
||||
"train_loss": 0.8457854390144348,
|
||||
"val_loss": 1.7540744543075562,
|
||||
"acc": 57.55,
|
||||
"time": 2.876599782999989,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 32,
|
||||
"train_loss": 0.7922731041908264,
|
||||
"val_loss": 2.066330671310425,
|
||||
"acc": 55.91,
|
||||
"time": 2.8669344470000055,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 33,
|
||||
"train_loss": 0.6974552273750305,
|
||||
"val_loss": 1.8737841844558716,
|
||||
"acc": 57.54,
|
||||
"time": 2.8812602219999803,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 34,
|
||||
"train_loss": 0.6545730829238892,
|
||||
"val_loss": 2.59554386138916,
|
||||
"acc": 56.33,
|
||||
"time": 2.992003705000002,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 35,
|
||||
"train_loss": 0.6207562685012817,
|
||||
"val_loss": 2.878373861312866,
|
||||
"acc": 56.64,
|
||||
"time": 2.880951809999999,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 36,
|
||||
"train_loss": 0.643873929977417,
|
||||
"val_loss": 2.8013312816619873,
|
||||
"acc": 56.3,
|
||||
"time": 2.8770664789999785,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 37,
|
||||
"train_loss": 0.5519305467605591,
|
||||
"val_loss": 2.3760335445404053,
|
||||
"acc": 56.42,
|
||||
"time": 2.860177633999996,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 38,
|
||||
"train_loss": 0.4566989541053772,
|
||||
"val_loss": 3.261263370513916,
|
||||
"acc": 56.38,
|
||||
"time": 2.8768970639999907,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 39,
|
||||
"train_loss": 0.4070911407470703,
|
||||
"val_loss": 3.162980079650879,
|
||||
"acc": 55.65,
|
||||
"time": 2.8753553720000014,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 40,
|
||||
"train_loss": 0.3369447588920593,
|
||||
"val_loss": 3.6949386596679688,
|
||||
"acc": 56.16,
|
||||
"time": 2.8803950129999976,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 41,
|
||||
"train_loss": 0.233261376619339,
|
||||
"val_loss": 3.1482434272766113,
|
||||
"acc": 56.81,
|
||||
"time": 2.958535100000006,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 42,
|
||||
"train_loss": 0.2671070098876953,
|
||||
"val_loss": 3.852823495864868,
|
||||
"acc": 56.28,
|
||||
"time": 2.8413516090000144,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 43,
|
||||
"train_loss": 0.17192059755325317,
|
||||
"val_loss": 3.9386160373687744,
|
||||
"acc": 56.43,
|
||||
"time": 2.882356039000001,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 44,
|
||||
"train_loss": 0.21115460991859436,
|
||||
"val_loss": 3.674778938293457,
|
||||
"acc": 56.6,
|
||||
"time": 2.8627468419999786,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 45,
|
||||
"train_loss": 0.17167963087558746,
|
||||
"val_loss": 3.678553819656372,
|
||||
"acc": 55.82,
|
||||
"time": 2.8513568369999973,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 46,
|
||||
"train_loss": 0.1660262942314148,
|
||||
"val_loss": 4.021941661834717,
|
||||
"acc": 56.93,
|
||||
"time": 2.883665002000015,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 47,
|
||||
"train_loss": 0.14190584421157837,
|
||||
"val_loss": 4.224549770355225,
|
||||
"acc": 57.28,
|
||||
"time": 2.8793066700000054,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 48,
|
||||
"train_loss": 0.13566425442695618,
|
||||
"val_loss": 3.8729546070098877,
|
||||
"acc": 56.01,
|
||||
"time": 2.8533336880000206,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 49,
|
||||
"train_loss": 0.20210444927215576,
|
||||
"val_loss": 4.181947231292725,
|
||||
"acc": 56.79,
|
||||
"time": 2.871362753999989,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 50,
|
||||
"train_loss": 0.17691804468631744,
|
||||
"val_loss": 4.109798908233643,
|
||||
"acc": 56.21,
|
||||
"time": 2.868851712000037,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 51,
|
||||
"train_loss": 0.09427379816770554,
|
||||
"val_loss": 4.5040059089660645,
|
||||
"acc": 56.41,
|
||||
"time": 2.8451589609999814,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 52,
|
||||
"train_loss": 0.1147536113858223,
|
||||
"val_loss": 4.073654651641846,
|
||||
"acc": 57.34,
|
||||
"time": 2.849203299999999,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 53,
|
||||
"train_loss": 0.053406428545713425,
|
||||
"val_loss": 4.1925435066223145,
|
||||
"acc": 57.14,
|
||||
"time": 2.9053060059999893,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 54,
|
||||
"train_loss": 0.09250318259000778,
|
||||
"val_loss": 4.594822406768799,
|
||||
"acc": 56.91,
|
||||
"time": 2.8818305060000284,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 55,
|
||||
"train_loss": 0.040889639407396317,
|
||||
"val_loss": 4.507361888885498,
|
||||
"acc": 56.76,
|
||||
"time": 2.9061625310000068,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 56,
|
||||
"train_loss": 0.033295318484306335,
|
||||
"val_loss": 4.082192897796631,
|
||||
"acc": 57.37,
|
||||
"time": 2.895834288000003,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 57,
|
||||
"train_loss": 0.03692072629928589,
|
||||
"val_loss": 5.112850189208984,
|
||||
"acc": 57.24,
|
||||
"time": 2.905829856999958,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 58,
|
||||
"train_loss": 0.08262491971254349,
|
||||
"val_loss": 5.430809497833252,
|
||||
"acc": 57.25,
|
||||
"time": 2.8767505010000036,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 59,
|
||||
"train_loss": 0.018249312415719032,
|
||||
"val_loss": 4.355240821838379,
|
||||
"acc": 56.89,
|
||||
"time": 2.865761673999998,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 60,
|
||||
"train_loss": 0.03834836557507515,
|
||||
"val_loss": 4.825883388519287,
|
||||
"acc": 56.69,
|
||||
"time": 2.8549083070000165,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 61,
|
||||
"train_loss": 0.07849375903606415,
|
||||
"val_loss": 4.652444362640381,
|
||||
"acc": 57.56,
|
||||
"time": 2.8533472180000103,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 62,
|
||||
"train_loss": 0.04940357804298401,
|
||||
"val_loss": 5.252978801727295,
|
||||
"acc": 57.81,
|
||||
"time": 2.8411212999999975,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 63,
|
||||
"train_loss": 0.035617221146821976,
|
||||
"val_loss": 4.783238410949707,
|
||||
"acc": 56.78,
|
||||
"time": 2.8682083010000383,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 64,
|
||||
"train_loss": 0.04857274517416954,
|
||||
"val_loss": 4.7291789054870605,
|
||||
"acc": 57.41,
|
||||
"time": 2.876162964999992,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 65,
|
||||
"train_loss": 0.02348227985203266,
|
||||
"val_loss": 4.526753902435303,
|
||||
"acc": 57.38,
|
||||
"time": 2.896928296999988,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 66,
|
||||
"train_loss": 0.014324082992970943,
|
||||
"val_loss": 5.439307689666748,
|
||||
"acc": 57.11,
|
||||
"time": 2.8848618569999758,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 67,
|
||||
"train_loss": 0.023290712386369705,
|
||||
"val_loss": 5.721205234527588,
|
||||
"acc": 58.09,
|
||||
"time": 2.8774091299999895,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 68,
|
||||
"train_loss": 0.01625976897776127,
|
||||
"val_loss": 6.3620829582214355,
|
||||
"acc": 57.82,
|
||||
"time": 2.8771049090000247,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 69,
|
||||
"train_loss": 0.024740245193243027,
|
||||
"val_loss": 4.560676097869873,
|
||||
"acc": 58.33,
|
||||
"time": 2.88398736299996,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 70,
|
||||
"train_loss": 0.024858517572283745,
|
||||
"val_loss": 4.982525825500488,
|
||||
"acc": 58.51,
|
||||
"time": 2.8906606280000346,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 71,
|
||||
"train_loss": 0.02504434995353222,
|
||||
"val_loss": 5.8997955322265625,
|
||||
"acc": 58.4,
|
||||
"time": 2.907554915999981,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 72,
|
||||
"train_loss": 0.014550256542861462,
|
||||
"val_loss": 4.25344705581665,
|
||||
"acc": 58.63,
|
||||
"time": 2.909489785000005,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 73,
|
||||
"train_loss": 0.013704107142984867,
|
||||
"val_loss": 4.419747352600098,
|
||||
"acc": 58.4,
|
||||
"time": 2.934780938000017,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 74,
|
||||
"train_loss": 0.0007794767734594643,
|
||||
"val_loss": 5.364877700805664,
|
||||
"acc": 58.34,
|
||||
"time": 2.8807501869999896,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 75,
|
||||
"train_loss": 0.024075284600257874,
|
||||
"val_loss": 4.890963554382324,
|
||||
"acc": 58.39,
|
||||
"time": 2.890908232999948,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 76,
|
||||
"train_loss": 0.0007155562052503228,
|
||||
"val_loss": 5.429594039916992,
|
||||
"acc": 58.47,
|
||||
"time": 2.9814038840000308,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 77,
|
||||
"train_loss": 0.025066755712032318,
|
||||
"val_loss": 5.024533748626709,
|
||||
"acc": 58.43,
|
||||
"time": 2.90384117800005,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 78,
|
||||
"train_loss": 0.024619707837700844,
|
||||
"val_loss": 5.127645492553711,
|
||||
"acc": 58.48,
|
||||
"time": 2.889764621999973,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 79,
|
||||
"train_loss": 0.011625911109149456,
|
||||
"val_loss": 5.643911361694336,
|
||||
"acc": 58.44,
|
||||
"time": 2.8605676199999834,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 80,
|
||||
"train_loss": 0.035187769681215286,
|
||||
"val_loss": 4.96016263961792,
|
||||
"acc": 58.22,
|
||||
"time": 3.0620140229999606,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 81,
|
||||
"train_loss": 0.012686162255704403,
|
||||
"val_loss": 4.861446380615234,
|
||||
"acc": 58.29,
|
||||
"time": 3.0467026740000165,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 82,
|
||||
"train_loss": 0.03516190126538277,
|
||||
"val_loss": 5.382493495941162,
|
||||
"acc": 58.59,
|
||||
"time": 2.948570172000018,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 83,
|
||||
"train_loss": 0.044635310769081116,
|
||||
"val_loss": 5.259823799133301,
|
||||
"acc": 58.32,
|
||||
"time": 2.9322751760000187,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 84,
|
||||
"train_loss": 0.02375461347401142,
|
||||
"val_loss": 4.794026851654053,
|
||||
"acc": 58.32,
|
||||
"time": 2.904587699999979,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 85,
|
||||
"train_loss": 0.0005525348242372274,
|
||||
"val_loss": 6.069370746612549,
|
||||
"acc": 58.49,
|
||||
"time": 2.9163807219999853,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 86,
|
||||
"train_loss": 0.011588472872972488,
|
||||
"val_loss": 4.31306266784668,
|
||||
"acc": 58.26,
|
||||
"time": 2.90974307700003,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 87,
|
||||
"train_loss": 0.04098305106163025,
|
||||
"val_loss": 5.128759860992432,
|
||||
"acc": 58.38,
|
||||
"time": 2.9223316519999685,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 88,
|
||||
"train_loss": 0.0007651126361452043,
|
||||
"val_loss": 5.77998685836792,
|
||||
"acc": 58.31,
|
||||
"time": 2.895599219000019,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 89,
|
||||
"train_loss": 0.01347972359508276,
|
||||
"val_loss": 5.071125030517578,
|
||||
"acc": 58.44,
|
||||
"time": 2.8723491999999737,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 90,
|
||||
"train_loss": 0.0595603846013546,
|
||||
"val_loss": 5.926196575164795,
|
||||
"acc": 58.55,
|
||||
"time": 2.933795826999983,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 91,
|
||||
"train_loss": 0.010888534598052502,
|
||||
"val_loss": 4.442563533782959,
|
||||
"acc": 58.29,
|
||||
"time": 2.873865481999985,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 92,
|
||||
"train_loss": 0.01195440161973238,
|
||||
"val_loss": 5.685233116149902,
|
||||
"acc": 58.4,
|
||||
"time": 2.908180643000037,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 93,
|
||||
"train_loss": 0.012302059680223465,
|
||||
"val_loss": 5.39638090133667,
|
||||
"acc": 58.33,
|
||||
"time": 2.8950546359999976,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 94,
|
||||
"train_loss": 0.022545015439391136,
|
||||
"val_loss": 4.769997596740723,
|
||||
"acc": 58.38,
|
||||
"time": 2.8618274910000423,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 95,
|
||||
"train_loss": 0.032540563493967056,
|
||||
"val_loss": 6.151810646057129,
|
||||
"acc": 58.28,
|
||||
"time": 2.8651707379999607,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 96,
|
||||
"train_loss": 0.012408148497343063,
|
||||
"val_loss": 5.0970139503479,
|
||||
"acc": 58.3,
|
||||
"time": 2.875970102999986,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 97,
|
||||
"train_loss": 0.024478793144226074,
|
||||
"val_loss": 5.438201904296875,
|
||||
"acc": 58.36,
|
||||
"time": 2.868522979999966,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 98,
|
||||
"train_loss": 0.03768819198012352,
|
||||
"val_loss": 6.296820640563965,
|
||||
"acc": 58.3,
|
||||
"time": 2.888704816000029,
|
||||
"param": null
|
||||
},
|
||||
{
|
||||
"epoch": 99,
|
||||
"train_loss": 0.03524984419345856,
|
||||
"val_loss": 5.514108657836914,
|
||||
"acc": 58.17,
|
||||
"time": 2.8756691419999925,
|
||||
"param": null
|
||||
}
|
||||
]
|
||||
}
|
83
higher/test_brutus.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
from model import *
|
||||
from dataug import *
|
||||
#from utils import *
|
||||
from train_utils import *
|
||||
|
||||
tf_names = [
|
||||
## Geometric TF ##
|
||||
'Identity',
|
||||
'FlipUD',
|
||||
'FlipLR',
|
||||
'Rotate',
|
||||
'TranslateX',
|
||||
'TranslateY',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
|
||||
## Color TF (Expect image in the range of [0, 1]) ##
|
||||
'Contrast',
|
||||
'Color',
|
||||
'Brightness',
|
||||
'Sharpness',
|
||||
'Posterize',
|
||||
'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
|
||||
]
|
||||
|
||||
device = torch.device('cuda')
|
||||
|
||||
if device == torch.device('cpu'):
|
||||
device_name = 'CPU'
|
||||
else:
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
|
||||
##########################################
|
||||
if __name__ == "__main__":
|
||||
|
||||
res_folder="res/brutus-tests/"
|
||||
epochs= 150
|
||||
inner_its = [1]
|
||||
dist_mix = [1]
|
||||
dataug_epoch_starts= [0]
|
||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||
TF_nb = [len(tf_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
|
||||
N_seq_TF= [2, 3, 4]
|
||||
mag_setup = [(True,True), (False, False)]
|
||||
#prob_setup = [True, False]
|
||||
nb_run= 3
|
||||
|
||||
try:
|
||||
os.mkdir(res_folder)
|
||||
os.mkdir(res_folder+"log/")
|
||||
except FileExistsError:
|
||||
pass
|
||||
|
||||
for n_inner_iter in inner_its:
|
||||
for dataug_epoch_start in dataug_epoch_starts:
|
||||
for n_tf in N_seq_TF:
|
||||
for dist in dist_mix:
|
||||
#for i in TF_nb:
|
||||
for m_setup in mag_setup:
|
||||
#for p_setup in prob_setup:
|
||||
for run in range(nb_run):
|
||||
if n_inner_iter == 0 and (m_setup!=(True,True) or p_setup!=True): continue #Autres setup inutiles sans meta-opti
|
||||
if n_tf ==2 and m_setup==(True,True): continue #Deja resultats
|
||||
#keys = list(TF.TF_dict.keys())[0:i]
|
||||
#ntf_dict = {k: TF.TF_dict[k] for k in keys}
|
||||
|
||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=False, fixed_mag=m_setup[0], shared_mag=m_setup[1]), LeNet(3,10)).to(device)
|
||||
print(str(aug_model), 'on', device_name)
|
||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=20, loss_patience=None)
|
||||
|
||||
####
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
|
||||
print(str(aug_model),": acc", out["Accuracy"], "in :", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{}epochs(dataug:{})-{}in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter,run)
|
||||
with open(res_folder+"log/%s.json" % filename, "w+") as f:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
|
||||
#plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names)
|
||||
print('-'*9)
|
|
@ -64,12 +64,19 @@ else:
|
|||
##########################################
|
||||
if __name__ == "__main__":
|
||||
|
||||
n_inner_iter = 1
|
||||
epochs = 200
|
||||
tasks={
|
||||
#'classic',
|
||||
'aug_dataset',
|
||||
#'aug_model'
|
||||
}
|
||||
n_inner_iter = 0
|
||||
epochs = 100
|
||||
dataug_epoch_start=0
|
||||
|
||||
|
||||
#### Classic ####
|
||||
'''
|
||||
if 'classic' in tasks:
|
||||
t0 = time.process_time()
|
||||
model = LeNet(3,10).to(device)
|
||||
#model = WideResNet(num_classes=10, wrn_size=16).to(device)
|
||||
#model = Augmented_model(Data_augV3(mix_dist=0.0), LeNet(3,10)).to(device)
|
||||
|
@ -80,23 +87,57 @@ if __name__ == "__main__":
|
|||
#log= train_classic_higher(model=model, epochs=epochs)
|
||||
|
||||
####
|
||||
plot_res(log, fig_name="res/{}-{} epochs".format(str(model),epochs))
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Log": log}
|
||||
print(str(model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
with open("res/log/%s.json" % "{}-{} epochs".format(str(model),epochs), "w+") as f:
|
||||
filename = "{}-{} epochs".format(str(model),epochs)
|
||||
with open("res/log/%s.json" % filename, "w+") as f:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
|
||||
plot_res(log, fig_name="res/"+filename)
|
||||
|
||||
print('Execution Time : %.00f '%(time.process_time() - t0))
|
||||
print('-'*9)
|
||||
'''
|
||||
#### Augmented Model ####
|
||||
'''
|
||||
|
||||
|
||||
#### Augmented Dataset ####
|
||||
if 'aug_dataset' in tasks:
|
||||
t0 = time.process_time()
|
||||
model = LeNet(3,10).to(device)
|
||||
#model = WideResNet(num_classes=10, wrn_size=16).to(device)
|
||||
#model = Augmented_model(Data_augV3(mix_dist=0.0), LeNet(3,10)).to(device)
|
||||
#model.augment(mode=False)
|
||||
|
||||
print(str(model), 'on', device_name)
|
||||
log= train_classic(model=model, epochs=epochs)
|
||||
#log= train_classic_higher(model=model, epochs=epochs)
|
||||
|
||||
####
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Log": log}
|
||||
print(str(model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{}-{} epochs".format(str(data_train_aug),str(model),epochs)
|
||||
with open("res/log/%s.json" % filename, "w+") as f:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
|
||||
plot_res(log, fig_name="res/"+filename)
|
||||
|
||||
print('Execution Time : %.00f '%(time.process_time() - t0))
|
||||
print('-'*9)
|
||||
|
||||
|
||||
#### Augmented Model ####
|
||||
if 'aug_model' in tasks:
|
||||
t0 = time.process_time()
|
||||
|
||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||
#tf_dict = TF.TF_dict
|
||||
|
||||
#aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
|
||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
|
||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
|
||||
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
||||
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device)
|
||||
print(str(aug_model), 'on', device_name)
|
||||
|
@ -117,56 +158,3 @@ if __name__ == "__main__":
|
|||
|
||||
print('Execution Time : %.00f '%(time.process_time() - t0))
|
||||
print('-'*9)
|
||||
'''
|
||||
#### TF tests ####
|
||||
#'''
|
||||
res_folder="res/brutus-tests/"
|
||||
epochs= 150
|
||||
inner_its = [1]
|
||||
dist_mix = [1]
|
||||
dataug_epoch_starts= [0]
|
||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||
TF_nb = [len(tf_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
|
||||
N_seq_TF= [2, 3, 4]
|
||||
mag_setup = [(True,True), (False, False)]
|
||||
#prob_setup = [True, False]
|
||||
nb_run= 3
|
||||
|
||||
try:
|
||||
os.mkdir(res_folder)
|
||||
os.mkdir(res_folder+"log/")
|
||||
except FileExistsError:
|
||||
pass
|
||||
|
||||
for n_inner_iter in inner_its:
|
||||
for dataug_epoch_start in dataug_epoch_starts:
|
||||
for n_tf in N_seq_TF:
|
||||
for dist in dist_mix:
|
||||
#for i in TF_nb:
|
||||
for m_setup in mag_setup:
|
||||
#for p_setup in prob_setup:
|
||||
for run in range(nb_run):
|
||||
if n_inner_iter == 0 and (m_setup!=(True,True) or p_setup!=True): continue #Autres setup inutiles sans meta-opti
|
||||
if n_tf ==2 and m_setup==(True,True): continue #Deja resultats
|
||||
#keys = list(TF.TF_dict.keys())[0:i]
|
||||
#ntf_dict = {k: TF.TF_dict[k] for k in keys}
|
||||
|
||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=False, fixed_mag=m_setup[0], shared_mag=m_setup[1]), LeNet(3,10)).to(device)
|
||||
print(str(aug_model), 'on', device_name)
|
||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=20, loss_patience=None)
|
||||
|
||||
####
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
|
||||
print(str(aug_model),": acc", out["Accuracy"], "in :", out["Time"][0], "+/-", out["Time"][1])
|
||||
filename = "{}-{}epochs(dataug:{})-{}in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter,run)
|
||||
with open(res_folder+"log/%s.json" % filename, "w+") as f:
|
||||
json.dump(out, f, indent=True)
|
||||
print('Log :\"',f.name, '\" saved !')
|
||||
|
||||
#plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names)
|
||||
print('-'*9)
|
||||
|
||||
#'''
|
||||
|
|
|
@ -616,7 +616,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
model_copy(src=fmodel, dst=model)
|
||||
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||
|
||||
#if epoch>50:
|
||||
if epoch>50:
|
||||
meta_opt.step()
|
||||
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
|
||||
#model['data_aug'].next_TF_set()
|
||||
|
@ -683,8 +683,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
model.augment(mode=True)
|
||||
if inner_it != 0: high_grad_track = True
|
||||
|
||||
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
|
||||
#print("Copy ", countcopy)
|
||||
return log
|
||||
|
|