My Project
Functions
utils Namespace Reference

Functions

def print_graph (PyTorch_obj, fig_name='graph')
 
def plot_resV2 (log, fig_name='res', param_names=None)
 
def plot_compare (filenames, fig_name='res')
 
def viz_sample_data (imgs, labels, fig_name='data_sample', weight_labels=None)
 
def print_torch_mem (add_info='')
 
def clip_norm (tensors, max_norm, norm_type=2)
 

Detailed Description

Utility functions.

Function Documentation

◆ clip_norm()

def utils.clip_norm (   tensors,
  max_norm,
  norm_type = 2 
)
Clips norm of passed tensors.
    The norm is computed over all tensors together, as if they were
    concatenated into a single vector. Clipped tensors are returned.
    
    See: https://github.com/facebookresearch/higher/issues/18

    Args:
        tensors (Iterable[Tensor]): an iterable of Tensors or a
            single Tensor to be normalized.
        max_norm (float or int): maximum allowed total norm of the tensors (computed over all of them together).
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
    Returns:
        List[Tensor]: The clipped tensors.

◆ plot_compare()

def utils.plot_compare (   filenames,
  fig_name = 'res' 
)
Save a visual graph comparing training stats.

    Args:
        filenames (list[str]): Relative paths to the logs (JSON files).
        fig_name (string): Relative path where to save the graph. (default: res)

◆ plot_resV2()

def utils.plot_resV2 (   log,
  fig_name = 'res',
  param_names = None 
)
Save a visual graph of the logs.

    Args:
        log (dict): Training logs, as generated by most of the functions in train_utils.
        fig_name (string): Relative path where to save the graph. (default: res)
        param_names (list): Labels for the parameters. (default: None)

◆ print_graph()

def utils.print_graph (   PyTorch_obj,
  fig_name = 'graph' 
)
Save the computational graph.

    Args:
        PyTorch_obj (Tensor): End of the graph. Commonly, the loss tensor to get the whole graph.
        fig_name (string): Relative path where to save the graph. (default: graph)

◆ print_torch_mem()

def utils.print_torch_mem (   add_info = '')
Print information about PyTorch memory usage.

    Args:
        add_info (string): Prefix added before the printed output. (default: '')

◆ viz_sample_data()

def utils.viz_sample_data (   imgs,
  labels,
  fig_name = 'data_sample',
  weight_labels = None 
)
Save data samples.

    Args:
        imgs (Tensor): Batch of images to sample from. Intended to contain at least 25 images.
        labels (Tensor): Labels of the images.
        fig_name (string): Relative path where to save the graph. (default: data_sample)
        weight_labels (Tensor): Weights associated with each label. (default: None)