Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)
Changes to dist_dataugv3 (-copy / +faster) + slight TF modification
This commit is contained in:
parent e291bc2e44
commit 75901b69b4
6 changed files with 198 additions and 83 deletions
@@ -314,4 +314,38 @@ class loss_monitor(): #Voir https://github.com/pytorch/ignite

            return False

    def reset(self):
        self.__init__(self.patience, self.end_train)


### https://github.com/facebookresearch/higher/issues/18 ####
from torch._six import inf

def clip_norm(tensors, max_norm, norm_type=2):
    r"""Clips norm of passed tensors.

    The norm is computed over all tensors together, as if they were
    concatenated into a single vector. Clipped tensors are returned.

    Arguments:
        tensors (Iterable[Tensor]): an iterable of Tensors or a
            single Tensor to be normalized.
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Clipped (List[Tensor]) tensors.
    """
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    tensors = list(tensors)
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(t.abs().max() for t in tensors)
    else:
        total_norm = 0
        for t in tensors:
            param_norm = t.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef >= 1:
        return tensors
    return [t.mul(clip_coef) for t in tensors]
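For context, a minimal usage sketch of the clip_norm helper added above, assuming torch is imported at the top of the file as in the repository; the tensor shapes, learning rate, and update loop below are illustrative placeholders, not code from this commit.

# Hypothetical usage sketch (not part of the commit): jointly clip a set of
# gradients to a maximum 2-norm before applying a plain SGD-style update.
import torch

params = [torch.randn(10, requires_grad=True), torch.randn(5, requires_grad=True)]
loss = sum((p ** 2).sum() for p in params)
grads = torch.autograd.grad(loss, params)      # tuple of gradient tensors
clipped = clip_norm(grads, max_norm=1.0)       # joint 2-norm clipped to at most ~1.0
with torch.no_grad():
    for p, g in zip(params, clipped):
        p -= 0.1 * g                           # illustrative update with clipped grads

Because clip_norm returns the clipped tensors instead of modifying them in place, it can be used on gradients produced by torch.autograd.grad (which are not attached to .grad fields), as in the higher issue linked above.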