Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-03 11:40:46 +02:00)

Confmat / F1 + Minor fix

This commit is contained in:
parent 250ce2c3cf
commit 3ccacd0366

5 changed files with 120 additions and 32 deletions
@@ -13,7 +13,7 @@ TEST_SIZE = BATCH_SIZE
 #TEST_SIZE = 10000 # slightly faster / higher memory consumption!

 #Whether to download data.
-download_data=False
+download_data=True
 #Number of workers to use.
 num_workers=2 #4
 #Pin GPU memory
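Note: the flag presumably reaches the torchvision dataset constructors in datasets.py; a minimal sketch of the usual pattern, with CIFAR-10 purely as an illustration:

from torchvision import datasets, transforms

download_data = True  # set back to False once the archives are on disk
train_set = datasets.CIFAR10('./data', train=True, download=download_data,
                             transform=transforms.ToTensor())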
@@ -814,15 +814,16 @@ class Higher_model(nn.Module):
         _name (string): Name of the model.
         _mods (nn.ModuleDict): Models (Original and Higher version).
     """
-    def __init__(self, model):
+    def __init__(self, model, model_name=None):
         """Init Higher_model.

             Args:
                 model (nn.Module): Network for which higher gradients can be tracked.
+                model_name (string): Model name. (Default: Class name of model)
         """
         super(Higher_model, self).__init__()

-        self._name = model.__class__.__name__ #model.__str__()
+        self._name = model_name if model_name else model.__class__.__name__ #model.__str__()
         self._mods = nn.ModuleDict({
             'original': model,
             'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
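Note: a usage sketch for the new signature (Higher_model comes in via the script's `from dataug import *`; reading `_name` directly is only for illustration):

import torchvision.models as models
from dataug import Higher_model

net = models.resnet18()
wrapped = Higher_model(net, 'resnet18')
print(wrapped._name)  # 'resnet18'; without the argument it falls back to 'ResNet', the class name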
@@ -1,7 +1,7 @@
 """ Script to run experiment on smart augmentation.

 """

 import sys
 from LeNet import *
 from dataug import *
 #from utils import *
@@ -79,7 +79,7 @@ if __name__ == "__main__":
     }
     #Parameters
     n_inner_iter = 1
-    epochs = 150
+    epochs = 2
     dataug_epoch_start=0
     optim_param={
         'Meta':{
@@ -94,18 +94,21 @@ if __name__ == "__main__":
     }

     #Models
-    model = LeNet(3,10)
+    #model = LeNet(3,10)
     #model = ResNet(num_classes=10)
-    #import torchvision.models as models
+    import torchvision.models as models
     #model=models.resnet18()
+    model_name = 'resnet18' #'wide_resnet50_2' #'resnet18' #str(model)
+    model = getattr(models.resnet, model_name)(pretrained=False)

     #### Classic ####
     if 'classic' in tasks:
         t0 = time.process_time()
         model = model.to(device)

-        print("{} on {} for {} epochs".format(str(model), device_name, epochs))
-        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=20)
+        print("{} on {} for {} epochs".format(model_name, device_name, epochs))
+        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=5)
         #log= train_classic_higher(model=model, epochs=epochs)

         exec_time=time.process_time() - t0
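Note: getattr resolves the constructor by name from torchvision.models.resnet, so one string now drives both model construction and the log filename below. A sketch of the pattern, with a hypothetical build_resnet helper:

import torchvision.models as models

def build_resnet(name, num_classes=10):
    # e.g. 'resnet18' -> models.resnet.resnet18; 'wide_resnet50_2' resolves the same way
    ctor = getattr(models.resnet, name, None)
    if ctor is None:
        raise ValueError('unknown resnet variant: ' + name)
    return ctor(pretrained=False, num_classes=num_classes)

model = build_resnet('resnet18')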
@@ -114,12 +117,12 @@ if __name__ == "__main__":
         times = [x["time"] for x in log]
         out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param['Inner'], "Device": device_name, "Log": log}
         print(str(model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
-        filename = "{}-{} epochs".format(str(model),epochs)
+        filename = "{}-{} epochs".format(model_name,epochs)
         with open("../res/log/%s.json" % filename, "w+") as f:
             json.dump(out, f, indent=True)
             print('Log :\"',f.name, '\" saved !')

-        plot_res(log, fig_name="../res/"+filename)
+        #plot_res(log, fig_name="../res/"+filename)

         print('Execution Time : %.00f '%(exec_time))
         print('-'*9)
@@ -129,8 +132,8 @@ if __name__ == "__main__":
         t0 = time.process_time()

         tf_dict = {k: TF.TF_dict[k] for k in tf_names}
-        model = Higher_model(model) #run_dist_dataugV3
-        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
+        model = Higher_model(model, model_name) #run_dist_dataugV3
+        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
         #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)

         print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
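Note: for readers new to the repo, a commented reading of what this hunk assembles; the glosses are my interpretation of the dataug.py docstrings, not new API:

tf_dict = {k: TF.TF_dict[k] for k in tf_names}  # learnable subset of transformations
model = Higher_model(model, model_name)         # wrap for higher-order (meta) gradients
aug_model = Augmented_model(
    Data_augV5(TF_dict=tf_dict,
               N_TF=2,             # TFs applied sequentially per sample (lowered from 3)
               mix_dist=0.8,       # mix between uniform and learned TF distribution
               fixed_prob=False,   # learn per-TF probabilities
               fixed_mag=False,    # learn magnitudes
               shared_mag=False),  # separate magnitude per TF
    model).to(device)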
@@ -139,7 +142,7 @@ if __name__ == "__main__":
             inner_it=n_inner_iter,
             dataug_epoch_start=dataug_epoch_start,
             opt_param=optim_param,
-            print_freq=20,
+            print_freq=1,
             unsup_loss=1,
             hp_opt=False,
             save_sample_freq=None)
@@ -157,10 +160,12 @@ if __name__ == "__main__":
             print('Log :\"',f.name, '\" saved !')
         except:
             print("Failed to save logs :",f.name)
+            print(sys.exc_info()[0])
         try:
             plot_resV2(log, fig_name="../res/"+filename, param_names=aug_model.TF_names())
         except:
             print("Failed to plot res")
+            print(sys.exc_info()[0])

         print('Execution Time : %.00f '%(exec_time))
         print('-'*9)
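Note: sys.exc_info()[0] is the class of the exception currently being handled, so these bare except blocks now report at least the error type. A standalone illustration:

import sys

try:
    raise TypeError('Object of type Tensor is not JSON serializable')
except:
    print(sys.exc_info()[0])  # prints: <class 'TypeError'>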
@@ -10,6 +10,8 @@ import higher
 from datasets import *
 from utils import *

+confmat = ConfusionMatrix(num_classes=len(dl_test.dataset.classes))
+
 def test(model):
     """Evaluate a model on test data.

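Note: dl_test.dataset.classes relies on the torchvision convention that a dataset exposes a classes list, so the confusion matrix is sized once at import time. Assuming datasets.py serves CIFAR-10 (an illustration, not confirmed by this diff):

from torchvision import datasets

classes = datasets.CIFAR10('./data', train=False, download=True).classes
print(len(classes))  # 10 -> ConfusionMatrix(num_classes=10)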
@@ -17,7 +19,7 @@ def test(model):
         model (nn.Module): Model to test.

     Returns:
-        (float, Tensor) Returns the accuracy and test loss of the model.
+        (float, Tensor) Returns the accuracy and F1 score of the model.
     """
     device = next(model.parameters()).device
     model.eval()
@@ -30,7 +32,8 @@ def test(model):

     correct = 0
     total = 0
-    loss = []
+    #loss = []
+    confmat.reset()
     with torch.no_grad():
         for features, labels in dl_test:
             features,labels = features.to(device), labels.to(device)
@@ -40,11 +43,16 @@ def test(model):
             total += labels.size(0)
             correct += (predicted == labels).sum().item()

-            loss.append(F.cross_entropy(outputs, labels).item())
+            #loss.append(F.cross_entropy(outputs, labels).item())
+            confmat.update(labels, predicted)

     accuracy = 100 * correct / total

-    return accuracy, np.mean(loss)
+    #print(confmat)
+    #from sklearn.metrics import f1_score
+    #f1 = f1_score(labels.data.to('cpu'), predicted.data.to('cpu'), average="macro")
+
+    return accuracy, confmat.f1_metric(average="macro")

 def compute_vaLoss(model, dl_it, dl):
     """Evaluate a model on a batch of data.
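Note: callers now receive a macro-F1 tensor where they used to get a mean test loss. The new contract, sketched:

acc, f1 = test(model)  # acc: float, in percent; f1: 0-dim tensor in [0, 1]
print('acc: {:.2f}%  macro-F1: {:.4f}'.format(acc, f1.item()))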
@@ -202,7 +210,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
         print_freq (int): Number of epochs between displays of the training state. If set to None, no display is done. (default: 1)
         unsup_loss (float): Proportion of the unsupervised loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
         hp_opt (bool): Whether to learn inner optimizer parameters. (default: False)
-        save_sample_freq (int): Number of epochs between saves of samples of data. If set to None, only one save would be done at the end of the training. (default: None)
+        save_sample_freq (int): Number of epochs between saves of samples of data. If set to None, no sample will be saved. (default: None)

     Returns:
         (list) Logs of training. Each item is a dict containing the results of an epoch.
@@ -310,7 +318,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
         val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)

         # Test model
-        accuracy, test_loss =test(model)
+        accuracy, f1 =test(model)
         model.train()

         #### Log ####
|
|||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"f1": f1.cpu().numpy().tolist(),
|
||||
"time": tf - t0,
|
||||
|
||||
"param": param,
|
||||
|
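Note: the .cpu().numpy().tolist() chain is what makes the logged F1 JSON-serializable when the log dict is dumped later. A standalone illustration:

import json
import torch

f1 = torch.tensor(0.8731)  # macro-F1 arrives as a 0-dim tensor
# json.dumps({'f1': f1}) would raise: Object of type Tensor is not JSON serializable
print(json.dumps({'f1': f1.cpu().numpy().tolist()}))  # {"f1": 0.8731...} as a plain float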
@@ -360,15 +369,6 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
                 grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
                 track_higher_grads=high_grad_track)

-
-    #Data sample saving
-    try:
-        viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
-        viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch))
-    except:
-        print("Couldn't save final samples")
-        pass
-
     return log

 def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, unsup_loss=1):
@@ -14,6 +14,75 @@ import torch.nn.functional as F

 import time

+class ConfusionMatrix(object):
+    def __init__(self, num_classes):
+        self.num_classes = num_classes
+        self.mat = None
+
+    def update(self, a, b):
+        n = self.num_classes
+        if self.mat is None:
+            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
+        with torch.no_grad():
+            k = (a >= 0) & (a < n)
+            inds = n * a[k].to(torch.int64) + b[k]
+            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
+
+    def reset(self):
+        if self.mat is not None:
+            self.mat.zero_()
+
+    def compute(self):
+        h = self.mat.float()
+        acc_global = torch.diag(h).sum() / h.sum()
+        acc = torch.diag(h) / h.sum(1)
+        return acc_global, acc
+
+    def f1_metric(self, average=None):
+        #https://discuss.pytorch.org/t/how-to-get-the-sensitivity-and-specificity-of-a-dataset/39373/6
+        h = self.mat.float()
+        TP = torch.diag(h)
+        TN = []
+        FP = []
+        FN = []
+        for c in range(self.num_classes):
+            idx = torch.ones(self.num_classes).bool()
+            idx[c] = 0
+            # all non-class samples classified as non-class
+            TN.append(self.mat[idx.nonzero()[:, None], idx.nonzero()].sum()) #conf_matrix[idx[:, None], idx].sum() - conf_matrix[idx, c].sum()
+            # all non-class samples classified as class
+            FP.append(self.mat[idx, c].sum())
+            # all class samples not classified as class
+            FN.append(self.mat[c, idx].sum())
+
+            #print('Class {}\nTP {}, TN {}, FP {}, FN {}'.format(c, TP[c], TN[c], FP[c], FN[c]))
+
+        tp = (TP/h.sum(1))#.sum()
+        tn = (torch.tensor(TN, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
+        fp = (torch.tensor(FP, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
+        fn = (torch.tensor(FN, device=h.device, dtype=torch.float)/h.sum(1))#.sum()
+
+        if average=="micro":
+            tp, tn, fp, fn = tp.sum(), tn.sum(), fp.sum(), fn.sum()
+
+        epsilon = 1e-7
+        precision = tp / (tp + fp + epsilon)
+        recall = tp / (tp + fn + epsilon)
+
+        f1 = 2* (precision*recall) / (precision + recall + epsilon)
+
+        if average=="macro":
+            f1=f1.mean()
+        return f1
+
+    def __str__(self):
+        acc_global, acc = self.compute()
+        return (
+            'global correct: {:.1f}\n'
+            'average row correct: {}').format(
+                acc_global.item() * 100,
+                ['{:.1f}'.format(i) for i in (acc * 100).tolist()])
+
 def print_graph(PyTorch_obj, fig_name='graph'):
     """Save the computational graph.

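Note: to make the new metric concrete, a small self-checking sketch with toy labels (values worked out by hand from precision = TP/(TP+FP), recall = TP/(TP+FN), F1 = 2PR/(P+R)). The row-sum normalization by h.sum(1) cancels inside each per-class precision/recall ratio, so average="macro" matches the textbook macro-F1; only the "micro" branch, which sums the normalized counts, deviates from the textbook definition under class imbalance.

import torch
from utils import ConfusionMatrix  # the class added above

cm = ConfusionMatrix(num_classes=3)
labels    = torch.tensor([0, 0, 1, 1, 2, 2])  # rows of cm.mat: true class
predicted = torch.tensor([0, 1, 1, 1, 2, 0])  # columns: predicted class
cm.update(labels, predicted)
# cm.mat == [[1, 1, 0],
#            [0, 2, 0],
#            [1, 0, 1]]
print(cm.f1_metric())                  # per-class F1: [0.5000, 0.8000, 0.6667]
print(cm.f1_metric(average="macro"))   # mean -> ~0.6556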
@@ -42,8 +111,21 @@ def plot_resV2(log, fig_name='res', param_names=None):
     ax[0, 0].plot(epochs,[x["val_loss"] for x in log], label='Val')
     ax[0, 0].legend()

-    ax[1, 0].set_title('Acc')
-    ax[1, 0].plot(epochs,[x["acc"] for x in log])
+    ax[1, 0].set_title('Test')
+    ax[1, 0].plot(epochs,[x["acc"] for x in log], label='Acc')
+
+    if "f1" in log[0].keys():
+        ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1')
+        '''
+        #print(log[0]["f1"])
+        if len(log[0]["f1"])==1:
+            ax[1, 0].plot(epochs,[x["f1"]*100 for x in log], label='F1')
+        else:
+            for c in range(len(log[0]["f1"])):
+                ax[1, 0].plot(epochs,[x["f1"][c]*100 for x in log], label='F1-'+str(c))
+        '''
+
+    ax[1, 0].legend()

     if log[0]["param"]!= None:
         if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
@@ -73,7 +155,7 @@ def plot_resV2(log, fig_name='res', param_names=None):
     plt.sca(ax[1, 2]), plt.xticks(rotation=90)


-    fig_name = fig_name.replace('.',',')
+    fig_name = fig_name.replace('.',',').replace(',,/','../')
     plt.savefig(fig_name, bbox_inches='tight')
     plt.close()

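Note: the appended .replace(',,/','../') is the "minor fix" of the commit title: replacing every '.' with ',' (presumably to keep decimal points in hyperparameter values out of file names) also mangled the leading '../' of the output path. A standalone check, with an illustrative figure name:

fig_name = '../res/Aug_mod(Data_augV5(Mix0.8))-2 epochs'
mangled = fig_name.replace('.', ',')   # ',,/res/Aug_mod(Data_augV5(Mix0,8))-2 epochs'
print(mangled.replace(',,/', '../'))   # '../res/Aug_mod(Data_augV5(Mix0,8))-2 epochs'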