Changes since Teledyne

Antoine Harlé 2024-08-20 11:53:35 +02:00 committed by AntoineH
parent 03ffd7fe05
commit b89dac9084
185 changed files with 16668 additions and 484 deletions


@@ -3,17 +3,88 @@
"""
import numpy as np
import json, math, time, os
import matplotlib
matplotlib.use('Agg') #https://stackoverflow.com/questions/4706451/how-to-save-a-figure-remotely-with-pylab
import matplotlib.pyplot as plt
import copy
import gc
from torchviz import make_dot
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from nets.LeNet import *
from nets.wideresnet import *
from nets.wideresnet_cifar import *
import nets.resnet_abn as resnet_abn
import nets.resnet_deconv as resnet_DC
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import url_map as EfficientNet_map
import torchvision.models as models

def load_model(model, num_classes, pretrained=False):
    if model in models.resnet.__all__ :
        model_name = model #'resnet18' #'resnet34' #'wide_resnet50_2'
        if pretrained :
            print("Using pretrained weights")
            model = getattr(models.resnet, model_name)(pretrained=True)
            num_ftrs = model.fc.in_features
            model.fc = nn.Linear(num_ftrs, num_classes)
        else:
            model = getattr(models.resnet, model_name)(pretrained=False, num_classes=num_classes)
    elif model in models.vgg.__all__ :
        model_name = model #'vgg11', 'vgg11_bn'
        if pretrained :
            print("Using pretrained weights")
            model = getattr(models.vgg, model_name)(pretrained=True)
            num_ftrs = model.classifier[-1].in_features
            model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
        else :
            model = getattr(models.vgg, model_name)(pretrained=False, num_classes=num_classes)
    elif model in models.densenet.__all__ :
        model_name = model #'densenet121' #'densenet201'
        if pretrained :
            print("Using pretrained weights")
            model = getattr(models.densenet, model_name)(pretrained=True)
            num_ftrs = model.classifier.in_features
            model.classifier = nn.Linear(num_ftrs, num_classes)
        else:
            model = getattr(models.densenet, model_name)(pretrained=False, num_classes=num_classes)
    elif model == 'LeNet':
        if pretrained :
            print("Pretrained weights not available")
        model = LeNet(3,num_classes)
        model_name=str(model)
    elif model == 'WideResNet':
        if pretrained :
            print("Pretrained weights not available")
        # model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_classes)
        # model = WideResNet(16, 4, dropout_rate=0.0, num_classes=num_classes)
        model = wide_resnet_cifar(26, 10, num_classes=num_classes)
        # model = wide_resnet_cifar(20, 10, num_classes=num_classes)
        model_name=str(model)
    elif model in EfficientNet_map.keys():
        model_name=model # efficientnet-b0, efficientnet-b1, efficientnet-b4
        if pretrained: # ImageNet or AdvProp (usually better performance, but different normalization)
            print("Using pretrained weights")
            model = EfficientNet.from_pretrained(model_name, advprop=False)
        else:
            model = EfficientNet.from_name(model_name)
    elif model in resnet_abn.__all__ :
        if pretrained :
            print("Pretrained weights not available")
        model_name=model
        model = getattr(resnet_abn, model_name)(pretrained=False, num_classes=num_classes)
    elif model in resnet_DC.__all__:
        if pretrained :
            print("Pretrained weights not available")
        model_name = model
        model = getattr(resnet_DC, model_name)(num_classes=num_classes)
    else:
        raise Exception('Unknown model')
    return model, model_name
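
# Hypothetical usage sketch (not part of this commit): build a torchvision ResNet-18
# with a 10-class head. 'resnet18' and num_classes=10 are illustrative values only.
def _demo_load_model():
    net, name = load_model('resnet18', num_classes=10, pretrained=False)
    print(name, sum(p.numel() for p in net.parameters()))  # model name and parameter count
    return net
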
class ConfusionMatrix(object):
""" Confusion matrix.
@@ -120,7 +191,8 @@ class ConfusionMatrix(object):
        f1=f1.mean()
        return f1
def print_graph(PyTorch_obj, fig_name='graph'):
#from torchviz import make_dot
def print_graph(PyTorch_obj=torch.randn(1, 3, 32, 32), fig_name='graph'):
"""Save the computational graph.
Args:
@@ -128,7 +200,7 @@ def print_graph(PyTorch_obj, fig_name='graph'):
        fig_name (string): Relative path where to save the graph. (default: graph)
    """
    graph=make_dot(PyTorch_obj)
    graph.format = 'pdf' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
    graph.format = 'png' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
    graph.render(fig_name)
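
# Hypothetical usage sketch (not part of this commit): make_dot reads the autograd
# graph from a tensor's grad_fn, so the usual call passes a model output rather than
# a raw input tensor. LeNet(3, 10) and the 32x32 input are illustrative values, and
# rendering assumes the graphviz binaries are installed.
def _demo_print_graph():
    net = LeNet(3, 10)
    out = net(torch.randn(1, 3, 32, 32))      # forward pass builds the autograd graph
    print_graph(out, fig_name='lenet_graph')  # writes lenet_graph.png
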
def plot_resV2(log, fig_name='res', param_names=None, f1=True):
@@ -157,8 +229,13 @@ def plot_resV2(log, fig_name='res', param_names=None, f1=True):
    #'''
    #print(log[0]["f1"])
    if isinstance(log[0]["f1"], list):
        for c in range(len(log[0]["f1"])):
            ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c), ls='--')
        if len(log[0]["f1"])>10:
            print("Plotting results : Too many classes for F1, plotting only min/max")
            ax[1, 0].plot(epochs,[max(x["f1"])*100 for x in log], label='F1-Max', ls='--')
            ax[1, 0].plot(epochs,[min(x["f1"])*100 for x in log], label='F1-Min', ls='--')
        else:
            for c in range(len(log[0]["f1"])):
                ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c), ls='--')
    else:
        ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1', ls='--')
    #'''
@@ -251,7 +328,6 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
        fig_name (string): Relative path where to save the graph. (default: data_sample)
        weight_labels (Tensor): Weights associated with each label. (default: None)
    """
    sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()
    plt.figure(figsize=(10,10))
@@ -262,7 +338,7 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
        plt.grid(False)
        plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
        label = str(labels[i].item())
        if weight_labels is not None : label+= (" - p %.2f" % weight_labels[i].item())
        if torch.is_tensor(weight_labels): label+= (" - p %.2f" % weight_labels[i].item())
        plt.xlabel(label)
    plt.savefig(fig_name)
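
# Hypothetical usage sketch (not part of this commit): visualise the first batch of
# an arbitrary DataLoader. `dl_train` is an illustrative name not defined in this
# file, and the batch is assumed to hold at least 25 images.
def _demo_viz_sample_data(dl_train):
    imgs, labels = next(iter(dl_train))
    viz_sample_data(imgs, labels, fig_name='data_sample')
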
@@ -348,10 +424,13 @@ def clip_norm(tensors, max_norm, norm_type=2):
    else:
        total_norm = 0
        for t in tensors:
            if t is None:
                continue
            param_norm = t.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef >= 1:
        return tensors
    return [t.mul(clip_coef) for t in tensors]
    #return [t.mul(clip_coef) for t in tensors]
    return [t if t is None else t.mul(clip_coef) for t in tensors]