mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Confmat / F1 + Minor fix
This commit is contained in:
parent
250ce2c3cf
commit
3ccacd0366
5 changed files with 120 additions and 32 deletions
|
@ -814,15 +814,16 @@ class Higher_model(nn.Module):
|
|||
_name (string): Name of the model.
|
||||
_mods (nn.ModuleDict): Models (Orginial and Higher version).
|
||||
"""
|
||||
def __init__(self, model):
|
||||
def __init__(self, model, model_name=None):
|
||||
"""Init Higher_model.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Network for which higher gradients can be tracked.
|
||||
model_name (string): Model name. (Default: Class name of model)
|
||||
"""
|
||||
super(Higher_model, self).__init__()
|
||||
|
||||
self._name = model.__class__.__name__ #model.__str__()
|
||||
self._name = model_name if model_name else model.__class__.__name__ #model.__str__()
|
||||
self._mods = nn.ModuleDict({
|
||||
'original': model,
|
||||
'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue