My Project
dataug.Augmented_model Class Reference
[Inheritance and collaboration diagrams for dataug.Augmented_model]

Public Member Functions

def __init__(self, data_augmenter, model)
def forward(self, x)
def augment(self, mode=True)
def train(self, mode=True)
def eval(self)
def items(self)
def update(self, modules)
def is_augmenting(self)
def TF_names(self)
def __getitem__(self, key)
def __str__(self)

Detailed Description

Wrapper for a Data Augmentation module and a model.

    Attributes:
        _mods (nn.ModuleDict): A dictionary containing the modules.
        _data_augmentation (bool): Whether data augmentation should be used.
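The attribute layout above can be illustrated with a minimal sketch. It assumes PyTorch. The class name `AugmentedModelSketch` and the dictionary keys `"data_aug"` and `"model"` are hypothetical, and two `nn.Linear` layers stand in for a real augmenter and network. The sketch shows what storing the modules in an `nn.ModuleDict` buys: both modules become registered submodules, so their parameters reach the optimizer and the `state_dict`. The plain boolean flag does not.

```python
import torch.nn as nn

class AugmentedModelSketch(nn.Module):
    """Hypothetical minimal wrapper mirroring the documented attributes."""
    def __init__(self, data_augmenter: nn.Module, model: nn.Module):
        super().__init__()
        # Keys are assumptions. ModuleDict registers both modules as submodules,
        # so .parameters(), .to(device) and state_dict() reach both of them.
        self._mods = nn.ModuleDict({"data_aug": data_augmenter, "model": model})
        # Plain Python flag: not a parameter or buffer, so it is not saved in state_dict().
        self._data_augmentation = True

aug = nn.Linear(4, 4)   # stand-in for a learnable augmentation module
net = nn.Linear(4, 2)   # stand-in for the classifier
wrapper = AugmentedModelSketch(aug, net)

n_params = sum(p.numel() for p in wrapper.parameters())
print(n_params)                           # 30  (16 + 4 from aug, 8 + 2 from net)
print(sorted(wrapper.state_dict().keys()))
# ['_mods.data_aug.bias', '_mods.data_aug.weight', '_mods.model.bias', '_mods.model.weight']
```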

Constructor & Destructor Documentation

◆ __init__()

def dataug.Augmented_model.__init__(self, data_augmenter, model)
Initialize the augmented model.

    Data augmentation is enabled by default.

    Args:
        data_augmenter (nn.Module): Data augmentation module.
        model (nn.Module): Network that processes the (augmented) data.
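The following sketch shows construction with the default described above. It assumes PyTorch. `GaussianNoise` is a hypothetical learnable augmenter, `AugmentedModelSketch` and its dictionary keys are illustrative names, and having the constructor call `augment(True)` is one possible way to get the default, not necessarily the project's.

```python
import torch
import torch.nn as nn

class GaussianNoise(nn.Module):
    """Hypothetical augmenter: adds Gaussian noise with a learnable scale."""
    def __init__(self, std: float = 0.1):
        super().__init__()
        self.std = nn.Parameter(torch.tensor(std))

    def forward(self, x):
        return x + self.std * torch.randn_like(x)

class AugmentedModelSketch(nn.Module):
    def __init__(self, data_augmenter, model):
        super().__init__()
        self._mods = nn.ModuleDict({"data_aug": data_augmenter, "model": model})
        self.augment(mode=True)  # augmentation is enabled by default

    def augment(self, mode=True):
        self._data_augmentation = mode

    def is_augmenting(self):
        return self._data_augmentation

wrapped = AugmentedModelSketch(GaussianNoise(0.05), nn.Linear(8, 3))
print(wrapped.is_augmenting())  # True
```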

Member Function Documentation

◆ __getitem__()

def dataug.Augmented_model.__getitem__(self, key)
Access one of the wrapped modules.

    Args:
        key (str): Name of the module to access.

    Returns:
        nn.Module: The module stored under key.
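A minimal sketch of key-based access, assuming PyTorch and that `__getitem__` simply delegates to the internal `nn.ModuleDict`. The key names `"data_aug"` and `"model"` are hypothetical. With this delegation, an unknown key raises `KeyError`, as it would for a plain dictionary.

```python
import torch.nn as nn

class AugmentedModelSketch(nn.Module):
    def __init__(self, data_augmenter, model):
        super().__init__()
        self._mods = nn.ModuleDict({"data_aug": data_augmenter, "model": model})

    def __getitem__(self, key):
        # Delegates to nn.ModuleDict; an unknown key raises KeyError.
        return self._mods[key]

wrapper = AugmentedModelSketch(nn.Identity(), nn.Linear(4, 2))
print(type(wrapper["model"]).__name__)  # Linear
try:
    wrapper["optimizer"]
except KeyError as err:
    print("missing key:", err)          # missing key: 'optimizer'
```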

◆ __str__()

def dataug.Augmented_model.__str__(self)
Return the name of the module.

    Returns:
        str: The name of the module, along with its high-level parameters.
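One way such a name could be built is sketched below, assuming PyTorch. The `"Aug_mod(<augmenter>-<model>)"` format is a hypothetical choice, not the project's confirmed output. It reuses each submodule's own `str()`, which in PyTorch includes the layer's `extra_repr` hyperparameters (for example `in_features` for `nn.Linear`).

```python
import torch.nn as nn

class AugmentedModelSketch(nn.Module):
    def __init__(self, data_augmenter, model):
        super().__init__()
        self._mods = nn.ModuleDict({"data_aug": data_augmenter, "model": model})

    def __str__(self):
        # Hypothetical format: combine the submodules' own string forms,
        # which carry their high-level parameters via extra_repr().
        return f"Aug_mod({self._mods['data_aug']}-{self._mods['model']})"

m = AugmentedModelSketch(nn.Identity(), nn.Linear(4, 2))
print(str(m))  # Aug_mod(Identity()-Linear(in_features=4, out_features=2, bias=True))
```

Overriding `__str__` leaves `repr(m)` unchanged, so the full module tree is still available for debugging.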

◆ augment()

def dataug.Augmented_model.augment(self, mode=True)
Set the augmentation mode.

    Args:
        mode (bool): Whether to perform data augmentation on the forward pass. (default: True)
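The effect of the augmentation flag can be sketched as follows, assuming PyTorch and assuming that the augmenter is skipped entirely when the flag is off. `AddOne` is a hypothetical deterministic "augmentation" chosen so the difference is visible. The sketch also exercises `is_augmenting()`.

```python
import torch
import torch.nn as nn

class AddOne(nn.Module):
    """Hypothetical deterministic 'augmentation' so its effect is easy to see."""
    def forward(self, x):
        return x + 1

class AugmentedModelSketch(nn.Module):
    def __init__(self, data_augmenter, model):
        super().__init__()
        self._mods = nn.ModuleDict({"data_aug": data_augmenter, "model": model})
        self.augment(True)

    def augment(self, mode=True):
        self._data_augmentation = mode

    def is_augmenting(self):
        return self._data_augmentation

    def forward(self, x):
        # Assumption: the augmenter is bypassed when augmentation is off.
        if self._data_augmentation:
            x = self._mods["data_aug"](x)
        return self._mods["model"](x)

m = AugmentedModelSketch(AddOne(), nn.Identity())
x = torch.zeros(2)
out_on = m(x)
print(out_on)                       # tensor([1., 1.])
m.augment(False)
out_off = m(x)
print(m.is_augmenting(), out_off)   # False tensor([0., 0.])
```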

◆ eval()

def dataug.Augmented_model.eval(self)
Set the module to evaluation mode.
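In PyTorch, `nn.Module.eval()` is equivalent to `train(False)`. It sets the `training` flag recursively on every submodule. The sketch below assumes a wrapper that does not override `eval`, and uses `nn.Dropout` as a stand-in for a stochastic augmenter. Whether this class's `eval` also disables data augmentation is not stated here.

```python
import torch
import torch.nn as nn

class AugmentedModelSketch(nn.Module):
    def __init__(self, data_augmenter, model):
        super().__init__()
        self._mods = nn.ModuleDict({"data_aug": data_augmenter, "model": model})

    def forward(self, x):
        return self._mods["model"](self._mods["data_aug"](x))

# Dropout stands in for a stochastic augmentation; it is only active in training mode.
m = AugmentedModelSketch(nn.Dropout(p=0.5), nn.Identity())
m.eval()  # same as m.train(False); propagates to every submodule
print(all(not mod.training for mod in m.modules()))  # True
x = torch.ones(4)
print(torch.equal(m(x), x))  # True: dropout is a no-op in eval mode
```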

◆ forward()

def dataug.Augmented_model.forward(self, x)
Main method of the augmented model.

    Perform the forward pass of both modules: the data augmentation module first, then the model.

    Args:
        x (Tensor): Batch of data.

    Returns:
        Tensor: Output of the model, which should be logits.
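Because the output should be logits, it can be passed directly to a loss such as `torch.nn.functional.cross_entropy`, which applies log-softmax internally. The sketch below assumes PyTorch. `LearnableScale` is a hypothetical differentiable augmenter, and the class and key names are illustrative. Since both modules sit in one computation graph, a single `backward()` call also produces gradients for the augmenter's parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnableScale(nn.Module):
    """Hypothetical differentiable augmenter: multiplies inputs by a learnable scalar."""
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x * self.scale

class AugmentedModelSketch(nn.Module):
    def __init__(self, data_augmenter, model):
        super().__init__()
        self._mods = nn.ModuleDict({"data_aug": data_augmenter, "model": model})

    def forward(self, x):
        # Augment the batch first, then classify it; the model returns raw logits.
        return self._mods["model"](self._mods["data_aug"](x))

torch.manual_seed(0)
m = AugmentedModelSketch(LearnableScale(), nn.Linear(10, 3))
x = torch.randn(5, 10)            # batch of 5 samples, 10 features each
y = torch.randint(0, 3, (5,))     # integer class labels
logits = m(x)
print(logits.shape)               # torch.Size([5, 3])
loss = F.cross_entropy(logits, y) # no softmax needed in forward
loss.backward()                   # gradients reach both the model and the augmenter
```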

◆ is_augmenting()

def dataug.Augmented_model.is_augmenting(self)
Return whether data augmentation is applied.

    Returns:
        bool: True if data augmentation is applied.

◆ items()

def dataug.Augmented_model.items(self)
Return an iterable of the ModuleDict key/value pairs.
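A minimal sketch of iterating over the key/value pairs, assuming PyTorch and that `items()` forwards to `nn.ModuleDict.items()`, which yields `(name, module)` pairs. The key names are hypothetical.

```python
import torch.nn as nn

class AugmentedModelSketch(nn.Module):
    def __init__(self, data_augmenter, model):
        super().__init__()
        self._mods = nn.ModuleDict({"data_aug": data_augmenter, "model": model})

    def items(self):
        # nn.ModuleDict.items() yields (name, module) pairs.
        return self._mods.items()

m = AugmentedModelSketch(nn.Identity(), nn.Linear(4, 2))
for name, module in m.items():
    print(name, type(module).__name__)
# data_aug Identity
# model Linear
```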

◆ TF_names()

def dataug.Augmented_model.TF_names(self)
Get the names of the transformations used by the data augmentation module.

    Returns:
        list: Names of the transformations of the data augmentation module.
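The sketch below shows one way to expose the transformation names. It assumes PyTorch and an augmenter that stores its names in a list attribute. The attribute name `tf_names`, the class `TFAugmenter` and the transformation names are all hypothetical. Falling back to an empty list when the augmenter has no such attribute is a design choice of this sketch.

```python
import torch.nn as nn

class TFAugmenter(nn.Module):
    """Hypothetical augmenter that records the names of its transformations."""
    def __init__(self, tf_names):
        super().__init__()
        self.tf_names = list(tf_names)  # assumed attribute name

    def forward(self, x):
        return x

class AugmentedModelSketch(nn.Module):
    def __init__(self, data_augmenter, model):
        super().__init__()
        self._mods = nn.ModuleDict({"data_aug": data_augmenter, "model": model})

    def TF_names(self):
        # Assumption: the augmenter exposes `tf_names`; return [] if it does not.
        return list(getattr(self._mods["data_aug"], "tf_names", []))

m = AugmentedModelSketch(TFAugmenter(["Identity", "Rotate", "Shear"]), nn.Identity())
print(m.TF_names())                                                  # ['Identity', 'Rotate', 'Shear']
print(AugmentedModelSketch(nn.Identity(), nn.Identity()).TF_names())  # []
```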

◆ train()

def dataug.Augmented_model.train(self, mode=True)
Set the module to training mode.

    Args:
        mode (bool): Whether to set the module to training mode (True) or evaluation mode (False). (default: True)
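A sketch of overriding `train` in a PyTorch wrapper follows. It calls `super().train(mode)` so the `training` flag reaches every submodule, and it returns `self` as `nn.Module.train` does. Tying the augmentation flag to the training mode is an assumption of this sketch, not a documented behavior. Because `nn.Module.eval()` calls `self.train(False)`, the override would also apply on `eval()`. The last line shows that training mode does not freeze weights: that is controlled separately through `requires_grad`.

```python
import torch.nn as nn

class AugmentedModelSketch(nn.Module):
    def __init__(self, data_augmenter, model):
        super().__init__()
        self._mods = nn.ModuleDict({"data_aug": data_augmenter, "model": model})
        self._data_augmentation = True

    def augment(self, mode=True):
        self._data_augmentation = mode

    def is_augmenting(self):
        return self._data_augmentation

    def train(self, mode=True):
        # Assumption for this sketch: augmentation follows the training mode.
        self.augment(mode)
        # nn.Module.train sets .training recursively on all submodules and returns self.
        return super().train(mode)

m = AugmentedModelSketch(nn.Dropout(0.1), nn.Linear(4, 2))
m.train(False)
print(m.training, m.is_augmenting())  # False False
m.train()
print(m.training, m.is_augmenting())  # True True
print(all(p.requires_grad for p in m.parameters()))  # True
```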

◆ update()

def dataug.Augmented_model.update(self, modules)
Update the module dictionary.

    The new dictionary should keep the same structure.
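A minimal sketch of an update that keeps the same structure, assuming PyTorch. The guard that rejects names not already present is a hypothetical way to enforce the constraint, and the key names are illustrative. Valid entries are then replaced through `nn.ModuleDict.update`.

```python
import torch.nn as nn

class AugmentedModelSketch(nn.Module):
    def __init__(self, data_augmenter, model):
        super().__init__()
        self._mods = nn.ModuleDict({"data_aug": data_augmenter, "model": model})

    def update(self, modules):
        # Hypothetical guard: only replace existing entries, never add new keys.
        unknown = set(modules) - set(self._mods.keys())
        if unknown:
            raise KeyError(f"unexpected module names: {sorted(unknown)}")
        self._mods.update(modules)

m = AugmentedModelSketch(nn.Identity(), nn.Linear(4, 2))
m.update({"model": nn.Linear(4, 10)})  # swap the classifier, keep the augmenter
print(m._mods["model"].out_features)   # 10
```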

The documentation for this class was generated from the following file: