drugforge.ml.early_stopping.ConvergedEarlyStopping

class drugforge.ml.early_stopping.ConvergedEarlyStopping(n_check, divergence, burnin=0)

Bases: object

Early-stopping handler that decides when to stop training based on whether the loss is still changing. It checks whether the mean difference of the past n_check losses from the average of those losses is within the tolerance set by divergence.
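The description leaves the exact statistic implicit. A signed mean difference from the average is always zero, so any implementation has to use a nonnegative measure of spread. The sketch below assumes the mean absolute deviation of the window, compared to divergence with a non-strict inequality. The function names mean_abs_deviation and has_converged are hypothetical and are not part of drugforge.

```python
from statistics import mean


def mean_abs_deviation(losses: list[float]) -> float:
    """Average absolute distance of each loss from the window's mean (assumed statistic)."""
    centre = mean(losses)
    return mean(abs(x - centre) for x in losses)


def has_converged(losses: list[float], divergence: float) -> bool:
    """Hypothetical helper: True if the window of losses is flat to within `divergence`."""
    return mean_abs_deviation(losses) <= divergence


# A nearly flat window (spread about 0.001) versus a still-falling one (spread 0.15).
print(has_converged([0.512, 0.508, 0.510, 0.511], divergence=0.01))  # True
print(has_converged([0.90, 0.70, 0.55, 0.45], divergence=0.01))      # False
```

Because the deviation is measured around the window's own mean, a loss that is slowly but steadily drifting can still count as converged if the drift over n_check epochs is small relative to divergence.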

__init__(n_check, divergence, burnin=0)
Parameters:
  • n_check (int) – Number of past epochs to keep track of when calculating the divergence.

  • divergence (float) – Maximum allowable difference from the mean of the tracked losses.

  • burnin (int, optional) – If given, training is not stopped until at least this many epochs have been completed. Defaults to 0.
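To show how the three parameters might interact, here is a minimal stand-in class with the same constructor and check signature. It is a sketch, not drugforge's implementation. ConvergedEarlyStoppingSketch is a hypothetical name, and the sketch rests on three assumptions: losses are kept in a fixed-length window of n_check entries, the spread statistic is the mean absolute deviation, and epoch counts the epochs completed so far.

```python
from collections import deque
from statistics import mean


class ConvergedEarlyStoppingSketch:
    """Hypothetical stand-in mirroring the documented constructor and check()."""

    def __init__(self, n_check: int, divergence: float, burnin: int = 0):
        self.n_check = n_check
        self.divergence = divergence
        self.burnin = burnin
        # Only the most recent n_check losses are retained; older ones fall off.
        self.losses = deque(maxlen=n_check)

    def check(self, epoch: int, loss: float) -> bool:
        """Return True to stop, False to keep training."""
        self.losses.append(loss)
        # Never stop during burn-in, or before the window holds n_check losses.
        if epoch < self.burnin or len(self.losses) < self.n_check:
            return False
        # Assumed statistic: mean absolute deviation from the window's mean.
        centre = mean(self.losses)
        spread = mean(abs(x - centre) for x in self.losses)
        return spread <= self.divergence
```

With n_check=3, divergence=0.01 and burnin=5, a perfectly flat loss is ignored until epoch 5 even though the window is already full at epoch 3. Burn-in therefore acts as a hard floor on training length, independent of how quickly the loss settles.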

Methods

__init__(n_check, divergence[, burnin])

check(epoch, loss)

Check if training should be stopped.

check(epoch, loss)

Check whether training should be stopped. Returns True to stop and False to continue.

Parameters:

loss (float) – Loss from the previous training epoch

Returns:

Whether to stop training

Return type:

bool
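Because check returns a plain bool, it can be called once per epoch and used to break out of the training loop. The sketch below is a generic driver for any object that exposes check(epoch, loss) -> bool. The names run_training, EarlyStopper and StopBelow are hypothetical. StopBelow is a toy stopper included only so the example runs without drugforge installed. Passing epochs starting from 1 is an assumption, because the exact epoch convention is not documented here.

```python
from typing import Iterable, Protocol


class EarlyStopper(Protocol):
    """Anything with a check(epoch, loss) -> bool method, e.g. ConvergedEarlyStopping."""

    def check(self, epoch: int, loss: float) -> bool: ...


def run_training(stopper: EarlyStopper, epoch_losses: Iterable[float]) -> int:
    """Feed one loss per epoch to `stopper`; return the epoch at which training ended."""
    epoch = 0
    for epoch, loss in enumerate(epoch_losses, start=1):
        # In real training, `loss` would come from this epoch's train/validation pass.
        if stopper.check(epoch, loss):
            break
    return epoch


class StopBelow:
    """Toy stopper for demonstration only: stops once the loss drops below a threshold."""

    def __init__(self, threshold: float):
        self.threshold = threshold

    def check(self, epoch: int, loss: float) -> bool:
        return loss < self.threshold


print(run_training(StopBelow(0.3), [1.0, 0.6, 0.4, 0.25, 0.2]))  # 4
```

With drugforge available, the same loop would instead receive an instance such as ConvergedEarlyStopping(n_check=5, divergence=1e-3, burnin=10). Those parameter values are illustrative only.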