Embeddings search experimental API #1164
Open · mlin wants to merge 11 commits into main from mlin/similarity-search-api
Commits (11):
- 639e64c — squash for PR (mlin)
- 8088f9e — use DEFAULT_TILEDB_CONFIGURATION (mlin)
- fc91d2d — workaround (mlin)
- c8bb01a — workaround (mlin)
- e73102b — fix (mlin)
- fcc05b4 — Merge remote-tracking branch 'origin/main' into mlin/similarity-searc… (mlin)
- ca8d44e — resolve indexes through JSONs (mlin)
- ee6c184 — lint (mlin)
- a1e4daa — API refactoring (mlin)
- 874b2eb — Merge remote-tracking branch 'origin/main' into mlin/similarity-searc… (mlin)
- c33df3f — Merge remote-tracking branch 'origin/main' into mlin/similarity-searc… (mlin)
api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py (167 additions, 0 deletions)
```python
"""Nearest-neighbor search based on vector index of Census embeddings."""

from contextlib import ExitStack
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, cast

import anndata as ad
import numpy as np
import numpy.typing as npt
import pandas as pd
import tiledb.vector_search as vs
import tiledbsoma as soma

from .._experiment import _get_experiment_name
from .._open import DEFAULT_TILEDB_CONFIGURATION, open_soma
from .._release_directory import CensusMirror, _get_census_mirrors
from .._util import _uri_join
from ._embedding import get_embedding_metadata_by_name


class NeighborObs(NamedTuple):
    """Results of nearest-neighbor search for Census obs embeddings."""

    distances: npt.NDArray[np.float32]
    """
    Distances to the nearest neighbors for each query obs embedding (q by k, where q is the number
    of query embeddings and k is the desired number of neighbors). The distance metric is
    implementation-dependent.
    """

    neighbor_ids: npt.NDArray[np.int64]
    """
    obs soma_joinid's of the nearest neighbors for each query embedding (q by k).
    """


def find_nearest_obs(
    embedding_name: str,
    organism: str,
    census_version: str,
    query: ad.AnnData,
    *,
    k: int = 10,
    nprobe: int = 100,
    memory_GiB: int = 4,
    mirror: Optional[str] = None,
    embedding_metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Dict[str, Any],
) -> NeighborObs:
    """Search Census for similar obs (cells) based on nearest neighbors in embedding space.

    Args:
        embedding_name, organism, census_version:
            Identify the embedding to search, as in :func:`get_embedding_metadata_by_name`.
        query:
            AnnData object with an obsm layer embedding the query cells. The obsm layer name
            matches ``embedding_metadata["embedding_name"]`` (e.g. scvi, geneformer). The layer
            shape matches the number of query cells and the number of features in the embedding.
        k:
            Number of nearest neighbors to return for each query obs.
        nprobe:
            Sensitivity parameter; defaults to 100 (roughly N^0.25 where N is the number of Census
            cells) for a thorough search. Decrease for faster but less accurate search.
        memory_GiB:
            Memory budget for the search index, in gibibytes; defaults to 4 GiB.
        mirror:
            Name of the Census mirror to use for the search.
        embedding_metadata:
            The result of ``get_embedding_metadata_by_name(embedding_name, organism, census_version)``.
            Supplying this saves a network request for repeated searches.
    """
    if embedding_metadata is None:
        embedding_metadata = get_embedding_metadata_by_name(embedding_name, organism, census_version)
    assert embedding_metadata["embedding_name"] == embedding_name
    n_features = embedding_metadata["n_features"]

    # validate query (expected obsm layer exists with the expected dimensionality)
    if embedding_name not in query.obsm:
        raise ValueError(f"Query does not have the expected layer {embedding_name}")
    if query.obsm[embedding_name].shape[1] != n_features:
        raise ValueError(
            f"Query embedding {embedding_name} has {query.obsm[embedding_name].shape[1]} features, expected {n_features}"
        )

    # formulate index URI and run query
    resolved_index = _resolve_embedding_index(embedding_metadata, mirror=mirror)
    if not resolved_index:
        raise ValueError("No suitable embedding index found for " + embedding_name)
    index_uri, index_region = resolved_index
    config = {k: str(v) for k, v in DEFAULT_TILEDB_CONFIGURATION.items()}
    config["vfs.s3.region"] = index_region
    memory_vectors = memory_GiB * (2**30) // (4 * n_features)  # number of float32 vectors
    index = vs.ivf_flat_index.IVFFlatIndex(uri=index_uri, config=config, memory_budget=memory_vectors)
    distances, neighbor_ids = index.query(query.obsm[embedding_name], k=k, nprobe=nprobe, **kwargs)

    return NeighborObs(distances=distances, neighbor_ids=neighbor_ids)
```
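The memory budget above converts gibibytes into a count of float32 vectors the index may hold in memory at once. A quick sketch of that arithmetic (the 512-feature embedding below is a hypothetical example, not a specific Census embedding):

```python
def memory_budget_vectors(memory_GiB: int, n_features: int) -> int:
    # Each indexed vector holds n_features float32 values (4 bytes each), so
    # dividing the byte budget (GiB * 2**30) by the per-vector size gives the
    # number of vectors that fit in the budget.
    return memory_GiB * (2**30) // (4 * n_features)

# With the 4 GiB default and a hypothetical 512-dimensional embedding:
print(memory_budget_vectors(4, 512))  # 2097152 vectors (~2M)
```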
PR author comment on `_resolve_embedding_index`: "@ebezzi new index resolution method here"

```python
def _resolve_embedding_index(
    embedding_metadata: Dict[str, Any],
    mirror: Optional[str] = None,
) -> Optional[Tuple[str, str]]:
    index_metadata = embedding_metadata.get("indexes", None)
    if not index_metadata:
        return None
    # TODO (future): support multiple index [types]
    assert index_metadata[0]["type"] == "IVFFlat", "Only IVFFlat index is supported (update cellxgene_census)"
    mirrors = _get_census_mirrors()
    mirror = mirror or cast(str, mirrors["default"])
    mirror_info = cast(CensusMirror, mirrors[mirror])
    uri = _uri_join(mirror_info["embeddings_base_uri"], index_metadata[0]["relative_uri"])
    return uri, cast(str, mirror_info["region"])


def predict_obs_metadata(
    organism: str,
    census_version: str,
    neighbors: NeighborObs,
    column_names: Sequence[str],
    experiment: Optional[soma.Experiment] = None,
) -> pd.DataFrame:
    """Predict obs metadata attributes for the query cells based on the embedding nearest neighbors.

    Args:
        organism, census_version:
            Embedding information as supplied to :func:`find_nearest_obs`.
        neighbors:
            Results of a :func:`find_nearest_obs` search.
        column_names:
            Desired obs metadata column names. The current implementation is suitable for
            categorical attributes (e.g. cell_type, tissue_general).
        experiment:
            Open handle for the relevant SOMAExperiment, if available (otherwise, will be opened
            internally), e.g. ``census["census_data"]["homo_sapiens"]`` with the relevant Census
            version.

    Returns:
        Pandas DataFrame with the desired column predictions. Additionally, for each predicted
        column ``col``, an additional column ``col_confidence`` with a confidence score between 0
        and 1.
    """
    with ExitStack() as cleanup:
        if experiment is None:
            # open Census transiently
            census = cleanup.enter_context(open_soma(census_version=census_version))
            experiment = census["census_data"][_get_experiment_name(organism)]

        # fetch the desired obs metadata for all of the found neighbors
        neighbor_obs = (
            experiment.obs.read(
                coords=(neighbors.neighbor_ids.flatten(),), column_names=(["soma_joinid"] + list(column_names))
            )
            .concat()
            .to_pandas()
        ).set_index("soma_joinid")

        # step through query cells to generate prediction for each column as the plurality value
        # found among its neighbors, with a confidence score based on the simple fraction (for now)
        # TODO: something more intelligent for numeric columns! also use distances, etc.
        out: Dict[str, List[Any]] = {}
        for i in range(neighbors.neighbor_ids.shape[0]):
            neighbors_i = neighbor_obs.loc[neighbors.neighbor_ids[i]]
            for col in column_names:
                col_value_counts = neighbors_i[col].value_counts(normalize=True)
                out.setdefault(col, []).append(col_value_counts.idxmax())
                out.setdefault(col + "_confidence", []).append(col_value_counts.max())

        return pd.DataFrame.from_dict(out)
```
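The plurality-vote step above can be sketched in isolation with the standard library. This mirrors the `value_counts(normalize=True).idxmax()`/`.max()` logic for one query cell's neighbors; the cell-type labels below are made-up examples, not Census data:

```python
from collections import Counter


def plurality_vote(neighbor_labels):
    """Return (predicted_label, confidence), where confidence is the fraction
    of neighbors sharing the plurality label."""
    counts = Counter(neighbor_labels)
    label, n = counts.most_common(1)[0]  # plurality label and its count
    return label, n / len(neighbor_labels)


# e.g. labels of the 10 nearest neighbors of one query cell (hypothetical):
labels = ["T cell"] * 7 + ["B cell"] * 2 + ["NK cell"]
print(plurality_vote(labels))  # ('T cell', 0.7)
```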
ivirshup (review comment): On the API side, it would be nice if this could produce output that can be used directly with sklearn-style classes. For example, if this returned a KNNTransformer subclass, that could be used directly with the KNeighborsClassifier and KNeighborsRegressor classes.
mlin (reply):
@ivirshup I like this idea very much, but I'm not quite sure it's workable (albeit I'm not as familiar with those APIs)...
Those scikit-learn classes seem oriented around the scenario where you're providing either all the points (in the "universe") or the complete distance matrix for them. Here we're working with a more limited view of the query points and their neighbor distances; we don't have or want the complete distance matrix, and actually we don't even have the coordinates of the neighbors immediately handy.
Do you think the shoe fits? I see there's some stuff about the "K neighbors graph" that might be relevant, but I'm not personally familiar enough to use them in an unconventional/advanced way like this.
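For reference, the "K neighbors graph" mentioned above is conventionally represented in CSR sparse-matrix form (the layout `scipy.sparse.csr_matrix` and sklearn's `kneighbors_graph` use). A rough, purely illustrative sketch of packing a q-by-k result like `NeighborObs` into the CSR `(data, indices, indptr)` triplet, ignoring that soma_joinid values would still need remapping to dense column positions (this illustrates the format, it is not a proposed implementation):

```python
def pack_knn_csr(distances, neighbor_ids):
    """Flatten q-by-k neighbor results into a CSR triplet (data, indices, indptr).
    Row i's k entries live in positions indptr[i]:indptr[i+1] of data/indices."""
    q = len(distances)
    k = len(distances[0]) if q else 0
    data = [d for row in distances for d in row]         # neighbor distances
    indices = [j for row in neighbor_ids for j in row]   # neighbor column ids
    indptr = [i * k for i in range(q + 1)]               # exactly k entries per row
    return data, indices, indptr


# Two query cells, k=2 (hypothetical distances and ids):
d, idx, ptr = pack_knn_csr([[0.1, 0.2], [0.3, 0.4]], [[5, 7], [2, 9]])
```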