Source code for drugforge.workflows.postera.postera_uploader

import logging
from typing import Optional
from uuid import UUID

import numpy as np
import pandas as pd
from drugforge.data.backend.rdkit import rdkit_smiles_roundtrip
from drugforge.data.services.postera.molecule_set import MoleculeSetAPI, MoleculeSetKeys
from drugforge.data.services.services_config import PosteraSettings
from drugforge.docking.docking_data_validation import DockingResultCols
from pydantic import BaseModel, Field

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class PosteraUploader(BaseModel):
    """
    Push a dataframe of ligands to a Postera molecule set.

    The molecule set is looked up by ``molecule_set_name``; it is created if
    it does not exist and otherwise updated (and extended with rows that
    could not be matched to an existing molecule).
    """

    settings: PosteraSettings = Field(default_factory=PosteraSettings)
    molecule_set_name: str = Field(
        ...,
        description="Name of the molecule set to push to Postera, if it doesn't exist it will be created",
    )
    id_field: str = Field(
        DockingResultCols.LIGAND_ID.value,
        description="Name of the column in the dataframe to use as the ligand id",
    )
    smiles_field: str = Field(
        DockingResultCols.SMILES.value,
        description="Name of the column in the dataframe to use as the SMILES field",
    )
    overwrite: bool = Field(
        False, description="Overwrite existing data on molecule set"
    )

    def push(
        self,
        df: pd.DataFrame,
        sort_column: Optional[str] = None,
        sort_ascending: bool = False,
    ) -> tuple[pd.DataFrame, str, bool]:
        """
        Push molecules to a Postera molecule set

        Parameters
        ----------
        df : DataFrame
            DataFrame of data to upload
        sort_column : str, optional
            The column to sort the data by if duplicates are found
        sort_ascending : bool
            Whether the data should be sorted in ascending order

        Returns
        -------
        DataFrame
            The content of the molecule set, fetched from the API after the
            push has completed
        molecule_set_name
            The name of the molecule set, as returned by
            ``MoleculeSetAPI.get_name_from_id`` for the pushed molecule set
        new_molset : bool
            Whether a new molecule set was created

        Raises
        ------
        ValueError
            If ``smiles_field`` or ``id_field`` is not a column of ``df``, if
            duplicate ids are found and no valid ``sort_column`` was given,
            or if the ids are UUIDs and some of them are blank.
        """
        if self.smiles_field not in df.columns:
            raise ValueError(f"smiles_field {self.smiles_field} not found in dataframe")

        if self.id_field not in df.columns:
            raise ValueError(f"id_field {self.id_field} not found in dataframe")

        ms_api = MoleculeSetAPI.from_settings(self.settings)

        # work on a copy so the caller's dataframe is never modified
        data = df.copy()

        new_molset = False
        # if the molecule set doesn't exist, create it
        if not ms_api.exists(self.molecule_set_name, by="name"):
            logger.debug(
                f"molecule set {self.molecule_set_name} does not exist, creating new molecule set"
            )
            molset_id = ms_api.create_molecule_set_from_df_with_manifold_validation(
                molecule_set_name=self.molecule_set_name,
                df=df,
                id_field=self.id_field,
                smiles_field=self.smiles_field,
            )
            new_molset = True
        else:
            # grab id of molecule set
            molset_id = ms_api.get_id_from_name(self.molecule_set_name)
            logger.debug(
                f"molecule set {self.molecule_set_name} exists with id {molset_id}, updating molecule set"
            )

            if not self.id_data_is_uuid_castable(data, self.id_field):
                # the local ids are not manifold UUIDs: pull the molecule set
                # and match local rows to existing molecules by SMILES
                new_data = ms_api.get_molecules(molset_id, return_as="dataframe")
                data = self.join_with_manifold_data(
                    data,
                    new_data,
                    smiles_field=self.smiles_field,
                    id_field=self.id_field,
                )
                data = self.remove_duplicates(data, sort_column, sort_ascending)

                # find rows with blank id, they need to be added to molset,
                # using **add** endpoint rather than **update**
                has_blank_id_rows, blank_id_rows = self._check_for_blank_ids(
                    data, self.id_field, raise_error=False
                )
                logger.debug(f"data has_blank_id_rows: {has_blank_id_rows}")

                if has_blank_id_rows:
                    ms_api.add_molecules_from_df_with_manifold_validation(
                        molecule_set_id=molset_id,
                        df=blank_id_rows,
                        id_field=self.id_field,
                        smiles_field=self.smiles_field,
                    )
                    logger.debug(
                        "appending to molecule set where some molecules have not been matched to an existing molecule in the molecule set, these ligands will be added to the molecule set"
                    )

                # find rows with a UUID, they need to be updated using the
                # **update** endpoint
                uuid_rows = data[~data[self.id_field].isna()]
                ms_api.update_molecules_from_df_with_manifold_validation(
                    molecule_set_id=molset_id,
                    df=uuid_rows,
                    id_field=self.id_field,
                    smiles_field=self.smiles_field,
                    overwrite=self.overwrite,
                )
            else:
                # if the id data is castable to UUID, we can just update the
                # molecule set

                # check for duplicates, removing them
                data = self.remove_duplicates(df, sort_column, sort_ascending)
                # check for blanks, raising
                self._check_for_blank_ids(data, self.id_field, raise_error=True)

                # ok to update the molecule set
                ms_api.update_molecules_from_df_with_manifold_validation(
                    molecule_set_id=molset_id,
                    df=data,
                    id_field=self.id_field,
                    smiles_field=self.smiles_field,
                    overwrite=self.overwrite,
                )

        # report the state of the molecule set after the push
        new_data = ms_api.get_molecules(molset_id, return_as="dataframe")
        molset_name = ms_api.get_name_from_id(molset_id)
        return new_data, molset_name, new_molset

    @staticmethod
    def join_with_manifold_data(
        original, molset_query_df, smiles_field, id_field, drop_no_uuid=False
    ):
        """
        Join the original dataframe with manifold data that is returned from a
        query to the manifold API

        Rows are matched on SMILES after both sides have been passed through
        ``rdkit_smiles_roundtrip``. The ``id_field`` column of the result
        holds the manifold id of the matched molecule, and is blank for rows
        without a match. Neither input dataframe is modified.

        Parameters
        ----------
        original : DataFrame
            The original dataframe
        molset_query_df : DataFrame
            The dataframe returned from a query to the manifold API
        smiles_field : str
            The name of the smiles field in the original dataframe
        id_field : str
            The name of the id field in the original dataframe
        drop_no_uuid : bool
            Whether to drop rows that don't have a UUID

        Returns
        -------
        DataFrame
            The outer join of both dataframes on the SMILES column
        """
        data = original.copy()
        # explicit copy: the subset is modified below and must neither alias
        # the caller's dataframe nor trigger SettingWithCopyWarning
        subset = molset_query_df[
            [MoleculeSetKeys.id.value, MoleculeSetKeys.smiles.value]
        ].copy()

        # use rdkit here, to match postera backend which uses rdkit
        # provides better matching performance with smiles pulled down from postera
        # do a roundtrip to canonicalize the smiles
        subset.loc[:, MoleculeSetKeys.smiles.value] = subset[
            MoleculeSetKeys.smiles.value
        ].apply(rdkit_smiles_roundtrip)
        # do the same to the original data
        data.loc[:, smiles_field] = data[smiles_field].apply(rdkit_smiles_roundtrip)

        # rename so both sides share the join key
        subset = subset.rename(columns={MoleculeSetKeys.smiles.value: smiles_field})

        # manifold columns that clash with a local column get a "_y" suffix in
        # the merge; remember exactly those so that unrelated local columns
        # whose name merely ends in "_y" are not dropped with them
        clashing = [
            f"{column}_y"
            for column in subset.columns
            if column != smiles_field and column in data.columns
        ]

        # merge the data, outer join very important here to avoid dropping rows
        # that are present in local data but not in manifold
        data = data.merge(subset, on=smiles_field, how="outer", suffixes=("", "_y"))
        data = data.drop(columns=clashing)

        # drop original ID column and replace with the manifold ID
        if id_field != MoleculeSetKeys.id.value:
            data = data.drop(columns=id_field)
            data = data.rename(columns={MoleculeSetKeys.id.value: id_field})

        if drop_no_uuid:
            data = data[~data[id_field].isna()]

        return data

    @staticmethod
    def id_data_is_uuid_castable(df, id_field) -> bool:
        """
        Check if the id data is castable to UUID

        Parameters
        ----------
        df : DataFrame
            DataFrame of data to upload
        id_field : str
            Name of the column in the dataframe to use as the ligand id

        Returns
        -------
        bool
            Whether the entire data column is castable to UUID
        """
        try:
            for value in df[id_field]:
                UUID(value)
        except (ValueError, TypeError, AttributeError):
            # what UUID() raises for malformed strings, None and non-string
            # values such as NaN; anything else (e.g. KeyboardInterrupt) must
            # not be swallowed
            return False
        return True

    @staticmethod
    def _check_for_duplicates(
        df,
        id_field,
        allow_empty=True,
        raise_error=False,
        sort_column=None,
        sort_ascending=False,
    ):
        """
        Check for duplicate UUIDs in the dataframe

        Parameters
        ----------
        df : DataFrame
            DataFrame of data to upload
        id_field : str
            Name of the column in the dataframe to use as the ligand id
        allow_empty : bool
            Whether to allow empty UUIDs to be exempt from the check
        raise_error : bool
            Whether to raise an error if duplicates are found
        sort_column : str, optional
            Unused, kept for backwards compatibility
        sort_ascending : bool
            Unused, kept for backwards compatibility

        Returns
        -------
        bool
            Whether duplicates were found
        DataFrame or None
            The second and later occurrences of each duplicated id, or None
            if there are no duplicates

        Raises
        ------
        ValueError
            If there are duplicate UUIDs and ``raise_error`` is set
        """
        # replace returns a new dataframe, the input is left untouched;
        # empty strings count as missing ids
        df = df.replace("", np.nan)
        if allow_empty:
            df = df[~df[id_field].isna()]

        repeated = df[id_field].duplicated()
        if not repeated.any():
            return False, None

        duplicates = df[repeated]
        if raise_error:
            raise ValueError(f"{len(duplicates)} duplicate UUIDs found in dataframe")
        return True, duplicates

    @staticmethod
    def _check_for_blank_ids(df, id_field, raise_error=False):
        """
        Check for blank UUIDs in the dataframe

        Parameters
        ----------
        df : DataFrame
            DataFrame of data to upload
        id_field : str
            Name of the column in the dataframe to use as the ligand id
        raise_error : bool
            Whether to raise an error if blank ids are found

        Returns
        -------
        bool
            Whether blank ids were found
        DataFrame or None
            The rows with a blank id (empty strings replaced by NaN), or None
            if there are none

        Raises
        ------
        ValueError
            If there are blank UUIDs and ``raise_error`` is set
        """
        # replace returns a new dataframe, the input is left untouched;
        # empty strings count as missing ids
        df = df.replace("", np.nan)
        blank = df[id_field].isna()
        if not blank.any():
            return False, None

        if raise_error:
            raise ValueError("Blank UUIDs found in dataframe")
        return True, df[blank]

    def remove_duplicates(self, data, sort_column, sort_ascending=False):
        """
        Remove rows with a duplicated id from the dataframe

        If duplicates are found the data is sorted by ``sort_column`` and the
        first row of every id is kept. Rows with a blank id (NaN or empty
        string) are never treated as duplicates of each other and are all
        kept.

        Parameters
        ----------
        data : DataFrame
            DataFrame of data to upload
        sort_column : str
            The column to sort the data by if duplicates are found
        sort_ascending : bool
            Whether the data should be sorted in ascending order

        Returns
        -------
        DataFrame
            The input dataframe with duplicates removed; it is returned
            unchanged and unsorted if there are no duplicates

        Raises
        ------
        ValueError
            If duplicates are found and ``sort_column`` is missing or not a
            column of ``data``
        """
        has_duplicates, _ = self._check_for_duplicates(
            data, self.id_field, allow_empty=True, raise_error=False
        )
        if not has_duplicates:
            return data

        if not sort_column:
            raise ValueError("sort_column must be provided if duplicates are found")
        if sort_column not in data.columns:
            raise ValueError(f"sort_column {sort_column} not found in dataframe")

        data = data.sort_values(by=sort_column, ascending=sort_ascending)

        # pandas treats NaN ids as equal to each other, so a plain
        # drop_duplicates would collapse all rows without an id into one;
        # those rows are molecules still to be added and must survive
        ids = data[self.id_field]
        blank = ids.isna() | (ids == "")
        repeated = ids.duplicated(keep="first") & ~blank
        return data[~repeated]