drugforge.ml.early_stopping.ProgressQuotientEarlyStopping
- class drugforge.ml.early_stopping.ProgressQuotientEarlyStopping(alpha, k, burnin=0)
Bases: object
Class for stopping based on the quotient of generalization loss and training progress (PQ_alpha in the book). The criterion is evaluated only once every k epochs, and training is never stopped at the epochs in between.
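The quantities behind PQ_alpha are commonly defined as follows (Prechelt's formulation). Here E_va(t) is the validation loss and E_tr(t) the training loss at epoch t. Reading `loss` as E_va and `train_loss` as E_tr, and assuming the scaling constants 100 and 1000, are both assumptions; this page does not confirm them for drugforge.

```latex
\begin{aligned}
E_{\mathrm{opt}}(t) &= \min_{t' \le t} E_{\mathrm{va}}(t') \\
GL(t) &= 100 \left( \frac{E_{\mathrm{va}}(t)}{E_{\mathrm{opt}}(t)} - 1 \right) \\
P_k(t) &= 1000 \left( \frac{\sum_{t'=t-k+1}^{t} E_{\mathrm{tr}}(t')}{k \, \min_{t-k+1 \le t' \le t} E_{\mathrm{tr}}(t')} - 1 \right) \\
PQ_{\alpha}:\ &\text{stop at the first end-of-strip epoch } t \text{ with } \frac{GL(t)}{P_k(t)} > \alpha
\end{aligned}
```

GL measures how far, in percent, the current validation loss sits above the best seen so far. P_k measures, in per mille, how much the mean training loss over the last strip of k epochs exceeds that strip's minimum. A small P_k means training has stalled, so even a modest GL triggers stopping.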
- __init__(alpha, k, burnin=0)
- Parameters:
alpha (float) – Quotient threshold
k (int) – Length of each training strip, in epochs; the criterion is evaluated at the end of each strip
burnin (int, default = 0) – Minimum number of training epochs that must be completed before stopping is allowed
Methods
__init__(alpha, k[, burnin])
check(epoch, loss, wts_dict, train_loss) – Check if training should be stopped.
- check(epoch, loss, wts_dict, train_loss)
Check if training should be stopped. Return True to stop, False to keep going.
- Parameters:
epoch (int) – Current epoch number
loss (float) – Model loss from the current epoch of training
wts_dict (dict) – Weights dict from PyTorch for keeping track of the best model
train_loss (float) – Training loss from the current epoch
- Returns:
Whether to stop training
- Return type:
bool
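The logic that check() describes can be sketched as follows. This is a minimal, hypothetical re-implementation: `ProgressQuotientSketch` is an illustrative name, not drugforge's class. It makes several assumptions:

- `loss` is the validation loss and `train_loss` is the training loss.
- Losses are strictly positive.
- Strips are counted by the number of calls to check(), not by the `epoch` value.
- A perfectly flat training strip counts as an infinite quotient.

It uses the GL and P_k definitions commonly given for PQ_alpha. The real implementation may differ in indexing and edge cases.

```python
import copy
import math
from collections import deque


class ProgressQuotientSketch:
    """Hypothetical PQ_alpha early-stopping rule (illustrative, not drugforge's code)."""

    def __init__(self, alpha: float, k: int, burnin: int = 0) -> None:
        self.alpha = alpha
        self.k = k
        self.burnin = burnin
        self.n_epochs = 0                   # number of check() calls so far (assumption)
        self.best_loss = math.inf           # lowest validation loss seen (E_opt)
        self.best_wts = None                # weights snapshot taken at best_loss
        self.best_epoch = None
        self.train_strip = deque(maxlen=k)  # training losses of the last k epochs

    def check(self, epoch: int, loss: float, wts_dict: dict, train_loss: float) -> bool:
        """Return True to stop training, False to keep going."""
        self.n_epochs += 1
        self.train_strip.append(train_loss)

        # Remember the best model so far; deepcopy so that later in-place
        # weight updates do not overwrite the snapshot.
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_wts = copy.deepcopy(wts_dict)
            self.best_epoch = epoch

        # Never stop during burn-in or at epochs inside a strip.
        if self.n_epochs < self.burnin or self.n_epochs % self.k != 0:
            return False

        # Generalization loss GL: percent above the best validation loss.
        gen_loss = 100.0 * (loss / self.best_loss - 1.0)
        # Training progress P_k: per mille by which the strip mean exceeds its
        # minimum. Assumes strictly positive training losses.
        strip_min = min(self.train_strip)
        progress = 1000.0 * (sum(self.train_strip) / (self.k * strip_min) - 1.0)

        if progress == 0.0:
            # Flat strip: treat GL / 0 as infinite when GL > 0 (assumption).
            return gen_loss > 0.0
        return gen_loss / progress > self.alpha
```

In a training loop, call `check` once per epoch and break when it returns True. Then restore `best_wts` into the model (with PyTorch, via `model.load_state_dict`). Snapshotting with `copy.deepcopy` matters because a PyTorch `state_dict()` holds references to the live parameter tensors.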