mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 20:20:46 +02:00
Changes since Teledyne
This commit is contained in:
parent
03ffd7fe05
commit
b89dac9084
185 changed files with 16668 additions and 484 deletions
|
@ -3,17 +3,88 @@
|
|||
"""
|
||||
import numpy as np
|
||||
import json, math, time, os
|
||||
import matplotlib
|
||||
matplotlib.use('Agg') #https://stackoverflow.com/questions/4706451/how-to-save-a-figure-remotely-with-pylab
|
||||
import matplotlib.pyplot as plt
|
||||
import copy
|
||||
import gc
|
||||
|
||||
from torchviz import make_dot
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import time
|
||||
|
||||
from nets.LeNet import *
|
||||
from nets.wideresnet import *
|
||||
from nets.wideresnet_cifar import *
|
||||
import nets.resnet_abn as resnet_abn
|
||||
import nets.resnet_deconv as resnet_DC
|
||||
from efficientnet_pytorch import EfficientNet
|
||||
from efficientnet_pytorch.utils import url_map as EfficientNet_map
|
||||
import torchvision.models as models
|
||||
def load_model(model, num_classes, pretrained=False):
|
||||
if model in models.resnet.__all__ :
|
||||
model_name = model #'resnet18' #'resnet34' #'wide_resnet50_2'
|
||||
if pretrained :
|
||||
print("Using pretrained weights")
|
||||
model = getattr(models.resnet, model_name)(pretrained=True)
|
||||
num_ftrs = model.fc.in_features
|
||||
model.fc = nn.Linear(num_ftrs, num_classes)
|
||||
else:
|
||||
model = getattr(models.resnet, model_name)(pretrained=False, num_classes=num_classes)
|
||||
elif model in models.vgg.__all__ :
|
||||
model_name = model #'vgg11', 'vgg1_bn'
|
||||
if pretrained :
|
||||
print("Using pretrained weights")
|
||||
model = getattr(models.vgg, model_name)(pretrained=True)
|
||||
num_ftrs = model.classifier[-1].in_features
|
||||
model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
|
||||
else :
|
||||
model = getattr(models.vgg, model_name)(pretrained=False, num_classes=num_classes)
|
||||
elif model in models.densenet.__all__ :
|
||||
model_name = model #'densenet121' #'densenet201'
|
||||
if pretrained :
|
||||
print("Using pretrained weights")
|
||||
model = getattr(models.densenet, model_name)(pretrained=True)
|
||||
num_ftrs = model.classifier.in_features
|
||||
model.classifier = nn.Linear(num_ftrs, num_classes)
|
||||
else:
|
||||
model = getattr(models.densenet, model_name)(pretrained=False, num_classes=num_classes)
|
||||
elif model == 'LeNet':
|
||||
if pretrained :
|
||||
print("Pretrained weights not available")
|
||||
model = LeNet(3,num_classes)
|
||||
model_name=str(model)
|
||||
elif model == 'WideResNet':
|
||||
if pretrained :
|
||||
print("Pretrained weights not available")
|
||||
# model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_classes)
|
||||
# model = WideResNet(16, 4, dropout_rate=0.0, num_classes=num_classes)
|
||||
model = wide_resnet_cifar(26, 10, num_classes=num_classes)
|
||||
# model = wide_resnet_cifar(20, 10, num_classes=num_classes)
|
||||
model_name=str(model)
|
||||
elif model in EfficientNet_map.keys():
|
||||
model_name=model # efficientnet-b0 , efficientnet-b1, efficientnet-b4
|
||||
if pretrained: #ImageNet ou Advprop (Meilleurs perf normalement mais normalisation differentes)
|
||||
print("Using pretrained weights")
|
||||
model = EfficientNet.from_pretrained(model_name, advprop=False)
|
||||
else:
|
||||
model = EfficientNet.from_name(model_name)
|
||||
elif model in resnet_abn.__all__ :
|
||||
if pretrained :
|
||||
print("Pretrained weights not available")
|
||||
model_name=model
|
||||
model = getattr(resnet_abn, model_name)(pretrained=False, num_classes=num_classes)
|
||||
elif model in resnet_DC.__all__:
|
||||
if pretrained :
|
||||
print("Pretrained weights not available")
|
||||
model_name = model
|
||||
model = getattr(resnet_DC, model_name)(num_classes=num_classes)
|
||||
else:
|
||||
raise Exception('Unknown model')
|
||||
|
||||
return model, model_name
|
||||
|
||||
class ConfusionMatrix(object):
|
||||
""" Confusion matrix.
|
||||
|
||||
|
@ -120,7 +191,8 @@ class ConfusionMatrix(object):
|
|||
f1=f1.mean()
|
||||
return f1
|
||||
|
||||
def print_graph(PyTorch_obj, fig_name='graph'):
|
||||
#from torchviz import make_dot
|
||||
def print_graph(PyTorch_obj=torch.randn(1, 3, 32, 32), fig_name='graph'):
|
||||
"""Save the computational graph.
|
||||
|
||||
Args:
|
||||
|
@ -128,7 +200,7 @@ def print_graph(PyTorch_obj, fig_name='graph'):
|
|||
fig_name (string): Relative path where to save the graph. (default: graph)
|
||||
"""
|
||||
graph=make_dot(PyTorch_obj)
|
||||
graph.format = 'pdf' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
|
||||
graph.format = 'png' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
|
||||
graph.render(fig_name)
|
||||
|
||||
def plot_resV2(log, fig_name='res', param_names=None, f1=True):
|
||||
|
@ -157,8 +229,13 @@ def plot_resV2(log, fig_name='res', param_names=None, f1=True):
|
|||
#'''
|
||||
#print(log[0]["f1"])
|
||||
if isinstance(log[0]["f1"], list):
|
||||
for c in range(len(log[0]["f1"])):
|
||||
ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c), ls='--')
|
||||
if len(log[0]["f1"])>10:
|
||||
print("Plotting results : Too many class for F1, plotting only min/max")
|
||||
ax[1, 0].plot(epochs,[max(x["f1"])*100 for x in log], label='F1-Max', ls='--')
|
||||
ax[1, 0].plot(epochs,[min(x["f1"])*100 for x in log], label='F1-Min', ls='--')
|
||||
else:
|
||||
for c in range(len(log[0]["f1"])):
|
||||
ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c), ls='--')
|
||||
else:
|
||||
ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1', ls='--')
|
||||
#'''
|
||||
|
@ -251,7 +328,6 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
|
|||
fig_name (string): Relative path where to save the graph. (default: data_sample)
|
||||
weight_labels (Tensor): Weights associated to each labels. (default: None)
|
||||
"""
|
||||
|
||||
sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()
|
||||
|
||||
plt.figure(figsize=(10,10))
|
||||
|
@ -262,7 +338,7 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
|
|||
plt.grid(False)
|
||||
plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
|
||||
label = str(labels[i].item())
|
||||
if weight_labels is not None : label+= (" - p %.2f" % weight_labels[i].item())
|
||||
if torch.is_tensor(weight_labels): label+= (" - p %.2f" % weight_labels[i].item())
|
||||
plt.xlabel(label)
|
||||
|
||||
plt.savefig(fig_name)
|
||||
|
@ -348,10 +424,13 @@ def clip_norm(tensors, max_norm, norm_type=2):
|
|||
else:
|
||||
total_norm = 0
|
||||
for t in tensors:
|
||||
if t is None:
|
||||
continue
|
||||
param_norm = t.norm(norm_type)
|
||||
total_norm += param_norm.item() ** norm_type
|
||||
total_norm = total_norm ** (1. / norm_type)
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
if clip_coef >= 1:
|
||||
return tensors
|
||||
return [t.mul(clip_coef) for t in tensors]
|
||||
#return [t.mul(clip_coef) for t in tensors]
|
||||
return [t if t is None else t.mul(clip_coef) for t in tensors]
|
Loading…
Add table
Add a link
Reference in a new issue