train_utils Namespace Reference

Functions

def test(model)
def compute_vaLoss(model, dl_it, dl)
def train_classic(model, opt_param, epochs=1, print_freq=1)
def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, print_freq=1, KLdiv=1, hp_opt=False, save_sample_freq=None)

Detailed Description

Utility functions for training.

Function Documentation

◆ compute_vaLoss()

def train_utils.compute_vaLoss(model, dl_it, dl)
Evaluate a model on a batch of data.

    Args: 
        model (nn.Module): Model to evaluate.
        dl_it (Iterator): Data loader iterator.
        dl (DataLoader): Data loader.

    Returns:
        (Tensor) Loss on a single batch of data.
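The page does not say how `dl_it` and `dl` interact. A common reason to pass both is to restart the iterator from the loader once it is exhausted; the sketch below assumes exactly that and is not taken from train_utils. `next_batch` and `batch_loss` are hypothetical names, a plain list stands in for the DataLoader, and `loss_fn` stands in for evaluating the model on a batch.

```python
from typing import Callable, Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def next_batch(dl_it: Iterator[T], dl: Iterable[T]) -> Tuple[T, Iterator[T]]:
    """Return the next batch, restarting from `dl` if `dl_it` is exhausted.

    Hypothetical helper: assumes `dl` can be iterated more than once,
    as a DataLoader can.
    """
    try:
        batch = next(dl_it)
    except StopIteration:
        dl_it = iter(dl)  # start a new pass over the data
        batch = next(dl_it)
    return batch, dl_it


def batch_loss(
    loss_fn: Callable[[T], float], dl_it: Iterator[T], dl: Iterable[T]
) -> Tuple[float, Iterator[T]]:
    """Loss on a single batch; also returns the (possibly restarted) iterator."""
    batch, dl_it = next_batch(dl_it, dl)
    return loss_fn(batch), dl_it


# Usage: two batches, so the third call wraps around to the first batch.
dl = [[1.0, 2.0], [3.0]]
it = iter(dl)
mean = lambda b: sum(b) / len(b)
for _ in range(3):
    loss, it = batch_loss(mean, it, dl)
    print(loss)
```

Returning the iterator lets the caller keep using the restarted pass; the documented compute_vaLoss returns only the loss, so it may handle this differently.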

◆ run_dist_dataugV3()

def train_utils.run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, print_freq=1, KLdiv=1, hp_opt=False, save_sample_freq=None)
Train an augmented model using higher.

        This function is intended to be used with an Augmented_model containing a Higher_model (see dataug.py).
        Example: Augmented_model(Data_augV5(...), Higher_model(model))

        The training loss can be computed directly from the augmented inputs (KLdiv=0).
        However, it is recommended to use the KLdiv loss computation inspired by UDA, which combines original and augmented inputs to compute the loss (KLdiv>0).
        See: https://github.com/google-research/uda
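A hedged, framework-free sketch of the two loss modes described above follows. It assumes that, with KLdiv>0, the supervised cross-entropy is taken on the original input and the extra term is KL(p_orig || p_aug) between softmax predictions on the original and augmented inputs, weighted by KLdiv; the exact terms used by run_dist_dataugV3 are not documented on this page. `training_loss` and its helpers are hypothetical names operating on a single sample's logits.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one sample's logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Negative log-likelihood of the target class."""
    return -math.log(softmax(logits)[target])


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def training_loss(
    orig_logits: Sequence[float],
    aug_logits: Sequence[float],
    target: int,
    kl_weight: float,
) -> float:
    """Hypothetical per-sample loss mirroring the KLdiv option.

    kl_weight == 0: supervised loss on the augmented input only.
    kl_weight > 0: supervised loss on the original input plus
    kl_weight * KL(p_orig || p_aug). In an autograd framework p_orig
    would typically be detached, as UDA does for its target predictions.
    """
    if kl_weight == 0:
        return cross_entropy(aug_logits, target)
    p_orig = softmax(orig_logits)
    p_aug = softmax(aug_logits)
    return cross_entropy(orig_logits, target) + kl_weight * kl_divergence(p_orig, p_aug)
```

The KL term is zero when the model predicts the same distribution for both views, so it only penalizes predictions that change under augmentation.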

    Args:
        model (nn.Module): Augmented model to train.
        opt_param (dict): Dictionary containing the optimizer parameters.
        epochs (int): Number of epochs to perform. (default: 1)
        inner_it (int): Number of inner iterations before a meta-step. With 0 inner iterations, no meta-step is performed. (default: 1)
        dataug_epoch_start (int): Epoch at which data augmentation starts. (default: 0)
        print_freq (int): Number of epochs between displays of the training state. If set to None, nothing is displayed. (default: 1)
        KLdiv (float): Proportion of the KLdiv loss added to the supervised loss. If set to 0, the loss is computed classically on the augmented inputs. (default: 1)
        hp_opt (bool): Whether to learn the inner optimizer parameters. (default: False)
        save_sample_freq (int): Number of epochs between saves of data samples. If set to None, a single save is done at the end of training. (default: None)

    Returns:
        (list) Training logs. Each item is a dict containing the results of one epoch.
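The following sketch traces, epoch by epoch, when each scheduling option would take effect and builds a list of per-epoch dicts shaped like the returned logs. It is a hypothetical illustration, not the actual training loop: epochs are assumed to be numbered from 1, a meta-step is assumed to follow every `inner_it` inner steps once augmentation is active, and `training_schedule`, `steps_per_epoch` and the dict keys are invented for the example.

```python
from typing import Dict, List, Optional


def training_schedule(
    epochs: int = 1,
    steps_per_epoch: int = 4,
    inner_it: int = 1,
    dataug_epoch_start: int = 0,
    print_freq: Optional[int] = 1,
    save_sample_freq: Optional[int] = None,
) -> List[Dict[str, object]]:
    """Hypothetical per-epoch trace of when the scheduling options apply.

    Assumptions (not taken from run_dist_dataugV3): epochs are numbered
    1..epochs, each epoch has `steps_per_epoch` inner steps, and a
    meta-step follows every `inner_it` inner steps once augmentation is on.
    """
    logs: List[Dict[str, object]] = []
    inner_steps = 0
    for epoch in range(1, epochs + 1):
        augmented = epoch >= dataug_epoch_start
        meta_steps = 0
        for _ in range(steps_per_epoch):
            inner_steps += 1  # one inner update of the model
            if augmented and inner_it > 0 and inner_steps % inner_it == 0:
                meta_steps += 1  # one meta-step on the augmentation parameters
        if save_sample_freq is None:
            save = epoch == epochs  # single save at the end of training
        else:
            save = epoch % save_sample_freq == 0
        logs.append({
            "epoch": epoch,
            "augmented": augmented,
            "meta_steps": meta_steps,
            "printed": print_freq is not None and epoch % print_freq == 0,
            "saved_samples": save,
        })
    return logs


for entry in training_schedule(epochs=4, inner_it=2, dataug_epoch_start=2, print_freq=2):
    print(entry)
```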

◆ test()

def train_utils.test(model)
Evaluate a model on test data.

    Args:
        model (nn.Module): Model to test.

    Returns:
        (float, Tensor) Accuracy and test loss of the model.
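As a rough illustration of the two values test() reports, the sketch below computes accuracy and mean cross-entropy over batches of precomputed logits. It is a hypothetical stand-in: `evaluate` is an invented name, the logits replace the model's forward pass, and accuracy is assumed to be a fraction in [0, 1] (the real function may report it differently).

```python
import math
from typing import Iterable, List, Sequence, Tuple

# One batch: (logits for each sample, integer label for each sample).
Batch = Tuple[List[Sequence[float]], List[int]]


def evaluate(batches: Iterable[Batch]) -> Tuple[float, float]:
    """Return (accuracy, mean cross-entropy loss) over every sample.

    Hypothetical stand-in for test(): precomputed logits replace the
    model's forward pass, and accuracy is a fraction in [0, 1].
    """
    correct, total, loss_sum = 0, 0, 0.0
    for logits_batch, labels in batches:
        for logits, label in zip(logits_batch, labels):
            # log-sum-exp with max subtraction for numerical stability
            m = max(logits)
            log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
            loss_sum += log_norm - logits[label]  # equals -log softmax(logits)[label]
            pred = list(logits).index(m)  # argmax (first maximum on ties)
            correct += int(pred == label)
            total += 1
    if total == 0:
        raise ValueError("no test samples")
    return correct / total, loss_sum / total


print(evaluate([([[2.0, 0.0], [0.0, 2.0]], [0, 0])]))
```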

◆ train_classic()

def train_utils.train_classic(model, opt_param, epochs=1, print_freq=1)
Classic training of a model.

    Args:
        model (nn.Module): Model to train.
        opt_param (dict): Dictionary containing the optimizer parameters.
        epochs (int): Number of epochs to perform. (default: 1)
        print_freq (int): Number of epochs between displays of the training state. If set to None, nothing is displayed. (default: 1)

    Returns:
        (list) Training logs. Each item is a dict containing the results of one epoch.
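The shape of a classic training loop with `print_freq` and per-epoch logs can be sketched without a deep-learning framework. This is a hypothetical stand-in, not train_classic itself: it fits y = w * x by plain SGD, and the "lr" key in `opt_param` and the log keys are assumptions, since the real dictionary layouts are not documented here.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def train_classic_sketch(
    data: Sequence[Tuple[float, float]],
    opt_param: Dict[str, float],
    epochs: int = 1,
    print_freq: Optional[int] = 1,
) -> List[Dict[str, float]]:
    """Hypothetical stand-in for train_classic: fits y = w * x by SGD.

    `opt_param` is assumed to hold a learning rate under the key "lr";
    the real dictionary layout is not documented here.
    """
    w = 0.0
    lr = opt_param["lr"]
    logs: List[Dict[str, float]] = []
    for epoch in range(1, epochs + 1):
        total = 0.0
        for x, y in data:
            err = w * x - y
            total += err * err     # squared error of this sample, before its update
            w -= lr * 2 * err * x  # gradient step on d(err^2)/dw
        log = {"epoch": epoch, "train_loss": total / len(data), "w": w}
        logs.append(log)
        if print_freq is not None and epoch % print_freq == 0:
            print(f"Epoch {epoch}: train_loss={log['train_loss']:.4f}")
    return logs


logs = train_classic_sketch([(1.0, 2.0), (2.0, 4.0)], {"lr": 0.05}, epochs=5, print_freq=1)
```

Setting `print_freq=None` silences the per-epoch display while still returning one log dict per epoch.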