diff --git a/ersilia/cli/commands/fetch.py b/ersilia/cli/commands/fetch.py
index 94db41bd8..325f9e202 100644
--- a/ersilia/cli/commands/fetch.py
+++ b/ersilia/cli/commands/fetch.py
@@ -1,15 +1,18 @@
 import click
+import asyncio
+import nest_asyncio
 
 from . import ersilia_cli
 from .. import echo
 from ...hub.fetch.fetch import ModelFetcher
 from ... import ModelBase
 
+nest_asyncio.apply()
 
 def fetch_cmd():
     """Create fetch command"""
 
     def _fetch(mf, model_id):
-        mf.fetch(model_id)
+        asyncio.run(mf.fetch(model_id))
 
     # Example usage: ersilia fetch {MODEL}
     @ersilia_cli.command(
diff --git a/ersilia/cli/commands/run.py b/ersilia/cli/commands/run.py
index a8994ab5d..3b0b3cd81 100644
--- a/ersilia/cli/commands/run.py
+++ b/ersilia/cli/commands/run.py
@@ -25,7 +25,7 @@ def run_cmd():
     @click.option(
         "--standard",
         is_flag=True,
-        default=False,
+        default=True,
         help="Assume that the run is standard and, therefore, do not do so many checks.",
     )
     def run(input, output, batch_size, standard):
diff --git a/ersilia/core/model.py b/ersilia/core/model.py
index 925e4c0b5..4eabaab76 100644
--- a/ersilia/core/model.py
+++ b/ersilia/core/model.py
@@ -278,7 +278,7 @@ def _standard_api_runner(self, input, output):
                 "Standard CSV Api runner is not ready for this particular model"
             )
             return None
-        if not scra.is_amenable(input, output):
+        if not scra.is_amenable(output):
            self.logger.debug(
                "Standard CSV Api runner is not amenable for this model, input and output"
            )
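Note on the CLI bridge above: click callbacks are synchronous, so the command drives the new coroutine with asyncio.run(), and nest_asyncio.apply() keeps that working when an event loop already exists (e.g. under Jupyter). A minimal sketch of the pattern, with a toy ModelFetcher stand-in and an illustrative model id:

```python
import asyncio
import nest_asyncio

# nest_asyncio patches the loop so asyncio.run() also works where a loop is
# already running; plain asyncio.run() would raise
# "asyncio.run() cannot be called from a running event loop" there.
nest_asyncio.apply()

class ModelFetcher:
    # Toy stand-in; only the coroutine signature matters for this sketch.
    async def fetch(self, model_id):
        await asyncio.sleep(0)  # placeholder for the real async fetch steps
        print(f"fetched {model_id}")

def _fetch(mf, model_id):
    # Synchronous click callback driving the coroutine to completion.
    asyncio.run(mf.fetch(model_id))

_fetch(ModelFetcher(), "eos3b5e")
```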
diff --git a/ersilia/hub/fetch/fetch.py b/ersilia/hub/fetch/fetch.py
index e3a9207ac..6bb666553 100644
--- a/ersilia/hub/fetch/fetch.py
+++ b/ersilia/hub/fetch/fetch.py
@@ -1,5 +1,3 @@
-"""Fetch Model from the Ersilia Model Hub."""
-
 import os
 import json
 import importlib
@@ -13,6 +11,7 @@
     NotInstallableWithFastAPI,
     NotInstallableWithBentoML,
 )
+from .register.standard_example import ModelStandardExample
 from ...utils.exceptions_utils.throw_ersilia_exception import throw_ersilia_exception
 from ...default import PACK_METHOD_BENTOML, PACK_METHOD_FASTAPI, EOS, MODEL_SOURCE_FILE
 from . import STATUS_FILE, DONE_TAG
@@ -156,9 +155,13 @@ def _fetch_not_from_dockerhub(self, model_id):
         else:
             self.logger.debug("Model already exists in your local, skipping fetching")
 
-    def _fetch_from_dockerhub(self, model_id):
+    def _standard_csv_example(self, model_id):
+        ms = ModelStandardExample(model_id=model_id, config_json=self.config_json)
+        ms.run()
+
+    async def _fetch_from_dockerhub(self, model_id):
         self.logger.debug("Fetching from DockerHub")
-        self.model_dockerhub_fetcher.fetch(model_id=model_id)
+        await self.model_dockerhub_fetcher.fetch(model_id=model_id)
 
     def _fetch_from_hosted(self, model_id):
         self.logger.debug("Fetching from hosted")
@@ -213,14 +216,14 @@ def exists(self, model_id):
             return True
         else:
             return False
-
-    def _fetch(self, model_id):
+
+    async def _fetch(self, model_id):
         self.logger.debug("Starting fetching procedure")
         do_dockerhub = self._decide_if_use_dockerhub(model_id=model_id)
         if do_dockerhub:
             self.logger.debug("Decided to fetch from DockerHub")
-            self._fetch_from_dockerhub(model_id=model_id)
+            await self._fetch_from_dockerhub(model_id=model_id)
             return
         do_hosted = self._decide_if_use_hosted(model_id=model_id)
         if do_hosted:
@@ -233,10 +236,14 @@ def _fetch(self, model_id):
         self.logger.debug("Fetching in your system, not from DockerHub")
         self._fetch_not_from_dockerhub(model_id=model_id)
 
-    def fetch(self, model_id):
-        self._fetch(model_id)
+    async def fetch(self, model_id):
+        await self._fetch(model_id)
+        self._standard_csv_example(model_id)
         self.logger.debug("Writing model source to file")
         model_source_file = os.path.join(self._model_path(model_id), MODEL_SOURCE_FILE)
+        try:
+            os.makedirs(self._model_path(model_id), exist_ok=True)
+        except OSError as error:
+            self.logger.error(f"Error during folder creation: {error}")
         with open(model_source_file, "w") as f:
             f.write(self.model_source)
-
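The fetch() coroutine above awaits the whole pipeline, generates the standard CSV example, and only then records the model source; the os.makedirs(..., exist_ok=True) guard makes that final write safe to repeat. A small sketch of the idempotent write, using hypothetical constant values rather than ersilia's real defaults from ersilia.default:

```python
import os

# Illustrative constants; the real values come from ersilia.default
# (EOS, MODEL_SOURCE_FILE).
EOS = os.path.expanduser("~/eos")
MODEL_SOURCE_FILE = "model_source.txt"

def write_model_source(model_id, model_source):
    model_path = os.path.join(EOS, "dest", model_id)
    # exist_ok=True makes the call idempotent: no error when the folder was
    # already created by an earlier fetch step or a previous run.
    os.makedirs(model_path, exist_ok=True)
    with open(os.path.join(model_path, MODEL_SOURCE_FILE), "w") as f:
        f.write(model_source)

write_model_source("eos3b5e", "DockerHub")
```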
diff --git a/ersilia/hub/fetch/lazy_fetchers/dockerhub.py b/ersilia/hub/fetch/lazy_fetchers/dockerhub.py
index 78b04ebba..eb6affadb 100644
--- a/ersilia/hub/fetch/lazy_fetchers/dockerhub.py
+++ b/ersilia/hub/fetch/lazy_fetchers/dockerhub.py
@@ -1,5 +1,6 @@
 import os
 import json
+import asyncio
 
 from ..register.register import ModelRegisterer
 from .... import ErsiliaBase, throw_ersilia_exception
@@ -16,13 +17,11 @@
 from ....serve.services import PulledDockerImageService
 from ....setup.requirements.docker import DockerRequirement
 from ....utils.docker import SimpleDocker, resolve_pack_method_docker, PACK_METHOD_BENTOML
-from ....utils.exceptions_utils.fetch_exceptions import DockerNotActiveError
 from .. import STATUS_FILE
 
-
 class ModelDockerHubFetcher(ErsiliaBase):
     def __init__(self, overwrite=None, config_json=None):
-        ErsiliaBase.__init__(self, config_json=config_json, credentials_json=None)
+        super().__init__(config_json=config_json, credentials_json=None)
         self.simple_docker = SimpleDocker()
         self.overwrite = overwrite
@@ -42,7 +41,7 @@ def is_available(self, model_id):
                 return True
         return False
 
-    def write_apis(self, model_id):
+    async def write_apis(self, model_id):
         self.logger.debug("Writing APIs")
         di = PulledDockerImageService(
             model_id=model_id, config_json=self.config_json, preferred_port=None
@@ -50,21 +49,24 @@
         )
         di.serve()
         di.close()
 
-    def _copy_from_bentoml_image(self, model_id, file):
-        fr_file = "/root/eos/dest/{0}/{1}".format(model_id, file)
-        to_file = "{0}/dest/{1}/{2}".format(EOS, model_id, file)
-        self.simple_docker.cp_from_image(
-            img_path=fr_file,
-            local_path=to_file,
-            org=DOCKERHUB_ORG,
-            img=model_id,
-            tag=DOCKERHUB_LATEST_TAG,
-        )
-
-    def _copy_from_ersiliapack_image(self, model_id, file):
-        fr_file = "/root/{0}".format(file)
-        to_file = "{0}/dest/{1}/{2}".format(EOS, model_id, file)
-        self.simple_docker.cp_from_image(
+    async def _copy_from_bentoml_image(self, model_id, file):
+        fr_file = f"/root/eos/dest/{model_id}/{file}"
+        to_file = f"{EOS}/dest/{model_id}/{file}"
+        try:
+            await self.simple_docker.cp_from_image(
+                img_path=fr_file,
+                local_path=to_file,
+                org=DOCKERHUB_ORG,
+                img=model_id,
+                tag=DOCKERHUB_LATEST_TAG,
+            )
+        except Exception as e:
+            self.logger.error(f"Exception when copying: {e}")
+
+    async def _copy_from_ersiliapack_image(self, model_id, file):
+        fr_file = f"/root/{file}"
+        to_file = f"{EOS}/dest/{model_id}/{file}"
+        await self.simple_docker.cp_from_image(
             img_path=fr_file,
             local_path=to_file,
             org=DOCKERHUB_ORG,
@@ -72,30 +74,30 @@
             tag=DOCKERHUB_LATEST_TAG,
         )
 
-    def _copy_from_image_to_local(self, model_id, file):
+    async def _copy_from_image_to_local(self, model_id, file):
         pack_method = resolve_pack_method_docker(model_id)
         if pack_method == PACK_METHOD_BENTOML:
-            self._copy_from_bentoml_image(model_id, file)
+            await self._copy_from_bentoml_image(model_id, file)
         else:
-            self._copy_from_ersiliapack_image(model_id, file)
+            await self._copy_from_ersiliapack_image(model_id, file)
 
-    def copy_information(self, model_id):
+    async def copy_information(self, model_id):
         self.logger.debug("Copying information file from model container")
-        self._copy_from_image_to_local(model_id, INFORMATION_FILE)
+        await self._copy_from_image_to_local(model_id, INFORMATION_FILE)
 
-    def copy_metadata(self, model_id):
+    async def copy_metadata(self, model_id):
         self.logger.debug("Copying api_schema_file file from model container")
-        self._copy_from_image_to_local(model_id, API_SCHEMA_FILE)
+        await self._copy_from_image_to_local(model_id, API_SCHEMA_FILE)
 
-    def copy_status(self, model_id):
+    async def copy_status(self, model_id):
         self.logger.debug("Copying status file from model container")
-        self._copy_from_image_to_local(model_id, STATUS_FILE)
-
-    def copy_example_if_available(self, model_id):
-        # TODO This also needs to change to accomodate ersilia pack
+        await self._copy_from_image_to_local(model_id, STATUS_FILE)
+
+    async def copy_example_if_available(self, model_id):
+        # This needs to accommodate ersilia pack
         for pf in PREDEFINED_EXAMPLE_FILES:
-            fr_file = "/root/eos/dest/{0}/{1}".format(model_id, pf)
-            to_file = "{0}/dest/{1}/{2}".format(EOS, model_id, "input.csv")
+            fr_file = f"/root/eos/dest/{model_id}/{pf}"
+            to_file = f"{EOS}/dest/{model_id}/input.csv"
             try:
-                self.simple_docker.cp_from_image(
+                await self.simple_docker.cp_from_image(
                     img_path=fr_file,
@@ -108,7 +110,7 @@
             except:
                 self.logger.debug("Could not find example file in docker image")
 
-    def modify_information(self, model_id):
+    async def modify_information(self, model_id):
         """
         Modify the information file being copied from docker container to the host machine.
         :param file: The model information file being copied.
@@ -124,7 +126,7 @@
             self.logger.error("Information file not found, not modifying anything")
             return None
 
-        # Using this literal here to prevent a file read
+        # Using this literal here to prevent a file read
         # from service class file for a model fetched through DockerHub
         # since we already know the service class.
         data["service_class"] = "pulled_docker"
@@ -133,15 +135,20 @@
             json.dump(data, outfile, indent=4)
 
     @throw_ersilia_exception
-    def fetch(self, model_id):
+    async def fetch(self, model_id):
         mp = ModelPuller(model_id=model_id, config_json=self.config_json)
         self.logger.debug("Pulling model image from DockerHub")
-        mp.pull()
+        # Asynchronous pulling
+        await mp.async_pull()
         mr = ModelRegisterer(model_id=model_id, config_json=self.config_json)
-        mr.register(is_from_dockerhub=True)
-        self.write_apis(model_id)
-        self.copy_information(model_id)
-        self.modify_information(model_id)
-        self.copy_metadata(model_id)
-        self.copy_status(model_id)
-        self.copy_example_if_available(model_id)
+        # Asynchronous and concurrent execution
+        self.logger.debug("Asynchronous and concurrent execution started!")
+        await asyncio.gather(
+            mr.register(is_from_dockerhub=True),
+            self.write_apis(model_id),
+            self.copy_information(model_id),
+            self.modify_information(model_id),
+            self.copy_metadata(model_id),
+            self.copy_status(model_id),
+            self.copy_example_if_available(model_id)
+        )
\ No newline at end of file
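After the image pull above, registration and the copy steps are scheduled together with asyncio.gather(), so total latency approaches the slowest step rather than the sum. A self-contained sketch of the fan-out pattern (toy coroutines, not the fetcher's real methods):

```python
import asyncio

# Toy stand-ins for the fetcher's register/copy coroutines.
async def copy_file(label, delay):
    await asyncio.sleep(delay)  # simulates docker-cp latency
    return f"{label} copied"

async def fetch_post_pull(model_id):
    # gather() schedules all coroutines at once and waits for all of them,
    # so wall time approaches the slowest task rather than the sum.
    results = await asyncio.gather(
        copy_file("information", 0.3),
        copy_file("api_schema", 0.2),
        copy_file("status", 0.1),
    )
    print(model_id, results)

asyncio.run(fetch_post_pull("eos3b5e"))
```

One caveat of this design: by default a failing task propagates its exception out of gather() while the remaining tasks keep running; passing return_exceptions=True collects failures instead.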
f"/root/eos/dest/{model_id}/{pf}" + to_file = f"{EOS}/dest/{model_id}/input.csv" try: self.simple_docker.cp_from_image( img_path=fr_file, @@ -108,7 +110,7 @@ def copy_example_if_available(self, model_id): except: self.logger.debug("Could not find example file in docker image") - def modify_information(self, model_id): + async def modify_information(self, model_id): """ Modify the information file being copied from docker container to the host machine. :param file: The model information file being copied. @@ -124,7 +126,7 @@ def modify_information(self, model_id): self.logger.error("Information file not found, not modifying anything") return None - # Using this literal here to prevent a file read + # Using this literal here to prevent a file read # from service class file for a model fetched through DockerHub # since we already know the service class. data["service_class"] = "pulled_docker" @@ -133,15 +135,20 @@ def modify_information(self, model_id): json.dump(data, outfile, indent=4) @throw_ersilia_exception - def fetch(self, model_id): + async def fetch(self, model_id): mp = ModelPuller(model_id=model_id, config_json=self.config_json) self.logger.debug("Pulling model image from DockerHub") - mp.pull() + # Asynchronous pulling + await mp.async_pull() mr = ModelRegisterer(model_id=model_id, config_json=self.config_json) - mr.register(is_from_dockerhub=True) - self.write_apis(model_id) - self.copy_information(model_id) - self.modify_information(model_id) - self.copy_metadata(model_id) - self.copy_status(model_id) - self.copy_example_if_available(model_id) + # Asynchronous and concurent execution + self.logger.debug("Asynchronous and concurrent execution started!") + await asyncio.gather( + mr.register(is_from_dockerhub=True), + self.write_apis(model_id), + self.copy_information(model_id), + self.modify_information(model_id), + self.copy_metadata(model_id), + self.copy_status(model_id), + self.copy_example_if_available(model_id) + ) \ No newline at end of file diff --git a/ersilia/hub/fetch/register/register.py b/ersilia/hub/fetch/register/register.py index 23c826a79..03ae34933 100644 --- a/ersilia/hub/fetch/register/register.py +++ b/ersilia/hub/fetch/register/register.py @@ -120,7 +120,7 @@ def register_not_from_hosted(self): with open(file_name, "w") as f: json.dump(data, f) - def register(self, is_from_dockerhub=False, is_from_hosted=False): + async def register(self, is_from_dockerhub=False, is_from_hosted=False): if is_from_dockerhub and is_from_hosted: raise Exception if is_from_dockerhub and not is_from_hosted: diff --git a/ersilia/hub/pull/pull.py b/ersilia/hub/pull/pull.py index ccb869f80..80a4a1668 100644 --- a/ersilia/hub/pull/pull.py +++ b/ersilia/hub/pull/pull.py @@ -4,7 +4,8 @@ import json import os import re - +import asyncio +import aiofiles from ... import ErsiliaBase from ...utils.terminal import yes_no_input, run_command from ... import throw_ersilia_exception @@ -84,6 +85,107 @@ def _get_size_of_local_docker_image_in_mb(self): self.logger.warning("Image not found locally") return None + @throw_ersilia_exception + async def async_pull(self): + if self.is_available_locally(): + if self.overwrite is None: + do_pull = yes_no_input( + "Requested image {0} is available locally. Do you still want to fetch it? 
[Y/n]".format( + self.model_id + ), + default_answer=PULL_IMAGE, + ) + elif self.overwrite: + do_pull = True + else: + do_pull = False + if not do_pull: + self.logger.info("Skipping pulling the image") + return + self._delete() + else: + self.logger.debug("Docker image of the model is not available locally") + if self.is_available_in_dockerhub(): + self.logger.debug( + "Pulling image {0} from DockerHub...".format(self.image_name) + ) + try: + self.logger.debug( + "Trying to pull image {0}/{1}".format(DOCKERHUB_ORG, self.model_id) + ) + tmp_file = os.path.join( + make_temp_dir(prefix="ersilia-"), "docker_pull.log" + ) + self.logger.debug("Keeping logs of pull in {0}".format(tmp_file)) + + # Construct the pull command + pull_command = f"docker pull {DOCKERHUB_ORG}/{self.model_id}:{DOCKERHUB_LATEST_TAG} > {tmp_file} 2>&1" + + # Use asyncio to run the pull command asynchronously + process = await asyncio.create_subprocess_shell( + pull_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + + # Wait for the command to complete + stdout, stderr = await process.communicate() + + # Handle output + if process.returncode != 0: + self.logger.error(f"Pull command failed: {stderr.decode()}") + raise subprocess.CalledProcessError(process.returncode, pull_command) + + self.logger.debug(stdout.decode()) + + # Reading log asynchronously + async with aiofiles.open(tmp_file, 'r') as f: + pull_log = await f.read() + self.logger.debug(pull_log) + + if re.search(r"no match.*manifest", pull_log): + self.logger.warning( + "No matching manifest for image {0}".format(self.model_id) + ) + raise DockerConventionalPullError(model=self.model_id) + + self.logger.debug("Image pulled successfully!") + + except DockerConventionalPullError: + self.logger.warning( + "Conventional pull did not work, Ersilia is now forcing linux/amd64 architecture" + ) + # Force platform specification pull command + force_pull_command = f"docker pull {DOCKERHUB_ORG}/{self.model_id}:{DOCKERHUB_LATEST_TAG} --platform linux/amd64" + + # Run forced pull asynchronously + process = await asyncio.create_subprocess_shell( + force_pull_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + + stdout, stderr = await process.communicate() + + if process.returncode != 0: + self.logger.error(f"Forced pull command failed: {stderr.decode()}") + raise subprocess.CalledProcessError(process.returncode, force_pull_command) + + self.logger.debug(stdout.decode()) + size = self._get_size_of_local_docker_image_in_mb() + if size: + self.logger.debug("Size of image {0} MB".format(size)) + # path = os.path.join(self._model_path(self.model_id), MODEL_SIZE_FILE) + # with open(path, "w") as f: + # json.dump({"size": size, "units": "MB"}, f, indent=4) + # self.logger.debug("Size written to {}".format(path)) + else: + self.logger.warning("Could not obtain size of image") + return size + else: + self.logger.info("Image {0} is not available".format(self.image_name)) + raise DockerImageNotAvailableError(model=self.model_id) + @throw_ersilia_exception def pull(self): if self.is_available_locally(): @@ -151,4 +253,4 @@ def pull(self): return size else: self.logger.info("Image {0} is not available".format(self.image_name)) - raise DockerImageNotAvailableError(model=self.model_id) + raise DockerImageNotAvailableError(model=self.model_id) \ No newline at end of file diff --git a/ersilia/serve/standard_api.py b/ersilia/serve/standard_api.py index 497ff02d4..3daa7784e 100644 --- a/ersilia/serve/standard_api.py +++ 
diff --git a/ersilia/serve/standard_api.py b/ersilia/serve/standard_api.py
index 497ff02d4..3daa7784e 100644
--- a/ersilia/serve/standard_api.py
+++ b/ersilia/serve/standard_api.py
@@ -3,19 +3,21 @@
 import json
 import importlib
 import requests
-
-from .. import ErsiliaBase
+import asyncio
+import nest_asyncio
 from ..store.api import InferenceStoreApi
 from ..store.utils import OutputSource
+from .. import ErsiliaBase
 from ..default import (
     EXAMPLE_STANDARD_INPUT_CSV_FILENAME,
     EXAMPLE_STANDARD_OUTPUT_CSV_FILENAME,
 )
-from ..default import INFORMATION_FILE
+from ..default import INFORMATION_FILE, API_SCHEMA_FILE
 from ..default import DEFAULT_API_NAME
 
 MAX_INPUT_ROWS_STANDARD = 1000
 
+nest_asyncio.apply()
+
 
 class StandardCSVRunApi(ErsiliaBase):
     def __init__(self, model_id, url, config_json=None):
@@ -40,6 +42,7 @@ def __init__(self, model_id, url, config_json=None):
         self.input_type = self.get_input_type()
         self.logger.debug("This is the input type: {0}".format(self.input_type))
         self.encoder = self.get_identifier_object_by_input_type()
+        self.validate_smiles = self.encoder.validate_smiles
         self.header = self.get_expected_output_header(self.standard_output_csv)
         self.logger.debug(
             "This is the expected header (max 10): {0}".format(self.header[:10])
@@ -111,9 +114,25 @@ def is_input_standard_csv_file(self, input_data):
         return False
 
     def get_input_type(self):
-        with open(os.path.join(self.path, INFORMATION_FILE), "r") as f:
-            info = json.load(f)
-        return info["metadata"]["Input"]
+        try:
+            with open(os.path.join(self.path, INFORMATION_FILE), "r") as f:
+                info = json.load(f)
+            if "metadata" in info and "Input" in info["metadata"]:
+                return info["metadata"]["Input"]
+            elif "card" in info and "Input" in info["card"]:
+                return info["card"]["Input"]
+            else:
+                raise KeyError("Neither 'metadata' nor 'card' contains 'Input' key.")
+        except FileNotFoundError:
+            self.logger.debug(f"Error: File '{INFORMATION_FILE}' not found in the path '{self.path}'")
+        except json.JSONDecodeError:
+            self.logger.debug(f"Error: Failed to parse JSON in file '{INFORMATION_FILE}'")
+        except KeyError as e:
+            self.logger.debug(f"Error: {e}")
+        except Exception as e:
+            self.logger.debug(f"An unexpected error occurred: {e}")
+        return None
 
     def is_input_type_standardizable(self):
         if len(self.input_type) != 1:
@@ -121,36 +140,53 @@
             return False
         return True
 
     def is_output_type_standardizable(self):
-        with open(os.path.join(self.path, INFORMATION_FILE), "r") as f:
-            api_schema = json.load(f)["api_schema"]
-        if DEFAULT_API_NAME not in api_schema:
-            return False
-        meta = api_schema[DEFAULT_API_NAME]
-        output_keys = meta["output"].keys()
-        if len(output_keys) != 1:
-            return False
+        api_schema_file_path = os.path.join(self.path, API_SCHEMA_FILE)
+        try:
+            with open(api_schema_file_path, "r") as f:
+                api_schema = json.load(f)
+        except (FileNotFoundError, json.JSONDecodeError):
+            return False
+
+        meta = api_schema.get(DEFAULT_API_NAME)
+        if not meta or len(meta.get("output", {})) != 1:
+            return False
         return True
 
     def is_output_csv_file(self, output_data):
         if type(output_data) != str:
             return False
         if not output_data.endswith(".csv"):
             return False
         return True
-
+
     def get_expected_output_header(self, output_data):
         with open(output_data, "r") as f:
             reader = csv.reader(f)
             header = next(reader)
         return header
+
+    def parse_smiles_list(self, input_data):
+        if not input_data or all(not s.strip() for s in input_data):
+            raise ValueError("The list of SMILES strings is empty or contains only empty strings.")
+        return [
+            {"key": self.encoder.encode(smiles), "input": smiles, "text": smiles}
+            for smiles in input_data
+            if self.validate_smiles(smiles)
+        ]
+
+    def parse_smiles_string(self, smiles):
+        if not self.validate_smiles(smiles):
+            raise ValueError("The SMILES string is invalid.")
+        key = self.encoder.encode(smiles)
+        return [{"key": key, "input": smiles, "text": smiles}]
 
     def serialize_to_json_three_columns(self, input_data):
         json_data = []
         with open(input_data, "r") as f:
             reader = csv.reader(f)
-            next(reader)
-            for r in reader:
-                json_data += [{"key": r[0], "input": r[1], "text": r[2]}]
+            next(reader)
+            for row in reader:
+                if self.validate_smiles(row[1]):
+                    json_data += [{"key": row[0], "input": row[1], "text": row[2]}]
         return json_data
 
     def serialize_to_json_two_columns(self, input_data):
@@ -158,8 +194,9 @@
         json_data = []
         with open(input_data, "r") as f:
             reader = csv.reader(f)
             next(reader)
-            for r in reader:
-                json_data += [{"key": r[0], "input": r[1], "text": r[1]}]
+            for row in reader:
+                if self.validate_smiles(row[1]):
+                    json_data += [{"key": row[0], "input": row[1], "text": row[1]}]
         return json_data
 
     def serialize_to_json_one_column(self, input_data):
@@ -167,42 +204,62 @@
         json_data = []
         with open(input_data, "r") as f:
             reader = csv.reader(f)
             next(reader)
-            for r in reader:
-                key = self.encoder.encode(r[0])
-                json_data += [{"key": key, "input": r[0], "text": r[0]}]
+            for row in reader:
+                if self.validate_smiles(row[0]):
+                    key = self.encoder.encode(row[0])
+                    json_data += [{"key": key, "input": row[0], "text": row[0]}]
+        return json_data
+
+    async def async_serialize_to_json_one_column(self, input_data):
+        smiles_list = self.get_list_from_csv(input_data)
+        smiles_list = [smiles for smiles in smiles_list if self.validate_smiles(smiles)]
+        json_data = await self.encoder.encode_batch(smiles_list)
         return json_data
 
+    def get_list_from_csv(self, input_data):
+        smiles_list = []
+        with open(input_data, mode="r") as file:
+            reader = csv.DictReader(file)
+            for row in reader:
+                smiles = row.get("input")
+                if smiles and smiles not in smiles_list and self.validate_smiles(smiles):
+                    smiles_list.append(smiles)
+        return smiles_list
+
     def serialize_to_json(self, input_data):
-        with open(input_data, "r") as f:
-            reader = csv.reader(f)
-            h = next(reader)
-        if len(h) == 1:
-            self.logger.debug("One column found in input")
-            return self.serialize_to_json_one_column(input_data=input_data)
-        elif len(h) == 2:
-            self.logger.debug("Two columns found in input")
-            return self.serialize_to_json_two_columns(input_data=input_data)
-        elif len(h) == 3:
-            self.logger.debug("Three columns found in input")
-            return self.serialize_to_json_three_columns(input_data=input_data)
-        else:
-            self.logger.info(
-                "More than two columns found in input! This is not standard."
-            )
-            return None
-
-    def is_amenable(self, input_data, output_data):
+        if isinstance(input_data, str) and os.path.isfile(input_data):
+            with open(input_data, "r") as f:
+                reader = csv.reader(f)
+                h = next(reader)
+            if len(h) == 1:
+                self.logger.debug("One column found in input")
+                return asyncio.run(self.async_serialize_to_json_one_column(input_data))
+            elif len(h) == 2:
+                self.logger.debug("Two columns found in input")
+                return self.serialize_to_json_two_columns(input_data=input_data)
+            elif len(h) == 3:
+                self.logger.debug("Three columns found in input")
+                return self.serialize_to_json_three_columns(input_data=input_data)
+            else:
+                self.logger.info("More than three columns found in input! This is not standard.")
+                return None
+        elif isinstance(input_data, str):
+            return self.parse_smiles_string(input_data)
+        elif isinstance(input_data, list):
+            return self.parse_smiles_list(input_data)
+        else:
+            raise ValueError("Input must be either a file path (string), a SMILES string, or a list of SMILES strings.")
+
+    def is_amenable(self, output_data):
         if not self.is_input_type_standardizable():
             return False
         if not self.is_output_type_standardizable():
             return False
-        if not self.is_input_standard_csv_file(input_data):
-            return False
         if not self.is_output_csv_file(output_data):
             return False
         self.logger.debug("It seems amenable for standard run")
         return True
-
+
     def serialize_to_csv(self, input_data, result, output_data):
         k = list(result[0].keys())[0]
         v = result[0][k]
@@ -212,8 +269,7 @@
         is_list = False
         with open(output_data, "w") as f:
             writer = csv.writer(f)
-            if not os.path.exists(output_data):
-                writer.writerow(self.header)
+            writer.writerow(self.header)
             for i_d, r_d in zip(input_data, result):
                 v = r_d[k]
                 if not is_list:
@@ -237,7 +293,6 @@ def post(self, input, output, output_source=OutputSource.LOCAL_ONLY):
         else:
             return None
 
-
 class StandardQueryApi(object):
     def __init__(self, model_id, url):
         # TODO This class will be used to query directly the calculations lake.
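serialize_to_json() above now accepts three input shapes — a CSV path, a single SMILES string, or a list of SMILES — filtering each through validate_smiles. A condensed sketch of that dispatch, with a toy validator standing in for the RDKit-backed one:

```python
import csv
import os

def validate_smiles(smiles):
    # Toy validator; the real one delegates to RDKit via
    # CompoundIdentifier.validate_smiles.
    return bool(smiles and smiles.strip())

def serialize_to_json(input_data):
    # Mirrors the dispatch contract of StandardCSVRunApi.serialize_to_json.
    if isinstance(input_data, str) and os.path.isfile(input_data):
        with open(input_data) as f:
            rows = list(csv.reader(f))[1:]  # skip the header row
        return [{"input": r[0]} for r in rows if validate_smiles(r[0])]
    if isinstance(input_data, str):
        if not validate_smiles(input_data):
            raise ValueError("The SMILES string is invalid.")
        return [{"input": input_data}]
    if isinstance(input_data, list):
        return [{"input": s} for s in input_data if validate_smiles(s)]
    raise ValueError("Expected a CSV path, a SMILES string, or a list of SMILES.")

print(serialize_to_json(["CCO", " ", "c1ccccc1"]))
```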
diff --git a/ersilia/utils/docker.py b/ersilia/utils/docker.py
index 448df52dd..0bee3384d 100644
--- a/ersilia/utils/docker.py
+++ b/ersilia/utils/docker.py
@@ -246,7 +246,7 @@ def cp_from_container(name, img_path, local_path, org=None, img=None, tag=None):
         cmd = "docker cp %s:%s %s" % (name, img_path, local_path)
         run_command(cmd)
 
-    def cp_from_image(self, img_path, local_path, org, img, tag):
+    async def cp_from_image(self, img_path, local_path, org, img, tag):
         name = self.run(org, img, tag, name=None)
         self.cp_from_container(name, img_path, local_path, org=org, img=img, tag=tag)
         self.remove(name)
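Declaring cp_from_image as async makes it awaitable inside the fetcher's gather(), but its body still runs blocking docker commands on the event loop. If that ever becomes a bottleneck, the usual remedy is a thread offload; a sketch under that assumption (asyncio.to_thread requires Python 3.9+, and blocking_docker_cp is a stand-in, not ersilia code):

```python
import asyncio
import time

def blocking_docker_cp(src, dst):
    # Stand-in for the synchronous "docker cp" helper; sleep simulates I/O.
    time.sleep(0.2)
    print(f"copied {src} -> {dst}")

async def cp_from_image(src, dst):
    # asyncio.to_thread (Python 3.9+) runs the blocking call in a worker
    # thread, so sibling tasks in the same gather() keep making progress.
    await asyncio.to_thread(blocking_docker_cp, src, dst)

async def main():
    await asyncio.gather(
        cp_from_image("/root/a.csv", "/tmp/a.csv"),
        cp_from_image("/root/b.csv", "/tmp/b.csv"),
    )

asyncio.run(main())
```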
diff --git a/ersilia/utils/identifiers/compound.py b/ersilia/utils/identifiers/compound.py
index 65465567d..a68db6ae6 100644
--- a/ersilia/utils/identifiers/compound.py
+++ b/ersilia/utils/identifiers/compound.py
@@ -1,6 +1,10 @@
+import asyncio
+import nest_asyncio
+import aiohttp
 import urllib.parse
 import requests
-import json
+from functools import lru_cache
+from ..logging import logger
 
 try:
     from chembl_webresource_client.unichem import unichem_client as unichem
@@ -16,8 +20,10 @@
 
 from ...default import UNPROCESSABLE_INPUT
 
+nest_asyncio.apply()
+
 
 class CompoundIdentifier(object):
-    def __init__(self, local=True):
+    def __init__(self, local=True, concurrency_limit=10, cache_maxsize=128):
         if local:
             self.Chem = Chem
         else:
@@ -26,32 +32,40 @@
         self.default_type = "smiles"
         self.input_header_synonyms = set(["smiles", "input"])
         self.key_header_synonyms = set(["inchikey", "key"])
+        # The public APIs rate-limit aggressively, so it is better not to go above 10
+        self.concurrency_limit = concurrency_limit
+        self.cache_maxsize = cache_maxsize
+        # Cache the synchronous resolvers with a per-instance, configurable size.
+        # Note that lru_cache cannot memoize the async fetchers: it would cache
+        # single-use coroutine objects rather than their results.
+        self.chemical_identifier_resolver = lru_cache(maxsize=self.cache_maxsize)(
+            self.chemical_identifier_resolver
+        )
+        self.convert_smiles_to_inchikey_with_rdkit = lru_cache(
+            maxsize=self.cache_maxsize
+        )(self.convert_smiles_to_inchikey_with_rdkit)
 
     def is_input_header(self, h):
-        if h.lower() in self.input_header_synonyms:
-            return True
-        else:
-            return False
+        return h.lower() in self.input_header_synonyms
 
     def is_key_header(self, h):
-        if h.lower() in self.key_header_synonyms:
-            return True
-        else:
-            return False
+        return h.lower() in self.key_header_synonyms
 
     def _is_smiles(self, text):
         if self.Chem is None:
-            if self._pubchem_smiles_to_inchikey(text) is not None:
-                return True
-            else:
-                return False
+            return asyncio.run(self._process_pubchem_inchikey(text)) is not None
         else:
             mol = self.Chem.MolFromSmiles(text)
-            if mol is None:
-                return False
-            else:
-                return True
+            return mol is not None
+
+    async def _process_pubchem_inchikey(self, text):
+        async with aiohttp.ClientSession() as session:
+            return await self._pubchem_smiles_to_inchikey(session, text)
 
     @staticmethod
     def _is_inchikey(text):
         if len(text) != 27:
@@ -82,66 +96,144 @@ def unichem_resolver(self, inchikey):
         try:
             ret = self.unichem.inchiFromKey(inchikey)
         except:
-            return self.chemical_identifier_resolver(inchikey)
+            return None
         inchi = ret[0]["standardinchi"]
         mol = self.Chem.inchi.MolFromInchi(inchi)
         return self.Chem.MolToSmiles(mol)
 
     @staticmethod
-    def _nci_smiles_to_inchikey(smiles):
-        identifier = urllib.parse.quote(smiles)
-        url = "https://cactus.nci.nih.gov/chemical/structure/{0}/stdinchikey".format(
+    def chemical_identifier_resolver(identifier):
+        """Returns SMILES string of a given identifier, using NCI tool"""
+        if not identifier or not isinstance(identifier, str):
+            return UNPROCESSABLE_INPUT
+        identifier = urllib.parse.quote(identifier)
+        url = "https://cactus.nci.nih.gov/chemical/structure/{0}/smiles".format(
             identifier
         )
         req = requests.get(url)
         if req.status_code != 200:
             return None
-        return req.text.split("=")[1]
+        return req.text
 
     @staticmethod
-    def _pubchem_smiles_to_inchikey(smiles):
+    async def _pubchem_smiles_to_inchikey(session, smiles):
+        """
+        Fetch the InChIKey for a single SMILES using the PubChem API asynchronously.
+        """
         identifier = urllib.parse.quote(smiles)
-        url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/{0}/property/InChIKey/json".format(
-            identifier
-        )
-        req = requests.get(url)
-        if req.status_code != 200:
+        url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/{identifier}/property/InChIKey/json"
+        try:
+            async with session.get(url) as response:
+                if response.status != 200:
+                    return None
+                data = await response.json()
+                return data["PropertyTable"]["Properties"][0]["InChIKey"]
+        except Exception:
             return None
-        return json.loads(req.text)["PropertyTable"]["Properties"][0]["InChIKey"]
 
     @staticmethod
+ """ + identifier = urllib.parse.quote(smiles) + url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/{0}/property/InChIKey/json".format( identifier ) - req = requests.get(url) - if req.status_code != 200: + try: + async with session.get(url) as response: + if response.status != 200: + return None + text = await response.text() + return text.split("=")[1] + except Exception as e: + logger.info(f"Failed to fetch InChIKey from NCI for {smiles}: {e}") return None - return req.text + + def convert_smiles_to_inchikey_with_rdkit(self, smiles): + """ + Converts a SMILES string to an InChIKey using RDKit. + The results are cached to improve performance. + """ + if not self.Chem: + return None + + try: + mol = self.Chem.MolFromSmiles(smiles) + if mol is None: + raise ValueError(f"Invalid SMILES string: {smiles}") + inchi = self.Chem.rdinchi.MolToInchi(mol)[0] + inchikey = self.Chem.rdinchi.InchiToInchiKey(inchi) + return inchikey + except Exception as e: + logger.info(f"RDKit failed to convert SMILES {smiles}: {e}") + return None + + async def process_smiles(self, smiles, semaphore, session, result_list): + async with semaphore: # high performance resource manager + inchikey = self.convert_smiles_to_inchikey_with_rdkit(smiles) + + if inchikey is None: + inchikey = await self._pubchem_smiles_to_inchikey(session, smiles) + if inchikey: + logger.info("Inchikey converted using PUBCHEM") + + if inchikey is None: + inchikey = self._nci_smiles_to_inchikey(smiles) + if inchikey: + logger.info("Inchikey converted using NCI") + + if inchikey: + result_list.append({"key": inchikey, "input": smiles, "text": smiles}) + else: + logger.info(f"No InChIKey found for SMILES {smiles}. Skipping.") + + async def encode_batch(self, smiles_list): + result_list = [] + semaphore = asyncio.Semaphore(self.concurrency_limit) + async with aiohttp.ClientSession() as session: + tasks = [] + for _, smiles in enumerate(smiles_list): + tasks.append( + self.process_smiles(smiles, semaphore, session, result_list) + ) + + await asyncio.gather(*tasks) + + return result_list def encode(self, smiles): - """Get InChIKey of compound based on SMILES string""" - if not isinstance(smiles, str) or not smiles.strip() or smiles == UNPROCESSABLE_INPUT: - return UNPROCESSABLE_INPUT - - if self.Chem is None: - inchikey = self._pubchem_smiles_to_inchikey(smiles) or self._nci_smiles_to_inchikey(smiles) - else: - try: - mol = self.Chem.MolFromSmiles(smiles) - if mol is None: - return UNPROCESSABLE_INPUT - inchi = self.Chem.rdinchi.MolToInchi(mol)[0] - if inchi is None: - return UNPROCESSABLE_INPUT - inchikey = self.Chem.rdinchi.InchiToInchiKey(inchi) - except: - inchikey = None - - return inchikey if inchikey else UNPROCESSABLE_INPUT + """Get InChIKey of compound based on SMILES string""" + if not isinstance(smiles, str) or not smiles.strip() or smiles == UNPROCESSABLE_INPUT: + return UNPROCESSABLE_INPUT + + if self.Chem is None: + async def fetch_inchikeys(): + async with aiohttp.ClientSession() as session: + inchikey = await self._pubchem_smiles_to_inchikey(session, smiles) + if inchikey: + return inchikey + inchikey = await self._nci_smiles_to_inchikey(session, smiles) + return inchikey + + inchikey = asyncio.run(fetch_inchikeys()) + else: + try: + mol = self.Chem.MolFromSmiles(smiles) + if mol is None: + return UNPROCESSABLE_INPUT + inchi = self.Chem.rdinchi.MolToInchi(mol)[0] + if inchi is None: + return UNPROCESSABLE_INPUT + inchikey = self.Chem.rdinchi.InchiToInchiKey(inchi) + except: + inchikey = None + return inchikey if inchikey 
+        return inchikey if inchikey else UNPROCESSABLE_INPUT
+
+    def validate_smiles(self, smiles):
+        if not smiles or not smiles.strip():
+            return False
+        return Chem is not None and Chem.MolFromSmiles(smiles) is not None
 
-Identifier = CompoundIdentifier
\ No newline at end of file
+Identifier = CompoundIdentifier
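encode_batch() above combines three ingredients: an aiohttp session shared across requests, a semaphore that caps concurrency at concurrency_limit, and lru_cache on the synchronous helpers. A condensed sketch of the same pattern (the PubChem URL follows the REST PUG convention used in the diff; error handling is trimmed for brevity):

```python
import asyncio
import urllib.parse
from functools import lru_cache

import aiohttp

@lru_cache(maxsize=128)
def quote_smiles(smiles):
    # lru_cache memoizes synchronous helpers; wrapping a coroutine function
    # instead would cache one-shot coroutine objects rather than results.
    return urllib.parse.quote(smiles)

async def fetch_inchikey(session, smiles, semaphore):
    # The semaphore caps in-flight requests, mirroring concurrency_limit=10.
    async with semaphore:
        url = ("https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/"
               f"{quote_smiles(smiles)}/property/InChIKey/json")
        async with session.get(url) as resp:
            if resp.status != 200:
                return None
            data = await resp.json()
            return data["PropertyTable"]["Properties"][0]["InChIKey"]

async def encode_batch(smiles_list):
    semaphore = asyncio.Semaphore(10)
    async with aiohttp.ClientSession() as session:
        return await asyncio.gather(
            *(fetch_inchikey(session, s, semaphore) for s in smiles_list)
        )

print(asyncio.run(encode_batch(["CCO", "c1ccccc1"])))
```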
diff --git a/pyproject.toml b/pyproject.toml
index 7ca233d23..f6cd9b95c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,20 +8,26 @@
 readme = "README.md"
 homepage = "https://ersilia.io"
 repository = "https://github.com/ersilia-os/ersilia"
 documentation = "https://ersilia.io/model-hub"
-keywords= ["drug-discovery", "machine-learning", "ersilia", "open-science", "global-health", "model-hub", "infectious-diseases"]
-classifiers=[
-    "Programming Language :: Python :: 3.8",
-    "Programming Language :: Python :: 3.9",
-    "Programming Language :: Python :: 3.10",
-    "Programming Language :: Python :: 3.11",
-    "Programming Language :: Python :: 3.12",
-    "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
-    "Operating System :: OS Independent",
-    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+keywords = [
+    "drug-discovery",
+    "machine-learning",
+    "ersilia",
+    "open-science",
+    "global-health",
+    "model-hub",
+    "infectious-diseases",
 ]
-packages = [
-    {include = "ersilia"},
+classifiers = [
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
+    "Operating System :: OS Independent",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
 ]
+packages = [{ include = "ersilia" }]
 include = [
     "ersilia/hub/content/metadata/*.txt",
     "ersilia/io/types/examples/*.tsv",
@@ -31,13 +37,11 @@
 python = ">=3.8"
 inputimeout = "^1.0.4"
 emoji = "^2.8.0"
-validators = [
-    {version="~0.21.0", python=">=3.8"},
-]
+validators = [{ version = "~0.21.0", python = ">=3.8" }]
 psutil = ">=5.9.0"
-h5py = "^3.7.0" # For compatibility with isaura
-loguru = "^0.6.0" # For compatibility with isaura
+h5py = "^3.7.0"    # For compatibility with isaura
+loguru = "^0.6.0"  # For compatibility with isaura
 PyYAML = "^6.0.1"
 dockerfile-parse = "^2.0.1"
 tqdm = "^4.66.1"
@@ -46,18 +50,18 @@
 docker = "^6.1.3"
 boto3 = "^1.28.40"
 requests = "<=2.31.0"
 numpy = "<=1.26.4"
-setuptools = "^65.0.0" # added to fix the issue with setuptools
-isaura = {version="0.1", optional=true}
+setuptools = "^65.0.0"  # added to fix the issue with setuptools
+isaura = { version = "0.1", optional = true }
 aiofiles = "<=24.1.0"
 aiohttp = "<=3.10.9"
-pytest = {version = "^7.4.0", optional = true}
-pytest-asyncio = {version = "<=0.24.0", optional = true}
-pytest-benchmark = {version = "<=4.0.0", optional = true}
-fuzzywuzzy = {version = "^0.18.0", optional = true}
-sphinx = {version = ">=6.0.0", optional = true} # for minimum version and support for Python 3.10
-jinja2 = {version = "^3.1.2", optional = true}
-scipy = {version = "<=1.10.0", optional = true}
-
+nest_asyncio = "<=1.6.0"
+pytest = { version = "^7.4.0", optional = true }
+pytest-asyncio = { version = "<=0.24.0", optional = true }
+pytest-benchmark = { version = "<=4.0.0", optional = true }
+fuzzywuzzy = { version = "^0.18.0", optional = true }
+sphinx = { version = ">=6.0.0", optional = true }  # for minimum version and support for Python 3.10
+jinja2 = { version = "^3.1.2", optional = true }
+scipy = { version = "<=1.10.0", optional = true }
 
 [tool.poetry.extras]
 # Instead of using poetry dependency groups, we use extras to make it pip installable
diff --git a/test/test_models.py b/test/test_models.py
index 2fc0ea842..69339da73 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -1,8 +1,9 @@
 import os
+import asyncio
 from ersilia.hub.fetch.fetch import ModelFetcher
 from ersilia import ErsiliaModel
 
-MODELS = ["eos0t01", "eos0t02", "eos0t03", "eos0t04"]
+MODELS = ["eos0t01", "eos3b5e", "eos0t03", "eos0t04"]
 
 
 def test_model_1():
@@ -21,12 +22,12 @@ def test_model_1():
 def test_model_2():
     MODEL_ID = MODELS[1]
     INPUT = "CCCC"
-    ModelFetcher(repo_path=os.path.join(os.getcwd(), "test/models", MODEL_ID)).fetch(
-        MODEL_ID
-    )
+    mf = ModelFetcher(overwrite=True)
+    asyncio.run(mf.fetch(MODEL_ID))
     em = ErsiliaModel(MODEL_ID)
     em.serve()
-    em.predict(INPUT)
     em.close()
     assert 1 == 1
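Since pytest-asyncio is pinned as an optional dependency above, async fetch tests could also be written natively instead of wrapping with asyncio.run(). A hypothetical variant (assumes pytest-asyncio is installed, its marker or asyncio_mode is configured, and eos3b5e is reachable from the test environment):

```python
import pytest
from ersilia.hub.fetch.fetch import ModelFetcher

# Hypothetical test sketch, not part of this patch.
@pytest.mark.asyncio
async def test_fetch_async():
    mf = ModelFetcher(overwrite=True)
    await mf.fetch("eos3b5e")
```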