diff --git a/higher/smart_aug/dataug.py b/higher/smart_aug/dataug.py
index e5d6600..0b459f1 100755
--- a/higher/smart_aug/dataug.py
+++ b/higher/smart_aug/dataug.py
@@ -19,6 +19,12 @@ import copy
 
 import transformations as TF
 
+import higher
+import higher_patch
+
+from utils import clip_norm
+from train_utils import compute_vaLoss
+
 ### Data augmenter ###
 class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
     """Data augmentation module with learnable parameters.
@@ -798,7 +804,6 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
         return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
 
 ### Models ###
-import higher
 class Higher_model(nn.Module):
     """Model wrapper for higher gradient tracking.
 
@@ -897,8 +902,6 @@ class Higher_model(nn.Module):
         """
         return self._name
 
-from utils import clip_norm
-from train_utils import compute_vaLoss
 class Augmented_model(nn.Module):
     """Wrapper for a Data Augmentation module and a model.
 
diff --git a/higher/smart_aug/higher_patch.py b/higher/smart_aug/higher_patch.py
new file mode 100644
index 0000000..f6d3182
--- /dev/null
+++ b/higher/smart_aug/higher_patch.py
@@ -0,0 +1,21 @@
+
+import higher
+import torch as _torch
+
+def detach_(self):
+    """Removes all params from their compute graph in place."""
+    # detach param groups
+    for group in self.param_groups:
+        for k, v in group.items():
+            if isinstance(v,_torch.Tensor):
+                v.detach_().requires_grad_()
+
+    # detach state
+    for state_dict in self.state:
+        for k,v_dict in state_dict.items():
+            if isinstance(k,_torch.Tensor): k.detach_().requires_grad_()
+            for k2,v2 in v_dict.items():
+                if isinstance(v2,_torch.Tensor):
+                    v2.detach_().requires_grad_()
+
+higher.optim.DifferentiableOptimizer.detach_ = detach_
\ No newline at end of file
diff --git a/higher/smart_aug/train_utils.py b/higher/smart_aug/train_utils.py
index deafa06..682b05b 100755
--- a/higher/smart_aug/train_utils.py
+++ b/higher/smart_aug/train_utils.py
@@ -6,6 +6,7 @@ import torch
 #import torch.optim
 import torchvision
 import higher
+import higher_patch
 
 from datasets import *
 from utils import *
@@ -219,7 +220,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
     """
     device = next(model.parameters()).device
     log = []
-    #dl_val_it = iter(dl_val)
+    dl_val_it = iter(dl_val)
     val_loss=None
 
     high_grad_track = True
@@ -253,8 +254,8 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
     for epoch in range(1, epochs+1):
         t0 = time.perf_counter()
 
-        dl_train, dl_val = cvs.next_split()
-        dl_val_it = iter(dl_val)
+        #dl_train, dl_val = cvs.next_split()
+        #dl_val_it = iter(dl_val)
 
         for i, (xs, ys) in enumerate(dl_train):
             xs, ys = xs.to(device), ys.to(device)
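
For context, a minimal sketch of the pattern the new higher_patch.py enables: calling the monkey-patched detach_() on a higher DifferentiableOptimizer to truncate the meta-gradient graph in place every few inner steps, instead of rebuilding the functional model and optimizer from scratch. The toy model, training loop, and window size K below are illustrative assumptions (in this repository, run_dist_dataugV3 and its inner_it argument play that role); only higher.patch.monkeypatch, higher.optim.get_diff_optim, diffopt.step, fmodel.fast_params, and the patched detach_() are actual APIs.

    # Illustrative sketch, not code from this diff.
    import torch
    import torch.nn as nn
    import higher
    import higher_patch  # attaches detach_() to higher.optim.DifferentiableOptimizer

    model = nn.Linear(10, 2)
    inner_opt = torch.optim.SGD(model.parameters(), lr=1e-2)

    # Functional copy of the model and a differentiable optimizer tracking it.
    fmodel = higher.patch.monkeypatch(model, copy_initial_weights=True)
    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),
                                          fmodel=fmodel, track_higher_grads=True)

    K = 4  # hypothetical truncation window
    for step in range(20):
        x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
        loss = nn.functional.cross_entropy(fmodel(x), y)
        diffopt.step(loss)  # differentiable inner update of fmodel's fast weights

        if (step + 1) % K == 0:
            # (a meta-update of the augmentation parameters would happen here)
            # Truncate history in place: fast weights and optimizer state keep
            # their values but are cut from the autograd graph, so memory does
            # not grow with the number of inner steps.
            diffopt.detach_()
            fmodel.fast_params = [p.detach().requires_grad_()
                                  for p in fmodel.fast_params]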