""" Utility functions.
"""
import numpy as np
import json, math, time, os

import matplotlib
matplotlib.use('Agg') #https://stackoverflow.com/questions/4706451/how-to-save-a-figure-remotely-with-pylab
import matplotlib.pyplot as plt

import copy
import gc

import torch
import torch.nn as nn #needed for the nn.Linear classifier heads below
import torch.nn.functional as F

from nets.LeNet import *
from nets.wideresnet import *
from nets.wideresnet_cifar import *
import nets.resnet_abn as resnet_abn
import nets.resnet_deconv as resnet_DC

from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import url_map as EfficientNet_map

import torchvision.models as models

def load_model(model, num_classes, pretrained=False):
    """ Load a model by name from torchvision, EfficientNet, or the local nets/ implementations.

        Args:
            model (str): Name of the model to load.
            num_classes (int): Number of classes for the classifier head.
            pretrained (bool): Whether to load pretrained weights when available. (default: False)

        Returns:
            (nn.Module, str) The loaded model and its name.
    """
    if model in models.resnet.__all__ :
        model_name = model #'resnet18' #'resnet34' #'wide_resnet50_2'
        if pretrained :
            print("Using pretrained weights")
            model = getattr(models.resnet, model_name)(pretrained=True)
            num_ftrs = model.fc.in_features
            model.fc = nn.Linear(num_ftrs, num_classes)
        else:
            model = getattr(models.resnet, model_name)(pretrained=False, num_classes=num_classes)
    elif model in models.vgg.__all__ :
        model_name = model #'vgg11', 'vgg11_bn'
        if pretrained :
            print("Using pretrained weights")
            model = getattr(models.vgg, model_name)(pretrained=True)
            num_ftrs = model.classifier[-1].in_features
            model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
        else :
            model = getattr(models.vgg, model_name)(pretrained=False, num_classes=num_classes)
    elif model in models.densenet.__all__ :
        model_name = model #'densenet121' #'densenet201'
        if pretrained :
            print("Using pretrained weights")
            model = getattr(models.densenet, model_name)(pretrained=True)
            num_ftrs = model.classifier.in_features
            model.classifier = nn.Linear(num_ftrs, num_classes)
        else:
            model = getattr(models.densenet, model_name)(pretrained=False, num_classes=num_classes)
    elif model == 'LeNet':
        if pretrained :
            print("Pretrained weights not available")
        model = LeNet(3,num_classes)
        model_name=str(model)
    elif model == 'WideResNet':
        if pretrained :
            print("Pretrained weights not available")
        # model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_classes)
        # model = WideResNet(16, 4, dropout_rate=0.0, num_classes=num_classes)
        model = wide_resnet_cifar(26, 10, num_classes=num_classes)
        # model = wide_resnet_cifar(20, 10, num_classes=num_classes)
        model_name=str(model)
    elif model in EfficientNet_map.keys():
        model_name=model # efficientnet-b0 , efficientnet-b1, efficientnet-b4
        if pretrained: #ImageNet or AdvProp weights (AdvProp usually performs better but uses a different normalization)
            print("Using pretrained weights")
            model = EfficientNet.from_pretrained(model_name, advprop=False)
        else:
            model = EfficientNet.from_name(model_name)
    elif model in resnet_abn.__all__ :
        if pretrained :
            print("Pretrained weights not available")
        model_name=model
        model = getattr(resnet_abn, model_name)(pretrained=False, num_classes=num_classes)
    elif model in resnet_DC.__all__:
        if pretrained :
            print("Pretrained weights not available")
        model_name = model
        model = getattr(resnet_DC, model_name)(num_classes=num_classes)
    else:
        raise Exception('Unknown model')

    return model, model_name
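
# Hedged usage sketch for load_model: build a 10-class classifier from the
# torchvision ResNet family. 'resnet18' and num_classes=10 are illustrative
# choices, not values mandated by this module.
def _example_load_model():
    model, model_name = load_model('resnet18', num_classes=10, pretrained=False)
    return model, model_name
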
class ConfusionMatrix(object):
    """ Confusion matrix.

        Helps computing the confusion matrix and F1 scores.

        Example use ::

            confmat = ConfusionMatrix(num_classes)

            confmat.reset()
            for data in dataset:
                ...
                confmat.update(target, pred)

            confmat.f1_metric(average='macro')

        Attributes:
            num_classes (int): Number of classes.
            mat (Tensor): Confusion matrix. Filled by update method.
    """
def __init__(self, num_classes):
        """ Initialize ConfusionMatrix.

            Args:
                num_classes (int): Number of classes.
        """
        self.num_classes = num_classes
        self.mat = None

def update(self, target, pred):
        """ Update the confusion matrix.

            Args:
                target (Tensor): Target labels (class indices).
                pred (Tensor): Predicted labels (class indices).
        """
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
        with torch.no_grad():
            k = (target >= 0) & (target < n)
            inds = n * target[k].to(torch.int64) + pred[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

def reset(self):
        """ Reset the confusion matrix.
        """
        if self.mat is not None:
            self.mat.zero_()

def f1_metric(self, average=None):
        """ Compute the F1 score.

            Inspired by:
                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
                https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6

            Args:
                average (str): Type of averaging performed on the data. (default: None)
                    ``None``:
                        The scores for each class are returned.
                    ``'micro'``:
                        Calculate metrics globally by counting the total true positives,
                        false negatives and false positives.
                    ``'macro'``:
                        Calculate metrics for each label, and find their unweighted
                        mean. This does not take label imbalance into account.

            Returns:
                Tensor containing the F1 score. Its shape is (1,) if averaging was performed, (num_classes,) otherwise.
        """
        h = self.mat.float()
        TP = torch.diag(h)
        TN = []
        FP = []
        FN = []
        for c in range(self.num_classes):
            idx = torch.ones(self.num_classes, dtype=torch.bool, device=h.device) #mask on the same device as the confusion matrix
            idx[c] = 0
            # all non-class samples classified as non-class
            TN.append(self.mat[idx.nonzero()[:, None], idx.nonzero()].sum()) #conf_matrix[idx[:, None], idx].sum() - conf_matrix[idx, c].sum()
            # all non-class samples classified as class
            FP.append(self.mat[idx, c].sum())
            # all class samples not classified as class
            FN.append(self.mat[c, idx].sum())

            #print('Class {}\nTP {}, TN {}, FP {}, FN {}'.format(c, TP[c], TN[c], FP[c], FN[c]))

        # Counts are normalized by per-class support (row sums); for per-class scores the common factor cancels in the precision/recall ratios
        tp = (TP/h.sum(1))#.sum()
        tn = (torch.tensor(TN, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
        fp = (torch.tensor(FP, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
        fn = (torch.tensor(FN, device=h.device, dtype=torch.float)/h.sum(1))#.sum()

        if average=="micro":
            tp, tn, fp, fn = tp.sum(), tn.sum(), fp.sum(), fn.sum()

        epsilon = 1e-7
        precision = tp / (tp + fp + epsilon)
        recall = tp / (tp + fn + epsilon)

        f1 = 2* (precision*recall) / (precision + recall + epsilon)

        if average=="macro":
            f1=f1.mean()
        return f1
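
# Hedged usage sketch for ConfusionMatrix: accumulate predictions over a loader
# and report macro-F1. `net`, `loader` and `device` are illustrative parameters
# supplied by the caller, not objects defined in this module.
def _example_confmat_f1(net, loader, device, num_classes):
    confmat = ConfusionMatrix(num_classes)
    confmat.reset()
    net.eval()
    with torch.no_grad():
        for xs, ys in loader:
            pred = net(xs.to(device)).argmax(dim=1)
            confmat.update(ys.to(device), pred)
    return confmat.f1_metric(average='macro')
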

try:
    from torchviz import make_dot #optional dependency, only needed by print_graph
except ImportError:
    make_dot = None

def print_graph(PyTorch_obj=torch.randn(1, 3, 32, 32), fig_name='graph'):
    """Save the computational graph.

        Args:
            PyTorch_obj (Tensor): End of the graph. Commonly, the loss tensor to get the whole graph.
            fig_name (string): Relative path where to save the graph. (default: graph)
    """
    if make_dot is None:
        raise ImportError("print_graph requires the optional torchviz package")
    graph=make_dot(PyTorch_obj)
    graph.format = 'png' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
    graph.render(fig_name)
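
# Hedged usage sketch for print_graph: render the graph of a scalar loss.
# `net` and the 1x3x32x32 input shape are illustrative; requires torchviz and
# graphviz to be installed.
def _example_print_graph(net):
    out = net(torch.randn(1, 3, 32, 32))
    print_graph(out.sum(), fig_name='graph_example')
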
def plot_resV2(log, fig_name='res', param_names=None, f1=True):
    """Save a visual graph of the logs.

        Args:
            log (list[dict]): Logs of the training generated by most of train_utils.
            fig_name (string): Relative path where to save the graph. (default: res)
            param_names (list): Labels for the parameters. (default: None)
            f1 (bool): Whether to plot F1 scores. (default: True)
    """
    epochs = [x["epoch"] for x in log]

    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(30, 15))

    ax[0, 0].set_title('Loss')
    ax[0, 0].plot(epochs,[x["train_loss"] for x in log], label='Train')
    ax[0, 0].plot(epochs,[x["val_loss"] for x in log], label='Val')
    ax[0, 0].legend()

    ax[1, 0].set_title('Test')
    ax[1, 0].plot(epochs,[x["acc"] for x in log], label='Acc')

    if f1 and "f1" in log[0].keys():
        #ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1')
        if isinstance(log[0]["f1"], list):
            if len(log[0]["f1"])>10:
                print("Plotting results: too many classes for per-class F1, plotting only min/max")
                ax[1, 0].plot(epochs,[max(x["f1"])*100 for x in log], label='F1-Max', ls='--')
                ax[1, 0].plot(epochs,[min(x["f1"])*100 for x in log], label='F1-Min', ls='--')
            else:
                for c in range(len(log[0]["f1"])):
                    ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c), ls='--')
        else:
            ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1', ls='--')

    ax[1, 0].legend()

    if log[0]["param"] is not None:
        if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
        #proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
        proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
        mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]

        ax[0, 1].set_title('Prob =f(epoch)')
        ax[0, 1].stackplot(epochs, proba, labels=param_names)
        #ax[0, 1].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))

        ax[1, 1].set_title('Prob =f(TF)')
        mean = np.mean(proba, axis=1)
        std = np.std(proba, axis=1)
        ax[1, 1].bar(param_names, mean, yerr=std)
        plt.sca(ax[1, 1]), plt.xticks(rotation=90)

        ax[0, 2].set_title('Mag =f(epoch)')
        ax[0, 2].stackplot(epochs, mag, labels=param_names)
        #ax[0, 2].plot(epochs, np.array(mag).T, label=param_names)
        ax[0, 2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))

        ax[1, 2].set_title('Mag =f(TF)')
        mean = np.mean(mag, axis=1)
        std = np.std(mag, axis=1)
        ax[1, 2].bar(param_names, mean, yerr=std)
        plt.sca(ax[1, 2]), plt.xticks(rotation=90)

    fig_name = fig_name.replace('.',',').replace(',,/','../')
    plt.savefig(fig_name, bbox_inches='tight')
    plt.close()
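
# Hedged sketch of the log structure plot_resV2 expects: a list with one dict
# per epoch. The keys match the fields read above ("epoch", "train_loss",
# "val_loss", "acc", optional "f1" and "param"); the two-TF values below are
# illustrative only.
_example_log = [
    {"epoch": 1, "train_loss": 1.2, "val_loss": 1.4, "acc": 55.0,
     "f1": [0.50, 0.60], "param": [{'p': 0.5, 'm': 0.3}, {'p': 0.5, 'm': 0.7}]},
]
#plot_resV2(_example_log, fig_name='res_example', param_names=['TF0', 'TF1'])
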
def plot_compare(filenames, fig_name='res'):
    """Save a visual graph comparing training stats.

        Args:
            filenames (list[str]): Relative paths to the logs (JSON files).
            fig_name (string): Relative path where to save the graph. (default: res)
    """
    all_data=[]
    legend=""
    for idx, file in enumerate(filenames):
        legend+=str(idx)+'-'+file+'\n'
        with open(file) as json_file:
            data = json.load(json_file)
            all_data.append(data)

    fig, ax = plt.subplots(ncols=3, figsize=(30, 8))

    for data_idx, log in enumerate(all_data):
        log=log['Log']
        epochs = [x["epoch"] for x in log]

        ax[0].plot(epochs,[x["train_loss"] for x in log], label=str(data_idx)+'-Train')
        ax[0].plot(epochs,[x["val_loss"] for x in log], label=str(data_idx)+'-Val')

        ax[1].plot(epochs,[x["acc"] for x in log], label=str(data_idx))
        #ax[1].text(x=0.5,y=0,s=str(data_idx)+'-'+filenames[data_idx], transform=ax[1].transAxes)

        if log[0]["param"] is not None:
            if isinstance(log[0]["param"],float):
                ax[2].plot(epochs,[x["param"] for x in log], label=str(data_idx)+'-Mag')
            else :
                for idx, _ in enumerate(log[0]["param"]):
                    ax[2].plot(epochs,[x["param"][idx] for x in log], label=str(data_idx)+'-P'+str(idx))

    fig.suptitle(legend)
    ax[0].set_title('Loss')
    ax[1].set_title('Acc')
    ax[2].set_title('Param')
    for a in ax: a.legend()

    fig_name = fig_name.replace('.',',')
    plt.savefig(fig_name, bbox_inches='tight')
    plt.close()
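
# Hedged usage sketch for plot_compare: overlay two training runs. The JSON
# paths are illustrative; each file must hold a dict with a 'Log' entry shaped
# like the list plot_resV2 consumes.
def _example_plot_compare():
    plot_compare(filenames=['res/log/run0.json', 'res/log/run1.json'],
                 fig_name='compare_example')
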
def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
    """Save data samples.

        Args:
            imgs (Tensor): Batch of images to sample from. Intended to contain at least 25 images.
            labels (Tensor): Labels of the images.
            fig_name (string): Relative path where to save the graph. (default: data_sample)
            weight_labels (Tensor): Weights associated with each label. (default: None)
    """
    sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()

    plt.figure(figsize=(10,10))
    for i in range(25):
        plt.subplot(5,5,i+1) #Too many figures created?
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
        label = str(labels[i].item())
        if torch.is_tensor(weight_labels): label+= (" - p %.2f" % weight_labels[i].item())
        plt.xlabel(label)

    plt.savefig(fig_name)
    print("Sample saved :", fig_name)
    plt.close('all')
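
# Hedged usage sketch for viz_sample_data: save a 5x5 grid from the first batch
# of a loader. `loader` is an illustrative parameter supplied by the caller and
# must yield batches of at least 25 images.
def _example_viz_sample(loader):
    imgs, labels = next(iter(loader))
    viz_sample_data(imgs, labels, fig_name='data_sample_example')
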
def print_torch_mem(add_info=''):
    """Print information on PyTorch memory usage.

        Args:
            add_info (string): Prefix added before the print. (default: '')
    """
    nb=0
    max_size=0
    for obj in gc.get_objects():
        #print(type(obj))
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): # and len(obj.size())>1:
                #print(i, type(obj), obj.size())
                size = np.sum(obj.size())
                if(size>max_size): max_size=size
                nb+=1
        except Exception:
            pass
    print(add_info, "-PyTorch tensor nb:",nb," / Max dim:", max_size)

    #print(add_info, "-Garbage size :",len(gc.garbage))

    # Simple GPU memory report.
    mega_bytes = 1024.0 * 1024.0
    string = add_info + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
    string += ' | max cached: {}'.format(
        torch.cuda.max_memory_cached()/ mega_bytes)
    print(string)

'''
def plot_TF_influence(log, fig_name='TF_influence', param_names=None):
    proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
    mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]

    plt.figure()

    mean = np.mean(proba, axis=1)*np.mean(mag, axis=1) #Could be interesting to multiply before taking the mean
    std = np.std(proba, axis=1)*np.std(mag, axis=1)
    plt.bar(param_names, mean, yerr=std)

    plt.xticks(rotation=90)
    fig_name = fig_name.replace('.',',')
    plt.savefig(fig_name, bbox_inches='tight')
    plt.close()
'''

try:
    from torch._six import inf
except ImportError: #torch._six was removed from recent PyTorch releases
    from math import inf

def clip_norm(tensors, max_norm, norm_type=2):
    """Clips norm of passed tensors.

        The norm is computed over all tensors together, as if they were
        concatenated into a single vector. Clipped tensors are returned.

        See: https://github.com/facebookresearch/higher/issues/18

        Args:
            tensors (Iterable[Tensor]): an iterable of Tensors or a
                single Tensor to be normalized.
            max_norm (float or int): max norm of the gradients.
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.

        Returns:
            Clipped (List[Tensor]) tensors.
    """
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    tensors = list(tensors)
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(t.abs().max() for t in tensors if t is not None) #skip None entries, consistent with the loop below
    else:
        total_norm = 0
        for t in tensors:
            if t is None:
                continue
            param_norm = t.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef >= 1:
        return tensors
    #return [t.mul(clip_coef) for t in tensors]
    return [t if t is None else t.mul(clip_coef) for t in tensors]
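
# Hedged usage sketch for clip_norm: clip a list of gradients (e.g. from
# torch.autograd.grad, where some entries may be None) to a global L2 norm.
# `loss`, `params` and max_norm=1.0 are illustrative names and values supplied
# by the caller.
def _example_clip_norm(loss, params, max_norm=1.0):
    grads = torch.autograd.grad(loss, params, allow_unused=True)
    return clip_norm(grads, max_norm, norm_type=2)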