"""
Classes for handling early stopping in training.
"""
from copy import deepcopy
import numpy as np
def _sanitize_loss(loss):
"""
Helper function for the ES classes to make sure that they receive a single float as
their loss value. If an iterable of floats is passed, the mean loss will be returned
Parameters
----------
loss : Union[float, List[float], np.ndarray, torch.Tensor]
Loss value(s)
Returns
-------
float
Sanitized loss value
"""
try:
# This should work for common types of numeric values (single float, list,
# tensor, etc of floats)
return np.asarray(loss).mean()
except Exception:
raise ValueError(f"Bad value passed for loss: {loss}")
[docs]
class BestEarlyStopping:
    """
    Class for handling early stopping in training based on improvement over best loss.
    """

    def __init__(self, patience, burnin=0):
        """
        Parameters
        ----------
        patience : int
            The maximum number of epochs to continue training with no improvement in the
            val loss
        burnin : int, optional
            If given, ensure that at least this many epochs of training have been done
            before we stop
        """
        super().__init__()
        self.patience = patience
        self.burnin = burnin

        # Variables to track early stopping
        # Number of consecutive checks without an improvement over the best loss
        self.counter = 0
        self.best_loss = None
        self.best_wts = None
        self.best_epoch = 0

    def check(self, epoch, loss, wts_dict):
        """
        Check if training should be stopped. Return True to stop, False to keep going.

        Parameters
        ----------
        epoch : int
            Current training epoch
        loss : float
            Model loss from the current epoch of training
        wts_dict : dict
            Weights dict from Pytorch for keeping track of the best model

        Returns
        -------
        bool
            Whether to stop training
        """
        # Make sure we've got a reasonable value for loss
        loss = _sanitize_loss(loss)

        # The first loss we see is the best by definition, after that only a strictly
        # lower loss counts as an improvement
        if (self.best_loss is None) or (loss < self.best_loss):
            self.best_loss = loss
            # Need to deepcopy so it doesn't update with the model weights
            self.best_wts = deepcopy(wts_dict)
            # Record the epoch on the first call too, so that best_epoch always
            # matches best_wts even if the first check isn't at epoch 0
            self.best_epoch = epoch
            # Reset counter
            self.counter = 0
            # Keep training
            return False

        # No improvement, so increment counter and check for stopping
        self.counter += 1
        return (self.counter >= self.patience) and (epoch >= self.burnin)
[docs]
class ConvergedEarlyStopping:
    """
    Early stopping criterion that triggers once the loss has stopped changing. The
    most recent n losses are kept, and training is considered converged when their
    mean absolute deviation from their own average falls below a tolerance.
    """

    def __init__(self, n_check, divergence, burnin=0):
        """
        Parameters
        ----------
        n_check : int
            How many of the most recent epochs' losses to use when calculating
            divergence
        divergence : float
            Tolerance that the mean absolute difference from the mean loss must be
            below in order to stop
        burnin : int, optional
            Minimum number of epochs of training that must have been done before
            stopping is allowed
        """
        super().__init__()
        self.burnin = burnin
        self.divergence = divergence
        self.n_check = n_check

        # Rolling window of the most recent losses
        self.losses = []

    def check(self, epoch, loss):
        """
        Decide whether training should stop. True means stop, False means keep going.

        Parameters
        ----------
        epoch : int
            Current training epoch
        loss : float
            Loss from the previous training epoch

        Returns
        -------
        bool
            Whether to stop training
        """
        # Reduce whatever was passed to a single value and store it in the window
        self.losses.append(_sanitize_loss(loss))

        n_stored = len(self.losses)
        # Window isn't full yet, so there's nothing to evaluate
        if n_stored < self.n_check:
            return False

        # Window has overflowed, so drop the oldest entry
        if n_stored > self.n_check:
            self.losses = self.losses[1:]

        # Mean absolute deviation of the window from its own mean
        window = np.asarray(self.losses)
        mean_abs_dev = np.mean(np.abs(window - np.mean(self.losses)))

        return (mean_abs_dev < self.divergence) and (epoch >= self.burnin)
