Minor improvement + Comments
This commit is contained in:
parent d21a6bbf5c
commit c1ad787d97
5 changed files with 165 additions and 62 deletions
@@ -6,15 +6,51 @@ import torch.nn.functional as F
import higher

class Higher_model(nn.Module):
    """Model wrapper for higher gradient tracking.

        Keeps in memory the original model and its functional (higher) version.

        Might not be needed anymore if higher implements detach for fmodel.

        See: https://github.com/facebookresearch/higher

        TODO: Get rid of the original model if not needed by the user.

        Attributes:
            _name (string): Name of the model.
            _mods (nn.ModuleDict): Models (original and higher versions).
    """
    def __init__(self, model):
        """Init Higher_model.

            Args:
                model (nn.Module): Network for which higher gradients can be tracked.
        """
        super(Higher_model, self).__init__()

        self._name = model.__str__()
        self._mods = nn.ModuleDict({
            'original': model,
            'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
            })

    def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
        """Get a differentiable version of an Optimizer.

            A higher/differentiable optimizer is required for higher gradient tracking.
            Usage: diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step)

            Be aware that, if track_higher_grads is set to True, a new state of the model is saved each time diffopt.step() is called,
            thus increasing memory consumption. The detach_() method should be called to reset the gradient tape and prevent memory saturation.

            Args:
                opt (torch.optim): Optimizer to make differentiable.
                grad_callback (fct(grads)=grads): Function applied to the list of parameter gradients (ex: clipping). (default: None)
                track_higher_grads (bool): Whether higher gradients are tracked. If True, the graph/states will be retained to allow backpropagation. (default: True)

            Returns:
                (higher.optim.DifferentiableOptimizer): Differentiable version of the optimizer.
        """
        return higher.optim.get_diff_optim(opt,
            self._mods['original'].parameters(),
            fmodel=self._mods['functional'],
@@ -22,20 +58,49 @@ class Higher_model(nn.Module):
            track_higher_grads=track_higher_grads)

    def forward(self, x):
        """Main method of the model.

            Args:
                x (Tensor): Batch of data.

            Returns:
                Tensor: Output of the network. Should be logits.
        """
        return self._mods['functional'](x)

    def detach_(self):
        """Detach from the graph.

            Needed to limit the number of states kept in memory.
        """
        tmp = self._mods['functional'].fast_params
        self._mods['functional']._fast_params = []
        self._mods['functional'].update_params(tmp)
        for p in self._mods['functional'].fast_params:
            p.detach_().requires_grad_()

    def state_dict(self):
        """Returns a dictionary containing a whole state of the module.
        """
        return self._mods['functional'].state_dict()

    def __getitem__(self, key):
        """Access to the inner modules.

            Args:
                key (string): Name of the module to access.

            Returns:
                nn.Module.
        """
        return self._mods[key]
    def __str__(self):
        """Name of the module.

            Returns:
                String containing the name of the module.
        """
        return self._name

## Basic CNN ##
class LeNet_F(nn.Module):
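
Usage sketch (not part of the commit): wrapping a small network in `Higher_model` and running a forward pass through the functional copy. The toy `nn.Sequential`, tensor shapes and variable names are illustrative assumptions; only `Higher_model`, its `__getitem__` and its `forward` come from the diff above, and the class is assumed to be importable in scope.

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 2))  # assumed toy network
hmodel = Higher_model(net)            # keeps `net` plus a monkeypatched, functional copy
logits = hmodel(torch.randn(4, 10))   # forward() routes through the functional version
assert hmodel['original'] is net      # the original module stays reachable via __getitem__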
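
A second sketch, continuing the previous one, of how `get_diffopt()` is meant to be used according to its docstring (`diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step)`). The SGD settings, loss function and fake batch are placeholders, not code from this repository.

import torch.nn.functional as F

inner_opt = torch.optim.SGD(hmodel['original'].parameters(), lr=1e-2)  # plain optimizer
diffopt = hmodel.get_diffopt(inner_opt, track_higher_grads=True)       # differentiable version
# Per the signature above, a grad_callback (e.g. gradient clipping) could also be supplied.

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))  # fake batch, assumed shapes
loss = F.cross_entropy(hmodel(x), y)
diffopt.step(loss)   # one call == zero_grad + backward + step, kept on the gradient tape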
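
Finally, a sketch of the memory pattern described in the `get_diffopt()` and `detach_()` docstrings: with `track_higher_grads=True`, every `diffopt.step()` stores another set of fast parameters, so the tape should be backpropagated through and then cut periodically. The unroll length `K`, the fake data and the validation loss are assumptions; in this project the higher gradients are presumably routed to augmentation meta-parameters rather than a plain validation loss.

loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(20)]  # fake data
x_val, y_val = torch.randn(8, 10), torch.randint(0, 2, (8,))                   # assumed validation batch
K = 10                                                                         # assumed unroll length

for step, (x, y) in enumerate(loader):
    loss = F.cross_entropy(hmodel(x), y)
    diffopt.step(loss)                       # each call stores another set of fast params
    if (step + 1) % K == 0:
        val_loss = F.cross_entropy(hmodel(x_val), y_val)
        val_loss.backward()                  # backprop through the K unrolled steps
        hmodel.detach_()                     # reset the gradient tape so old states can be freed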