diff --git a/higher/compare_res.py b/higher/compare_res.py index 2509b41..ee6c373 100644 --- a/higher/compare_res.py +++ b/higher/compare_res.py @@ -21,17 +21,22 @@ if __name__ == "__main__": ## Acc, Time, Epochs = f(n_tf) ## fig_name="res/TF_seq_tests_compare" - inner_its = [10] + inner_its = [0] dataug_epoch_starts= [0] - TF_nb = 14 #range(1,14+1) - N_seq_TF= [1, 2, 3, 4] + TF_nb = range(1,14+1) + N_seq_TF= [1] #, 2, 3, 4] fig, ax = plt.subplots(ncols=3, figsize=(30, 8)) for in_it in inner_its: for dataug in dataug_epoch_starts: + + n_tf = TF_nb #filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF)-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(n_tf, dataug, in_it) for n_tf in TF_nb] - filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF x {})-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(TF_nb, n_tf, dataug, in_it) for n_tf in N_seq_TF] - filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF x {})-LeNet)-100 epochs (dataug:{})- {} in_it.json".format(TF_nb, n_tf, dataug, in_it) for n_tf in N_seq_TF] + filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF x {})-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(n_tf, 1, dataug, in_it) for n_tf in TF_nb] + + #n_tf = N_seq_TF + #filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF x {})-LeNet)-200 epochs (dataug:{})- {} in_it.json".format(TF_nb, n_tf, dataug, in_it) for n_tf in N_seq_TF] + #filenames =["res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-{} TF x {})-LeNet)-100 epochs (dataug:{})- {} in_it.json".format(TF_nb, n_tf, dataug, in_it) for n_tf in N_seq_TF] all_data=[] @@ -41,9 +46,7 @@ if __name__ == "__main__": with open(file) as json_file: data = json.load(json_file) all_data.append(data) - - n_tf = N_seq_TF - #n_tf = [len(x["Param_names"]) for x in all_data] + acc = [x["Accuracy"] for x in all_data] epochs = [len(x["Log"]) for x in all_data] time = [x["Time"][0] for x in all_data] diff --git a/higher/model.py b/higher/model.py index fe7e609..3fd6435 100644 --- a/higher/model.py +++ b/higher/model.py @@ -54,6 +54,7 @@ class LeNet(nn.Module): ## Wide ResNet ## #https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py #https://github.com/arcelien/pba/blob/master/pba/wrn.py +#https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py class BasicBlock(nn.Module): def __init__(self, in_planes, out_planes, stride, dropRate=0.0): super(BasicBlock, self).__init__() @@ -97,9 +98,10 @@ class WideResNet(nn.Module): def __init__(self, num_classes, wrn_size, depth=28, dropRate=0.0): super(WideResNet, self).__init__() - kernel_size = wrn_size + self.kernel_size = wrn_size + self.depth=depth filter_size = 3 - nChannels = [min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4] + nChannels = [min(self.kernel_size, 16), self.kernel_size, self.kernel_size * 2, self.kernel_size * 4] strides = [1, 2, 2] # stride for each resblock #nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] @@ -137,4 +139,10 @@ class WideResNet(nn.Module): out = self.relu(self.bn1(out)) out = F.avg_pool2d(out, 8) out = out.view(-1, self.nChannels) - return self.fc(out) \ No newline at end of file + return self.fc(out) + + def architecture(self): + return super(WideResNet, self).__str__() + + def __str__(self): + return "WideResNet(s{}-d{})".format(self.kernel_size, self.depth) \ No newline at end of file diff --git a/higher/res/TF_nb_tests/Aug_mod(Data_augV4(Uniform-14 TF x 1)-LeNet)-200 epochs (dataug:0)- 0 in_it.png b/higher/res/TF_nb_tests/Aug_mod(Data_augV4(Uniform-14 TF x 1)-LeNet)-200 epochs (dataug:0)- 0 in_it.png index afc4be4..235f854 100644 Binary files a/higher/res/TF_nb_tests/Aug_mod(Data_augV4(Uniform-14 TF x 1)-LeNet)-200 epochs (dataug:0)- 0 in_it.png and b/higher/res/TF_nb_tests/Aug_mod(Data_augV4(Uniform-14 TF x 1)-LeNet)-200 epochs (dataug:0)- 0 in_it.png differ diff --git a/higher/res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-14 TF x 1)-LeNet)-200 epochs (dataug:0)- 0 in_it.json b/higher/res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-14 TF x 1)-LeNet)-200 epochs (dataug:0)- 0 in_it.json index ddc0284..e403b99 100644 --- a/higher/res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-14 TF x 1)-LeNet)-200 epochs (dataug:0)- 0 in_it.json +++ b/higher/res/TF_nb_tests/log/Aug_mod(Data_augV4(Uniform-14 TF x 1)-LeNet)-200 epochs (dataug:0)- 0 in_it.json @@ -1,10 +1,10 @@ { - "Accuracy": 68.16, + "Accuracy": 68.66, "Time": [ - 31.493859366909113, - 2.15108059308094 + 33.33878931903108, + 0.7264083224707042 ], - "Device": "GeForce GTX 1060 with Max-Q Design", + "Device": "TITAN RTX", "Param_names": [ "Identity", "FlipUD", @@ -24,10 +24,10 @@ "Log": [ { "epoch": 1, - "train_loss": 2.275712251663208, - "val_loss": 2.2696399688720703, - "acc": 13.0, - "time": 27.403998994999995, + "train_loss": 2.2361512184143066, + "val_loss": 2.2366385459899902, + "acc": 17.88, + "time": 33.05993920400215, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -47,10 +47,10 @@ }, { "epoch": 2, - "train_loss": 2.117508888244629, - "val_loss": 1.9593552350997925, - "acc": 28.77, - "time": 27.006608027000006, + "train_loss": 2.0710363388061523, + "val_loss": 1.9404091835021973, + "acc": 28.64, + "time": 33.911141052005405, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -70,10 +70,10 @@ }, { "epoch": 3, - "train_loss": 2.000314235687256, - "val_loss": 1.8604073524475098, - "acc": 33.53, - "time": 27.250486267, + "train_loss": 1.9797052145004272, + "val_loss": 1.9548026323318481, + "acc": 29.91, + "time": 32.924524751004355, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -93,10 +93,10 @@ }, { "epoch": 4, - "train_loss": 1.8282198905944824, - "val_loss": 1.8107564449310303, - "acc": 36.82, - "time": 27.057047364, + "train_loss": 1.8238189220428467, + "val_loss": 1.7366312742233276, + "acc": 37.03, + "time": 33.56608005799353, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -116,10 +116,10 @@ }, { "epoch": 5, - "train_loss": 1.7922232151031494, - "val_loss": 1.579028606414795, - "acc": 43.82, - "time": 26.995444665000008, + "train_loss": 1.8354424238204956, + "val_loss": 1.7071852684020996, + "acc": 40.78, + "time": 34.506957949000935, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -139,10 +139,10 @@ }, { "epoch": 6, - "train_loss": 2.067457914352417, - "val_loss": 1.7379605770111084, - "acc": 39.23, - "time": 27.93069691299999, + "train_loss": 1.860048532485962, + "val_loss": 1.5845730304718018, + "acc": 44.11, + "time": 33.044019629000104, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -162,10 +162,10 @@ }, { "epoch": 7, - "train_loss": 1.7192401885986328, - "val_loss": 1.6571940183639526, - "acc": 44.41, - "time": 31.668140728999987, + "train_loss": 1.8181153535842896, + "val_loss": 1.5851922035217285, + "acc": 46.19, + "time": 32.89763066800515, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -185,10 +185,10 @@ }, { "epoch": 8, - "train_loss": 1.5483973026275635, - "val_loss": 1.5771993398666382, - "acc": 47.58, - "time": 32.578817886999985, + "train_loss": 1.7653775215148926, + "val_loss": 1.523246169090271, + "acc": 47.09, + "time": 32.92637052300415, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -208,10 +208,10 @@ }, { "epoch": 9, - "train_loss": 1.8271228075027466, - "val_loss": 1.3951427936553955, - "acc": 48.88, - "time": 35.423403560999986, + "train_loss": 1.747633934020996, + "val_loss": 1.5431556701660156, + "acc": 49.06, + "time": 32.998405736005225, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -231,10 +231,10 @@ }, { "epoch": 10, - "train_loss": 1.4928618669509888, - "val_loss": 1.4250946044921875, - "acc": 49.95, - "time": 34.642475778999994, + "train_loss": 1.5814456939697266, + "val_loss": 1.4110125303268433, + "acc": 49.68, + "time": 33.171790264001174, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -254,10 +254,10 @@ }, { "epoch": 11, - "train_loss": 1.3911683559417725, - "val_loss": 1.4054059982299805, - "acc": 51.49, - "time": 32.22097354599998, + "train_loss": 1.6201672554016113, + "val_loss": 1.4525870084762573, + "acc": 50.82, + "time": 33.52474787200481, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -277,10 +277,10 @@ }, { "epoch": 12, - "train_loss": 1.5814663171768188, - "val_loss": 1.2640619277954102, - "acc": 53.28000000000001, - "time": 34.08548932100001, + "train_loss": 1.3984065055847168, + "val_loss": 1.3418843746185303, + "acc": 52.43, + "time": 34.218198128997756, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -300,10 +300,10 @@ }, { "epoch": 13, - "train_loss": 1.447560429573059, - "val_loss": 1.279397964477539, - "acc": 53.55, - "time": 32.63061355899998, + "train_loss": 1.3165555000305176, + "val_loss": 1.3960962295532227, + "acc": 52.32, + "time": 33.146675167001376, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -323,10 +323,10 @@ }, { "epoch": 14, - "train_loss": 1.4522494077682495, - "val_loss": 1.243777871131897, - "acc": 55.22, - "time": 31.293837975999963, + "train_loss": 1.5227468013763428, + "val_loss": 1.3094451427459717, + "acc": 53.18, + "time": 33.75051434899797, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -346,10 +346,10 @@ }, { "epoch": 15, - "train_loss": 1.3711776733398438, - "val_loss": 1.4037364721298218, - "acc": 56.18, - "time": 31.95970496299998, + "train_loss": 1.317140817642212, + "val_loss": 1.2790968418121338, + "acc": 54.29, + "time": 32.75788728400221, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -369,10 +369,10 @@ }, { "epoch": 16, - "train_loss": 1.3742002248764038, - "val_loss": 1.1787173748016357, - "acc": 57.75, - "time": 31.958187341999917, + "train_loss": 1.4629753828048706, + "val_loss": 1.269505500793457, + "acc": 55.0, + "time": 33.426912771996285, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -392,10 +392,10 @@ }, { "epoch": 17, - "train_loss": 1.3894418478012085, - "val_loss": 1.195398211479187, - "acc": 57.56, - "time": 31.539487654000027, + "train_loss": 1.4578170776367188, + "val_loss": 1.43046236038208, + "acc": 51.8, + "time": 34.40564017499855, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -415,10 +415,10 @@ }, { "epoch": 18, - "train_loss": 1.3557854890823364, - "val_loss": 1.1568230390548706, - "acc": 59.31999999999999, - "time": 32.22961745200007, + "train_loss": 1.3679463863372803, + "val_loss": 1.4067586660385132, + "acc": 57.71, + "time": 33.55429620000359, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -438,10 +438,10 @@ }, { "epoch": 19, - "train_loss": 1.3807461261749268, - "val_loss": 1.2305530309677124, - "acc": 58.5, - "time": 32.1257128609999, + "train_loss": 1.3431265354156494, + "val_loss": 1.070261836051941, + "acc": 58.48, + "time": 33.07030031100294, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -461,10 +461,10 @@ }, { "epoch": 20, - "train_loss": 1.226737141609192, - "val_loss": 1.148327112197876, - "acc": 59.660000000000004, - "time": 31.289705940999966, + "train_loss": 1.4094510078430176, + "val_loss": 1.123005747795105, + "acc": 59.68, + "time": 32.539832073998696, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -484,10 +484,10 @@ }, { "epoch": 21, - "train_loss": 1.344757080078125, - "val_loss": 1.1047474145889282, - "acc": 60.8, - "time": 31.70342632699999, + "train_loss": 1.217442274093628, + "val_loss": 1.233298659324646, + "acc": 56.61, + "time": 32.36720731999958, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -507,10 +507,10 @@ }, { "epoch": 22, - "train_loss": 1.1737595796585083, - "val_loss": 1.1111972332000732, - "acc": 61.7, - "time": 35.76539084199999, + "train_loss": 1.1805152893066406, + "val_loss": 1.266473650932312, + "acc": 59.68, + "time": 32.94165890399745, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -530,10 +530,10 @@ }, { "epoch": 23, - "train_loss": 1.1932852268218994, - "val_loss": 1.1477299928665161, - "acc": 58.440000000000005, - "time": 36.96900732300003, + "train_loss": 1.2170499563217163, + "val_loss": 1.1163690090179443, + "acc": 59.34, + "time": 33.072550057004264, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -553,10 +553,10 @@ }, { "epoch": 24, - "train_loss": 1.2846112251281738, - "val_loss": 1.228249192237854, - "acc": 61.11, - "time": 33.928253889000075, + "train_loss": 1.413037657737732, + "val_loss": 1.112147331237793, + "acc": 61.26, + "time": 33.176103555000736, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -576,10 +576,10 @@ }, { "epoch": 25, - "train_loss": 1.1475807428359985, - "val_loss": 1.1381301879882812, - "acc": 62.88, - "time": 31.589221180000095, + "train_loss": 1.1219078302383423, + "val_loss": 0.9503553509712219, + "acc": 60.64, + "time": 32.70201930800249, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -599,10 +599,10 @@ }, { "epoch": 26, - "train_loss": 1.1768566370010376, - "val_loss": 1.1205030679702759, - "acc": 63.07000000000001, - "time": 32.330985730000066, + "train_loss": 1.3311303853988647, + "val_loss": 1.1181837320327759, + "acc": 61.74, + "time": 33.267298737999226, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -622,10 +622,10 @@ }, { "epoch": 27, - "train_loss": 0.9738653302192688, - "val_loss": 1.163575291633606, - "acc": 62.09, - "time": 32.006105488, + "train_loss": 1.1976383924484253, + "val_loss": 1.1429328918457031, + "acc": 60.51, + "time": 30.91484549499728, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -645,10 +645,10 @@ }, { "epoch": 28, - "train_loss": 0.9619320631027222, - "val_loss": 0.9622966647148132, - "acc": 63.690000000000005, - "time": 32.37410356600003, + "train_loss": 1.1824053525924683, + "val_loss": 1.1777867078781128, + "acc": 62.84, + "time": 32.26152072900004, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -668,10 +668,10 @@ }, { "epoch": 29, - "train_loss": 1.0292567014694214, - "val_loss": 0.8874693512916565, - "acc": 63.99, - "time": 31.567717850000008, + "train_loss": 1.0975511074066162, + "val_loss": 1.1751550436019897, + "acc": 61.28, + "time": 33.74328628699732, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -691,10 +691,10 @@ }, { "epoch": 30, - "train_loss": 1.151373267173767, - "val_loss": 1.0104514360427856, - "acc": 64.53999999999999, - "time": 31.407606353000006, + "train_loss": 1.4155818223953247, + "val_loss": 1.0619454383850098, + "acc": 62.48, + "time": 33.73068693000096, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -714,10 +714,10 @@ }, { "epoch": 31, - "train_loss": 1.2468770742416382, - "val_loss": 0.9727469682693481, - "acc": 63.7, - "time": 31.579545977999942, + "train_loss": 1.1282376050949097, + "val_loss": 1.0351556539535522, + "acc": 63.54, + "time": 33.51944501199614, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -737,10 +737,10 @@ }, { "epoch": 32, - "train_loss": 1.304778814315796, - "val_loss": 1.004093885421753, - "acc": 65.53999999999999, - "time": 31.671146249000003, + "train_loss": 1.1249756813049316, + "val_loss": 1.1129587888717651, + "acc": 64.21, + "time": 34.15534235299856, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -760,10 +760,10 @@ }, { "epoch": 33, - "train_loss": 1.1888853311538696, - "val_loss": 1.1745442152023315, - "acc": 65.16, - "time": 30.698943410000084, + "train_loss": 1.0925499200820923, + "val_loss": 1.1527929306030273, + "acc": 62.82, + "time": 35.872784801002126, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -783,10 +783,10 @@ }, { "epoch": 34, - "train_loss": 0.8327383399009705, - "val_loss": 0.871751606464386, - "acc": 66.43, - "time": 30.610187330999906, + "train_loss": 1.175700306892395, + "val_loss": 1.0398688316345215, + "acc": 64.41, + "time": 33.69705061199784, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -806,10 +806,10 @@ }, { "epoch": 35, - "train_loss": 0.9476880431175232, - "val_loss": 0.9127776026725769, - "acc": 66.16, - "time": 31.13587345100018, + "train_loss": 1.1436457633972168, + "val_loss": 1.0456116199493408, + "acc": 63.56, + "time": 34.21032336400094, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -829,10 +829,10 @@ }, { "epoch": 36, - "train_loss": 1.1106438636779785, - "val_loss": 1.0809861421585083, - "acc": 66.08000000000001, - "time": 30.617437078000194, + "train_loss": 0.9975263476371765, + "val_loss": 1.1174721717834473, + "acc": 64.61, + "time": 33.310891945999174, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -852,10 +852,10 @@ }, { "epoch": 37, - "train_loss": 1.0052621364593506, - "val_loss": 0.9276190996170044, - "acc": 65.67, - "time": 30.505723940000053, + "train_loss": 1.0764957666397095, + "val_loss": 1.1194300651550293, + "acc": 64.2, + "time": 33.54629349899915, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -875,10 +875,10 @@ }, { "epoch": 38, - "train_loss": 1.0943366289138794, - "val_loss": 0.9282073974609375, - "acc": 66.5, - "time": 31.413280156000155, + "train_loss": 0.901468813419342, + "val_loss": 0.907350480556488, + "acc": 65.2, + "time": 33.36209933200007, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -898,10 +898,10 @@ }, { "epoch": 39, - "train_loss": 1.065986156463623, - "val_loss": 1.0112760066986084, - "acc": 67.16, - "time": 31.774527946000035, + "train_loss": 0.9623751640319824, + "val_loss": 0.9955642819404602, + "acc": 66.11, + "time": 33.1279754030038, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -921,10 +921,10 @@ }, { "epoch": 40, - "train_loss": 0.8673455715179443, - "val_loss": 0.9415293335914612, - "acc": 66.0, - "time": 31.496526270000004, + "train_loss": 1.0192790031433105, + "val_loss": 0.965372622013092, + "acc": 66.46, + "time": 32.999066793003294, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -944,10 +944,10 @@ }, { "epoch": 41, - "train_loss": 0.9649273753166199, - "val_loss": 0.8739274144172668, - "acc": 68.16, - "time": 31.461511772999984, + "train_loss": 0.9529908299446106, + "val_loss": 0.9809656739234924, + "acc": 64.65, + "time": 33.0327638939998, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -967,10 +967,10 @@ }, { "epoch": 42, - "train_loss": 1.31170654296875, - "val_loss": 1.0550568103790283, - "acc": 65.61, - "time": 31.239886107000075, + "train_loss": 0.94074946641922, + "val_loss": 1.0360963344573975, + "acc": 66.3, + "time": 32.90476281900192, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -990,10 +990,10 @@ }, { "epoch": 43, - "train_loss": 0.732260525226593, - "val_loss": 1.031165599822998, - "acc": 67.03, - "time": 31.145920496000144, + "train_loss": 0.8608241677284241, + "val_loss": 0.8884866237640381, + "acc": 66.43, + "time": 32.79904112499935, "param": [ 0.0714285746216774, 0.0714285746216774, @@ -1013,10 +1013,493 @@ }, { "epoch": 44, - "train_loss": 0.8465375900268555, - "val_loss": 1.0749541521072388, - "acc": 67.23, - "time": 31.44703260900019, + "train_loss": 0.8138001561164856, + "val_loss": 1.0393797159194946, + "acc": 65.34, + "time": 33.71409166199737, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 45, + "train_loss": 0.7662962675094604, + "val_loss": 0.9524388909339905, + "acc": 66.75, + "time": 33.36418350299937, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 46, + "train_loss": 0.8844456076622009, + "val_loss": 0.9529110789299011, + "acc": 67.34, + "time": 32.49647754200123, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 47, + "train_loss": 0.7100163698196411, + "val_loss": 1.048550009727478, + "acc": 65.53, + "time": 34.56229288999748, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 48, + "train_loss": 0.9129636883735657, + "val_loss": 0.8993743658065796, + "acc": 66.17, + "time": 34.36059833099716, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 49, + "train_loss": 0.7252851724624634, + "val_loss": 0.927655816078186, + "acc": 66.43, + "time": 35.04270024399739, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 50, + "train_loss": 0.833204984664917, + "val_loss": 0.811994194984436, + "acc": 67.2, + "time": 33.50954103699769, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 51, + "train_loss": 0.863890528678894, + "val_loss": 1.0132911205291748, + "acc": 67.71, + "time": 33.39508569499594, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 52, + "train_loss": 0.8476536273956299, + "val_loss": 0.9530481100082397, + "acc": 67.11, + "time": 33.77243350000208, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 53, + "train_loss": 1.0126402378082275, + "val_loss": 0.8366172313690186, + "acc": 66.89, + "time": 33.44726553199871, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 54, + "train_loss": 0.8542940616607666, + "val_loss": 1.0701818466186523, + "acc": 66.42, + "time": 33.40405249399919, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 55, + "train_loss": 1.0342178344726562, + "val_loss": 0.9405211210250854, + "acc": 67.8, + "time": 33.316023687999405, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 56, + "train_loss": 0.650306224822998, + "val_loss": 0.9202330708503723, + "acc": 67.55, + "time": 34.273164316000475, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 57, + "train_loss": 0.7297669649124146, + "val_loss": 1.060660719871521, + "acc": 67.29, + "time": 33.44395486100257, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 58, + "train_loss": 0.5833786725997925, + "val_loss": 0.9211423993110657, + "acc": 68.26, + "time": 32.726114864999545, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 59, + "train_loss": 0.6446430683135986, + "val_loss": 0.9914183616638184, + "acc": 67.31, + "time": 32.50800098099717, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 60, + "train_loss": 0.7211475968360901, + "val_loss": 1.033563256263733, + "acc": 67.71, + "time": 32.57526234900433, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 61, + "train_loss": 0.6638731956481934, + "val_loss": 1.1473658084869385, + "acc": 66.88, + "time": 32.88343731400528, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 62, + "train_loss": 0.7230252623558044, + "val_loss": 0.8214514851570129, + "acc": 68.38, + "time": 32.53498285599926, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 63, + "train_loss": 0.7875636219978333, + "val_loss": 0.8595832586288452, + "acc": 68.66, + "time": 32.76464586200018, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 64, + "train_loss": 0.9066817760467529, + "val_loss": 0.9415571093559265, + "acc": 67.72, + "time": 32.93092659299873, + "param": [ + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774, + 0.0714285746216774 + ] + }, + { + "epoch": 65, + "train_loss": 0.8445136547088623, + "val_loss": 0.8770421147346497, + "acc": 68.24, + "time": 33.87918717900175, "param": [ 0.0714285746216774, 0.0714285746216774, diff --git a/higher/res/TF_seq_tests_compare.png b/higher/res/TF_seq_tests_compare.png index 229885f..517a0b2 100644 Binary files a/higher/res/TF_seq_tests_compare.png and b/higher/res/TF_seq_tests_compare.png differ diff --git a/higher/test_dataug.py b/higher/test_dataug.py index 70a01f5..667169e 100644 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -38,13 +38,13 @@ else: if __name__ == "__main__": n_inner_iter = 10 - epochs = 200 + epochs = 2 dataug_epoch_start=0 #### Classic #### ''' - model = LeNet(3,10).to(device) - #model = torchvision.models.resnet18() + #model = LeNet(3,10).to(device) + model = WideResNet(num_classes=10, wrn_size=16).to(device) #model = Augmented_model(Data_augV3(mix_dist=0.0), LeNet(3,10)).to(device) #model.augment(mode=False) @@ -69,31 +69,32 @@ if __name__ == "__main__": tf_dict = {k: TF.TF_dict[k] for k in tf_names} #tf_dict = TF.TF_dict aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), LeNet(3,10)).to(device) + #aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device) print(str(aug_model), 'on', device_name) #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=10) #### - plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)) + plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), param_names=tf_names) print('-'*9) times = [x["time"] for x in log] out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} - print(str(aug_model),": acc", out["Accuracy"], "in (s ?):", out["Time"][0], "+/-", out["Time"][1]) + print(str(aug_model),": acc", out["Accuracy"], "in (s?):", out["Time"][0], "+/-", out["Time"][1]) with open("res/log/%s.json" % "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), "w+") as f: json.dump(out, f, indent=True) print('Log :\"',f.name, '\" saved !') - print('Execution Time : %.00f (s ?)'%(time.process_time() - t0)) + print('Execution Time : %.00f (s?)'%(time.process_time() - t0)) print('-'*9) #''' #### TF number tests #### ''' res_folder="res/TF_nb_tests/" - epochs= 100 + epochs= 200 inner_its = [10] dataug_epoch_starts= [0] - TF_nb = [len(TF.TF_dict)] #range(1,len(TF.TF_dict)+1) - N_seq_TF= [1, 2, 3, 4] + TF_nb = range(1,len(TF.TF_dict)+1) #[len(TF.TF_dict)] + N_seq_TF= [1] #[1, 2, 3, 4] try: os.mkdir(res_folder) @@ -106,7 +107,6 @@ if __name__ == "__main__": for dataug_epoch_start in dataug_epoch_starts: print("---Starting dataug", dataug_epoch_start,"---") for n_tf in N_seq_TF: - print("---Starting N_TF", n_tf,"---") for i in TF_nb: keys = list(TF.TF_dict.keys())[0:i] ntf_dict = {k: TF.TF_dict[k] for k in keys} @@ -114,7 +114,7 @@ if __name__ == "__main__": aug_model = Augmented_model(Data_augV4(TF_dict=ntf_dict, N_TF=n_tf, mix_dist=0.0), LeNet(3,10)).to(device) print(str(aug_model), 'on', device_name) #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10) #### plot_res(log, fig_name=res_folder+"{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)) @@ -127,6 +127,4 @@ if __name__ == "__main__": print('Log :\"',f.name, '\" saved !') print('-'*9) - ''' - - \ No newline at end of file + ''' \ No newline at end of file diff --git a/higher/utils.py b/higher/utils.py index bfa0e2e..eb1be55 100644 --- a/higher/utils.py +++ b/higher/utils.py @@ -15,7 +15,7 @@ def print_graph(PyTorch_obj, fig_name='graph'): graph.format = 'svg' #https://graphviz.readthedocs.io/en/stable/manual.html#formats graph.render(fig_name) -def plot_res(log, fig_name='res'): +def plot_res(log, fig_name='res', param_names=None): epochs = [x["epoch"] for x in log] @@ -36,10 +36,13 @@ def plot_res(log, fig_name='res'): ax[2].legend() else : ax[2].set_title('Prob') - for idx, _ in enumerate(log[0]["param"]): - ax[2].plot(epochs,[x["param"][idx] for x in log], label='P'+str(idx)) - ax[2].legend() - #ax[2].legend(('P-0', 'P-45', 'P-180')) + #for idx, _ in enumerate(log[0]["param"]): + #ax[2].plot(epochs,[x["param"][idx] for x in log], label='P'+str(idx)) + if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])] + proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])] + ax[2].stackplot(epochs, proba, labels=param_names) + ax[2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5)) + fig_name = fig_name.replace('.',',') plt.savefig(fig_name) @@ -193,6 +196,20 @@ def print_torch_mem(add_info=''): #print(add_info, "-Garbage size :",len(gc.garbage)) + """Simple GPU memory report.""" + + mega_bytes = 1024.0 * 1024.0 + string = add_info + ' memory (MB)' + string += ' | allocated: {}'.format( + torch.cuda.memory_allocated() / mega_bytes) + string += ' | max allocated: {}'.format( + torch.cuda.max_memory_allocated() / mega_bytes) + string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes) + string += ' | max cached: {}'.format( + torch.cuda.max_memory_cached()/ mega_bytes) + print(string) + + class loss_monitor(): #Voir https://github.com/pytorch/ignite def __init__(self, patience, end_train=1): self.patience = patience