[docs]
class PatientConvergedEarlyStopping:
    """
    Early stopping criterion that triggers once the loss has stopped changing and has
    stayed that way for a while. The most recent n losses are kept, and the loss is
    considered converged when their mean absolute deviation from their own average
    falls below a tolerance. Training only stops after the loss has stayed converged
    for a number of further epochs, to avoid stopping on a temporary plateau.
    """

    def __init__(self, n_check, divergence, patience, burnin=0):
        """
        Parameters
        ----------
        n_check : int
            How many of the most recent epochs' losses to use when calculating
            divergence
        divergence : float
            Tolerance that the mean absolute difference from the mean loss must be
            below in order to count as converged
        patience : int
            The maximum number of epochs to wait after convergence
        burnin : int, optional
            Minimum number of epochs of training that must have been done before
            stopping is allowed
        """
        super().__init__()
        self.n_check = n_check
        self.divergence = divergence
        self.patience = patience
        self.burnin = burnin

        # Rolling window of the most recent losses
        self.losses = []

        # State describing the current convergence streak (if any)
        # Whether the previous check found the loss to be converged
        self.converged = False
        # Number of converged epochs seen since the streak started
        self.counter = 0
        # Loss, model weights and epoch captured when the streak started
        self.converged_loss = None
        self.converged_wts = None
        self.converged_epoch = 0

    def check(self, epoch, loss, wts_dict):
        """
        Decide whether training should stop. True means stop, False means keep going.

        Parameters
        ----------
        epoch : int
            Current training epoch
        loss : float
            Model loss from the current epoch of training
        wts_dict : dict
            Weights dict from Pytorch, copied when convergence is first reached

        Returns
        -------
        bool
            Whether to stop training
        """
        # Reduce whatever was passed to a single value and store it in the window
        loss = _sanitize_loss(loss)
        self.losses.append(loss)

        n_stored = len(self.losses)
        # Window isn't full yet, so there's nothing to evaluate
        if n_stored < self.n_check:
            return False

        # Window has overflowed, so drop the oldest entry
        if n_stored > self.n_check:
            self.losses = self.losses[1:]

        # Mean absolute deviation of the window from its own mean
        window = np.asarray(self.losses)
        mean_abs_dev = np.mean(np.abs(window - np.mean(self.losses)))
        plateau_now = mean_abs_dev < self.divergence

        if not plateau_now:
            # Loss is moving again, so throw away any streak we had going
            if self.converged:
                self.converged = False
                self.counter = 0
                self.converged_loss = None
                self.converged_wts = None
                self.converged_epoch = 0
            return False

        if not self.converged:
            # Start of a new streak, capture the state at this point
            self.converged = True
            self.converged_loss = loss
            # Copy so the stored weights don't change as the model keeps training
            self.converged_wts = deepcopy(wts_dict)
            self.converged_epoch = epoch
            return False

        # Streak is continuing, so count this epoch and see if we've waited enough
        self.counter += 1
        print("converged patience counter", self.counter, flush=True)
        if (self.counter >= self.patience) and (epoch >= self.burnin):
            return True

        return False
[docs]
class ThresholdEarlyStopping:
    """
    Class for handling early stopping in training based on whether loss has been below
    a certain threshold for some number of epochs.
    """

    def __init__(self, threshold, patience, burnin=0):
        """
        Parameters
        ----------
        threshold : float
            Loss below which to stop model training
        patience : int
            Number of epochs to wait once loss has dipped below threshold to make sure
            it stays there
        burnin : int, optional
            If given, ensure that at least this many epochs of training have been done
            before we stop
        """
        super().__init__()
        self.threshold = threshold
        self.patience = patience
        self.burnin = burnin

        # Variables to track early stopping
        # Number of consecutive epochs the loss has not been above the threshold
        self.converged_epochs = 0

    def check(self, epoch, loss):
        """
        Check if training should be stopped. Return True to stop, False to keep going.

        Parameters
        ----------
        epoch : int
            Current training epoch
        loss : float
            Model loss from the current epoch of training

        Returns
        -------
        bool
            Whether to stop training
        """
        # Make sure we've got a reasonable value for loss
        loss = _sanitize_loss(loss)

        # Loss is back above the threshold, so start counting from scratch
        if loss > self.threshold:
            self.converged_epochs = 0
            return False

        self.converged_epochs += 1
        # Use >= rather than == so that we still stop if patience was reached while
        # we were still in the burn-in period
        return (self.converged_epochs >= self.patience) and (epoch >= self.burnin)
