import numpy as np
import json, math, time, os
import matplotlib.pyplot as plt
import copy
import gc

from torchviz import make_dot

import torch
import torch.nn.functional as F

class timer():
    """Minimal stopwatch: measures the wall-clock time elapsed between calls."""
    def __init__(self):
        self._start_time = time.time()

    def exec_time(self):
        """Return the time elapsed since the last call (or since creation) and restart the clock."""
        end = time.time()
        res = end - self._start_time
        self._start_time = end
        return res

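# Usage sketch (illustrative, not part of the original file): time successive
# phases of a training loop. `train_epoch` is a hypothetical function.
#
#   chrono = timer()
#   train_epoch(model, dl_train)
#   print("Epoch time: %.2fs" % chrono.exec_time())  # exec_time() also resets the clock
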
def plot_res(log, fig_name='res', param_names=None):
    """Plot the training curves (loss, accuracy, transformation parameters) of a run.

    Each entry of `log` is a dict with keys "epoch", "train_loss", "val_loss",
    "acc" and "param" (a float magnitude, a list of probabilities, or None).
    """
    epochs = [x["epoch"] for x in log]

    fig, ax = plt.subplots(ncols=3, figsize=(15, 3))

    ax[0].set_title('Loss')
    ax[0].plot(epochs, [x["train_loss"] for x in log], label='Train')
    ax[0].plot(epochs, [x["val_loss"] for x in log], label='Val')
    ax[0].legend()

    ax[1].set_title('Acc')
    ax[1].plot(epochs, [x["acc"] for x in log])

    if log[0]["param"] is not None:
        if isinstance(log[0]["param"], float):
            #Single scalar parameter: plot the magnitude over epochs
            ax[2].set_title('Mag')
            ax[2].plot(epochs, [x["param"] for x in log], label='Mag')
            ax[2].legend()
        else:
            #List of parameters: stack the transformation probabilities
            ax[2].set_title('Prob')
            if not param_names:
                param_names = ['P' + str(idx) for idx, _ in enumerate(log[0]["param"])]
            proba = [[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
            ax[2].stackplot(epochs, proba, labels=param_names)
            ax[2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))

    fig_name = fig_name.replace('.', ',')  #Dots would be read as a file extension by savefig
    plt.savefig(fig_name)
    plt.close()

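# Usage sketch (illustrative; the log layout is the one read above):
#
#   log = [{"epoch": 0, "train_loss": 1.2, "val_loss": 1.4, "acc": 55.0,
#           "param": [0.2, 0.3, 0.5]},
#          ...]
#   plot_res(log, fig_name='res/run1', param_names=['Rotate', 'Flip', 'Shear'])
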
def plot_res_compare(filenames, fig_name='res'):
    """Compare several runs saved as JSON files: accuracy and training time
    as a function of the number of transformations used."""
    all_data = []
    for file in filenames:
        with open(file) as json_file:
            data = json.load(json_file)
        all_data.append(data)

    n_tf = [len(x["Param_names"]) for x in all_data]
    acc = [x["Accuracy"] for x in all_data]
    times = [x["Time"][0] for x in all_data]

    fig, ax = plt.subplots(ncols=3, figsize=(30, 8))  #Third axis is left unused

    ax[0].plot(n_tf, acc)
    ax[1].plot(n_tf, times)

    ax[0].set_title('Acc')
    ax[1].set_title('Time')

    fig_name = fig_name.replace('.', ',')  #Dots would be read as a file extension by savefig
    plt.savefig(fig_name, bbox_inches='tight')
    plt.close()

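# Usage sketch (illustrative; each JSON file is assumed to hold the
# "Param_names", "Accuracy" and "Time" keys read above):
#
#   plot_res_compare(['res/log/run_2tf.json', 'res/log/run_4tf.json'],
#                    fig_name='res/compare_ntf')
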
def plot_TF_res(log, tf_names, fig_name='res'):
    """Bar plot of each transformation's parameter, averaged over the epochs
    of `log` (error bars show the standard deviation)."""
    mean = np.mean([x["param"] for x in log], axis=0)
    std = np.std([x["param"] for x in log], axis=0)

    fig, ax = plt.subplots(1, 1, figsize=(30, 8), sharey=True)
    ax.bar(tf_names, mean, yerr=std)
    #ax.bar(tf_names, log[-1]["param"]) #Alternative: last-epoch values only

    fig_name = fig_name.replace('.', ',')  #Dots would be read as a file extension by savefig
    plt.savefig(fig_name, bbox_inches='tight')
    plt.close()

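# Usage sketch (illustrative; `log` has the same layout as in plot_res, with
# "param" holding one value per transformation):
#
#   plot_TF_res(log, tf_names=['Rotate', 'Flip', 'Shear'], fig_name='res/TF_usage')
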
def model_copy(src, dst, patch_copy=True, copy_grad=True):
    """Copy the weights (and optionally the gradients) of `src` into `dst`."""
    #copy.deepcopy(fmodel) would not be appropriate: we only want the weights and
    #gradients, not the whole functional model and its internal state.

    dst.load_state_dict(src.state_dict()) #Does not copy the gradients!

    if patch_copy:
        dst['model'].load_state_dict(src['model'].state_dict()) #Data possibly missed by the copy above?
        dst['data_aug'].load_state_dict(src['data_aug'].state_dict())

    #Copy the gradients
    if copy_grad:
        for paramName, paramValue in src.named_parameters():
            for netCopyName, netCopyValue in dst.named_parameters():
                if paramName == netCopyName:
                    netCopyValue.grad = paramValue.grad

    try: #Data_augV4 keeps extra buffers that load_state_dict does not carry over
        dst['data_aug']._input_info = src['data_aug']._input_info
        dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix
    except Exception: #Other data augmentation modules do not have these attributes
        pass

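# Usage sketch (illustrative; `fmodel` would be the higher-style patched copy
# trained in the inner loop and `model` the regular module it mirrors):
#
#   model_copy(src=fmodel, dst=model, patch_copy=True, copy_grad=True)
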
def optim_copy(dopt, opt):
    """Copy the state (momentum, etc.) of the differentiable optimizer `dopt`
    back into the regular torch optimizer `opt`."""
    #inner_opt.load_state_dict(diffopt.state_dict()) #We need to save the optimizer state (momentum, etc.) => does not copy the state...
    #opt_param=higher.optim.get_trainable_opt_params(diffopt)

    for group_idx, group in enumerate(opt.param_groups):
        for p_idx, p in enumerate(group['params']):
            opt.state[p] = dopt.state[group_idx][p_idx]

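# Usage sketch (illustrative; `diffopt` would be the higher differentiable
# optimizer used in the inner loop and `inner_opt` the torch.optim instance
# it was built from, matching the names in the comments above):
#
#   optim_copy(dopt=diffopt, opt=inner_opt)
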
class loss_monitor(): #See https://github.com/pytorch/ignite for a full-featured alternative
    """Early-stopping helper: counts epochs without loss improvement and
    signals when training should end."""
    def __init__(self, patience, end_train=1):
        self.patience = patience #Number of non-improving epochs tolerated before a limit is reached
        self.end_train = end_train #Number of limits to reach before ending training
        self.counter = 0
        self.best_score = None
        self.reached_limit = 0

    def register(self, loss):
        """Record the latest loss value."""
        if self.best_score is None:
            self.best_score = loss
        elif loss > self.best_score:
            self.counter += 1
            print("loss no improve counter", self.counter, self.reached_limit)
        else:
            self.best_score = loss
            self.counter = 0

    def limit_reached(self):
        """Return how many times the patience limit has been reached so far."""
        if self.counter >= self.patience:
            self.counter = 0
            self.reached_limit += 1
            self.best_score = None
        return self.reached_limit

    def end_training(self):
        """Return True once the patience limit has been reached `end_train` times."""
        return self.limit_reached() >= self.end_train

    def reset(self):
        self.__init__(self.patience, self.end_train)
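
# Usage sketch (illustrative training loop; `validate` is a hypothetical function):
#
#   monitor = loss_monitor(patience=5, end_train=2)
#   for epoch in range(max_epochs):
#       val_loss = validate(model)
#       monitor.register(val_loss)
#       if monitor.end_training():
#           break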