{ "cells": [ { "cell_type": "markdown", "id": "431376e6-49ed-472e-9c07-61e7ef4eb2a0", "metadata": {}, "source": [ "# Training ML models " ] }, { "cell_type": "markdown", "id": "18e6585c-6681-474c-bb55-dd9d7429e510", "metadata": {}, "source": [ "## Prerequisites\n", "\n", "Before starting this tutorial, you should have already worked through the tutorials on [Interfacing with databases and systems](https://asapdiscovery.readthedocs.io/en/latest/tutorials/index.html#interfacing-with-databases-and-systems) and [Docking and scoring](https://asapdiscovery.readthedocs.io/en/latest/tutorials/index.html#docking-and-scoring).\n", "\n", "We will use publicly available test files as the starting point for this tutorial, but feel free to substitute with your own data." ] }, { "cell_type": "code", "execution_count": 1, "id": "8575246f-bd67-403c-b46a-b2c753e7659e", "metadata": {}, "outputs": [], "source": [ "# First download the needed files\n", "from drugforge.data.testing.test_resources import fetch_test_file\n", "\n", "docked_files = fetch_test_file(\n", " [\n", " \"ml_testing/docked/AAR-POS-5507155c-1_Mpro-P0018_0A_0_bound_best.pdb\",\n", " \"ml_testing/docked/AAR-RCN-390aeb1f-1_Mpro-P3074_0A_0_bound_best.pdb\",\n", " \"ml_testing/docked/AAR-RCN-67438d21-1_Mpro-P3074_0A_0_bound_best.pdb\",\n", " \"ml_testing/docked/AAR-POS-d2a4d1df-27_Mpro-P0238_0A_0_bound_best.pdb\",\n", " \"ml_testing/docked/AAR-POS-8a4e0f60-7_Mpro-P0053_0A_0_bound_best.pdb\",\n", " \"ml_testing/docked/AAR-RCN-28a8122f-1_Mpro-P2005_0A_0_bound_best.pdb\",\n", " \"ml_testing/docked/AAR-RCN-37d0aa00-1_Mpro-P3074_0A_0_bound_best.pdb\",\n", " \"ml_testing/docked/AAR-RCN-748c104b-1_Mpro-P3074_0A_0_bound_best.pdb\",\n", " \"ml_testing/docked/AAR-RCN-521d1733-1_Mpro-P2005_0A_0_bound_best.pdb\",\n", " \"ml_testing/docked/AAR-RCN-845f9611-1_Mpro-P2005_0A_0_bound_best.pdb\",\n", " ]\n", ")\n", "docked_dir = docked_files[0].parent\n", "unfilt_csv_file = fetch_test_file(\"ml_testing/cdd_unfiltered.csv\")" ] }, { 
"cell_type": "markdown", "id": "bd7d9d63-75d7-405c-a183-ffadb141f0d4", "metadata": {}, "source": [ "## Intro\n", "\n", "In this guide, we will start with a CSV file downloaded from CDD and a directory of docked protein-ligand complex PDB files. These are the two outputs from the guides mentioned in the Prerequisites section, so be sure to complete those if you haven't already." ] }, { "cell_type": "markdown", "id": "42109390-694f-4638-b9b5-1ecfaa20a49f", "metadata": {}, "source": [ "## Preparing the experimental data\n", "\n", "Before using the data in training, we will do some filtering and processing of the experimental data to ensure that everything is in the correct format. We will use all the default values for column names, which come from the [COVID Moonshot project](https://www.science.org/doi/10.1126/science.abo7201). See the docs for individual functions to see how these values can be tuned for your use-case." ] }, { "cell_type": "code", "execution_count": 2, "id": "7b3b399e-4d94-4940-ad82-e266ae0898d4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Replacing unicode character with - in UNK-CYC-68f84b31−70\n", "Wrote cdd_filtered_processed.json\n" ] } ], "source": [ "from drugforge.data.util.utils import (\n", " cdd_to_schema,\n", " filter_molecules_dataframe,\n", " parse_fluorescence_data_cdd,\n", ")\n", "from pathlib import Path\n", "\n", "import pandas\n", "\n", "# Replace this name with whatever you've saved your CDD download as\n", "mol_df = pandas.read_csv(unfilt_csv_file)\n", "\n", "\"\"\"\n", "In this example, we will ultimately use this data to train both 2D and structure-based\n", "models, so we will keep all achiral and enantiopure molecules, including any molecules\n", "with semiquantitative fluorescence values.\n", "\"\"\"\n", "mol_df_filt = filter_molecules_dataframe(\n", " mol_df,\n", " retain_achiral=True,\n", " retain_enantiopure=True,\n", " retain_semiquantitative_data=True,\n", ")\n", "\n", 
"\"\"\"\n", "In addition to being appropriately filtered, mol_df_filt now contains some identifying\n", "colums with standardized names.\n", "\n", "The parse_fluorescence_data_cdd function standardizes the fluorescence assay results,\n", "adding a number of columns to the data frame. In addition to IC50 and pIC50 values, it\n", "will also calculate deltaG in kcal/mol and kT units. If you know your fluorescence\n", "assay conditions and they were consistent across all molecules, you can supply the\n", "information to the cp_values arg as a tuple of (substrate_concentration, Km), which the\n", "function will use in the Cheng-Prusoff equation to calculate the deltaG values. If you\n", "don't have these values, the function will use a less accurate approximation. We will\n", "exclude the Cheng-Prusoff values in this example for simplicity.\n", "\n", "More details on the columns that the function expects the input to have and that it\n", "adds to the output can be found in the function's docstring.\n", "\"\"\"\n", "mol_df_filt_processed = parse_fluorescence_data_cdd(mol_df_filt)\n", "\n", "# Save the processed data\n", "mol_df_filt_processed.to_csv(\"cdd_filtered_processed.csv\", index=False)\n", "\n", "\"\"\"\n", "The last step in this process is to convert it into the format that the ML pipeline\n", "expects it in. The below function does that, taking the previously generated CSV file\n", "as input and producing a JSON file that we will load later.\n", "\"\"\"\n", "_ = cdd_to_schema(\n", " cdd_csv=\"cdd_filtered_processed.csv\", out_json=\"cdd_filtered_processed.json\"\n", ")" ] }, { "cell_type": "markdown", "id": "d9e3ad3e-7725-44d2-b791-4e678cf8cbeb", "metadata": {}, "source": [ "## Building the ML `Trainer` object\n", "\n", "Training in the `drugforge-ml` repo is controlled by the `drugforge.ml.trainer.Trainer` class, which itself is composed of several different config `Pydantic` schema objects. 
These configs are defined in `drugforge.ml.config`, with the exception of the config for the actual model, which is defined in our ML backend [`mtenn`](https://github.com/choderalab/mtenn/) (see the [`mtenn.config`](https://mtenn.readthedocs.io/en/latest/_autosummary/mtenn.config.html#module-mtenn.config) docs for more information).\n", "The required configs are:\n", "\n", "* `OptimizerConfig`: Config describing the optimizer to use in training\n", "* `ModelConfigBase`: Config describing the model to train (defined in `mtenn.config`)\n", "* `EarlyStoppingConfig`: Config describing the early stopping check to use\n", "* `DatasetConfig`: Config describing the dataset object to train on\n", "* `DatasetSplitterConfig`: Config describing how to split the dataset into train, val, and test splits\n", "* `LossFunctionConfig`: Config describing the loss function for training\n", "* `DataAugConfig`: Config describing data augmentations to be applied to each pose\n", "\n", "Note that below we will define each of these configs separately, but this can also be done automatically by passing a `dict` defining each config to the `Trainer` instead of the actual config object.\n", "\n", "In the below example, we'll build a different `Trainer` object with a different model config for each model type that we are demoing:\n", "\n", "* Graph attention ([GAT](https://arxiv.org/abs/1710.10903)): a topology-only, ligand-only GNN\n", "* [SchNet](https://arxiv.org/abs/1706.08566): an E(3)-invariant, structure-based model\n", "* [e3nn](https://arxiv.org/abs/2207.09453): an E(3)-equivariant, structure-based model\n", "\n", "Because we use a config-based setup, we can share configs between the different `Trainer`s as the configs will be used to build the actual Python objects rather than doing the work themselves. The only exception to this sharing is, as previously mentioned, the model configs, which are model-specific and will need to be different for each `Trainer`." 
] }, { "cell_type": "code", "execution_count": 3, "id": "cfb2587e-3fe1-4f2a-b411-16dfd3d999c1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Filtering 0 structures that we don't have experimental data for.\n", "10 10\n", "Filtering 0 structures that we don't have experimental data for.\n", "10 10\n", "loading from cache\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.3" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "ds lengths 991 127 93\n", "Epoch 0/1\n" ] }, { "data": { "text/html": [ "\n", "
Run history: epoch, epoch_time, test_loss, train_loss, val_loss
Run summary: epoch 0 | epoch_time 3.7519 | test_loss 1.42651 | train_loss 7.22673 | val_loss 2.74728
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "loading from cache\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.3" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "ds lengths 2 2 6\n", "Epoch 0/1\n" ] }, { "data": { "text/html": [ "\n", "
Run history: epoch, epoch_time, test_loss, train_loss, val_loss
Run summary: epoch 0 | epoch_time 2.08681 | test_loss 0.0 | train_loss 11.28308 | val_loss 12.07008
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "loading from cache\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.3" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "ds lengths 2 2 6\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0/1\n" ] }, { "data": { "text/html": [ "\n", "
Run history: epoch, epoch_time, test_loss, train_loss, val_loss
Run summary: epoch 0 | epoch_time 5.29716 | test_loss 0.0 | train_loss 9.31578 | val_loss 10.75333
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from drugforge.data.util.utils import MOONSHOT_CDD_ID_REGEX, MPRO_ID_REGEX\n", "from drugforge.ml.config import (\n", " DataAugConfig,\n", " DatasetConfig,\n", " DatasetSplitterConfig,\n", " EarlyStoppingConfig,\n", " LossFunctionConfig,\n", " OptimizerConfig,\n", ")\n", "from drugforge.ml.trainer import Trainer\n", "from mtenn.config import E3NNModelConfig, GATModelConfig, SchNetModelConfig\n", "\n", "# We will use the Adam optimizer (default for OptimizerConfig) as well as all the\n", "# default parameters, so no need to pass anything here\n", "optimizer_config = OptimizerConfig()\n", "\n", "\"\"\"\n", "We will make 3 different model configs, one for each architecture. In this example we\n", "will build overly small models so that the training can be run in our docs runner, but\n", "in a real training example you should of course perform your own hyperparameter\n", "optimization. Any hyperparameter that is not specified will use the default of the\n", "underlying architecture.\n", "\n", "For the SchNet and e3nn models we will use a DeltaStrategy (default) and PIC50Readout\n", "(see the mtenn docs for more information).\n", "\"\"\"\n", "gat_model_config = GATModelConfig(num_layers=1)\n", "schnet_model_config = SchNetModelConfig(pred_readout=\"pic50\", num_interactions=1)\n", "# We recommend using Irreps of at least l=1, but will stick with l=0 for this example\n", "e3nn_model_config = E3NNModelConfig(pred_readout=\"pic50\", irreps_hidden={\"0\": 5})\n", "\n", "# We will need this configs to rebuild these models for inference, so we'll serialize\n", "# them to JSON files. 
This isn't strictly necessary, as this same information will be\n", "# stored in the Trainer JSON files that we'll save later, but this prevents us from\n", "# having to load in that whole file\n", "Path(\"gat_model_config.json\").write_text(gat_model_config.json())\n", "Path(\"schnet_model_config.json\").write_text(schnet_model_config.json())\n", "Path(\"e3nn_model_config.json\").write_text(e3nn_model_config.json())\n", "\n", "# We will use the default of no early stopping, but this can be configured as desired.\n", "# See the docs for drugforge.ml.early_stopping and EarlyStoppingConfig for more details\n", "es_config = EarlyStoppingConfig()\n", "\n", "\"\"\"\n", "The DatasetConfig requires a bit more explanation than the other configs, as it\n", "involves some data processing. Constructing a DatasetConfig object directly\n", "requires passing a list of drugforge.data.schema.ligand.Ligand (for 2D models)\n", "or a list of drugforge.data.schema.complex.Complex (for structure-based models).\n", "To get around having to parse the data from files yourself, we offer convenience\n", "functions DatasetConfig.from_exp_file (for 2D models) and DatasetConfig.from_str_files\n", "(for structure-based models). We will also need to create 3 of these configs, a 2D one\n", "for the GAT model, and 1 for each of the structural models, as the models expect\n", "slightly different inputs (this is all handled within the constructor methods).\n", "\n", "The from_exp_file method requires only an experimental data file, which is the file\n", "we generated in the previous step. The from_str_files method only requires this file\n", "if the dataset is being used for training, as this is where the method will pull data\n", "labels from. We will also need to pass in a glob or directory containing complex\n", "structure PDB files, as well as regular expressions defining what the crystal structure\n", "and compound IDs look like. 
We provide these regexes for the Moonshot data in\n", "drugforge.data.util.utils as MPRO_ID_REGEX and MOONSHOT_CDD_ID_REGEX respectively.\n", "If your structure files are named differently, you will need to modify these\n", "expressions to match your data.\n", "\n", "One other important consideration with DatasetConfig is caching. The ultimate object\n", "that this config will build is one of the classes in drugforge.ml.dataset. These\n", "objects can take a while to create from files (especially the structure-based ones), so\n", "we offer the ability to cache the built object using pickle. This way, if the build\n", "method is called for a DatasetConfig that has a Path set for cache_file and that file\n", "already exists, the object will be loaded directly without having to regenerate it.\n", "\"\"\"\n", "gat_ds_config = DatasetConfig.from_exp_file(\n", " Path(\"cdd_filtered_processed.json\"), cache_file=Path(\"gat_ds_config.pkl\")\n", ")\n", "# The complex structure files are in docked_dir (downloaded in the first cell)\n", "# and follow the Moonshot naming scheme for crystal structures and compound IDs\n", "schnet_ds_config = DatasetConfig.from_str_files(\n", " structures=f\"{str(docked_dir)}/*.pdb\",\n", " xtal_regex=MPRO_ID_REGEX,\n", " cpd_regex=MOONSHOT_CDD_ID_REGEX,\n", " for_training=True,\n", " exp_file=Path(\"cdd_filtered_processed.json\"),\n", " cache_file=Path(\"schnet_ds_config.pkl\"),\n", ")\n", "e3nn_ds_config = DatasetConfig.from_str_files(\n", " structures=f\"{str(docked_dir)}/*.pdb\",\n", " xtal_regex=MPRO_ID_REGEX,\n", " cpd_regex=MOONSHOT_CDD_ID_REGEX,\n", " for_training=True,\n", " exp_file=Path(\"cdd_filtered_processed.json\"),\n", " cache_file=Path(\"e3nn_ds_config.pkl\"),\n", " for_e3nn=True,\n", ")\n", "\n", "# Split our molecules temporally, using an 80:10:10 split for train:val:test (default)\n", "ds_splitter_config = DatasetSplitterConfig(split_type=\"temporal\")\n", "\n", "# Use a semi-quantitative MSE loss function\n", "# (see drugforge.ml.loss 
docs for more information)\n", "loss_config = LossFunctionConfig(loss_type=\"mse_step\")\n", "\n", "\"\"\"\n", "Finally, we are ready to build our Trainer and start training. We will set a couple\n", "other options here, including logging to Weights & Biases. This functionality is\n", "optional and can be avoided by simply not setting use_wandb=True, but we find it to be\n", "a useful way to track experiments. Note that you will first need to set up W&B on\n", "your machine (see their docs for how to get started). The only option other than the\n", "configs that is required to be set is output_dir.\n", "\n", "We are training here for 1 epoch (for docs purposes), with a mini-batch size of 25.\n", "The training will be done on the CPU (also for docs purposes), and each model's output\n", "will be saved to its own directory (./gat_training/, ./schnet_training/, and\n", "./e3nn_training/). We will log each training run to W&B, in a project named tutorial,\n", "as a run named after the model.\n", "\"\"\"\n", "t_gat = Trainer(\n", " optimizer_config=optimizer_config,\n", " model_config=gat_model_config,\n", " es_config=es_config,\n", " ds_config=gat_ds_config,\n", " ds_splitter_config=ds_splitter_config,\n", " loss_config=loss_config,\n", " n_epochs=1,\n", " batch_size=25,\n", " device=\"cpu\",\n", " output_dir=\"./gat_training/\",\n", " use_wandb=True,\n", " wandb_project=\"tutorial\",\n", " wandb_name=\"gat\",\n", ")\n", "t_schnet = Trainer(\n", " optimizer_config=optimizer_config,\n", " model_config=schnet_model_config,\n", " es_config=es_config,\n", " ds_config=schnet_ds_config,\n", " ds_splitter_config=ds_splitter_config,\n", " loss_config=loss_config,\n", " n_epochs=1,\n", " batch_size=25,\n", " device=\"cpu\",\n", " output_dir=\"./schnet_training/\",\n", " use_wandb=True,\n", " wandb_project=\"tutorial\",\n", " wandb_name=\"schnet\",\n", ")\n", "t_e3nn = Trainer(\n", " optimizer_config=optimizer_config,\n", " model_config=e3nn_model_config,\n", " es_config=es_config,\n", " ds_config=e3nn_ds_config,\n", " ds_splitter_config=ds_splitter_config,\n", 
" loss_config=loss_config,\n", " n_epochs=1,\n", " batch_size=25,\n", " device=\"cpu\",\n", " output_dir=\"./e3nn_training/\",\n", " use_wandb=True,\n", " wandb_project=\"tutorial\",\n", " wandb_name=\"e3nn\",\n", ")\n", "\n", "# If desired we can save each of these Trainers as a JSON file, which will let us skip\n", "# all of the above steps next time we want to re-run this training or something\n", "# similar to it\n", "Path(\"gat_trainer.json\").write_text(t_gat.json())\n", "Path(\"schnet_trainer.json\").write_text(t_schnet.json())\n", "Path(\"e3nn_trainer.json\").write_text(t_e3nn.json())\n", "\n", "# Finally, we initialize each Trainer and start training. The initialization step\n", "# handles building all the underlying Python objects, as well as syncing with W&B.\n", "t_gat.initialize()\n", "t_gat.train()\n", "t_schnet.initialize()\n", "t_schnet.train()\n", "t_e3nn.initialize()\n", "t_e3nn.train()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }