drugforge.ml.early_stopping.GeneralizationLossEarlyStopping

class drugforge.ml.early_stopping.GeneralizationLossEarlyStopping(alpha, burnin=0)

Bases: object

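Class for stopping training based on the relative increase of the validation loss at epoch t over the lowest validation loss seen up to epoch t. This is the GL_alpha criterion in the book. GL is defined as 100 * (Err(t) / Err_min - 1), where Err(t) is the validation loss at epoch t and Err_min is the lowest validation loss up to epoch t.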
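To make the definition concrete, here is a minimal sketch of the GL computation in plain Python. It assumes the relative-increase form 100 * (Err(t) / Err_min - 1) and strictly positive losses. The helper name `generalization_loss` and the sample loss history are hypothetical and are not part of drugforge.

```python
def generalization_loss(val_losses):
    """Return GL(t) in percent for each epoch t of a validation-loss history.

    Hypothetical helper for illustration only; it is not part of drugforge.
    Assumes every loss is strictly positive.
    """
    gl = []
    err_min = float("inf")
    for err in val_losses:
        # Lowest validation loss seen up to and including this epoch.
        err_min = min(err_min, err)
        # Relative increase of the current loss over that minimum, in percent.
        gl.append(100.0 * (err / err_min - 1.0))
    return gl


# Made-up loss history: improves for three epochs, then degrades.
history = [1.0, 0.8, 0.5, 0.55, 0.6]
print(generalization_loss(history))
```

GL stays at 0 while the loss keeps reaching new minima, then rises to roughly 10 and 20 percent as the loss drifts above the minimum of 0.5.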

__init__(alpha, burnin=0)
Parameters:
  • alpha (float) – Relative increase threshold

  • burnin (int, default = 0) – If given, at least this many epochs of training must be completed before training can be stopped

Methods

__init__(alpha[, burnin])

check(epoch, loss, wts_dict)

Check if training should be stopped.

check(epoch, loss, wts_dict)

Check if training should be stopped. Return True to stop, False to keep going.

Parameters:
  • epoch (int) – Current epoch of training

  • loss (float) – Model loss from the current epoch of training

  • wts_dict (dict) – Weights dict from PyTorch, used to keep track of the best model

Returns:

Whether to stop training

Return type:

bool
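The sketch below shows how a `check(epoch, loss, wts_dict)` call could fit into a training loop. It does not import drugforge. `GLStopperSketch` is a hypothetical stand-in that mirrors the documented interface, and the loss values and weights dicts are made up. Three behaviours are assumptions: stopping when GL exceeds `alpha` in percent, counting epochs from 0 for `burnin`, and deep-copying the best weights. The real class may differ on any of these.

```python
import copy


class GLStopperSketch:
    """Hypothetical stand-in mirroring the documented interface; not drugforge code."""

    def __init__(self, alpha, burnin=0):
        self.alpha = alpha          # assumed: GL threshold in percent
        self.burnin = burnin        # assumed: epochs are counted from 0
        self.best_loss = float("inf")
        self.best_wts = None
        self.best_epoch = None

    def check(self, epoch, loss, wts_dict):
        """Return True to stop training, False to keep going."""
        if loss < self.best_loss:
            # New best model: remember its loss, epoch, and a copy of its weights.
            self.best_loss = loss
            self.best_wts = copy.deepcopy(wts_dict)
            self.best_epoch = epoch
            return False
        if epoch < self.burnin:
            # Still inside the burn-in period, so never stop yet.
            return False
        gl = 100.0 * (loss / self.best_loss - 1.0)
        return gl > self.alpha


# Made-up validation losses standing in for a real training run.
val_losses = [1.0, 0.8, 0.5, 0.55, 0.6, 0.7]
stopper = GLStopperSketch(alpha=15.0, burnin=2)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    wts = {"w": float(epoch)}  # placeholder for something like model.state_dict()
    if stopper.check(epoch, loss, wts):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_epoch, stopper.best_wts)
```

With these numbers the loop stops at epoch 4, where the loss of 0.6 is about 20 percent above the best loss of 0.5. The weights kept are those from epoch 2.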