import pandas
import seaborn as sns
[docs]
def plot_split_losses(
    pred_tracker_dict,
    out_fn=None,
    splits=None,
    loss_label="Loss",
    legend_title="label",
    label_trans=None,
    for_fig=False,
    **kwargs,
):
    """
    Plot overall losses per split by training epoch.

    Parameters
    ----------
    pred_tracker_dict : dict[str, TrainingPredictionTracker]
        Dict mapping labels to pred trackers. Each tracker's
        ``to_plot_df(agg_compounds=True, agg_losses=True)`` output must contain
        "split", "epoch", and "loss" columns
    out_fn : Path, optional
        Path to save plot to
    splits : iterable[str], optional
        Which splits to actually plot. Defaults to ["train", "val", "test"]
    loss_label : str, default="Loss"
        What to label the y-axis of the plot
    legend_title : str, default="label"
        Column name for the dict keys, which will be used as the Legend title by
        default
    label_trans : callable, optional
        Function that should take a string as input and return a dict mapping
        str -> str. This function will be applied to each label, and each key in the
        output will be added as a column to the DataFrame with its corresponding value
        as the entry for that row in DF
    for_fig : bool, default=False
        Plotting for a figure rather than just visualization. Will take some liberties
        with capitalization to make labels look a bit more professional
    kwargs : dict
        Anything else to pass directly to relplot. If either "hue" or "style" is
        given, hue/hue_order/style/style_order are all taken from kwargs instead of
        being chosen automatically. "aspect" defaults to 1.5

    Returns
    -------
    seaborn.FacetGrid
        The grid returned by ``sns.relplot``

    Raises
    ------
    ValueError
        If pred_tracker_dict is empty
    """
    # None sentinel avoids a shared mutable default; list() lets callers pass any
    # iterable, since the splits are both measured (len) and reused below
    splits = ["train", "val", "test"] if splits is None else list(splits)
    if not pred_tracker_dict:
        raise ValueError("pred_tracker_dict must contain at least one entry")

    # Build overall DF
    all_dfs = []
    for lab, pred_tracker in pred_tracker_dict.items():
        df = pred_tracker.to_plot_df(agg_compounds=True, agg_losses=True)
        # These values must match the automatic hue_order computed below
        df[legend_title] = lab.title() if for_fig else lab
        # Apply label transform and add any new columns
        if callable(label_trans):
            for k, v in label_trans(lab).items():
                df[k] = v
        all_dfs.append(df)
    all_dfs = pandas.concat(all_dfs, ignore_index=True)
    # Subset; copy so the column added below goes into an independent frame rather
    # than a possible view (avoids SettingWithCopyWarning)
    all_dfs = all_dfs.loc[all_dfs["split"].isin(splits), :].copy()

    # Column name and level order for the split, capitalized when plotting a figure
    if for_fig:
        all_dfs["Split"] = [s.title() for s in all_dfs["split"]]
        split_col = "Split"
        split_order = [s.title() for s in splits]
    else:
        split_col = "split"
        split_order = splits

    if ("hue" not in kwargs) and ("style" not in kwargs):
        if len(pred_tracker_dict) > 1:
            # More than one different experiment, so use color for experiment and
            # style for split
            hue = legend_title
            # Must use exactly the values written to the legend_title column above,
            # otherwise seaborn finds no matching hue levels and drops the lines
            hue_order = [lab.title() if for_fig else lab for lab in pred_tracker_dict]
            if len(splits) > 1:
                style, style_order = split_col, split_order
            else:
                style, style_order = None, None
        else:
            # Single experiment, so color by split
            hue, hue_order = split_col, split_order
            style, style_order = None, None
    else:
        # Pull from kwargs
        hue = kwargs.pop("hue", None)
        hue_order = kwargs.pop("hue_order", None)
        style = kwargs.pop("style", None)
        style_order = kwargs.pop("style_order", None)
    # Other various kwargs
    aspect = kwargs.pop("aspect", 1.5)

    # Make plot
    fg = sns.relplot(
        data=all_dfs,
        x="epoch",
        y="loss",
        hue=hue,
        style=style,
        hue_order=hue_order,
        style_order=style_order,
        kind="line",
        aspect=aspect,
        **kwargs,
    )
    # Set axes
    fg.set_axis_labels("Training Epoch", loss_label)
    if out_fn:
        fg.savefig(out_fn, bbox_inches="tight", dpi=200)
    return fg
