This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-13 18:02:36 -05:00
parent 53bd421670
commit e291bc2e44
9 changed files with 55 additions and 44 deletions

View file

@ -157,7 +157,7 @@ def accuracy(output, target, topk=(1,)):
class EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience."""
def __init__(self, patience=7, verbose=False, delta=0):
def __init__(self, patience=7, verbose=False, delta=0, augmented_model=False):
"""
Args:
patience (int): How long to wait after last time validation loss improved.
@ -175,6 +175,8 @@ class EarlyStopping:
self.val_loss_min = np.Inf
self.delta = delta
self.augmented_model = augmented_model
def __call__(self, val_loss, model):
score = -val_loss
@ -196,5 +198,5 @@ class EarlyStopping:
'''Saves model when validation loss decrease.'''
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), 'checkpoint.pt')
torch.save(model.state_dict(), 'checkpoint.pt') if not self.augmented_model else torch.save(model['model'].state_dict(), 'checkpoint.pt')
self.val_loss_min = val_loss