dataug.Higher_model Class Reference
Public Member Functions

def __init__ (self, model)
def get_diffopt (self, opt, grad_callback=None, track_higher_grads=True)
def forward (self, x)
def detach_ (self)
def state_dict (self)
def __getitem__ (self, key)
def __str__ (self)
Model wrapper for higher gradient tracking. Keeps in memory the original model and its functional (higher) version. Might not be needed anymore if Higher implements detach for fmodel. See: https://github.com/facebookresearch/higher

TODO: Get rid of the original model if not needed by the user.

Attributes:
    _name (string): Name of the model.
    _mods (nn.ModuleDict): Models (original and higher version).
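A minimal usage sketch, assuming the wrapper is importable as dataug.Higher_model and behaves as documented on this page; the toy network and input shape are illustrative only.

    import torch
    import torch.nn as nn

    from dataug import Higher_model  # assumed import path for the wrapper documented here

    # Illustrative network; any nn.Module producing logits would do.
    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

    hmodel = Higher_model(net)   # keeps the original model and its functional (higher) version
    print(hmodel)                # __str__: name of the wrapped module

    x = torch.randn(4, 1, 28, 28)
    logits = hmodel.forward(x)   # forward pass through the functional version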
def dataug.Higher_model.__init__ ( self, model )
Init Higher_model.

Args:
    model (nn.Module): Network for which higher gradients can be tracked.
def dataug.Higher_model.__getitem__ ( self, key )
Access to modules.

Args:
    key (string): Name of the module to access.

Returns:
    nn.Module.
def dataug.Higher_model.__str__ ( self )
Name of the module.

Returns:
    String containing the name of the module.
def dataug.Higher_model.detach_ ( self )
Detach from the graph. Needed to limit the number of states kept in memory.
def dataug.Higher_model.forward ( self, x )
Main method of the model.

Args:
    x (Tensor): Batch of data.

Returns:
    Tensor: Output of the network. Should be logits.
def dataug.Higher_model.get_diffopt ( self, opt, grad_callback = None, track_higher_grads = True )
Get a differentiable version of an Optimizer. A Higher differentiable optimizer is required for higher gradient tracking.

Usage: diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step)

Be wary that if track_higher_grads is set to True, a new state of the model will be saved each time diffopt.step() is called, increasing memory consumption. The detach_() method should be called to reset the gradient tape and prevent memory saturation.

Args:
    opt (torch.optim): Optimizer to make differentiable.
    grad_callback (fct(grads)=grads): Function applied to the list of parameter gradients (e.g. clipping). (default: None)
    track_higher_grads (bool): Whether higher gradients are tracked. If True, the graph/states will be retained to allow backpropagation. (default: True)

Returns:
    (Higher.DifferentiableOptimizer): Differentiable version of the optimizer.
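A sketch of the intended inner loop, based only on the behaviour described above: diffopt.step(loss) replaces the usual zero_grad/backward/step sequence, and detach_() is called periodically to cut the gradient tape. The toy network, synthetic data, learning rate, clipping callback, and detach frequency are illustrative assumptions, not part of the documented API.

    import torch
    import torch.nn as nn

    from dataug import Higher_model  # assumed import path

    net = nn.Linear(10, 2)
    hmodel = Higher_model(net)                       # see class description above
    opt = torch.optim.SGD(net.parameters(), lr=0.1)  # regular optimizer on the wrapped network's parameters (assumed)

    diffopt = hmodel.get_diffopt(
        opt,
        grad_callback=lambda grads: [g.clamp(-1.0, 1.0) for g in grads],  # optional gradient clipping
        track_higher_grads=True,
    )

    # Synthetic data standing in for a real DataLoader.
    loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(20)]

    for step, (x, y) in enumerate(loader):
        logits = hmodel.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        diffopt.step(loss)  # == opt.zero_grad(); loss.backward(); opt.step(), but differentiable

        # Each diffopt.step() call keeps a model state on the tape; cut it
        # periodically to bound memory consumption.
        if (step + 1) % 10 == 0:
            hmodel.detach_()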
def dataug.Higher_model.state_dict ( self )
Returns a dictionary containing the whole state of the module.