dataug.Higher_model Class Reference
Inheritance diagram for dataug.Higher_model
Collaboration diagram for dataug.Higher_model

Public Member Functions

def __init__ (self, model)
 
def get_diffopt (self, opt, grad_callback=None, track_higher_grads=True)
 
def forward (self, x)
 
def detach_ (self)
 
def state_dict (self)
 
def __getitem__ (self, key)
 
def __str__ (self)
 

Detailed Description

Model wrapper for higher gradient tracking.

    Keeps in memory the original model and its functional (higher) version.

    Might not be needed anymore if Higher implements detach for fmodel.

    See: https://github.com/facebookresearch/higher

    TODO: Get rid of the original model if not needed by the user.

    Attributes:
        _name (string): Name of the model.
        _mods (nn.ModuleDict): Models (original and higher versions).
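
A minimal sketch of what the functional (higher) version refers to, assuming the wrapper builds on higher.patch.monkeypatch; this is an illustration of the idea, not the dataug source:

    import torch
    import torch.nn as nn
    import higher

    net = nn.Linear(10, 2)                  # original model, kept as-is
    fmodel = higher.patch.monkeypatch(      # functional ("higher") copy of net
        net, copy_initial_weights=True, track_higher_grads=True)

    x = torch.randn(4, 10)
    out = fmodel(x)                         # same forward pass, but parameter
                                            # updates stay on the autograd tape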

Constructor & Destructor Documentation

◆ __init__()

def dataug.Higher_model.__init__ ( self, model )

Init Higher_model.

    Args:
        model (nn.Module): Network for which higher gradients can be tracked.
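
A minimal construction example; the import path is inferred from the qualified class name and the wrapped network is illustrative:

    import torch.nn as nn
    from dataug import Higher_model

    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # any nn.Module
    model = Higher_model(net)   # wraps net and builds its functional, higher version
    print(model)                # __str__: name of the wrapped module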

Member Function Documentation

◆ __getitem__()

def dataug.Higher_model.__getitem__ ( self, key )

Access to modules.

    Args:
        key (string): Name of the module to access.

    Returns:
        nn.Module.

◆ __str__()

def dataug.Higher_model.__str__ ( self )

Name of the module.

    Returns:
        String containing the name of the module.

◆ detach_()

def dataug.Higher_model.detach_ ( self )

Detach from the graph.

    Needed to limit the number of states kept in memory.

◆ forward()

def dataug.Higher_model.forward ( self, x )

Main method of the model.

    Args:
        x (Tensor): Batch of data.

    Returns:
        Tensor: Output of the network. Should be logits.
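
A short usage sketch; the wrapped network and input shape are illustrative:

    import torch
    import torch.nn as nn
    from dataug import Higher_model

    model = Higher_model(nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)))
    x = torch.randn(32, 1, 28, 28)   # batch of data
    logits = model.forward(x)        # output of the network (logits)
    print(logits.shape)              # expected: torch.Size([32, 10])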

◆ get_diffopt()

def dataug.Higher_model.get_diffopt ( self, opt, grad_callback = None, track_higher_grads = True )

Get a differentiable version of an Optimizer.

    A Higher differentiable optimizer is required for higher gradient tracking.
    Usage: diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step)

    Be aware that if track_higher_grads is set to True, a new state of the model will be saved each time diffopt.step() is called,
    increasing memory consumption. The detach_() method should be called to reset the gradient tape and prevent memory saturation.

    Args:
        opt (torch.optim): Optimizer to make differentiable.
        grad_callback (fct(grads)=grads): Function applied to the list of gradients of the parameters (e.g. clipping). (default: None)
        track_higher_grads (bool): Whether higher gradients are tracked. If True, the graph/states will be retained to allow backpropagation. (default: True)

    Returns:
        (Higher.DifferentiableOptimizer): Differentiable version of the optimizer.
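
A sketch of an inner training loop built only on the documented methods; the network, the synthetic data, and the detach frequency are illustrative assumptions:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from dataug import Higher_model

    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    model = Higher_model(net)
    inner_opt = torch.optim.SGD(net.parameters(), lr=1e-2)
    diffopt = model.get_diffopt(inner_opt, track_higher_grads=True)

    for step in range(100):
        x = torch.randn(32, 1, 28, 28)       # stand-in for a real data batch
        y = torch.randint(0, 10, (32,))
        loss = F.cross_entropy(model.forward(x), y)
        diffopt.step(loss)                   # zero_grad + backward + step, differentiably

        # Each differentiable step retains a model state on the gradient tape;
        # detach periodically to reset the tape and keep memory bounded.
        if (step + 1) % 10 == 0:
            model.detach_()

With track_higher_grads=True, the retained graph/states allow an outer objective to be backpropagated through these inner steps, which is the purpose of the differentiable optimizer.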

◆ state_dict()

def dataug.Higher_model.state_dict ( self )

Returns a dictionary containing a whole state of the module.
