mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
leger chgt mag regularisation
This commit is contained in:
parent
64282bda3a
commit
cc737b7997
5 changed files with 21 additions and 8 deletions
|
@ -254,11 +254,20 @@ def print_torch_mem(add_info=''):
|
|||
torch.cuda.max_memory_cached()/ mega_bytes)
|
||||
print(string)
|
||||
|
||||
def TF_influence(log):
|
||||
def plot_TF_influence(log, fig_name='TF_influence', param_names=None):
|
||||
proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||
mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||
|
||||
return np.mean(proba, axis=1)*np.mean(mag, axis=1) #Pourrait etre interessant de multiplier avant le mean
|
||||
plt.figure()
|
||||
|
||||
mean = np.mean(proba, axis=1)*np.mean(mag, axis=1) #Pourrait etre interessant de multiplier avant le mean
|
||||
std = np.std(proba, axis=1)*np.std(mag, axis=1)
|
||||
plt.bar(param_names, mean, yerr=std)
|
||||
|
||||
plt.xticks(rotation=90)
|
||||
fig_name = fig_name.replace('.',',')
|
||||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
class loss_monitor(): #Voir https://github.com/pytorch/ignite
|
||||
def __init__(self, patience, end_train=1):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue