From 7221142a9a1bd99fd3e0db99b9274af7624bc0a4 Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Wed, 5 Feb 2020 12:23:23 -0500
Subject: [PATCH] Comments

---
 higher/smart_aug/higher_patch.py | 36 ++++++++++++++++++++------------
 higher/smart_aug/train_utils.py  |  1 +
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/higher/smart_aug/higher_patch.py b/higher/smart_aug/higher_patch.py
index f6d3182..d9a6f90 100644
--- a/higher/smart_aug/higher_patch.py
+++ b/higher/smart_aug/higher_patch.py
@@ -1,21 +1,31 @@
+""" Patch for the Higher package
+
+    Recommended use ::
+
+        import higher
+        import higher_patch
+    Might become unnecessary with a future update of the Higher package.
+"""
 import higher
 import torch as _torch
 
 def detach_(self):
-    """Removes all params from their compute graph in place."""
-    # detach param groups
-    for group in self.param_groups:
-        for k, v in group.items():
-            if isinstance(v,_torch.Tensor):
-                v.detach_().requires_grad_()
+    """Removes all params from their compute graph in place.
 
-    # detach state
-    for state_dict in self.state:
-        for k,v_dict in state_dict.items():
-            if isinstance(k,_torch.Tensor): k.detach_().requires_grad_()
-            for k2,v2 in v_dict.items():
-                if isinstance(v2,_torch.Tensor):
-                    v2.detach_().requires_grad_()
+    """
+    # detach param groups
+    for group in self.param_groups:
+        for k, v in group.items():
+            if isinstance(v,_torch.Tensor):
+                v.detach_().requires_grad_()
+
+    # detach state
+    for state_dict in self.state:
+        for k,v_dict in state_dict.items():
+            if isinstance(k,_torch.Tensor): k.detach_().requires_grad_()
+            for k2,v2 in v_dict.items():
+                if isinstance(v2,_torch.Tensor):
+                    v2.detach_().requires_grad_()
 
 
 higher.optim.DifferentiableOptimizer.detach_ = detach_
\ No newline at end of file
diff --git a/higher/smart_aug/train_utils.py b/higher/smart_aug/train_utils.py
index 682b05b..6f36951 100755
--- a/higher/smart_aug/train_utils.py
+++ b/higher/smart_aug/train_utils.py
@@ -254,6 +254,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
 
     for epoch in range(1, epochs+1):
         t0 = time.perf_counter()
 
+        #Cross-Validation
        #dl_train, dl_val = cvs.next_split()
        #dl_val_it = iter(dl_val)
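
A minimal usage sketch of the patched detach_ (it assumes the project's
higher_patch.py is importable as higher_patch; the model, optimizer, loss,
and step counts below are illustrative placeholders, not the project's own
training loop):

    import torch
    import higher
    import higher_patch  # attaches detach_ to higher.optim.DifferentiableOptimizer

    # Illustrative inner-loop setup; any model/optimizer pair works the same way.
    model = torch.nn.Linear(4, 2)
    inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)

    with higher.innerloop_ctx(model, inner_opt) as (fmodel, diffopt):
        for step in range(10):
            loss = fmodel(torch.randn(8, 4)).pow(2).mean()
            diffopt.step(loss)
            if (step + 1) % 5 == 0:
                # detach_ cuts, in place, any Tensor entries held by the
                # optimizer's param_groups (e.g. a tensor-valued lr) and
                # state (e.g. momentum buffers) from the compute graph, so
                # memory does not grow with the number of inner steps. The
                # functional model's fast weights would need a similar
                # in-place detach for full graph truncation.
                diffopt.detach_()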