leger chgt mag regularisation

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-19 21:46:14 -05:00
parent 64282bda3a
commit cc737b7997
5 changed files with 21 additions and 8 deletions

View file

@ -254,11 +254,20 @@ def print_torch_mem(add_info=''):
torch.cuda.max_memory_cached()/ mega_bytes)
print(string)
def TF_influence(log):
def plot_TF_influence(log, fig_name='TF_influence', param_names=None):
proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
return np.mean(proba, axis=1)*np.mean(mag, axis=1) #Pourrait etre interessant de multiplier avant le mean
plt.figure()
mean = np.mean(proba, axis=1)*np.mean(mag, axis=1) #Pourrait etre interessant de multiplier avant le mean
std = np.std(proba, axis=1)*np.std(mag, axis=1)
plt.bar(param_names, mean, yerr=std)
plt.xticks(rotation=90)
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
class loss_monitor(): #Voir https://github.com/pytorch/ignite
def __init__(self, patience, end_train=1):