"""Source code for drugforge.alchemy.schema.prep_workflow."""

import copy
from typing import Any, Literal, Optional, Union

import rich
from drugforge.alchemy.schema.base import _SchemaBase
from drugforge.alchemy.schema.charge import OpenFFCharges
from drugforge.data.operators.state_expanders.protomer_expander import (
    EpikExpander,
    ProtomerExpander,
)
from drugforge.data.operators.state_expanders.stereo_expander import StereoExpander
from drugforge.data.operators.state_expanders.tautomer_expander import TautomerExpander
from drugforge.data.schema.ligand import Ligand
from drugforge.docking.schema.pose_generation import (
    OpenEyeConstrainedPoseGenerator,
    RDKitConstrainedPoseGenerator,
)
from drugforge.modeling.schema import PreppedComplex
from pydantic import Field
from rich import pretty
from rich.padding import Padding


class _AlchemyPrepBase(_SchemaBase):
    """
    A base class for the Alchemy prep workflow to capture the settings for the factory and results objects.

    The fields declared here are shared by the workflow factory and by the dataset it produces, so a
    dataset records the settings that were used to build it.
    """

    type: Literal["_AlchemyPrepBase"] = "_AlchemyPrepBase"

    # Optional state-expansion stages; a stage set to `None` is skipped by the workflow.
    stereo_expander: Optional[StereoExpander] = Field(
        StereoExpander(),
        description="A class to expand the stereo"
        "chemistry of the ligands. This stage will be skipped if set to `None`.",
    )
    charge_expander: Optional[
        Union[EpikExpander, ProtomerExpander, TautomerExpander]
    ] = Field(
        None,
        # the trailing space after "that" is required, the two literals are concatenated
        description="The charge and tautomer expander that "
        "should be applied to the ligands. This stage will be skipped if set to `None`.",
    )
    # Pose generation settings.
    pose_generator: Union[
        OpenEyeConstrainedPoseGenerator, RDKitConstrainedPoseGenerator
    ] = Field(
        RDKitConstrainedPoseGenerator(),
        description="The method "
        "to generate the initial poses for the molecules for FEC.",
    )
    core_smarts: Optional[str] = Field(
        None,
        description="The SMARTS string which should be used to identify the MCS between the "
        "input and reference ligand if not provided the MCS will be automatically generated. SMARTS strings can be created manually, or with e.g. ChemDraw or https://smarts.plus/.",
    )
    strict_stereo: bool = Field(
        True,
        description="Molecules will have conformers generated if their stereo chemistry matches the input molecule.",
    )
    # How many experimental reference ligands the workflow tries to pose in addition to the inputs.
    n_references: int = Field(
        3,
        description="The number of experimental reference molecules we should try to generate "
        "poses for.",
    )
    # Local charge generation; `None` disables the charging stage.
    charge_method: Optional[OpenFFCharges] = Field(
        OpenFFCharges(charge_method="am1bccelf10"),
        description="The method which should be used to charge the ligands locally.",
    )


class AlchemyDataSet(_AlchemyPrepBase):
    """
    The output of the AlchemyPrepWorkflow: a named collection of posed ligands ready for FEC, stored
    together with the inputs, the failures and the provenance of each stage.
    """

    type: Literal["AlchemyDataSet"] = "AlchemyDataSet"

    dataset_name: str = Field(
        ...,
        description="The name of the dataset.",
    )
    reference_complex: PreppedComplex = Field(
        ...,
        description=(
            "The prepared complex which was used in pose generation "
            "including the crystal reference ligand."
        ),
    )
    input_ligands: list[Ligand] = Field(
        ...,
        description="The list of ligands input to the workflow.",
    )
    posed_ligands: list[Ligand] = Field(
        ...,
        description="The list of Ligands with their generated poses.",
    )
    # keyed by the name of the stage which rejected the ligands
    failed_ligands: Optional[dict[str, list[Ligand]]] = Field(
        None,
        description=(
            "A list of ligands removed from the workflow "
            "stored by the stage that removed them."
        ),
    )
    # keyed by stage name
    provenance: dict[str, dict[str, Any]] = Field(
        ...,
        description=(
            "The provenance information for each of the stages in the workflow "
            "stored by the stage name."
        ),
    )

    def save_posed_ligands(self, filename: str):
        """
        Write every posed ligand in this dataset to a single SDF file via the openeye backend.

        Parameters
        ----------
        filename: The name of the SDF the ligands should be saved to.
        """
        # imported lazily so the openeye backend is only needed when saving
        from drugforge.data.backend.openeye import save_openeye_sdfs

        save_openeye_sdfs(
            [posed.to_oemol() for posed in self.posed_ligands],
            filename,
        )
