Confmat / F1 + Minor fix

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-31 16:43:10 -05:00
parent 250ce2c3cf
commit 3ccacd0366
5 changed files with 120 additions and 32 deletions

View file

@ -814,15 +814,16 @@ class Higher_model(nn.Module):
_name (string): Name of the model.
_mods (nn.ModuleDict): Models (Orginial and Higher version).
"""
def __init__(self, model):
def __init__(self, model, model_name=None):
"""Init Higher_model.
Args:
model (nn.Module): Network for which higher gradients can be tracked.
model_name (string): Model name. (Default: Class name of model)
"""
super(Higher_model, self).__init__()
self._name = model.__class__.__name__ #model.__str__()
self._name = model_name if model_name else model.__class__.__name__ #model.__str__()
self._mods = nn.ModuleDict({
'original': model,
'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)