My Project
Functions
def print_graph(PyTorch_obj, fig_name='graph')
def plot_resV2(log, fig_name='res', param_names=None)
def plot_compare(filenames, fig_name='res')
def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None)
def print_torch_mem(add_info='')
def clip_norm(tensors, max_norm, norm_type=2)
Utility functions.
def utils.clip_norm(tensors, max_norm, norm_type=2)
Clips the norm of the passed tensors. The norm is computed over all tensors together, as if they were concatenated into a single vector, and the clipped tensors are returned. See: https://github.com/facebookresearch/higher/issues/18
Args:
    tensors (Iterable[Tensor]): an iterable of Tensors or a single Tensor to be clipped.
    max_norm (float or int): maximum allowed total norm of the tensors (typically gradients).
    norm_type (float or int): type of the p-norm used. Can be ``'inf'`` for the infinity norm. (default: 2)
Returns:
    List[Tensor]: the clipped tensors.
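The clipping rule described above can be sketched as follows. This is a hypothetical reimplementation written from the docstring, not the project's code, and the name `clip_norm_sketch` is made up for illustration. It computes one p-norm over all tensors and, only if that norm exceeds `max_norm`, scales every tensor by `max_norm / total_norm`; the small epsilon in the denominator is an assumption borrowed from common practice.

```python
import math
from typing import Iterable, List, Union

import torch


def clip_norm_sketch(
    tensors: Union[torch.Tensor, Iterable[torch.Tensor]],
    max_norm: float,
    norm_type: Union[float, int, str] = 2,
) -> List[torch.Tensor]:
    """Hypothetical stand-in for utils.clip_norm, based only on its docstring."""
    # Accept a single tensor as well as an iterable of tensors.
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    tensors = list(tensors)
    # float() also parses the string 'inf' used for the infinity norm.
    p = float(norm_type)
    if math.isinf(p):
        # Infinity norm of the concatenation = largest absolute entry overall.
        total_norm = max(float(t.abs().max()) for t in tensors)
    else:
        # p-norm of the concatenation = p-norm of the per-tensor p-norms.
        per_tensor = torch.stack([t.norm(p) for t in tensors])
        total_norm = float(per_tensor.norm(p))
    # Epsilon (an assumption) avoids division by zero for all-zero inputs.
    clip_coef = float(max_norm) / (total_norm + 1e-6)
    if clip_coef >= 1.0:
        return tensors  # already within the budget: nothing to clip
    # Out-of-place scaling keeps each result attached to its autograd graph.
    return [t * clip_coef for t in tensors]
```

For example, two tensors `[3.0]` and `[4.0]` have a joint 2-norm of 5, so clipping them to `max_norm=1.0` yields roughly `[0.6]` and `[0.8]`. Unlike `torch.nn.utils.clip_grad_norm_`, which rescales the `.grad` attributes of parameters in place, this sketch returns new tensors, so it can also be applied to tensors that are part of a differentiable computation.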
def utils.plot_compare(filenames, fig_name='res')
Save a visual graph comparing training statistics.
Args:
    filenames (list[string]): Relative paths to the logs (JSON files).
    fig_name (string): Relative path where the graph is saved. (default: res)
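A comparison plot of this kind can be sketched with matplotlib as below. The log layout is not documented here, so the sketch assumes a hypothetical format in which each JSON file holds a dict mapping metric names to lists of per-epoch values; the function name `plot_compare_sketch`, its `metric` parameter, and the `.png` suffix are likewise illustrative assumptions, not the project's API.

```python
import json
from typing import List

import matplotlib
matplotlib.use("Agg")  # render to files without needing a display
import matplotlib.pyplot as plt


def plot_compare_sketch(filenames: List[str], fig_name: str = "res", metric: str = "loss") -> str:
    # Hypothetical log format: {"loss": [0.9, 0.5, 0.3], ...} in each JSON file.
    fig, ax = plt.subplots()
    for path in filenames:
        with open(path) as f:
            log = json.load(f)
        ax.plot(log[metric], label=path)  # one curve per training run
    ax.set_xlabel("epoch")
    ax.set_ylabel(metric)
    ax.legend()
    out_path = fig_name + ".png"  # assumed extension
    fig.savefig(out_path)
    plt.close(fig)  # release the figure so repeated calls do not leak memory
    return out_path
```

Drawing every run on one shared axis, with the file path as its legend label, is the simplest way to make the runs directly comparable.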
def utils.plot_resV2(log, fig_name='res', param_names=None)
Save a visual graph of the logs.
Args:
    log (dict): Training logs, as generated by most train_utils functions.
    fig_name (string): Relative path where the graph is saved. (default: res)
    param_names (list): Labels for the parameters. (default: None)
def utils.print_graph(PyTorch_obj, fig_name='graph')
Save the computational graph.
Args:
    PyTorch_obj (Tensor): End of the graph. Commonly the loss tensor, so that the whole graph is captured.
    fig_name (string): Relative path where the graph is saved. (default: graph)
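Why the loss tensor is the natural end point can be seen by walking the autograd graph backwards from it. The sketch below does not render anything and is not how print_graph is implemented; it only lists the edges that a renderer (for instance a tool such as torchviz) would draw, using each node's `next_functions`. The name `list_graph_edges` is hypothetical.

```python
import torch


def list_graph_edges(end: torch.Tensor) -> list:
    """Hypothetical helper: collect (parent, child) node names reachable from `end`."""
    edges, seen = [], set()
    stack = [end.grad_fn]  # None if `end` does not require gradients
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        # next_functions holds (parent_node, input_index) pairs; the parent is
        # None for inputs that do not require gradients.
        for parent, _ in node.next_functions:
            if parent is not None:
                edges.append((type(parent).__name__, type(node).__name__))
                stack.append(parent)
    return edges


a = torch.ones(2, requires_grad=True)
b = torch.ones(2, requires_grad=True)
loss = (a * b).sum()
print(list_graph_edges(loss))
```

Starting from an intermediate tensor instead of the loss would only reach the part of the graph that produced that tensor, which is why the docstring recommends passing the loss.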
def utils.print_torch_mem(add_info='')
Print information on PyTorch memory usage.
Args:
    add_info (string): Prefix added before the printed message. (default: '')
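The kind of report this function prints can be sketched with PyTorch's CUDA memory counters. Which statistics print_torch_mem actually reports is not stated here, so the choice of counters, the message format, and the name `torch_mem_report` are assumptions; the sketch also returns the string so it can be logged or tested.

```python
import torch


def torch_mem_report(add_info: str = "") -> str:
    """Hypothetical stand-in for print_torch_mem: build, print and return a report."""
    if not torch.cuda.is_available():
        report = f"{add_info}CUDA not available, no GPU memory to report"
    else:
        mib = 1024 ** 2
        allocated = torch.cuda.memory_allocated() / mib  # memory occupied by live tensors
        reserved = torch.cuda.memory_reserved() / mib    # memory held by the caching allocator
        peak = torch.cuda.max_memory_allocated() / mib   # high-water mark since start or last reset
        report = (f"{add_info}allocated: {allocated:.1f} MiB | "
                  f"reserved: {reserved:.1f} MiB | peak: {peak:.1f} MiB")
    print(report)
    return report


torch_mem_report("[epoch 1] ")
```

Reporting both allocated and reserved memory helps explain why `nvidia-smi` often shows more usage than the tensors themselves occupy: the caching allocator keeps freed blocks for reuse.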
def utils.viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None)
Save data samples.
Args:
    imgs (Tensor): Batch of images to sample from. Intended to contain at least 25 images.
    labels (Tensor): Labels of the images.
    fig_name (string): Relative path where the graph is saved. (default: data_sample)
    weight_labels (Tensor): Weights associated with each label. (default: None)
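A sample grid of this kind can be sketched with matplotlib. The 5x5 layout is only inferred from the "at least 25 images" note; the `(N, C, H, W)` layout with values in [0, 1], the title format, the `.png` suffix, and the name `viz_sample_sketch` are assumptions rather than the project's behaviour.

```python
import math

import matplotlib
matplotlib.use("Agg")  # render to files without needing a display
import matplotlib.pyplot as plt
import torch


def viz_sample_sketch(imgs, labels, fig_name="data_sample", weight_labels=None, n=25):
    # Assumes imgs is (N, C, H, W) with values in [0, 1] and C in {1, 3}.
    n = min(n, imgs.shape[0])
    side = math.ceil(math.sqrt(n))  # 25 images -> 5x5 grid
    fig, axes = plt.subplots(side, side, figsize=(2 * side, 2 * side), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        if i >= n:
            continue  # leave unused cells blank
        img = imgs[i].detach().cpu().clamp(0, 1).permute(1, 2, 0)  # CHW -> HWC
        if img.shape[-1] == 1:
            ax.imshow(img[..., 0].numpy(), cmap="gray")
        else:
            ax.imshow(img.numpy())
        title = str(int(labels[i]))
        if weight_labels is not None:
            title += f" ({float(weight_labels[i]):.2f})"  # show the sample's weight
        ax.set_title(title, fontsize=8)
    out_path = fig_name + ".png"  # assumed extension
    fig.savefig(out_path)
    plt.close(fig)
    return out_path
```

Putting the optional weight next to each label in the subplot title keeps the per-sample weighting visible without needing a second figure.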