drugforge.ml.early_stopping.PatientConvergedEarlyStopping

class drugforge.ml.early_stopping.PatientConvergedEarlyStopping(n_check, divergence, patience, burnin=0)

Bases: object

Class for handling early stopping in training based on whether the loss is still changing, with patience. Training is treated as converged when the mean difference of the past n_check losses from the average of those losses is within the tolerance set by divergence; the class then waits up to patience epochs to make sure the plateau is not temporary before signalling a stop.
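The convergence test described above can be sketched as a small standalone function. This is an illustration, not drugforge's code: the function name losses_converged is hypothetical, and reading "mean difference from the average" as the mean absolute deviation of the window from its own mean is an assumption (a signed mean difference from the mean would always be zero).

```python
def losses_converged(losses, n_check, divergence):
    """Return True if the last ``n_check`` losses are within ``divergence``.

    Assumption: the "mean difference from the average" is read as the mean
    absolute deviation of the window from its own mean.
    """
    if n_check < 1:
        raise ValueError("n_check must be at least 1")
    if len(losses) < n_check:
        # Not enough history yet to judge convergence
        return False
    window = losses[-n_check:]
    mean = sum(window) / n_check
    # Average distance of each tracked loss from the window mean
    mean_abs_dev = sum(abs(x - mean) for x in window) / n_check
    return mean_abs_dev <= divergence


print(losses_converged([0.9, 0.7, 0.502, 0.500, 0.498], n_check=3, divergence=0.01))  # True
print(losses_converged([1.0, 0.8, 0.6], n_check=3, divergence=0.01))  # False
```

Only the most recent n_check losses matter, so an early high loss does not block convergence once the tail of the curve has flattened.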

__init__(n_check, divergence, patience, burnin=0)
Parameters:
  • n_check (int) – Number of past epochs to keep track of when calculating divergence

  • divergence (float) – Maximum allowable mean difference of the tracked losses from their mean for training to count as converged

  • patience (int) – The maximum number of epochs to wait after convergence

  • burnin (int, optional) – Minimum number of training epochs that must be completed before stopping is allowed (default 0, i.e. no minimum)
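Under the same mean-absolute-deviation reading of the criterion (an assumption, not confirmed by the source), the check at epoch t with n = n_check, δ = divergence and losses L_i can be written as:

```latex
\bar{L}_t = \frac{1}{n}\sum_{i=t-n+1}^{t} L_i,
\qquad
\frac{1}{n}\sum_{i=t-n+1}^{t} \left| L_i - \bar{L}_t \right| \le \delta
```

Per the parameter descriptions, once this holds the class waits up to patience further epochs, and it never stops before burnin epochs have been completed.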

Methods

  • __init__(n_check, divergence, patience[, burnin])

  • check(epoch, loss, wts_dict) – Check if training should be stopped.

check(epoch, loss, wts_dict)

Check if training should be stopped. Return True to stop, False to keep going.

Parameters:
  • epoch (int) – Current epoch of training

  • loss (float) – Model loss from the current epoch of training

  • wts_dict (dict) – Model weights dict from PyTorch (e.g. the output of model.state_dict()), used to keep track of the best model

Returns:

Whether to stop training

Return type:

bool
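A minimal, self-contained sketch of how check(epoch, loss, wts_dict) could combine the pieces documented above. The class name ConvergedStopperSketch is hypothetical, and several details are assumptions rather than drugforge's implementation: the mean-absolute-deviation test, resetting the patience counter when convergence is lost, and counting check() calls (one per epoch) for burnin. wts_dict is deep-copied because a PyTorch state_dict holds references to the live parameter tensors, which keep changing as training continues.

```python
import copy


class ConvergedStopperSketch:
    """Hypothetical stand-in mirroring the documented signature of
    PatientConvergedEarlyStopping; not the drugforge implementation."""

    def __init__(self, n_check, divergence, patience, burnin=0):
        self.n_check = n_check
        self.divergence = divergence
        self.patience = patience
        self.burnin = burnin
        self.losses = []
        self.wait = 0  # epochs counted since convergence was detected
        self.best_loss = None
        self.best_epoch = None
        self.best_wts = None

    def _converged(self):
        # Assumed criterion: mean absolute deviation of the window from its mean
        if len(self.losses) < self.n_check:
            return False
        window = self.losses[-self.n_check:]
        mean = sum(window) / len(window)
        return sum(abs(x - mean) for x in window) / len(window) <= self.divergence

    def check(self, epoch, loss, wts_dict):
        """Return True to stop, False to keep going."""
        self.losses.append(loss)
        # Snapshot the weights from the lowest-loss epoch seen so far
        if self.best_loss is None or loss < self.best_loss:
            self.best_loss = loss
            self.best_epoch = epoch
            self.best_wts = copy.deepcopy(wts_dict)

        if self._converged():
            self.wait += 1
        else:
            # Assumption: losing convergence resets the patience counter
            self.wait = 0

        # Stop once patience is exhausted and at least burnin epochs have run
        return self.wait > self.patience and len(self.losses) >= self.burnin


stopper = ConvergedStopperSketch(n_check=3, divergence=0.01, patience=2)
for epoch, loss in enumerate([1.0, 0.8, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]):
    if stopper.check(epoch, loss, {"w": [loss]}):
        print(f"stop at epoch {epoch}, best epoch {stopper.best_epoch}")
        break
# prints "stop at epoch 7, best epoch 3"
```

Convergence is first detected at epoch 5, when the window is [0.5, 0.5, 0.5]; the sketch then waits two more epochs (patience=2) and stops at epoch 7, while best_wts still holds the snapshot from epoch 3, the first epoch to reach the lowest loss.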