Skip to content
152 changes: 151 additions & 1 deletion biomni/tool/database.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
from __future__ import annotations

import json
import os
import pickle
import time
from typing import Any

# from typing import Any
from typing import TYPE_CHECKING, Any

import requests
from Bio.Blast import NCBIWWW, NCBIXML
from Bio.Seq import Seq
from langchain_core.messages import HumanMessage, SystemMessage

from biomni.llm import get_llm
from biomni.tool.kp_tool import KPClient
from biomni.utils import parse_hpo_obo

if TYPE_CHECKING:
from pathlib import Path


# Function to map HPO terms to names
def get_hpo_names(hpo_terms: list[str], data_lake_path: str) -> list[str]:
Expand Down Expand Up @@ -4972,3 +4980,145 @@ def query_encode(
api_result["result"] = _format_query_results(api_result["result"])

return api_result


# Shared client, constructed once at import time with KPClient's defaults.
# The KP wrapper functions in this module use it whenever the caller does not
# pass an explicit ``cache_path``.
_default_client = KPClient()


def ensure_fresh(
    max_age_seconds: int | float | None = 24 * 3600,
    force: bool = False,
    cache_path: str | Path | None = None,
    **_ignored_kwargs: Any,
) -> dict:
    """Invoke ``refresh`` on a KP client and return its result.

    Args:
        max_age_seconds: Passed to ``refresh`` on the shared client, or to the
            ``KPClient`` constructor when a dedicated client is created.
        force: Forwarded to ``refresh`` as ``force``.
        cache_path: When not ``None``, a fresh ``KPClient`` bound to this path
            is used instead of the module-level ``_default_client``.
        **_ignored_kwargs: Accepted for call compatibility and discarded.

    Returns:
        Whatever the client's ``refresh`` call returns.
    """
    if cache_path is None:
        # No custom cache location: reuse the shared module-level client.
        return _default_client.refresh(max_age_seconds=max_age_seconds, force=force)

    dedicated_client = KPClient(cache_path=cache_path, max_age_seconds=max_age_seconds)
    return dedicated_client.refresh(force=force)


def list_kps(
    max_age_seconds: int = 24 * 3600,
    *,
    cache_path: str | Path | None = None,
) -> dict:
    """Return the result of ``list()`` from a KP client.

    Args:
        max_age_seconds: Max-age setting applied to the client for this call.
        cache_path: When not ``None``, a fresh ``KPClient`` bound to this path
            (and constructed with ``max_age_seconds``) is used instead of the
            module-level ``_default_client``.

    Returns:
        Whatever the client's ``list`` call returns.
    """
    if cache_path is None:
        # Temporarily override the shared client's max age; the previous value
        # is put back even if the call raises.
        previous_max_age = _default_client._max_age
        _default_client._max_age = max_age_seconds
        try:
            listing = _default_client.list()
        finally:
            _default_client._max_age = previous_max_age
        return listing

    dedicated_client = KPClient(cache_path=cache_path, max_age_seconds=max_age_seconds)
    return dedicated_client.list()


def describe_kp(
    kp_id: str,
    max_age_seconds: int = 24 * 3600,
    *,
    cache_path: str | Path | None = None,
    **_ignored_kwargs: Any,
) -> dict:
    """Return the result of ``describe(kp_id)`` from a KP client.

    Args:
        kp_id: Identifier forwarded to the client's ``describe`` method.
        max_age_seconds: Max-age setting applied to the client for this call.
        cache_path: When not ``None``, a fresh ``KPClient`` bound to this path
            (and constructed with ``max_age_seconds``) is used instead of the
            module-level ``_default_client``.
        **_ignored_kwargs: Accepted for call compatibility and discarded.

    Returns:
        Whatever the client's ``describe`` call returns.
    """
    if cache_path is None:
        # Temporarily override the shared client's max age; the previous value
        # is put back even if the call raises.
        previous_max_age = _default_client._max_age
        _default_client._max_age = max_age_seconds
        try:
            description = _default_client.describe(kp_id)
        finally:
            _default_client._max_age = previous_max_age
        return description

    dedicated_client = KPClient(cache_path=cache_path, max_age_seconds=max_age_seconds)
    return dedicated_client.describe(kp_id)


def query_kp(
    kp_id: str,
    q: Any,
    *,
    normalize: bool = False,
    include_raw_hits: bool = False,
    max_age_seconds: int = 24 * 3600,
    cache_path: str | Path | None = None,
    timeout_s: float = 30.0,
) -> dict[str, Any]:
    """Run ``query(kp_id, q, ...)`` on a KP client and return its result.

    Args:
        kp_id: Identifier forwarded to the client's ``query`` method.
        q: Query payload forwarded unchanged to ``query``.
        normalize: Forwarded to ``query`` as ``normalize``.
        include_raw_hits: Forwarded to ``query`` as ``include_raw_hits``.
        max_age_seconds: Max-age setting applied to the client for this call.
        cache_path: When not ``None``, a fresh ``KPClient`` bound to this path
            (and constructed with ``max_age_seconds``) is used instead of the
            module-level ``_default_client``.
        timeout_s: Forwarded to ``query`` as ``timeout_s``.

    Returns:
        Whatever the client's ``query`` call returns.
    """
    # Keyword options are identical for both client paths; gather them once.
    query_options = {
        "normalize": normalize,
        "include_raw_hits": include_raw_hits,
        "timeout_s": timeout_s,
    }

    if cache_path is None:
        # Temporarily override the shared client's max age; the previous value
        # is put back even if the call raises.
        previous_max_age = _default_client._max_age
        _default_client._max_age = max_age_seconds
        try:
            response = _default_client.query(kp_id, q, **query_options)
        finally:
            _default_client._max_age = previous_max_age
        return response

    dedicated_client = KPClient(cache_path=cache_path, max_age_seconds=max_age_seconds)
    return dedicated_client.query(kp_id, q, **query_options)


def batch_query_kp(
    kp_id: str,
    queries: Any,
    *,
    normalize: bool = False,
    include_raw_hits: bool = False,
    max_age_seconds: int = 24 * 3600,
    cache_path: str | Path | None = None,
    timeout_s: float = 30.0,
) -> dict[str, Any]:
    """Run ``query_batch(kp_id, queries, ...)`` on a KP client and return its result.

    Args:
        kp_id: Identifier forwarded to the client's ``query_batch`` method.
        queries: Batch payload forwarded unchanged to ``query_batch``.
        normalize: Forwarded to ``query_batch`` as ``normalize``.
        include_raw_hits: Forwarded to ``query_batch`` as ``include_raw_hits``.
        max_age_seconds: Max-age setting applied to the client for this call.
        cache_path: When not ``None``, a fresh ``KPClient`` bound to this path
            (and constructed with ``max_age_seconds``) is used instead of the
            module-level ``_default_client``.
        timeout_s: Forwarded to ``query_batch`` as ``timeout_s``.

    Returns:
        Whatever the client's ``query_batch`` call returns.
    """
    # Keyword options are identical for both client paths; gather them once.
    batch_options = {
        "normalize": normalize,
        "include_raw_hits": include_raw_hits,
        "timeout_s": timeout_s,
    }

    if cache_path is None:
        # Temporarily override the shared client's max age; the previous value
        # is put back even if the call raises.
        previous_max_age = _default_client._max_age
        _default_client._max_age = max_age_seconds
        try:
            response = _default_client.query_batch(kp_id, queries, **batch_options)
        finally:
            _default_client._max_age = previous_max_age
        return response

    dedicated_client = KPClient(cache_path=cache_path, max_age_seconds=max_age_seconds)
    return dedicated_client.query_batch(kp_id, queries, **batch_options)


def scroll_kp(
    kp_id: str,
    *,
    q: str | None = None,
    scroll_id: str | None = None,
    size: int = 100,
    timeout_s: float = 30.0,
    max_age_seconds: int = 24 * 3600,
    cache_path: str | Path | None = None,
) -> dict[str, Any]:
    """Run ``scroll(kp_id, ...)`` on a KP client and return its result.

    Args:
        kp_id: Identifier forwarded to the client's ``scroll`` method.
        q: Forwarded to ``scroll`` as ``q``.
        scroll_id: Forwarded to ``scroll`` as ``scroll_id``.
        size: Forwarded to ``scroll`` as ``size``.
        timeout_s: Forwarded to ``scroll`` as ``timeout_s``.
        max_age_seconds: Max-age setting applied to the client for this call.
        cache_path: When not ``None``, a fresh ``KPClient`` bound to this path
            (and constructed with ``max_age_seconds``) is used instead of the
            module-level ``_default_client``.

    Returns:
        Whatever the client's ``scroll`` call returns.
    """
    # Keyword options are identical for both client paths; gather them once.
    scroll_options = {
        "q": q,
        "scroll_id": scroll_id,
        "size": size,
        "timeout_s": timeout_s,
    }

    if cache_path is None:
        # Temporarily override the shared client's max age; the previous value
        # is put back even if the call raises.
        previous_max_age = _default_client._max_age
        _default_client._max_age = max_age_seconds
        try:
            page = _default_client.scroll(kp_id, **scroll_options)
        finally:
            _default_client._max_age = previous_max_age
        return page

    dedicated_client = KPClient(cache_path=cache_path, max_age_seconds=max_age_seconds)
    return dedicated_client.scroll(kp_id, **scroll_options)
Loading