From f0c0559e731d9598fb573155b990050985fa531c Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Wed, 13 Nov 2019 13:38:00 -0500
Subject: [PATCH] Modify test() to simplify early stopping

---
 higher/datasets.py    |  4 +--
 higher/dataug.py      |  3 +-
 higher/test_dataug.py |  5 +++-
 higher/train_utils.py | 66 +++++++++++++++++++++++++++----------------
 4 files changed, 49 insertions(+), 29 deletions(-)

diff --git a/higher/datasets.py b/higher/datasets.py
index 39ba406..7d0589f 100644
--- a/higher/datasets.py
+++ b/higher/datasets.py
@@ -3,8 +3,8 @@ from torch.utils.data import SubsetRandomSampler
 import torchvision
 
 BATCH_SIZE = 300
-TEST_SIZE = 300
-#TEST_SIZE = 10000
+#TEST_SIZE = 300
+TEST_SIZE = 10000
 
 #WARNING: Dataug (Kornia) expects images in the range [0, 1]
 #transform_train = torchvision.transforms.Compose([
diff --git a/higher/dataug.py b/higher/dataug.py
index 63d10ef..24654c3 100644
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -553,9 +553,10 @@ class Augmented_model(nn.Module):
         mode=self._data_augmentation
         self._mods['data_aug'].augment(mode)
         super(Augmented_model, self).train(mode)
+        return self
 
     def eval(self):
-        self.train(mode=False)
+        return self.train(mode=False)
         #super(Augmented_model, self).eval()
 
     def items(self):
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index d14d583..70a01f5 100644
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -65,6 +65,7 @@ if __name__ == "__main__":
     '''
     #### Augmented Model ####
     #'''
+    t0 = time.process_time()
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
     #tf_dict = TF.TF_dict
     aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), LeNet(3,10)).to(device)
@@ -77,10 +78,12 @@ if __name__ == "__main__":
     print('-'*9)
     times = [x["time"] for x in log]
     out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
-    print(str(aug_model),": acc", out["Accuracy"], "in (ms):", out["Time"][0], "+/-", out["Time"][1])
+    print(str(aug_model),": acc", out["Accuracy"], "in (s):", out["Time"][0], "+/-", out["Time"][1])
     with open("res/log/%s.json" % "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), "w+") as f:
         json.dump(out, f, indent=True)
         print('Log :\"',f.name, '\" saved !')
+
+    print('Execution Time : %.00f s'%(time.process_time() - t0))
     print('-'*9)
 
     #''' #### TF number tests ####
diff --git a/higher/train_utils.py b/higher/train_utils.py
index 47560bd..bf970a6 100644
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -9,26 +9,43 @@ from utils import *
 def test(model):
     device = next(model.parameters()).device
     model.eval()
-    for i, (features, labels) in enumerate(dl_test):
-        features,labels = features.to(device), labels.to(device)
+
+    #for i, (features, labels) in enumerate(dl_test):
+    #    features,labels = features.to(device), labels.to(device)
 
-        pred = model.forward(features)
-        return pred.argmax(dim=1).eq(labels).sum().item() / dl_test.batch_size * 100
+    #    pred = model.forward(features)
+    #    return pred.argmax(dim=1).eq(labels).sum().item() / dl_test.batch_size * 100
+
+    correct = 0
+    total = 0
+    loss = []
+    with torch.no_grad():
+        for features, labels in dl_test:
+            features,labels = features.to(device), labels.to(device)
+
+            outputs = model(features)
+            _, predicted = torch.max(outputs.data, 1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+            loss.append(F.cross_entropy(outputs, labels).item())
+
+    accuracy = 100 * correct / total
+
+    return accuracy, np.mean(loss)
+
 
-def compute_loss(model, dl_it, dl):
+def compute_vaLoss(model, dl_it, dl):
     device = next(model.parameters()).device
     try:
-        xs_val, ys_val = next(dl_it)
+        xs, ys = next(dl_it)
     except StopIteration: #End of the validation epoch
-        dl_val_it = iter(dl)
-        xs_val, ys_val = next(dl_it)
-    xs_val, ys_val = xs_val.to(device), ys_val.to(device)
+        dl_it = iter(dl)
+        xs, ys = next(dl_it)
+    xs, ys = xs.to(device), ys.to(device)
 
-    try:
-        model.augment(mode=False) #Validation without transformations!
-    except:
-        pass
-    return F.cross_entropy(model(xs_val), ys_val)
+    model.eval() #Validation without transformations!
+
+    return F.cross_entropy(model(xs), ys)
 
 def train_classic(model, epochs=1):
     device = next(model.parameters()).device
@@ -61,7 +78,7 @@ def train_classic(model, epochs=1):
         xs_val, ys_val = xs_val.to(device), ys_val.to(device)
 
         val_loss = F.cross_entropy(model(xs_val), ys_val)
-        accuracy=test(model)
+        accuracy, _ = test(model)
         model.train()
         #### Log ####
         data={
@@ -120,7 +137,7 @@ def train_classic_higher(model, epochs=1):
         xs_val, ys_val = xs_val.to(device), ys_val.to(device)
 
         val_loss = F.cross_entropy(model(xs_val), ys_val)
-        accuracy=test(model)
+        accuracy, _ = test(model)
         model.train()
         #### Log ####
         data={
@@ -256,7 +273,7 @@ def train_classic_tests(model, epochs=1):
         xs_val, ys_val = xs_val.to(device), ys_val.to(device)
 
         val_loss = F.cross_entropy(model(xs_val), ys_val)
-        accuracy=test(model)
+        accuracy, _ = test(model)
         model.train()
         #### Log ####
         data={
@@ -309,7 +326,7 @@ def run_simple_dataug(inner_it, epochs=1):
                 dl_train_it = iter(dl_train)
                 xs, ys = next(dl_train_it)
 
-            accuracy=test(aug_model)
+            accuracy, _ = test(model)
             aug_model.train()
 
             #### Print ####
@@ -426,7 +443,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
             #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
             #viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
 
-            accuracy=test(model)
+            accuracy, _ = test(model)
             model.train()
 
             #### Print ####
@@ -522,7 +539,6 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
     countcopy=0
     val_loss=torch.tensor(0) #Needed if no meta step is performed during an epoch
     dl_val_it = iter(dl_val)
-    dl_test_it = iter(dl_test) #WARNING: USE ONLY FOR EARLY STOPPING
 
     meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2)
     inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
@@ -589,7 +605,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
             if(high_grad_track and i%inner_it==0): #Perform Meta step
                 #print("meta")
                 #Of little use if high_grad_track = False
-                val_loss = compute_loss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
+                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
 
                 #print_graph(val_loss)
 
@@ -614,20 +630,20 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
                 countcopy+=1
                 model_copy(src=fmodel, dst=model)
                 optim_copy(dopt=diffopt, opt=inner_opt)
-                val_loss = compute_loss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
+                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
 
                 #Needed to reset higher (it accumulates fast_params even with track_higher_grads = False)
                 fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
                 diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(), fmodel=fmodel, track_higher_grads=high_grad_track)
 
-            accuracy=test(model)
+            accuracy, test_loss = test(model)
             model.train()
 
             #### Print ####
             if(print_freq and epoch%print_freq==0):
                 print('-'*9)
                 print('Epoch : %d/%d'%(epoch,epochs))
-                print('Time : %.00f ms'%(tf - t0))
+                print('Time : %.00f s'%(tf - t0))
                 print('Train loss :',loss.item(), '/ val loss', val_loss.item())
                 print('Accuracy :', accuracy)
                 print('Data Augmentation : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
@@ -648,7 +664,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
 
             #############
             if val_loss_monitor :
                 model.eval()
-                val_loss_monitor.register(compute_loss(model, dl_it=dl_test_it, dl=dl_test))#val_loss.item())
+                val_loss_monitor.register(test_loss)#val_loss.item())
                 if val_loss_monitor.end_training(): break #Stop training
                 model.train()
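
Usage note: the sketch below illustrates the early-stopping flow this patch enables, where test() returns the mean test loss alongside accuracy and the loop feeds it straight to the monitor. It is an illustration, not part of the patch: LossMonitor is a hypothetical stand-in for the val_loss_monitor object passed to run_dist_dataugV2; only its register() / end_training() interface is taken from the code above, and the patience logic is assumed.

# Hypothetical stand-in for val_loss_monitor; only register()/end_training()
# mirror the interface used in run_dist_dataugV2 above.
class LossMonitor:
    def __init__(self, patience=5):
        self.patience = patience        # epochs tolerated without improvement (assumed policy)
        self.best_loss = float('inf')   # best test loss seen so far
        self.bad_epochs = 0             # consecutive epochs without improvement

    def register(self, loss):
        # Record a new loss value and update the no-improvement counter.
        if loss < self.best_loss:
            self.best_loss = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1

    def end_training(self):
        # Signal early stopping once patience is exhausted.
        return self.bad_epochs >= self.patience

# In the epoch loop of run_dist_dataugV2, after this patch:
#     accuracy, test_loss = test(model)
#     val_loss_monitor.register(test_loss)
#     if val_loss_monitor.end_training(): break

This is why the dl_test_it iterator is deleted in the -522 hunk: the loop no longer recomputes a loss on dl_test for early stopping, it reuses the test_loss already computed by test().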