# Stopping criteria from the "Neural Networks: Tricks of the Trade" book
[docs]
class GeneralizationLossEarlyStopping:
    """
    Class for stopping based on the relative increase of val loss at epoch t from
    lowest val loss up to epoch t. GL_alpha loss in the book.
    GL is defined as 100 * (Err(t) / Err_min - 1)
    """

    def __init__(self, alpha, burnin=0):
        """
        Parameters
        ----------
        alpha : float
            Relative increase threshold
        burnin : int, default = 0
            If given, ensure that at least this many epochs of training have been done
            before we stop
        """
        super().__init__()
        self.alpha = alpha
        self.burnin = burnin

        # Variables to track early stopping
        self.best_loss = None
        self.best_wts = None
        self.best_epoch = 0

    def check(self, epoch, loss, wts_dict):
        """
        Check if training should be stopped. Return True to stop, False to keep going.

        Parameters
        ----------
        epoch : int
            Current training epoch
        loss : float
            Model loss from the current epoch of training
        wts_dict : dict
            Weights dict from Pytorch for keeping track of the best model

        Returns
        -------
        bool
            Whether to stop training
        """
        # Make sure we've got a reasonable value for loss
        loss = _sanitize_loss(loss)

        # The first loss we see is the best by definition, after that only a strictly
        # lower loss counts as an improvement
        first_check = self.best_loss is None
        if first_check or (loss < self.best_loss):
            self.best_loss = loss
            # Need to deepcopy so it doesn't update with the model weights
            self.best_wts = deepcopy(wts_dict)
            # Record the epoch on the first call too, so that best_epoch always
            # matches best_wts even if the first check isn't at epoch 0
            self.best_epoch = epoch

        # Never stop on the very first check or during burn-in
        if first_check or (epoch < self.burnin):
            return False

        # Calculate generalization loss (percent increase of the current loss over the
        # best loss so far) and check stopping criteria
        generalization_loss = 100 * (loss / self.best_loss - 1)
        return generalization_loss > self.alpha
[docs]
class ProgressQuotientEarlyStopping:
    """
    Class for stopping based on quotient of generalization loss and training progress.
    PQ_alpha in the book. This criterion is calculated after every k epochs, and
    interstitial epochs will automatically not be stopped at.
    """

    def __init__(self, alpha, k, burnin=0):
        """
        Parameters
        ----------
        alpha : float
            Quotient threshold
        k : int
            Length of training strip to evaluate at the end of
        burnin : int, default = 0
            If given, ensure that at least this many epochs of training have been done
            before we stop
        """
        super().__init__()
        self.alpha = alpha
        self.k = k
        self.burnin = burnin

        # Variables to track early stopping
        self.best_loss = None
        self.best_wts = None
        self.best_epoch = 0
        # Train losses of the current training strip (at most the k most recent)
        self.strip_train_losses = []

    def check(self, epoch, loss, wts_dict, train_loss):
        """
        Check if training should be stopped. Return True to stop, False to keep going.

        Parameters
        ----------
        epoch : int
            Current training epoch
        loss : float
            Model loss from the current epoch of training
        wts_dict : dict
            Weights dict from Pytorch for keeping track of the best model
        train_loss : float
            Training loss from the current epoch of training, used to measure
            training progress over the strip

        Returns
        -------
        bool
            Whether to stop training
        """
        # Make sure we've got a reasonable value for loss
        loss = _sanitize_loss(loss)
        train_loss = _sanitize_loss(train_loss)

        # Only keep the k most recent train losses. Without this, losses pile up
        # during burn-in (where the strip is never reset) and the first strip that
        # gets evaluated would be longer than k epochs
        self.strip_train_losses.append(train_loss)
        self.strip_train_losses = self.strip_train_losses[-self.k :]

        # The first loss we see is the best by definition, after that only a strictly
        # lower loss counts as an improvement
        if (self.best_loss is None) or (loss < self.best_loss):
            self.best_loss = loss
            # Need to deepcopy so it doesn't update with the model weights
            self.best_wts = deepcopy(wts_dict)
            # Record the epoch on the first call too, so that best_epoch always
            # matches best_wts even if the first check isn't at epoch 0
            self.best_epoch = epoch

        # Make sure we're at the end of a training strip
        if (
            (epoch < self.burnin)
            or ((epoch + 1) < self.k)
            or ((epoch + 1) % self.k != 0)
        ):
            return False

        # Calculate generalization loss and progress
        generalization_loss = 100 * (loss / self.best_loss - 1)
        # NOTE: progress is 0 if every train loss in the strip is identical (always
        # the case for k == 1), in which case the quotient below is inf or nan
        progress = 1000 * (
            np.mean(self.strip_train_losses) / np.min(self.strip_train_losses) - 1
        )

        # Start a fresh strip
        self.strip_train_losses = []

        return generalization_loss / progress > self.alpha