class AlchemyPrepWorkflow(_AlchemyPrepBase):
    """
    A factory to handle the state expansion and constrained pose generation used as inputs to the Alchemy workflow.
    """

    type: Literal["AlchemyPrepWorkflow"] = "AlchemyPrepWorkflow"

    @staticmethod
    def _validate_ligands(ligands: list[Ligand]) -> list[Ligand]:
        """
        For the given set of ligands make sure that the docked ligand is the intended ligand target
        i.e does the 3D stereo match what we intend at input.

        Args:
            ligands: The posed ligands which should be checked.

        Returns:
            The ligands which FAILED the check, i.e. those whose current fixed InChIKey differs from the
            one stored in their provenance.
        """
        # check the original fixed inchikey against the current one
        return [
            ligand
            for ligand in ligands
            if ligand.provenance.fixed_inchikey != ligand.fixed_inchikey
        ]

    @staticmethod
    def _sort_similar_molecules(
        reference_ligand: Ligand, experimental_ligands: list[Ligand]
    ) -> list[Ligand]:
        """
        Sort the list of experimental ligands by MCS overlap with the reference crystal ligand to determine the order
        in which the structures should be generated.

        Args:
            reference_ligand: The crystal structure ligand which will be the basis for the constrained pose generation.
            experimental_ligands: The list experimental ligands we would like to add to this dataset.

        Returns:
            The experimental_ligands in order of MCS overlap with the reference ligand
        """
        from drugforge.docking.selectors.mcs_selector import sort_by_mcs

        # nothing to order, avoid calling the MCS code with no targets
        if not experimental_ligands:
            return []

        # use the mcs code to get the ordered indices of the matches
        sort_idx = sort_by_mcs(
            reference_ligand=reference_ligand,
            target_ligands=experimental_ligands,
            structure_matching=False,
        )
        # build a real list (as annotated) rather than a numpy object array
        return [experimental_ligands[index] for index in sort_idx]

    def pose_experimental_molecules(
        self,
        reference_complex: PreppedComplex,
        experimental_ligands: list[Ligand],
        processors: int = 1,
    ) -> list[Ligand]:
        """
        Iteratively try and generate poses for the experimental ligands until we have `self.n_references` posed.

        Args:
            reference_complex: The complex with the crystal structure which is used to constrain the generated poses.
            experimental_ligands: The list of experimental ligands ordered in list of priority.
            processors: The number of processor available to the pose generator.

        Returns:
            A list of posed experimental ligands, at most `self.n_references` long and in the priority
            order of the input. Empty if `self.n_references` is not positive.
        """
        # a zero batch size would make `range` raise, and no references were requested anyway
        if self.n_references <= 0:
            return []

        posed_refs = []
        # run in batches so we don't try and generate poses for everything but run faster than serial
        batch_size = self.n_references * 2
        for start in range(0, len(experimental_ligands), batch_size):
            ligand_batch = experimental_ligands[start : start + batch_size]
            poses = self.pose_generator.generate_poses(
                prepared_complex=reference_complex,
                ligands=ligand_batch,
                core_smarts=self.core_smarts,
                processors=processors,
            )
            posed_ligands = poses.posed_ligands
            if self.strict_stereo:
                # remove the stereo issue molecules before checking how many have been posed
                stereo_fails = AlchemyPrepWorkflow._validate_ligands(
                    ligands=posed_ligands
                )
                posed_ligands = AlchemyPrepWorkflow._remove_fails(
                    posed_ligands=posed_ligands, stereo_issue_ligands=stereo_fails
                )

            # skip to the next batch if none were generated
            if not posed_ligands:
                continue

            posed_ligands_by_inchi = {
                ligand.provenance.fixed_inchikey: ligand for ligand in posed_ligands
            }
            # ligands are not in order so check them in the input ordering
            for ligand in ligand_batch:
                posed = posed_ligands_by_inchi.get(ligand.provenance.fixed_inchikey)
                if posed is not None:
                    posed_refs.append(posed)

            # stop if we have enough posed ligands
            if len(posed_refs) >= self.n_references:
                break

        # finally return either when we have enough or run out of ligands
        return posed_refs[: self.n_references]

    @staticmethod
    def _remove_fails(
        posed_ligands: list[Ligand], stereo_issue_ligands: list[Ligand]
    ) -> list[Ligand]:
        """
        A helper method to remove ligands from the posed list which are in the stereo issue list.

        Ligands are matched by the fixed InChIKey stored in their provenance.

        Args:
            posed_ligands: A list of posed ligands which should be filtered.
            stereo_issue_ligands: The list of ligands with stereo issues which should be removed from the posed list.

        Returns:
            A new list of posed ligands with the flagged ligands removed, input order is kept.
        """
        # a set gives constant time membership checks while filtering
        failed_hashes = {
            ligand.provenance.fixed_inchikey for ligand in stereo_issue_ligands
        }
        return [
            mol
            for mol in posed_ligands
            if mol.provenance.fixed_inchikey not in failed_hashes
        ]

    @staticmethod
    def _remove_charge_fails(
        posed_ligands: list[Ligand], charge_issue_ligands: list[Ligand]
    ) -> list[Ligand]:
        """
        A helper method to remove ligands from the posed list which are in the charge issue list.

        Args:
            posed_ligands: A list of posed ligands which should be filtered.
            charge_issue_ligands: The list of ligands with charge issues which should be removed from the posed list.

        Returns:
            A new list of posed ligands with the ligands which failed charging removed.
        """
        # the filtering is identical to the stereo case, only the source of the failures differs
        return AlchemyPrepWorkflow._remove_fails(
            posed_ligands=posed_ligands, stereo_issue_ligands=charge_issue_ligands
        )

    @staticmethod
    def _deduplicate_experimental_ligands(
        posed_ligands: list[Ligand], experimental_ligands: list[Ligand]
    ) -> list[Ligand]:
        """
        Remove duplicated ligands in the experimental list which have already been posed.

        Notes:
            This function marks the duplicated ligands in the posed list as experimental which helps with predictions
            later in the workflow.

        Args:
            posed_ligands: A list of posed ligands.
            experimental_ligands: A list of experimental ligands which can be posed.

        Returns:
            The deduplicated list of experimental ligands which should be posed.
        """
        # with no experimental ligands there is no protocol name to read and nothing to filter
        if not experimental_ligands:
            return []

        # find the protocol name so we can mark the experimental ligands
        # NOTE(review): the name is read from the first ligand only and applied to every duplicate,
        # confirm that all experimental ligands share one protocol
        protocol_name = experimental_ligands[0].tags.get("cdd_protocol")

        posed_ligand_by_hash = {
            ligand.provenance.fixed_inchikey: ligand for ligand in posed_ligands
        }
        final_exp_ligands = []
        for ligand in experimental_ligands:
            ligand_hash = ligand.provenance.fixed_inchikey
            if ligand_hash not in posed_ligand_by_hash:
                final_exp_ligands.append(ligand)
            else:
                # already posed as an input ligand, tag the posed copy instead of posing it again
                posed_ligand_by_hash[ligand_hash].tags.update(
                    {"experimental": "True", "cdd_protocol": protocol_name}
                )
        return final_exp_ligands

    def create_alchemy_dataset(
        self,
        dataset_name: str,
        ligands: list[Ligand],
        reference_complex: PreppedComplex,
        processors: int = 1,
        reference_ligands: Optional[list[Ligand]] = None,
    ) -> AlchemyDataSet:
        """
        Run the set of input ligands through the state enumeration and pose generation workflow to create a set of
        posed ligands ready for FEC.

        Notes:
            Ligands with experimental data can be supplied via `reference_ligands`, poses will be generated until
            `self.n_references` have been successfully added. The ligands will be sorted by their MCS overlap with
            the crystal reference ligand to ensure a pose can be generated.

        Args:
            dataset_name: The name which should be given to this dataset.
            ligands: The list of input ligands which should be run through the workflow.
            reference_complex: The prepared target crystal structure with a reference ligand which the poses should be
                constrained to.
            processors: The number of parallel processors that should be used to run the workflow.
            reference_ligands: The list of reference ligands with experimental data which we should also generate
                poses for if `self.n_references` > 0.

        Returns:
            A prepared AlchemyDataset with state expanded ligands posed in the receptor ready for FEC, along with
            the provenance information of the workflow.
        """
        # use rich to display progress
        pretty.install()
        console = rich.get_console()

        # keep an untouched copy of the inputs, `ligands` is rebound by the expansion stages below
        input_ligands = copy.deepcopy(ligands)
        provenance = {}
        failed_ligands = {}

        # Build the workflow we want to run
        workflow = [
            stage
            for stage in ["stereo_expander", "charge_expander"]
            if getattr(self, stage) is not None
        ]

        # loop over each expansion stage and run
        for stage in workflow:
            expansion_engine = getattr(self, stage)
            # the context manager stops the spinner even if the stage raises
            with console.status(
                f"Running state expansion using {expansion_engine.expander_type}"
            ):
                ligands = expansion_engine.expand(ligands)
                # log the software versions used
                provenance[expansion_engine.expander_type] = (
                    expansion_engine.provenance()
                )
            console.print(
                f"[[green]✓[/green]] {expansion_engine.expander_type} successful, number of unique ligands {len(ligands)}."
            )
            console.line()

        # now run the pose generation stage
        console.print(
            f"Generating constrained poses using {self.pose_generator.type} for {len(ligands)} ligands."
        )
        # check for stereo in the reference ligand
        if reference_complex.ligand.has_perceived_stereo:
            console.print(
                "[yellow]! WARNING the reference structure is chiral, check output structures carefully! [/yellow]"
            )
            console.line()

        pose_result = self.pose_generator.generate_poses(
            prepared_complex=reference_complex,
            ligands=ligands,
            core_smarts=self.core_smarts,
            processors=processors,
        )
        posed_ligands = pose_result.posed_ligands
        provenance[self.pose_generator.type] = self.pose_generator.provenance()

        # save any failed ligands
        if pose_result.failed_ligands:
            failed_ligands[self.pose_generator.type] = pose_result.failed_ligands

        console.print(
            f"[[green]✓[/green]] Pose generation successful for {len(pose_result.posed_ligands)}/{len(ligands)}."
        )
        console.line()

        if self.strict_stereo:
            with console.status(
                "Removing molecules with inconsistent stereochemistry."
            ):
                stereo_fails = AlchemyPrepWorkflow._validate_ligands(
                    ligands=posed_ligands
                )
                if stereo_fails:
                    # add the new fails to the rest
                    failed_ligands["InconsistentStereo"] = stereo_fails
                    posed_ligands = AlchemyPrepWorkflow._remove_fails(
                        posed_ligands=posed_ligands, stereo_issue_ligands=stereo_fails
                    )
            console.print(
                f"[[green]✓[/green]] Stereochemistry filtering complete {len(stereo_fails)} molecules removed."
            )
            console.line()

        if reference_ligands is not None and self.n_references > 0:
            # we need to check if any of the ligands we have already generated poses for are in the experimental list
            # if so mark them with the correct tags for later and remove them from this list
            with console.status("Removing duplicated reference ligands"):
                reference_ligands = (
                    AlchemyPrepWorkflow._deduplicate_experimental_ligands(
                        posed_ligands=posed_ligands,
                        experimental_ligands=reference_ligands,
                    )
                )

            if not reference_ligands:
                # nothing left to sort or pose
                console.print("All experimental ligands removed!")
            else:
                with console.status("Sorting reference ligands by MCS overlap."):
                    sorted_exp_ligands = AlchemyPrepWorkflow._sort_similar_molecules(
                        reference_ligand=reference_complex.ligand,
                        experimental_ligands=reference_ligands,
                    )

                console.print(
                    f"Generating constrained poses using {self.pose_generator.type} for {self.n_references} reference"
                    " ligands."
                )
                # use the wrapper to keep generating poses until we have the correct number
                posed_refs = self.pose_experimental_molecules(
                    reference_complex=reference_complex,
                    experimental_ligands=sorted_exp_ligands,
                    processors=processors,
                )
                console.print(
                    f"[[green]✓[/green]] Pose generation successful for {len(posed_refs)}/{self.n_references} experimental "
                    "ligands:"
                )
                for ref_ligand in posed_refs:
                    console.print(
                        f"Injected ligand: {ref_ligand.compound_name}; SMILES: {ref_ligand.smiles}",
                    )
                posed_ligands.extend(posed_refs)

        message = Padding(
            f"Poses successfully generated for {len(posed_ligands)} ligands.",
            (1, 0, 1, 0),
        )
        console.print(message)

        # Generate charges locally if requested
        if self.charge_method is not None:
            console.print(f"Generating charges locally using {self.charge_method}")
            posed_ligands, charge_fails = self.charge_method.generate_charges(
                ligands=posed_ligands, processors=processors
            )
            if charge_fails:
                # add the new fails to the rest
                failed_ligands["ChargeFail"] = charge_fails
                posed_ligands = AlchemyPrepWorkflow._remove_charge_fails(
                    posed_ligands=posed_ligands, charge_issue_ligands=charge_fails
                )
            provenance[self.charge_method.type] = self.charge_method.provenance()

            message = Padding(
                f"[[green]✓[/green]] Charges successfully generated for {len(posed_ligands)} ligands.",
                (1, 0, 1, 0),
            )
            console.print(message)

        # gather the results, the workflow settings are copied onto the dataset (minus the type tag)
        return AlchemyDataSet(
            **self.model_dump(exclude={"type"}),
            dataset_name=dataset_name,
            reference_complex=reference_complex,
            input_ligands=input_ligands,
            posed_ligands=posed_ligands,
            failed_ligands=failed_ligands if failed_ligands else None,
            provenance=provenance,
        )