drugforge.ml.viz.plot_split_losses
- drugforge.ml.viz.plot_split_losses(pred_tracker_dict, out_fn=None, splits=['train', 'val', 'test'], loss_label='Loss', legend_title='label', label_trans=None, for_fig=False, **kwargs)
Plot overall losses per split by training epoch.
- Parameters:
pred_tracker_dict (dict[str, TrainingPredictionTracker]) – Dict mapping labels to pred trackers
out_fn (Path, optional) – Path to save plot to
splits (list[str], default=["train", "val", "test"]) – Which splits to plot
loss_label (str, default="Loss") – Label for the y-axis of the plot
legend_title (str, default="label") – Column name for the dict keys, which is used as the legend title by default
label_trans (callable, optional) – Function that takes a label string as input and returns a dict mapping str -> str. It is applied to each label, and each key in the returned dict is added as a column to the DataFrame, with the corresponding value as the entry for that label's rows
for_fig (bool, default=False) – Plot for a figure rather than just for visualization. Takes some liberties with capitalization to make labels look a bit more professional
kwargs (dict) – Anything else to pass directly to relplot
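Example: the label_trans contract described above can be sketched as follows. This is only an illustration, not part of drugforge: the "model_lr" label format and the parse_label helper are hypothetical, and the list of row dicts merely mimics how each returned key would become an extra column next to the original label. The commented-out call at the end is an assumed usage that needs drugforge and real TrainingPredictionTracker objects.

```python
def parse_label(label: str) -> dict[str, str]:
    """Hypothetical label_trans: split a "<model>_lr<rate>" label into fields."""
    # str.partition returns (before, separator, after) around the first match
    model, _, lr = label.partition("_lr")
    return {"model": model, "lr": lr}


# Hypothetical dict keys, i.e. the labels passed in pred_tracker_dict
labels = ["gat_lr0.001", "schnet_lr0.0001"]

# Mimic the documented behaviour: every key returned by label_trans becomes
# an additional column alongside the original label column.
rows = [{"label": label, **parse_label(label)} for label in labels]
for row in rows:
    print(row)

# Assumed usage (not run here; `trackers` would be a
# dict[str, TrainingPredictionTracker] keyed by the labels above):
# from drugforge.ml.viz import plot_split_losses
# plot_split_losses(trackers, out_fn="losses.png", label_trans=parse_label)
```

Keeping label_trans a pure string-to-dict function makes it easy to check on its own before passing it to the plotting call.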