Embeddings search experimental API #1164

Open · wants to merge 11 commits into base: main
3 changes: 3 additions & 0 deletions api/python/cellxgene_census/pyproject.toml
@@ -48,6 +48,9 @@ experimental = [
"psutil~=5.0",
"datasets~=2.0",
"tdigest~=0.5",
# choose newest version of tiledb-vector-search that doesn't need a newer version of tiledb
# than tiledbsoma: https://github.com/TileDB-Inc/TileDB-Vector-Search/blob/0.2.2/pyproject.toml
"tiledb-vector-search~=0.2",
# Not expressible in pyproject.toml:
#"git+https://huggingface.co/ctheodoris/Geneformer",
# instead, experimental/ml/geneformer_tokenizer.py catches ImportError to ask user to install that.

@@ -131,6 +131,7 @@ class CensusMirror(TypedDict):
provider: Provider
base_uri: str
region: Optional[str]
embeddings_base_uri: str


CensusMirrors = Dict[CensusMirrorName, Union[CensusMirrorName, CensusMirror]]

@@ -7,11 +7,15 @@
get_embedding_metadata,
get_embedding_metadata_by_name,
)
from ._embedding_search import NeighborObs, find_nearest_obs, predict_obs_metadata

__all__ = [
"get_embedding",
"get_embedding_metadata",
"get_embedding_metadata_by_name",
"get_all_available_embeddings",
"get_all_census_versions_with_embedding",
"find_nearest_obs",
"NeighborObs",
"predict_obs_metadata",
]
@@ -0,0 +1,167 @@
"""Nearest-neighbor search based on vector index of Census embeddings."""

from contextlib import ExitStack
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, cast

import anndata as ad
import numpy as np
import numpy.typing as npt
import pandas as pd
import tiledb.vector_search as vs
import tiledbsoma as soma

from .._experiment import _get_experiment_name
from .._open import DEFAULT_TILEDB_CONFIGURATION, open_soma
from .._release_directory import CensusMirror, _get_census_mirrors
from .._util import _uri_join
from ._embedding import get_embedding_metadata_by_name


class NeighborObs(NamedTuple):
"""Results of nearest-neighbor search for Census obs embeddings."""

distances: npt.NDArray[np.float32]
"""
Distances to the nearest neighbors for each query obs embedding (q by k, where q is the number
of query embeddings and k is the desired number of neighbors). The distance metric is
implementation-dependent.
"""

neighbor_ids: npt.NDArray[np.int64]
"""
obs soma_joinid's of the nearest neighbors for each query embedding (q by k).
"""


def find_nearest_obs(
A Contributor commented:

On the API side, it would be nice if this could produce output that can be used directly with sklearn-style classes. For example, if this returned a KNNTransformer subclass, it could be used directly with the KNeighborsClassifier and KNeighborsRegressor classes.

The Contributor Author replied:
@ivirshup I like this idea very much, but I'm not quite sure it's workable (albeit I'm not as familiar with those APIs)...

Those scikit-learn classes seem oriented around the scenario where you're providing either all the points (in the "universe") or the complete distance matrix for them. Here we're working with a more limited view of the query points and their neighbor distances; we don't have or want the complete distance matrix, and actually we don't even have the coordinates of the neighbors immediately handy.

Do you think the shoe fits? I see there's some stuff about the "K neighbors graph" that might be relevant, but I'm not personally familiar enough to use them in an unconventional/advanced way like this.

embedding_name: str,
organism: str,
census_version: str,
query: ad.AnnData,
*,
k: int = 10,
nprobe: int = 100,
memory_GiB: int = 4,
mirror: Optional[str] = None,
embedding_metadata: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> NeighborObs:
"""Search Census for similar obs (cells) based on nearest neighbors in embedding space.

Args:
embedding_name, organism, census_version:
Identify the embedding to search, as in :func:`get_embedding_metadata_by_name`.
query:
AnnData object with an obsm layer embedding the query cells. The obsm layer name
must match ``embedding_metadata["embedding_name"]`` (e.g. scvi, geneformer), and its
shape must match the number of query cells and the number of features in the embedding.
k:
Number of nearest neighbors to return for each query obs.
nprobe:
Sensitivity parameter; defaults to 100 (roughly N^0.25 where N is the number of Census
cells) for a thorough search. Decrease for faster but less accurate search.
memory_GiB:
Memory budget for the search index, in gibibytes; defaults to 4 GiB.
mirror:
Name of the Census mirror to use for the search.
embedding_metadata:
The result of `get_embedding_metadata_by_name(embedding_name, organism, census_version)`.
Supplying this saves a network request for repeated searches.
"""
if embedding_metadata is None:
embedding_metadata = get_embedding_metadata_by_name(embedding_name, organism, census_version)
assert embedding_metadata["embedding_name"] == embedding_name
n_features = embedding_metadata["n_features"]

# validate query (expected obsm layer exists with the expected dimensionality)
if embedding_name not in query.obsm:
raise ValueError(f"Query does not have the expected layer {embedding_name}")
if query.obsm[embedding_name].shape[1] != n_features:
raise ValueError(
f"Query embedding {embedding_name} has {query.obsm[embedding_name].shape[1]} features, expected {n_features}"
)

# formulate index URI and run query
resolved_index = _resolve_embedding_index(embedding_metadata, mirror=mirror)
if not resolved_index:
raise ValueError("No suitable embedding index found for " + embedding_name)
index_uri, index_region = resolved_index
config = {key: str(value) for key, value in DEFAULT_TILEDB_CONFIGURATION.items()}  # avoid shadowing parameter k
config["vfs.s3.region"] = index_region
memory_vectors = memory_GiB * (2**30) // (4 * n_features) # number of float32 vectors
index = vs.ivf_flat_index.IVFFlatIndex(uri=index_uri, config=config, memory_budget=memory_vectors)
distances, neighbor_ids = index.query(query.obsm[embedding_name], k=k, nprobe=nprobe, **kwargs)

return NeighborObs(distances=distances, neighbor_ids=neighbor_ids)
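
A minimal usage sketch for the function above; the embedding name, Census version, and input file are illustrative assumptions, not values fixed by this PR:

```python
# Example values; substitute a real embedding name and Census version.
import anndata as ad
from cellxgene_census.experimental import find_nearest_obs

query = ad.read_h5ad("query_cells.h5ad")  # hypothetical AnnData with obsm["scvi"]
neighbors = find_nearest_obs(
    embedding_name="scvi",
    organism="homo_sapiens",
    census_version="2023-12-15",
    query=query,
    k=10,
)
# Both arrays are (n_query_cells, k); neighbor_ids are Census obs soma_joinid's.
print(neighbors.distances.shape, neighbors.neighbor_ids.shape)
```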

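On the scikit-learn discussion above: a hedged sketch of how a ``NeighborObs`` result could be packed into the sparse "precomputed" distance format that sklearn's neighbors estimators accept. The helper is hypothetical and assumes neighbor IDs remapped to positional indices; as the reply notes, the reference-side distance matrix needed for ``fit()`` is not available from this search.

```python
# Hypothetical adapter, not part of this PR.
import numpy as np
from scipy.sparse import csr_matrix


def neighbors_to_sparse_graph(neighbors: NeighborObs, n_reference: int) -> csr_matrix:
    """Pack a NeighborObs into a (q x n_reference) CSR matrix of distances.

    Assumes neighbor_ids were remapped from Census soma_joinid's to positional
    indices 0..n_reference-1 of some reference set.
    """
    q, k = neighbors.distances.shape
    indptr = np.arange(0, q * k + 1, k)  # k stored distances per query row
    return csr_matrix(
        (neighbors.distances.ravel(), neighbors.neighbor_ids.ravel(), indptr),
        shape=(q, n_reference),
    )


# sklearn.neighbors.KNeighborsClassifier(metric="precomputed") accepts such a
# sparse graph at predict() time, but fit() would still need reference-vs-reference
# distances, which this search deliberately does not compute.
```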

def _resolve_embedding_index(
embedding_metadata: Dict[str, Any],
mirror: Optional[str] = None,
) -> Optional[Tuple[str, str]]:
The Contributor Author commented:
@ebezzi new index resolution method here

index_metadata = embedding_metadata.get("indexes", None)
if not index_metadata:
return None
# TODO (future): support multiple index [types]
assert index_metadata[0]["type"] == "IVFFlat", "Only IVFFlat index is supported (update cellxgene_census)"
mirrors = _get_census_mirrors()
mirror = mirror or cast(str, mirrors["default"])
mirror_info = cast(CensusMirror, mirrors[mirror])
uri = _uri_join(mirror_info["embeddings_base_uri"], index_metadata[0]["relative_uri"])
return uri, cast(str, mirror_info["region"])
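
For reference, the ``indexes`` metadata that ``_resolve_embedding_index`` consumes would look roughly like this; the values are invented for illustration, and only the ``type`` and ``relative_uri`` keys are read by the resolver:

```python
# Invented example; structure inferred from the resolver code above.
embedding_metadata = {
    "embedding_name": "scvi",
    "n_features": 50,
    "indexes": [
        {"type": "IVFFlat", "relative_uri": "scvi/ivf_flat"},
    ],
}
# The resolver joins the mirror's embeddings_base_uri with
# indexes[0]["relative_uri"] and returns (index_uri, mirror_region).
```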


def predict_obs_metadata(
organism: str,
census_version: str,
neighbors: NeighborObs,
column_names: Sequence[str],
experiment: Optional[soma.Experiment] = None,
) -> pd.DataFrame:
"""Predict obs metadata attributes for the query cells based on the embedding nearest neighbors.

Args:
organism, census_version:
Embedding information as supplied to :func:`find_nearest_obs`.
neighbors:
Results of a :func:`find_nearest_obs` search.
column_names:
Desired obs metadata column names. The current implementation is suitable for
categorical attributes (e.g. cell_type, tissue_general).
experiment:
Open handle for the relevant SOMAExperiment, if available (otherwise it will be opened
internally); e.g. ``census["census_data"]["homo_sapiens"]`` opened at the relevant
Census version.

Returns:
Pandas DataFrame with the desired column predictions. For each predicted column
``col``, it also includes a column ``col_confidence`` with a confidence score
between 0 and 1.
"""
with ExitStack() as cleanup:
if experiment is None:
# open Census transiently
census = cleanup.enter_context(open_soma(census_version=census_version))
experiment = census["census_data"][_get_experiment_name(organism)]

# fetch the desired obs metadata for all of the found neighbors
neighbor_obs = (
experiment.obs.read(
coords=(neighbors.neighbor_ids.flatten(),), column_names=(["soma_joinid"] + list(column_names))
)
.concat()
.to_pandas()
).set_index("soma_joinid")

# step through query cells to generate prediction for each column as the plurality value
# found among its neighbors, with a confidence score based on the simple fraction (for now)
# TODO: something more intelligent for numeric columns! also use distances, etc.
out: Dict[str, List[Any]] = {}
for i in range(neighbors.neighbor_ids.shape[0]):
neighbors_i = neighbor_obs.loc[neighbors.neighbor_ids[i]]
for col in column_names:
col_value_counts = neighbors_i[col].value_counts(normalize=True)
out.setdefault(col, []).append(col_value_counts.idxmax())
out.setdefault(col + "_confidence", []).append(col_value_counts.max())

return pd.DataFrame.from_dict(out)
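
Continuing the ``find_nearest_obs`` sketch above, a hedged end-to-end example; the column names are the ones this docstring cites:

```python
# Continues the earlier sketch; organism/version remain illustrative.
predictions = predict_obs_metadata(
    organism="homo_sapiens",
    census_version="2023-12-15",
    neighbors=neighbors,  # NeighborObs from find_nearest_obs above
    column_names=["cell_type", "tissue_general"],
)
# One row per query cell: each requested column plus a 0-1 confidence column.
print(predictions[["cell_type", "cell_type_confidence"]].head())
```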