[docs]
def plot_model_preds_scatter(
    pred_tracker_dict,
    stats_dict,
    out_fn=None,
    split="test",
    use_epoch=-1,
    label_trans=None,
    plot_stats=True,
    table_stats=False,
    **kwargs,
):
    """
    Plot a scatterplot of experimental vs predicted values.

    Parameters
    ----------
    pred_tracker_dict : dict[str, TrainingPredictionTracker]
        Dict mapping labels to pred trackers. Each tracker's
        ``to_plot_df(agg_losses=True)`` output must contain "epoch", "split",
        "in_range" (-1/0/1), "target", and "pred" columns
    stats_dict : dict
        Dict mapping lab -> pred stats (generated by
        pred_tracker.calculate_pred_statistics). Read as
        ``stats_dict[lab][split][stat][key]`` with stat in
        {"mae", "rmse", "sp_r", "tau"} and key in
        {"value", "95ci_low", "95ci_high"}. Only accessed if plot_stats is True
    out_fn : Path, optional
        Path to save plot to
    split : str, default="test"
        Which split to plot (also selects which split's stats are shown and is used
        in the figure title)
    use_epoch : int, default=-1
        Which epoch of training to take predictions from. Any negative value uses
        each tracker's final (maximum) epoch
    label_trans : callable, optional
        Function that should take a string as input and return a dict mapping
        str -> str. This function will be applied to each label, and each key in the
        output will be added as a column to the DataFrame with its corresponding value
        as the entry for that row in DF
    plot_stats : bool, default=True
        Whether to show MAE/RMSE/Spearman/Kendall stats (with 95% CIs) for each
        facet
    table_stats : bool, default=False
        If True (and plot_stats is True), show the stats as a table in an extra
        empty facet instead of as text inside each facet
    kwargs : dict
        Anything else to pass directly to relplot. "facet_kws" is merged over the
        defaults {"sharex": False, "sharey": False}; "col" defaults to "label"

    Returns
    -------
    seaborn.FacetGrid
        The grid returned by ``sns.relplot``

    Raises
    ------
    ValueError
        If pred_tracker_dict is empty
    """
    if not pred_tracker_dict:
        raise ValueError("pred_tracker_dict must contain at least one entry")

    stat_keys = ["mae", "rmse", "sp_r", "tau"]
    stat_labels = ["MAE", "RMSE", "Spearman's $\\rho$", "Kendall's $\\tau$"]

    def _format_stat(lab, stat):
        # "value" with the 95% CI bounds as sub/superscript (mathtext)
        s = stats_dict[lab][split][stat]
        return f"{s['value']:0.2f}$_{{{s['95ci_low']:0.2f}}}^{{{s['95ci_high']:0.2f}}}$"

    # Build overall DF
    all_dfs = []
    for lab, pred_tracker in pred_tracker_dict.items():
        df = pred_tracker.to_plot_df(agg_losses=True)
        df["label"] = lab
        # Apply label transform and add any new columns
        if callable(label_trans):
            for k, v in label_trans(lab).items():
                df[k] = v
        all_dfs.append(df)
    all_dfs = pandas.concat(all_dfs, ignore_index=True)

    # Subset by epoch (chosen per label, so each tracker can use its own final
    # epoch) and then by split
    epoch_idx = []
    for _, g in all_dfs.groupby("label"):
        cur_use_epoch = g["epoch"].max() if use_epoch < 0 else use_epoch
        epoch_idx.extend(g.index[g["epoch"] == cur_use_epoch])
    all_dfs = all_dfs.loc[epoch_idx, :]
    # Copy so the column added below goes into an independent frame rather than a
    # possible view (avoids SettingWithCopyWarning)
    all_dfs = all_dfs.loc[all_dfs["split"] == split, :].copy()

    # Set so the legend looks nicer
    legend_text_mapper = {
        -1: "Below Assay Range",
        0: "In Assay Range",
        1: "Above Assay Range",
    }
    all_dfs["Assay Range"] = list(map(legend_text_mapper.get, all_dfs["in_range"]))

    show_table = table_stats and plot_stats
    # If any facet_kws are passed in kwargs, update the defaults
    facet_kws = {"sharex": False, "sharey": False} | (
        kwargs.pop("facet_kws", None) or {}
    )
    col = kwargs.pop("col", "label")
    col_order = kwargs.pop("col_order", None)
    if col_order is None:
        col_order = list(all_dfs[col].unique())
    else:
        # Copy so appending the "blank" facet never mutates the caller's list
        col_order = list(col_order)
    if show_table:
        # Extra empty facet that will hold the stats table
        col_order.append("blank")

    fg = sns.relplot(
        data=all_dfs,
        x="target",
        y="pred",
        col=col,
        col_order=col_order,
        style="Assay Range",
        markers={
            "Below Assay Range": "<",
            "In Assay Range": "o",
            "Above Assay Range": ">",
        },
        style_order=["Below Assay Range", "In Assay Range", "Above Assay Range"],
        facet_kws=facet_kws,
        **kwargs,
    )
    # Figure title
    fg.figure.subplots_adjust(top=0.8)
    fg.figure.suptitle(f"{split.title()} Set Predictions", fontweight="bold")

    # Axes bounds; DataFrame.max skips NaN, whereas ndarray.max would propagate it
    # into the axis limits
    min_val = -0.5
    max_val = all_dfs[["target", "pred"]].max().max() + 0.5

    # Axis labels on the left column / bottom row only. set_axis_labels also
    # handles col_wrap, where fg.axes is 1-D and 2-D indexing would fail
    fg.set_axis_labels(
        r"Experimental $\mathrm{pIC}_{50}$",
        r"Predicted $\mathrm{pIC}_{50}$",
        clear_inner=False,
    )
    sns.move_legend(fg, loc="upper center", bbox_to_anchor=(0.5, 0), ncols=3)

    if show_table:
        # First column holds the row labels; one column per facet is appended below
        stats_table_text = [[""]] + [[stat_label] for stat_label in stat_labels]

    # NOTE(review): stats are looked up by the facet value, so this assumes the
    # facet values of `col` are keys of stats_dict (true for the default col="label")
    for lab, ax in fg.axes_dict.items():
        if lab == "blank":
            ax.set_title("")
            continue
        # Set title
        ax.set_title(lab, fontweight="bold")
        # Plot y=x line
        ax.plot(
            [min_val, max_val],
            [min_val, max_val],
            color="black",
            ls="--",
        )
        # Shade 0.5 pIC50 and 1 pIC50 regions around y=x
        for band in (0.5, 1):
            ax.fill_between(
                [min_val, max_val],
                [min_val - band, max_val - band],
                [min_val + band, max_val + band],
                color="gray",
                alpha=0.2,
            )
        # Stats labels
        if show_table:
            stats_table_text[0].append(lab)
            for row, stat in zip(stats_table_text[1:], stat_keys):
                row.append(_format_stat(lab, stat))
        elif plot_stats:
            stats_text = [
                f"{stat_label}: {_format_stat(lab, stat)}"
                for stat, stat_label in zip(stat_keys, stat_labels)
            ]
            ax.text(
                0.7,
                0,
                "\n".join(stats_text),
                transform=ax.transAxes,
                va="bottom",
                linespacing=0.8,
            )
        # Make it a square
        ax.set_aspect("equal", "box")
        ax.set_xlim((min_val, max_val))
        ax.set_ylim((min_val, max_val))

    if show_table:
        # The last axis is the "blank" facet appended to col_order above
        ax = fg.axes.flatten()[-1]
        ax.set_axis_off()
        ax.table(
            cellText=stats_table_text, cellLoc="center", loc="center", edges="open"
        )
    if out_fn:
        fg.savefig(out_fn, bbox_inches="tight", dpi=200)
    return fg