From 198fb06065769455c83b76a2bbcececa6deb0a06 Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Wed, 13 Nov 2019 12:03:54 -0500
Subject: [PATCH] Modify early stopping (on test data...)

---
 higher/datasets.py    |  4 ++--
 higher/test_dataug.py |  4 ++--
 higher/train_utils.py | 26 +++++++++++---------------
 3 files changed, 15 insertions(+), 19 deletions(-)

diff --git a/higher/datasets.py b/higher/datasets.py
index 7d0589f..39ba406 100644
--- a/higher/datasets.py
+++ b/higher/datasets.py
@@ -3,8 +3,8 @@ from torch.utils.data import SubsetRandomSampler
 import torchvision
 
 BATCH_SIZE = 300
-#TEST_SIZE = 300
-TEST_SIZE = 10000
+TEST_SIZE = 300
+#TEST_SIZE = 10000
 
 #WARNING: Dataug (Kornia) expects images in the range [0, 1]
 #transform_train = torchvision.transforms.Compose([
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index e4bafac..d14d583 100644
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -37,8 +37,8 @@ else:
 ##########################################
 if __name__ == "__main__":
 
-    n_inner_iter = 0
-    epochs = 100
+    n_inner_iter = 10
+    epochs = 200
     dataug_epoch_start=0
 
     #### Classic ####
diff --git a/higher/train_utils.py b/higher/train_utils.py
index d3bfa3d..47560bd 100644
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -6,12 +6,6 @@ import higher
 from datasets import *
 from utils import *
 
-#Variables to define
-#dl_train = None
-#dl_val = None
-#dl_test = None
-#device = torch.device('cuda')
-
 def test(model):
     device = next(model.parameters()).device
     model.eval()
@@ -21,13 +15,13 @@ def test(model):
         pred = model.forward(features)
 
     return pred.argmax(dim=1).eq(labels).sum().item() / dl_test.batch_size * 100
 
-def compute_vaLoss(model, dl_val_it):
+def compute_loss(model, dl_it, dl):
     device = next(model.parameters()).device
     try:
-        xs_val, ys_val = next(dl_val_it)
+        xs_val, ys_val = next(dl_it)
     except StopIteration: #End of val epoch
-        dl_val_it = iter(dl_val)
-        xs_val, ys_val = next(dl_val_it)
+        dl_it = iter(dl) #Rebind the exhausted iterator (local to this call)
+        xs_val, ys_val = next(dl_it)
     xs_val, ys_val = xs_val.to(device), ys_val.to(device)
     try:
@@ -528,6 +522,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
     countcopy=0
     val_loss=torch.tensor(0) #Needed if no meta step happens during an epoch
     dl_val_it = iter(dl_val)
+    dl_test_it = iter(dl_test) #WARNING: to be used ONLY for early stopping
 
     meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2)
     inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
@@ -542,7 +537,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
     val_loss_monitor= None
     if loss_patience != None :
         if dataug_epoch_start==-1: val_loss_monitor = loss_monitor(patience=loss_patience, end_train=2) #1st limit = dataug start
-        else: val_loss_monitor = loss_monitor(patience=loss_patience) #Val loss monitor
+        else: val_loss_monitor = loss_monitor(patience=loss_patience) #Loss monitor (not on val data, which is used by Dataug... => test data)
 
     model.train()
 
@@ -594,7 +589,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
 
             if(high_grad_track and i%inner_it==0): #Perform Meta step
                 #print("meta")
                 #Of little use if high_grad_track = False
-                val_loss = compute_vaLoss(model=fmodel, dl_val_it=dl_val_it)
+                val_loss = compute_loss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
 
                 #print_graph(val_loss)
@@ -619,7 +614,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
                 countcopy+=1
                 model_copy(src=fmodel, dst=model)
                 optim_copy(dopt=diffopt, opt=inner_opt)
-                val_loss = compute_vaLoss(model=fmodel, dl_val_it=dl_val_it)
+                val_loss = compute_loss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
 
                 #Needed to reset higher (it accumulates fast_params even with track_higher_grads = False)
                 fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
@@ -652,9 +647,10 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
             log.append(data)
             #############
             if val_loss_monitor :
-                val_loss_monitor.register(val_loss.item())
+                model.eval()
+                val_loss_monitor.register(compute_loss(model, dl_it=dl_test_it, dl=dl_test)) #val_loss.item()
                 if val_loss_monitor.end_training(): break #Stop training
-
+                model.train()
 
             if not model.is_augmenting() and (epoch == dataug_epoch_start or (val_loss_monitor and val_loss_monitor.limit_reached()==1)):
                 print('Starting Data Augmentation...')
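
Note: the patch relies on a loss_monitor helper whose definition is not part of this diff. The sketch below is a hypothetical reconstruction inferred only from the calls visible here (register, end_training, limit_reached, and the patience/end_train arguments); the actual class in the repository may differ.

class loss_monitor():
    """Patience-based stopper (sketch, not the repo's code).

    register() feeds one loss value per epoch. Each time the loss goes
    `patience` consecutive registrations without improving, one "limit"
    is counted; end_training() fires after `end_train` limits. With
    end_train=2 (as in the diff when dataug_epoch_start==-1), the first
    limit can trigger the start of data augmentation via
    limit_reached()==1, and the second ends training.
    """
    def __init__(self, patience, end_train=1):
        self.patience = patience
        self.end_train = end_train
        self.best_loss = float('inf')
        self.counter = 0  #Registrations since the last improvement
        self.limits = 0   #Patience windows exhausted so far

    def register(self, loss):
        if loss < self.best_loss:
            self.best_loss = loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.limits += 1
                self.counter = 0  #Start counting toward the next limit

    def limit_reached(self):
        return self.limits

    def end_training(self):
        return self.limits >= self.end_train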
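
A subtlety worth flagging in compute_loss: rebinding dl_it = iter(dl) inside the function only affects the local name, so the caller's iterator (dl_val_it, dl_test_it) stays exhausted and every later call rebuilds a fresh iter(dl). An alternative, sketched below under the same assumptions, is to let a small generator own the cycling so callers never see StopIteration. infinite_batches is an illustrative name, not something from the repository, and the forward pass is simplified (the original also wraps it in a try/except for the augmented-model case).

import torch.nn.functional as F

def infinite_batches(dl):
    #Re-create the iterator each time the DataLoader is exhausted, so
    #shuffling is re-applied every epoch (itertools.cycle would instead
    #replay the first epoch's batches verbatim).
    while True:
        for batch in dl:
            yield batch

def compute_loss(model, batches):
    #Simplified variant of the patched helper: draw one batch from the
    #generator and return its cross-entropy loss.
    device = next(model.parameters()).device
    xs, ys = next(batches)
    xs, ys = xs.to(device), ys.to(device)
    return F.cross_entropy(model(xs), ys)

#Usage: build the generator once, then call compute_loss freely.
#dl_val_batches = infinite_batches(dl_val)
#val_loss = compute_loss(model, dl_val_batches)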