My Project
Public Member Functions
def __init__(self, data_augmenter, model)
def forward(self, x)
def augment(self, mode=True)
def train(self, mode=True)
def eval(self)
def items(self)
def update(self, modules)
def is_augmenting(self)
def TF_names(self)
def __getitem__(self, key)
def __str__(self)
Wrapper for a data augmentation module and a model. Attributes: _mods (nn.ModuleDict): A dictionary containing the modules. _data_augmentation (bool): Whether data augmentation should be used.
def dataug.Augmented_model.__init__(self, data_augmenter, model)
Init Augmented Model. By default, data augmentation will be performed. Args: data_augmenter (nn.Module): Data augmentation module. model (nn.Module): Network.
def dataug.Augmented_model.__getitem__(self, key)
Access a module by name. Args: key (string): Name of the module to access. Returns: nn.Module.
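A minimal sketch of the storage and lookup described above. It is not the library's implementation: plain callables and a `dict` stand in for `nn.Module` and `nn.ModuleDict` so the snippet runs without PyTorch, and the class name `AugmentedModelSketch`, the helpers `flip` and `total`, and the keys `"data_aug"` and `"model"` are assumptions made for illustration.

```python
class AugmentedModelSketch:
    """Illustrative stand-in for the wrapper; not the dataug implementation."""

    def __init__(self, data_augmenter, model):
        # Hypothetical key names: the documentation does not state them.
        self._mods = {"data_aug": data_augmenter, "model": model}
        # Documented default: augmentation is enabled after construction.
        self._data_augmentation = True

    def __getitem__(self, key):
        # Look up a wrapped module by name.
        return self._mods[key]

    def items(self):
        # Iterate over (name, module) pairs.
        return self._mods.items()


def flip(batch):
    # Toy "augmentation": reverse each sample.
    return [sample[::-1] for sample in batch]


def total(batch):
    # Toy "model": one score per sample.
    return [sum(sample) for sample in batch]


wrapper = AugmentedModelSketch(flip, total)
names = [name for name, _ in wrapper.items()]
```

Indexing the wrapper by key returns the stored object itself, so `wrapper["model"]` here is the `total` function that was passed in.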
def dataug.Augmented_model.__str__(self)
Name of the module. Returns: String containing the name of the module as well as the higher-level parameters.
def dataug.Augmented_model.augment(self, mode=True)
Set the augmentation mode. Args: mode (bool): Whether to perform data augmentation on the forward pass. (default: True)
def dataug.Augmented_model.eval(self)
Set the module to evaluation mode.
def dataug.Augmented_model.forward(self, x)
Main method of the Augmented model. Perform the forward pass of both modules. Args: x (Tensor): Batch of data. Returns: Tensor : Output of the networks. Should be logits.
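The forward pass of both modules can be pictured as a simple composition: the batch goes through the augmentation module, and the result goes through the network. The sketch below assumes that ordering and uses plain Python callables in place of `nn.Module` so that it runs without PyTorch; `ForwardSketch`, `shift` and `two_class_logits` are hypothetical names, not part of `dataug`.

```python
class ForwardSketch:
    """Illustrative composition of an augmenter and a model."""

    def __init__(self, data_augmenter, model):
        self._data_augmenter = data_augmenter
        self._model = model

    def forward(self, x):
        # Assumed order: augment the batch first, then run the network.
        return self._model(self._data_augmenter(x))


def shift(batch):
    # Toy "augmentation": add 0.5 to every feature.
    return [[value + 0.5 for value in sample] for sample in batch]


def two_class_logits(batch):
    # Toy "network": two unnormalized scores (logits) per sample.
    return [[sum(sample), -sum(sample)] for sample in batch]


wrapper = ForwardSketch(shift, two_class_logits)
logits = wrapper.forward([[1.0, 2.0], [3.0, 4.0]])
```

The output keeps one row per input sample, which matches the documented contract of a batch in, logits out.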
def dataug.Augmented_model.is_augmenting(self)
Return whether data augmentation is applied. Returns: bool: True if data augmentation is applied.
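A hedged sketch of how the augmentation flag might interact with the forward pass. It assumes that a disabled flag lets the batch reach the model unchanged; the actual class may instead delegate that decision to the augmentation module. `AugmentToggleSketch`, `double` and `add_one` are hypothetical, and plain callables replace `nn.Module` so the snippet runs without PyTorch.

```python
class AugmentToggleSketch:
    """Illustrative augmentation switch; not the dataug implementation."""

    def __init__(self, data_augmenter, model):
        self._data_augmenter = data_augmenter
        self._model = model
        # Documented default: augmentation is performed.
        self._data_augmentation = True

    def augment(self, mode=True):
        # Record whether the forward pass should augment its input.
        self._data_augmentation = mode

    def is_augmenting(self):
        return self._data_augmentation

    def forward(self, x):
        # Assumption: with the flag off, the batch is passed through untouched.
        if self._data_augmentation:
            x = self._data_augmenter(x)
        return self._model(x)


def double(batch):
    # Toy "augmentation": scale every value by two.
    return [2 * value for value in batch]


def add_one(batch):
    # Toy "model": add one to every value.
    return [value + 1 for value in batch]


wrapper = AugmentToggleSketch(double, add_one)
with_aug = wrapper.forward([1, 2])

wrapper.augment(False)
without_aug = wrapper.forward([1, 2])
```

With the flag on, the toy model sees the doubled batch; after `augment(False)`, it sees the original values, and `is_augmenting()` reports `False`.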
def dataug.Augmented_model.items(self)
Return an iterable of the ModuleDict key/value pairs.
def dataug.Augmented_model.TF_names(self)
Get the names of the transformations used by the data augmentation module. Returns: list: Names of the transformations of the data augmentation module.
def dataug.Augmented_model.train(self, mode=True)
Set the module training mode. Args: mode (bool): Whether to learn the parameters of the module. (default: True)
def dataug.Augmented_model.update(self, modules)
Update the module dictionary. The new dictionary should keep the same structure.
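One way to read "keep the same structure" is that the new dictionary must provide exactly the same keys as the current one. The sketch below enforces that reading with an explicit check; this check is an assumption, and the real method may not validate its input at all. `UpdateSketch` and the keys `"data_aug"` and `"model"` are hypothetical, and a plain `dict` replaces `nn.ModuleDict` so the snippet runs without PyTorch.

```python
class UpdateSketch:
    """Illustrative module dictionary with a structure-preserving update."""

    def __init__(self, data_augmenter, model):
        # Hypothetical key names: the documentation does not state them.
        self._mods = {"data_aug": data_augmenter, "model": model}

    def update(self, modules):
        # Assumed meaning of "same structure": identical set of keys.
        if set(modules) != set(self._mods):
            raise KeyError(
                f"expected keys {sorted(self._mods)}, got {sorted(modules)}"
            )
        self._mods.update(modules)


wrapper = UpdateSketch(data_augmenter=abs, model=float)

# Replacing both entries under the same keys is accepted.
wrapper.update({"data_aug": round, "model": int})

# A dictionary with a missing key is rejected.
try:
    wrapper.update({"model": str})
    rejected = False
except KeyError:
    rejected = True
```

Rejecting a mismatched dictionary early keeps later lookups by key from failing in less obvious places.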