mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Brutus
This commit is contained in:
parent
53bd421670
commit
e291bc2e44
9 changed files with 55 additions and 44 deletions
|
@ -157,7 +157,7 @@ def accuracy(output, target, topk=(1,)):
|
|||
|
||||
class EarlyStopping:
|
||||
"""Early stops the training if validation loss doesn't improve after a given patience."""
|
||||
def __init__(self, patience=7, verbose=False, delta=0):
|
||||
def __init__(self, patience=7, verbose=False, delta=0, augmented_model=False):
|
||||
"""
|
||||
Args:
|
||||
patience (int): How long to wait after last time validation loss improved.
|
||||
|
@ -175,6 +175,8 @@ class EarlyStopping:
|
|||
self.val_loss_min = np.Inf
|
||||
self.delta = delta
|
||||
|
||||
self.augmented_model = augmented_model
|
||||
|
||||
def __call__(self, val_loss, model):
|
||||
|
||||
score = -val_loss
|
||||
|
@ -196,5 +198,5 @@ class EarlyStopping:
|
|||
'''Saves model when validation loss decrease.'''
|
||||
if self.verbose:
|
||||
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
|
||||
torch.save(model.state_dict(), 'checkpoint.pt')
|
||||
torch.save(model.state_dict(), 'checkpoint.pt') if not self.augmented_model else torch.save(model['model'].state_dict(), 'checkpoint.pt')
|
||||
self.val_loss_min = val_loss
|
Loading…
Add table
Add a link
Reference in a new issue