Minor improvement + Comments

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-21 13:53:07 -05:00
parent d21a6bbf5c
commit c1ad787d97
5 changed files with 165 additions and 62 deletions

View file

@ -6,15 +6,51 @@ import torch.nn.functional as F
import higher
class Higher_model(nn.Module):
"""Model wrapper for higher gradient tracking.
Keep in memory the orginial model and it's functionnal, higher, version.
Might not be needed anymore if Higher implement detach for fmodel.
see : https://github.com/facebookresearch/higher
TODO: Get rid of the original model if not needed by user.
Attributes:
_name (string): Name of the model.
_mods (nn.ModuleDict): Models (Orginial and Higher version).
"""
def __init__(self, model):
"""Init Higher_model.
Args:
model (nn.Module): Network for which higher gradients can be tracked.
"""
super(Higher_model, self).__init__()
self._name = model.__str__()
self._mods = nn.ModuleDict({
'original': model,
'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
})
def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
"""Get a differentiable version of an Optimizer.
Higher/Differentiable optimizer required to be used for higher gradient tracking.
Usage : diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step)
Be warry that if track_higher_grads is set to True, a new state of the model would be saved each time diffopt.step() is called.
Thus increasing memory consumption. The detach_() method should be called to reset the gradient tape and prevent memory saturation.
Args:
opt (torch.optim): Optimizer to make differentiable.
grad_callback (fct(grads)=grads): Function applied to the list of gradients parameters (ex: clipping). (default: None)
track_higher_grads (bool): Wether higher gradient are tracked. If True, the graph/states will be retained to allow backpropagation. (default: True)
Returns:
(Higher.DifferentiableOptimizer): Differentiable version of the optimizer.
"""
return higher.optim.get_diff_optim(opt,
self._mods['original'].parameters(),
fmodel=self._mods['functional'],
@ -22,20 +58,49 @@ class Higher_model(nn.Module):
track_higher_grads=track_higher_grads)
def forward(self, x):
""" Main method of the model.
Args:
x (Tensor): Batch of data.
Returns:
Tensor : Output of the network. Should be logits.
"""
return self._mods['functional'](x)
def detach_(self):
"""Detach from the graph.
Needed to limit the number of state kept in memory.
"""
tmp = self._mods['functional'].fast_params
self._mods['functional']._fast_params=[]
self._mods['functional'].update_params(tmp)
for p in self._mods['functional'].fast_params:
p.detach_().requires_grad_()
def state_dict(self):
"""Returns a dictionary containing a whole state of the module.
"""
return self._mods['functional'].state_dict()
def __getitem__(self, key):
"""Access to modules
Args:
key (string): Name of the module to access.
Returns:
nn.Module.
"""
return self._mods[key]
def __str__(self):
return self._mods['original'].__str__()
"""Name of the module
Returns:
String containing the name of the module.
"""
return self._name
## Basic CNN ##
class LeNet_F(nn.Module):