Inference

Prerequisites

Before starting this tutorial, you should have already worked through the tutorial on Running ML inference.

We will use publicly available test files as the inference inputs for this tutorial, but feel free to substitute your own data.

[1]:
# First download the needed files
from drugforge.data.testing.test_resources import fetch_test_file

docked_file = fetch_test_file(
    "ml_testing/docked/AAR-POS-5507155c-1_Mpro-P0018_0A_0_bound_best.pdb"
)
smi_file = fetch_test_file("Mpro_combined_labeled.smi")
# Keep only the first whitespace-separated token, i.e. the first SMILES in the file
pred_smiles = smi_file.read_text().split()[0]
print(pred_smiles)
CC(=O)NNC(=O)c1cc2CCCCc2s1
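The cell above keeps only the first SMILES in the file. To run predictions on every compound, you would need to parse the whole file. Below is a minimal sketch that assumes the common `.smi` layout of one SMILES per line, optionally followed by whitespace and a label (we have not verified the exact layout of `Mpro_combined_labeled.smi`). `read_smiles_lines` is a hypothetical helper, not part of drugforge, and the example text is a toy input.

```python
from __future__ import annotations


def read_smiles_lines(text: str) -> list[tuple[str, str | None]]:
    """Parse .smi text into (smiles, label) pairs.

    Assumes one record per line: a SMILES string, optionally followed by
    whitespace and a label. Blank lines and lines starting with '#' are skipped.
    """
    records = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        # Split on the first run of whitespace only, so labels may contain spaces
        parts = line.split(maxsplit=1)
        smiles = parts[0]
        label = parts[1] if len(parts) > 1 else None
        records.append((smiles, label))
    return records


# Toy input: one labeled record, one unlabeled record, and a trailing blank line
example = "CC(=O)NNC(=O)c1cc2CCCCc2s1 cmpd-1\nc1ccccc1\n\n"
for smiles, label in read_smiles_lines(example):
    print(smiles, label)
```

Each parsed SMILES could then be passed to a SMILES-based model one at a time.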

Intro

In this guide, we start from an existing model weights file and model config file, both of which should have been generated in the previous tutorial.

We will first build the model spec object and then the Inference model object. We show two ways of doing this: from a local model weights file, and by pulling one of our available pre-trained models. For inference, we pass each model either a SMILES string or a complex structure PDB file, depending on the model type: the GAT model takes a SMILES string, while the SchNet and e3nn models take a structure file.
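In the cells below, this split shows up as the prediction method each model is called with: `predict_from_smiles` for GAT and `predict_from_structure_file` for SchNet and e3nn. The sketch below routes an input to the matching method based on the model type string. `StubInference` and `predict_any` are hypothetical stand-ins that only mimic the method names used in this tutorial; they do not load or run a real model.

```python
from __future__ import annotations

from pathlib import Path

# Model types that take a SMILES string vs. a complex structure file,
# following the usage in this tutorial (assumption: no other types are needed).
SMILES_MODEL_TYPES = {"GAT"}
STRUCTURE_MODEL_TYPES = {"schnet", "e3nn"}


class StubInference:
    """Hypothetical stand-in exposing the two prediction methods used in this tutorial."""

    def __init__(self, model_type: str):
        self.model_type = model_type

    def predict_from_smiles(self, smiles: str) -> str:
        return f"{self.model_type} called with SMILES {smiles}"

    def predict_from_structure_file(self, path: Path) -> str:
        return f"{self.model_type} called with structure {path.name}"


def predict_any(model, smiles: str, structure_file: Path):
    """Call the prediction method that matches the model's input type."""
    if model.model_type in SMILES_MODEL_TYPES:
        return model.predict_from_smiles(smiles)
    if model.model_type in STRUCTURE_MODEL_TYPES:
        return model.predict_from_structure_file(structure_file)
    raise ValueError(f"Unknown model type: {model.model_type!r}")


for model_type in ["GAT", "schnet", "e3nn"]:
    print(predict_any(StubInference(model_type), "CCO", Path("complex.pdb")))
```

Keeping the type-to-input mapping in one place means that adding another architecture only requires updating the two sets.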

Building the Inference model from a local weights file

First, we build an Inference model from local weights and config files. We define a drugforge.ml.models.LocalMLModelSpec for each architecture, and then use it to construct the corresponding Inference model.
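Conceptually, a model spec bundles the metadata and file paths needed to rebuild a trained model. The dataclass below mirrors the keyword arguments passed to `LocalMLModelSpec` in the next cell; it is a hypothetical illustration, not drugforge's class. It adds one assumption of our own: a check that both files exist, so that a wrong path fails when the spec is built rather than later, when the model is loaded.

```python
from __future__ import annotations

import tempfile
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path


@dataclass
class ModelSpecSketch:
    """Hypothetical mirror of the fields passed to LocalMLModelSpec in this tutorial."""

    name: str
    type: str
    last_updated: datetime
    targets: set[str]
    weights_file: Path
    config_file: Path

    def __post_init__(self):
        # Fail early with a clear message if either file is missing
        for attr in ("weights_file", "config_file"):
            path = Path(getattr(self, attr))
            if not path.is_file():
                raise FileNotFoundError(f"{attr} not found: {path}")
            setattr(self, attr, path)


# Create placeholder files in a temporary directory so the example runs anywhere
with tempfile.TemporaryDirectory() as tmp:
    wts = Path(tmp) / "final.th"
    cfg = Path(tmp) / "gat_model_config.json"
    wts.touch()
    cfg.write_text("{}")
    spec = ModelSpecSketch(
        name="GAT",
        type="GAT",
        last_updated=datetime.today(),
        targets={"SARS-CoV-2-Mpro"},
        weights_file=wts,
        config_file=cfg,
    )
    print(spec.name, spec.weights_file.name)
```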

[2]:
from drugforge.ml.inference import E3nnInference, GATInference, SchnetInference
from drugforge.ml.models import LocalMLModelSpec
from datetime import datetime
from pathlib import Path

# Define weights and config files
gat_wts_path = next(iter(Path("gat_training/").glob("*/final.th")))
gat_config_path = Path("gat_model_config.json")
schnet_wts_path = next(iter(Path("schnet_training/").glob("*/final.th")))
schnet_config_path = Path("schnet_model_config.json")
e3nn_wts_path = next(iter(Path("e3nn_training/").glob("*/final.th")))
e3nn_config_path = Path("e3nn_model_config.json")

# Build LocalMLModelSpecs and Inference models
gat_model_spec = LocalMLModelSpec(
    name="GAT",
    type="GAT",
    last_updated=datetime.today(),
    targets={"SARS-CoV-2-Mpro"},
    weights_file=gat_wts_path,
    config_file=gat_config_path,
)
gat_inf_model = GATInference.from_local_model_spec(gat_model_spec)
schnet_model_spec = LocalMLModelSpec(
    name="SchNet",
    type="schnet",
    last_updated=datetime.today(),
    targets={"SARS-CoV-2-Mpro"},
    weights_file=schnet_wts_path,
    config_file=schnet_config_path,
)
schnet_inf_model = SchnetInference.from_local_model_spec(schnet_model_spec)
e3nn_model_spec = LocalMLModelSpec(
    name="e3nn",
    type="e3nn",
    last_updated=datetime.today(),
    targets={"SARS-CoV-2-Mpro"},
    weights_file=e3nn_wts_path,
    config_file=e3nn_config_path,
)
e3nn_inf_model = E3nnInference.from_local_model_spec(e3nn_model_spec)

# Make predictions
gat_pred = gat_inf_model.predict_from_smiles(pred_smiles)
schnet_pred = schnet_inf_model.predict_from_structure_file(docked_file)
e3nn_pred = e3nn_inf_model.predict_from_structure_file(docked_file)

print("GAT prediction:", gat_pred)
print("SchNet prediction:", schnet_pred)
print("e3nn prediction:", e3nn_pred)
GAT prediction: 2.397318124771118
SchNet prediction: -0.17303919792175293
e3nn prediction: 0.01794694922864437
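The cell above locates each weights file with `next(iter(Path(...).glob("*/final.th")))`. This takes whichever match the filesystem lists first and raises a bare `StopIteration` if the training directory contains no runs. If a training directory holds several runs, a small helper that picks the most recently modified `final.th` and raises a readable error may be safer. This is a minimal sketch: `latest_weights_file` is a hypothetical helper, and "the newest run wins" is our assumption about which run you want.

```python
from __future__ import annotations

import os
import tempfile
from pathlib import Path


def latest_weights_file(training_dir: Path, pattern: str = "*/final.th") -> Path:
    """Return the most recently modified file matching `pattern` under `training_dir`."""
    candidates = list(Path(training_dir).glob(pattern))
    if not candidates:
        raise FileNotFoundError(f"No files matching {pattern!r} in {training_dir}")
    # Path.glob yields files in filesystem order, so choose explicitly by mtime
    return max(candidates, key=lambda p: p.stat().st_mtime)


# Build a fake training directory with two runs in a temporary location
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "gat_training"
    for i, run in enumerate(["run_a", "run_b"]):
        wts = root / run / "final.th"
        wts.parent.mkdir(parents=True)
        wts.touch()
        # Give the runs distinct modification times (run_b is newer)
        os.utime(wts, (1_000_000 + i, 1_000_000 + i))
    chosen = latest_weights_file(root)
    print(chosen.parent.name)
```

With this helper, the weights lookup in the cell above could be written as, for example, `gat_wts_path = latest_weights_file(Path("gat_training/"))`.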

Building the Inference model from ASAP-trained models

We also offer several models that we have pre-trained on Moonshot/ASAP data. In this section, we only specify the ASAP target that we want a model for; drugforge pulls the necessary files and builds the Inference model.
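`from_latest_by_target` hides the lookup step. Conceptually, picking "the latest model for a target" means keeping only the available models of the requested architecture whose targets include the requested target, then taking the one with the newest `last_updated`. The sketch below illustrates that idea on an in-memory list. It is not drugforge's implementation (which also pulls the model files), and `RegistryEntry`, `latest_by_target`, and the registry entries are made up for illustration.

```python
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime


@dataclass(frozen=True)
class RegistryEntry:
    """Hypothetical record describing one pre-trained model."""

    name: str
    type: str
    targets: frozenset[str]
    last_updated: datetime


def latest_by_target(entries, model_type: str, target: str) -> RegistryEntry:
    """Return the newest entry of `model_type` that was trained on `target`."""
    matches = [e for e in entries if e.type == model_type and target in e.targets]
    if not matches:
        raise LookupError(f"No {model_type} model found for target {target!r}")
    return max(matches, key=lambda e: e.last_updated)


# Made-up entries, for illustration only
registry = [
    RegistryEntry("gat-old", "GAT", frozenset({"SARS-CoV-2-Mpro"}), datetime(2023, 1, 1)),
    RegistryEntry("gat-new", "GAT", frozenset({"SARS-CoV-2-Mpro"}), datetime(2024, 1, 1)),
    RegistryEntry("e3nn-mpro", "e3nn", frozenset({"SARS-CoV-2-Mpro"}), datetime(2024, 6, 1)),
]
print(latest_by_target(registry, "GAT", "SARS-CoV-2-Mpro").name)
```

Because "latest" depends on which models are available when you run the code, predictions from `from_latest_by_target` can change over time, whereas a local spec always points at the same weights file.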

[3]:
from drugforge.ml.inference import E3nnInference, GATInference, SchnetInference

# Get latest model trained on SARS-CoV-2 Mpro data for each architecture
gat_inf_model = GATInference.from_latest_by_target("SARS-CoV-2-Mpro")
schnet_inf_model = SchnetInference.from_latest_by_target("SARS-CoV-2-Mpro")
e3nn_inf_model = E3nnInference.from_latest_by_target("SARS-CoV-2-Mpro")

# Make predictions
gat_pred = gat_inf_model.predict_from_smiles(pred_smiles)
schnet_pred = schnet_inf_model.predict_from_structure_file(docked_file)
e3nn_pred = e3nn_inf_model.predict_from_structure_file(docked_file)

print("GAT prediction:", gat_pred)
print("SchNet prediction:", schnet_pred)
print("e3nn prediction:", e3nn_pred)
GAT prediction: 4.059902191162109
SchNet prediction: 1.3631529808044434
e3nn prediction: 3.001865863800049