formally support int8 and uint8 within langchain and 2 distance metrics #561

Open · BBC-Esq opened this issue Jan 6, 2025 · 6 comments

BBC-Esq commented Jan 6, 2025

Sorry for the overlapping posts, but I thought it prudent to initiate a separate issue if this is going to be worked on:

Here's a summary:

  1. Both langchain and Sentence Transformers allow a model to be run in float32, bfloat16, and float16.
  2. Sentence Transformers' `encode` method only supports float32 (and certain quantizations), which results in the embeddings themselves being in float32.
  3. tiledb.py, langchain's integration with the tiledb library, automatically converts embeddings to float32 here:

np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),

  4. Relatively newer versions of sentence-transformers support int8, uint8, binary, and ubinary. Pull request here

  5. TileDB seems to support int8 and uint8 but not the other two.

  6. Again, langchain's integration of tiledb within tiledb.py doesn't distinguish and converts everything to float32.
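
For concreteness, here's a minimal sketch of the mismatch (assuming a sentence-transformers version new enough to have the `precision` parameter on `encode`; the model name is just an example):

```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
emb = model.encode(["hello world"], precision="int8")
print(emb.dtype)  # int8

# What langchain's tiledb.py then does with it, regardless of the input dtype:
emb_f32 = np.array([np.array(emb[0]).astype(np.float32)]).astype(np.float32)
print(emb_f32.dtype)  # float32 -- the int8 quantization is silently undone
```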

Does that succinctly summarize the current state of affairs? Is it possible to at least modify tiledb.py to formally support int8 and uint8, if not the other two? I noticed that @nikolaos did the initial integration in November 2023. Is he still around at the company, by chance? lol.

Here is the related "issue" where I realized this, for people's cross-reference: TileDB-Inc/TileDB-Py#2130 (comment)

BBC-Esq (Author) commented Jan 6, 2025

I noticed that the underlying C++ library doesn't support float16 or bfloat16... where should I file an issue for that support as well?
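
For what it's worth, this is easy to confirm from Python, since tiledb-py's type mapping mirrors the core library (a quick check, assuming tiledb-py is installed):

```python
import numpy as np
import tiledb

# float32 maps cleanly; float16 has no TileDB datatype, so this raises TypeError
tiledb.Attr(name="v", dtype=np.float32)
try:
    tiledb.Attr(name="v", dtype=np.float16)
except TypeError as e:
    print(e)  # e.g. "dtype('float16') cannot be mapped to a DataType"
```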

BBC-Esq (Author) commented Jan 6, 2025

Here is a prompt that you can feed to an AI to get an initial draft. It's a starting point; unfortunately I don't have the expertise to perfect it myself, or I would. In my experience, AI is good at initial drafts and this will save time, but finishing it requires more expertise.

DRAFT QUESTION TO AI TO START IMPLEMENTING int8/uint8 AND ADD 2 NEW DISTANCE METRICS
It's my understanding that this script from the langchain library essentially acts as a wrapper around the TileDB vector search library.

"""Wrapper around TileDB vector database."""

from __future__ import annotations

import pickle
import random
import sys
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

INDEX_METRICS = frozenset(["euclidean"])
DEFAULT_METRIC = "euclidean"
DOCUMENTS_ARRAY_NAME = "documents"
VECTOR_INDEX_NAME = "vectors"
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
MAX_FLOAT_32 = np.finfo(np.dtype("float32")).max
MAX_FLOAT = sys.float_info.max


def dependable_tiledb_import() -> Any:
    """Import tiledb-vector-search if available, otherwise raise error."""
    return (
        guard_import("tiledb.vector_search"),
        guard_import("tiledb"),
    )


def get_vector_index_uri_from_group(group: Any) -> str:
    """Get the URI of the vector index."""
    return group[VECTOR_INDEX_NAME].uri


def get_documents_array_uri_from_group(group: Any) -> str:
    """Get the URI of the documents array from group.

    Args:
        group: TileDB group object.

    Returns:
        URI of the documents array.
    """
    return group[DOCUMENTS_ARRAY_NAME].uri


def get_vector_index_uri(uri: str) -> str:
    """Get the URI of the vector index."""
    return f"{uri}/{VECTOR_INDEX_NAME}"


def get_documents_array_uri(uri: str) -> str:
    """Get the URI of the documents array."""
    return f"{uri}/{DOCUMENTS_ARRAY_NAME}"


class TileDB(VectorStore):
    """TileDB vector store.

    To use, you should have the ``tiledb-vector-search`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community import TileDB
            embeddings = OpenAIEmbeddings()
            db = TileDB(embeddings, index_uri, metric)

    """

    def __init__(
        self,
        embedding: Embeddings,
        index_uri: str,
        metric: str,
        *,
        vector_index_uri: str = "",
        docs_array_uri: str = "",
        config: Optional[Mapping[str, Any]] = None,
        timestamp: Any = None,
        allow_dangerous_deserialization: bool = False,
        **kwargs: Any,
    ):
        """Initialize with necessary components.

        Args:
            allow_dangerous_deserialization: whether to allow deserialization
                of the data which involves loading data using pickle.
                data can be modified by malicious actors to deliver a
                malicious payload that results in execution of
                arbitrary code on your machine.
        """
        if not allow_dangerous_deserialization:
            raise ValueError(
                "TileDB relies on pickle for serialization and deserialization. "
                "This can be dangerous if the data is intercepted and/or modified "
                "by malicious actors prior to being de-serialized. "
                "If you are sure that the data is safe from modification, you can "
                " set allow_dangerous_deserialization=True to proceed. "
                "Loading of compromised data using pickle can result in execution of "
                "arbitrary code on your machine."
            )
        self.embedding = embedding
        self.embedding_function = embedding.embed_query
        self.index_uri = index_uri
        self.metric = metric
        self.config = config

        tiledb_vs, tiledb = (
            guard_import("tiledb.vector_search"),
            guard_import("tiledb"),
        )
        with tiledb.scope_ctx(ctx_or_config=config):
            index_group = tiledb.Group(self.index_uri, "r")
            self.vector_index_uri = (
                vector_index_uri
                if vector_index_uri != ""
                else get_vector_index_uri_from_group(index_group)
            )
            self.docs_array_uri = (
                docs_array_uri
                if docs_array_uri != ""
                else get_documents_array_uri_from_group(index_group)
            )
            index_group.close()
            group = tiledb.Group(self.vector_index_uri, "r")
            self.index_type = group.meta.get("index_type")
            group.close()
            self.timestamp = timestamp
            if self.index_type == "FLAT":
                self.vector_index = tiledb_vs.flat_index.FlatIndex(
                    uri=self.vector_index_uri,
                    config=self.config,
                    timestamp=self.timestamp,
                    **kwargs,
                )
            elif self.index_type == "IVF_FLAT":
                self.vector_index = tiledb_vs.ivf_flat_index.IVFFlatIndex(
                    uri=self.vector_index_uri,
                    config=self.config,
                    timestamp=self.timestamp,
                    **kwargs,
                )

    @property
    def embeddings(self) -> Optional[Embeddings]:
        return self.embedding

    def process_index_results(
        self,
        ids: List[int],
        scores: List[float],
        *,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        score_threshold: float = MAX_FLOAT,
    ) -> List[Tuple[Document, float]]:
        """Turns TileDB results into a list of documents and scores.

        Args:
            ids: List of indices of the documents in the index.
            scores: List of distances of the documents in the index.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
            score_threshold: Optional, a floating point value to filter the
                resulting set of retrieved docs
        Returns:
            List of Documents and scores.
        """
        tiledb = guard_import("tiledb")
        docs = []
        docs_array = tiledb.open(
            self.docs_array_uri, "r", timestamp=self.timestamp, config=self.config
        )
        for idx, score in zip(ids, scores):
            if idx == 0 and score == 0:
                continue
            if idx == MAX_UINT64 and score == MAX_FLOAT_32:
                continue
            doc = docs_array[idx]
            if doc is None or len(doc["text"]) == 0:
                raise ValueError(f"Could not find document for id {idx}, got {doc}")
            pickled_metadata = doc.get("metadata")
            result_doc = Document(page_content=str(doc["text"][0]))
            if pickled_metadata is not None:
                metadata = pickle.loads(  # ignore[pickle]: explicit-opt-in
                    np.array(pickled_metadata.tolist()).astype(np.uint8).tobytes()
                )
                result_doc.metadata = metadata
            if filter is not None:
                filter = {
                    key: [value] if not isinstance(value, list) else value
                    for key, value in filter.items()
                }
                if all(
                    result_doc.metadata.get(key) in value
                    for key, value in filter.items()
                ):
                    docs.append((result_doc, score))
            else:
                docs.append((result_doc, score))
        docs_array.close()
        docs = [(doc, score) for doc, score in docs if score <= score_threshold]
        return docs[:k]

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        *,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            embedding: Embedding vector to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                      Defaults to 20.
            **kwargs: kwargs to be passed to similarity search. Can include:
                nprobe: Optional, number of partitions to check if using IVF_FLAT index
                score_threshold: Optional, a floating point value to filter the
                    resulting set of retrieved docs

        Returns:
            List of documents most similar to the query text and distance
            in float for each. Lower score represents more similarity.
        """
        if "score_threshold" in kwargs:
            score_threshold = kwargs.pop("score_threshold")
        else:
            score_threshold = MAX_FLOAT
        d, i = self.vector_index.query(
            np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
            k=k if filter is None else fetch_k,
            **kwargs,
        )
        return self.process_index_results(
            ids=i[0], scores=d[0], filter=filter, k=k, score_threshold=score_threshold
        )

    def similarity_search_with_score(
        self,
        query: str,
        *,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                      Defaults to 20.

        Returns:
            List of documents most similar to the query text with
            Distance as float. Lower score represents more similarity.
        """
        embedding = self.embedding_function(query)
        docs = self.similarity_search_with_score_by_vector(
            embedding,
            k=k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        return docs

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                      Defaults to 20.

        Returns:
            List of Documents most similar to the embedding.
        """
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding,
            k=k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                      Defaults to 20.

        Returns:
            List of Documents most similar to the query.
        """
        docs_and_scores = self.similarity_search_with_score(
            query, k=k, filter=filter, fetch_k=fetch_k, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    def max_marginal_relevance_search_with_score_by_vector(
        self,
        embedding: List[float],
        *,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and their similarity scores selected using the maximal marginal
            relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                     pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
        Returns:
            List of Documents and similarity scores selected by maximal marginal
                relevance and score for each.
        """
        if "score_threshold" in kwargs:
            score_threshold = kwargs.pop("score_threshold")
        else:
            score_threshold = MAX_FLOAT
        scores, indices = self.vector_index.query(
            np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
            k=fetch_k if filter is None else fetch_k * 2,
            **kwargs,
        )
        results = self.process_index_results(
            ids=indices[0],
            scores=scores[0],
            filter=filter,
            k=fetch_k if filter is None else fetch_k * 2,
            score_threshold=score_threshold,
        )
        embeddings = [
            self.embedding.embed_documents([doc.page_content])[0] for doc, _ in results
        ]
        mmr_selected = maximal_marginal_relevance(
            np.array([embedding], dtype=np.float32),
            embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        docs_and_scores = []
        for i in mmr_selected:
            docs_and_scores.append(results[i])
        return docs_and_scores

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                     pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )
        return [doc for doc, _ in docs_and_scores]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering (if needed) to
                     pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        embedding = self.embedding_function(query)
        docs = self.max_marginal_relevance_search_by_vector(
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )
        return docs

    @classmethod
    def create(
        cls,
        index_uri: str,
        index_type: str,
        dimensions: int,
        vector_type: np.dtype,
        *,
        metadatas: bool = True,
        config: Optional[Mapping[str, Any]] = None,
    ) -> None:
        tiledb_vs, tiledb = (
            guard_import("tiledb.vector_search"),
            guard_import("tiledb"),
        )
        with tiledb.scope_ctx(ctx_or_config=config):
            try:
                tiledb.group_create(index_uri)
            except tiledb.TileDBError as err:
                raise err
            group = tiledb.Group(index_uri, "w")
            vector_index_uri = get_vector_index_uri(group.uri)
            docs_uri = get_documents_array_uri(group.uri)
            if index_type == "FLAT":
                tiledb_vs.flat_index.create(
                    uri=vector_index_uri,
                    dimensions=dimensions,
                    vector_type=vector_type,
                    config=config,
                )
            elif index_type == "IVF_FLAT":
                tiledb_vs.ivf_flat_index.create(
                    uri=vector_index_uri,
                    dimensions=dimensions,
                    vector_type=vector_type,
                    config=config,
                )
            group.add(vector_index_uri, name=VECTOR_INDEX_NAME)

            # Create TileDB array to store Documents
            # TODO add a Document store API to tiledb-vector-search to allow storing
            #  different types of objects and metadata in a more generic way.
            dim = tiledb.Dim(
                name="id",
                domain=(0, MAX_UINT64 - 1),
                dtype=np.dtype(np.uint64),
            )
            dom = tiledb.Domain(dim)

            text_attr = tiledb.Attr(name="text", dtype=np.dtype("U1"), var=True)
            attrs = [text_attr]
            if metadatas:
                metadata_attr = tiledb.Attr(name="metadata", dtype=np.uint8, var=True)
                attrs.append(metadata_attr)
            schema = tiledb.ArraySchema(
                domain=dom,
                sparse=True,
                allows_duplicates=False,
                attrs=attrs,
            )
            tiledb.Array.create(docs_uri, schema)
            group.add(docs_uri, name=DOCUMENTS_ARRAY_NAME)
            group.close()

    @classmethod
    def __from(
        cls,
        texts: List[str],
        embeddings: List[List[float]],
        embedding: Embeddings,
        index_uri: str,
        *,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        metric: str = DEFAULT_METRIC,
        index_type: str = "FLAT",
        config: Optional[Mapping[str, Any]] = None,
        index_timestamp: int = 0,
        **kwargs: Any,
    ) -> TileDB:
        if metric not in INDEX_METRICS:
            raise ValueError(
                (
                    f"Unsupported distance metric: {metric}. "
                    f"Expected one of {list(INDEX_METRICS)}"
                )
            )
        tiledb_vs, tiledb = (
            guard_import("tiledb.vector_search"),
            guard_import("tiledb"),
        )

        # Check if embeddings is a list
        if isinstance(embeddings, list) and embeddings:
            # Determine if it's a list of lists or a flat list
            is_nested = all(isinstance(item, list) for item in embeddings)
            structure = "List of lists" if is_nested else "Flat list"

            # Analyze the first element
            first_item = embeddings[0]
            if is_nested:
                first_item_dtype = type(first_item[0]) if first_item else None
                try:
                    first_item_shape = np.array(first_item).shape
                except Exception as e:
                    first_item_shape = f"Cannot determine shape: {e}"
            else:
                first_item_dtype = type(first_item)
                first_item_shape = ()

            # Print results
            print("Structure:", structure)
            print("First item dtype:", first_item_dtype)
            print("First item shape:", first_item_shape)
        else:
            print("Embeddings is not a non-empty list.")


        
        input_vectors = np.array(embeddings).astype(np.float32)
        cls.create(
            index_uri=index_uri,
            index_type=index_type,
            dimensions=input_vectors.shape[1],
            vector_type=input_vectors.dtype,
            metadatas=metadatas is not None,
            config=config,
        )
        with tiledb.scope_ctx(ctx_or_config=config):
            if not embeddings:
                raise ValueError("embeddings must be provided to build a TileDB index")

            vector_index_uri = get_vector_index_uri(index_uri)
            docs_uri = get_documents_array_uri(index_uri)
            if ids is None:
                ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]
            external_ids = np.array(ids).astype(np.uint64)

            tiledb_vs.ingestion.ingest(
                index_type=index_type,
                index_uri=vector_index_uri,
                input_vectors=input_vectors,
                external_ids=external_ids,
                index_timestamp=index_timestamp if index_timestamp != 0 else None,
                config=config,
                **kwargs,
            )
            with tiledb.open(docs_uri, "w") as A:
                if external_ids is None:
                    external_ids = np.zeros(len(texts), dtype=np.uint64)
                    for i in range(len(texts)):
                        external_ids[i] = i
                data = {}
                data["text"] = np.array(texts)
                if metadatas is not None:
                    metadata_attr = np.empty([len(metadatas)], dtype=object)
                    i = 0
                    for metadata in metadatas:
                        metadata_attr[i] = np.frombuffer(
                            pickle.dumps(metadata), dtype=np.uint8
                        )
                        i += 1
                    data["metadata"] = metadata_attr

                A[external_ids] = data
        return cls(
            embedding=embedding,
            index_uri=index_uri,
            metric=metric,
            config=config,
            **kwargs,
        )

    def delete(
        self, ids: Optional[List[str]] = None, timestamp: int = 0, **kwargs: Any
    ) -> Optional[bool]:
        """Delete by vector ID or other criteria.

        Args:
            ids: List of ids to delete.
            timestamp: Optional timestamp to delete with.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            Optional[bool]: True if deletion is successful,
            False otherwise, None if not implemented.
        """

        external_ids = np.array(ids).astype(np.uint64)
        self.vector_index.delete_batch(
            external_ids=external_ids, timestamp=timestamp if timestamp != 0 else None
        )
        return True

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        timestamp: int = 0,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional ids of each text object.
            timestamp: Optional timestamp to write new texts with.
            kwargs: vectorstore specific parameters

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        tiledb = guard_import("tiledb")
        embeddings = self.embedding.embed_documents(list(texts))
        if ids is None:
            ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]

        external_ids = np.array(ids).astype(np.uint64)
        vectors = np.empty((len(embeddings)), dtype="O")
        for i in range(len(embeddings)):
            vectors[i] = np.array(embeddings[i], dtype=np.float32)
        self.vector_index.update_batch(
            vectors=vectors,
            external_ids=external_ids,
            timestamp=timestamp if timestamp != 0 else None,
        )

        docs = {}
        docs["text"] = np.array(texts)
        if metadatas is not None:
            metadata_attr = np.empty([len(metadatas)], dtype=object)
            i = 0
            for metadata in metadatas:
                metadata_attr[i] = np.frombuffer(pickle.dumps(metadata), dtype=np.uint8)
                i += 1
            docs["metadata"] = metadata_attr

        docs_array = tiledb.open(
            self.docs_array_uri,
            "w",
            timestamp=timestamp if timestamp != 0 else None,
            config=self.config,
        )
        docs_array[external_ids] = docs
        docs_array.close()
        return ids

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        metric: str = DEFAULT_METRIC,
        index_uri: str = "/tmp/tiledb_array",
        index_type: str = "FLAT",
        config: Optional[Mapping[str, Any]] = None,
        index_timestamp: int = 0,
        **kwargs: Any,
    ) -> TileDB:
        """Construct a TileDB index from raw documents.

        Args:
            texts: List of documents to index.
            embedding: Embedding function to use.
            metadatas: List of metadata dictionaries to associate with documents.
            ids: Optional ids of each text object.
            metric: Metric to use for indexing. Defaults to "euclidean".
            index_uri: The URI to write the TileDB arrays
            index_type: Optional, Vector index type ("FLAT", "IVF_FLAT")
            config: Optional, TileDB config
            index_timestamp: Optional, timestamp to write new texts with.

        Example:
            .. code-block:: python

                from langchain_community import TileDB
                from langchain_community.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                index = TileDB.from_texts(texts, embeddings)
        """
        embeddings = embedding.embed_documents(texts)
        return cls.__from(
            texts=texts,
            embeddings=embeddings,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            metric=metric,
            index_uri=index_uri,
            index_type=index_type,
            config=config,
            index_timestamp=index_timestamp,
            **kwargs,
        )

    @classmethod
    def from_embeddings(
        cls,
        text_embeddings: List[Tuple[str, List[float]]],
        embedding: Embeddings,
        index_uri: str,
        *,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        metric: str = DEFAULT_METRIC,
        index_type: str = "FLAT",
        config: Optional[Mapping[str, Any]] = None,
        index_timestamp: int = 0,
        **kwargs: Any,
    ) -> TileDB:
        """Construct TileDB index from embeddings.

        Args:
            text_embeddings: List of tuples of (text, embedding)
            embedding: Embedding function to use.
            index_uri: The URI to write the TileDB arrays
            metadatas: List of metadata dictionaries to associate with documents.
            metric: Optional, Metric to use for indexing. Defaults to "euclidean".
            index_type: Optional, Vector index type ("FLAT", "IVF_FLAT")
            config: Optional, TileDB config
            index_timestamp: Optional, timestamp to write new texts with.

        Example:
            .. code-block:: python

                from langchain_community import TileDB
                from langchain_community.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                text_embeddings = embeddings.embed_documents(texts)
                text_embedding_pairs = list(zip(texts, text_embeddings))
                db = TileDB.from_embeddings(text_embedding_pairs, embeddings)
        """
        texts = [t[0] for t in text_embeddings]
        embeddings = [t[1] for t in text_embeddings]

        return cls.__from(
            texts=texts,
            embeddings=embeddings,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            metric=metric,
            index_uri=index_uri,
            index_type=index_type,
            config=config,
            index_timestamp=index_timestamp,
            **kwargs,
        )

    @classmethod
    def load(
        cls,
        index_uri: str,
        embedding: Embeddings,
        *,
        metric: str = DEFAULT_METRIC,
        config: Optional[Mapping[str, Any]] = None,
        timestamp: Any = None,
        **kwargs: Any,
    ) -> TileDB:
        """Load a TileDB index from a URI.

        Args:
            index_uri: The URI of the TileDB vector index.
            embedding: Embeddings to use when generating queries.
            metric: Optional, Metric to use for indexing. Defaults to "euclidean".
            config: Optional, TileDB config
            timestamp: Optional, timestamp to use for opening the arrays.
        """
        return cls(
            embedding=embedding,
            index_uri=index_uri,
            metric=metric,
            config=config,
            timestamp=timestamp,
            **kwargs,
        )

    def consolidate_updates(self, **kwargs: Any) -> None:
        self.vector_index = self.vector_index.consolidate_updates(**kwargs)

But what specific scripts, classes, or functions are used within the tiledb or tiledb-vector-search libraries once they're imported? Or would this be determined by other scripts?

Yes, several of these files look highly relevant to the TileDB vector store implementation:

1. `flat_index` and `ivf_flat_index` - These are directly referenced in the wrapper code for handling different types of vector indices. The wrapper specifically uses both "FLAT" and "IVF_FLAT" index types.

2. `dimension_coordinate` and `dimension_tile_extent` - Likely relevant to how TileDB handles dimensionality in the vector space, which is important for the embedding storage and retrieval.

3. `domain_array` - Used for managing the array domains, which the wrapper code explicitly works with via TileDB's Domain class.

4. `filter_data` - Could be relevant to how the wrapper implements its filtering functionality for vector searches.

5. `datatype` - The wrapper deals extensively with data types (like np.float32, np.uint64, etc.), so this would be relevant to understanding the underlying type system.

Would you like me to look more closely at any of these specific files? Based on the wrapper code, I'd suggest starting with `flat_index` and `ivf_flat_index` since those are explicitly used in the vector store implementation.
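
As for what gets used once the packages are imported, the wrapper only reaches a handful of entry points, and you can confirm they exist with plain introspection (a sketch; it assumes both `tiledb` and `tiledb-vector-search` are installed):

```python
from langchain_core.utils import guard_import

tiledb_vs = guard_import("tiledb.vector_search")
tiledb = guard_import("tiledb")

# Entry points the wrapper actually touches (per the code above):
print(tiledb_vs.flat_index.FlatIndex, tiledb_vs.ivf_flat_index.IVFFlatIndex)
print(tiledb_vs.flat_index.create, tiledb_vs.ivf_flat_index.create)
print(tiledb_vs.ingestion.ingest)
print(tiledb.Group, tiledb.open, tiledb.Array.create, tiledb.scope_ctx)
```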

Here is datatypes.py. Does it indicate that if I pass embeddings that are in "int8" (versus "float32") it will work?

from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Tuple

import numpy as np

import tiledb.libtiledb as lt


@dataclass(frozen=True)
class DataType:
    np_dtype: np.dtype
    tiledb_type: lt.DataType
    ncells: int

    @classmethod
    @lru_cache()
    def from_numpy(cls, dtype: np.dtype) -> DataType:
        if dtype == "ascii":
            return cls(np.dtype("S"), lt.DataType.STRING_ASCII, lt.TILEDB_VAR_NUM())

        if dtype == "blob":
            return cls(np.dtype("S"), lt.DataType.BLOB, 1)

        if hasattr(lt.DataType, "GEOM_WKB") and dtype == "wkb":
            return cls(np.dtype("S"), lt.DataType.GEOM_WKB, 1)

        if hasattr(lt.DataType, "GEOM_WKT") and dtype == "wkt":
            return cls(np.dtype("S"), lt.DataType.GEOM_WKT, 1)

        dtype = np.dtype(dtype)
        if dtype.kind == "V":
            # fixed-size record dtypes
            if dtype.shape != ():
                raise TypeError("nested sub-array numpy dtypes are not supported")

            # check that types are the same
            field_dtypes = set(v[0] for v in dtype.fields.values())
            if len(field_dtypes) > 1:
                raise TypeError("heterogenous record numpy dtypes are not supported")

            base_dtype = field_dtypes.pop()
            ncells = len(dtype.fields)

        elif np.issubdtype(dtype, np.character):
            # - flexible datatypes of unknown size have an itemsize of 0 (str, bytes, etc.)
            # - character types are always stored as VAR because we don't want to store
            #   the pad (numpy pads to max length for 'S' and 'U' dtypes)
            base_dtype = np.dtype((dtype.kind, 1))
            if dtype.itemsize == 0:
                ncells = lt.TILEDB_VAR_NUM()
            else:
                ncells = dtype.itemsize // base_dtype.itemsize

        else:
            base_dtype = dtype
            ncells = 2 if np.issubdtype(dtype, np.complexfloating) else 1

        tiledb_type = _NUMPY_TO_TILEDB.get(base_dtype)
        if tiledb_type is None:
            raise TypeError(f"{dtype!r} cannot be mapped to a DataType")

        return cls(dtype, tiledb_type, ncells)

    @classmethod
    @lru_cache()
    def from_tiledb(cls, tiledb_type: lt.DataType, ncells: int = 1) -> DataType:
        base_dtype = _TILEDB_TO_NUMPY[tiledb_type]
        if tiledb_type in (lt.DataType.CHAR, lt.DataType.STRING_UTF8):
            kind = base_dtype.kind
            dtype = np.dtype((kind, ncells) if ncells != lt.TILEDB_VAR_NUM() else kind)
        elif ncells == 1 or ncells == lt.TILEDB_VAR_NUM():
            dtype = base_dtype
        elif ncells == 2 and np.issubdtype(base_dtype, np.floating):
            dtype = np.dtype("complex64" if base_dtype.itemsize == 4 else "complex128")
        else:
            # construct anonymous record dtype
            assert ncells > 1
            dtype = np.dtype([("", base_dtype)] * ncells)

        return cls(dtype, tiledb_type, ncells)

    @property  # TODO: change to functools.cached_property in Python 3.8+
    def domain(self) -> Tuple[Any, Any]:
        dtype = self.np_dtype

        if np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64):
            info = np.iinfo(np.int64)
            dt_data = np.datetime_data(dtype)
            # +1 to exclude NaT
            return dtype.type(info.min + 1, dt_data), dtype.type(info.max, dt_data)

        if np.issubdtype(dtype, np.integer):
            info = np.iinfo(dtype)
            return info.min, info.max

        if np.issubdtype(dtype, np.inexact):
            info = np.finfo(dtype)
            return info.min, info.max

        if np.issubdtype(dtype, np.bool_):
            return False, True

        if np.issubdtype(dtype, np.character):
            return None, None

        raise TypeError(f"Cannot determine min/max for {dtype!r}")

    def cast_tile_extent(self, tile_extent: Any) -> np.ndarray:
        """Given a tile extent value, cast it to np.array of this datatype's np_dtype."""
        if np.issubdtype(self.np_dtype, np.datetime64):
            # Special handling for datetime domains
            if isinstance(tile_extent, np.timedelta64):
                unit = np.datetime_data(self.np_dtype)[0]
                tile_extent /= np.timedelta64(1, unit)
            tile_dtype = np.dtype(np.int64)
        else:
            tile_dtype = self.np_dtype
        tile_size_array = np.array(tile_extent, tile_dtype)
        if tile_size_array.size != 1:
            raise ValueError("tile extent must be a scalar")
        return tile_size_array

    def uncast_tile_extent(self, tile_extent: Any) -> np.generic:
        """Given a tile extent value from PyBind, cast it to appropriate output."""
        if np.issubdtype(self.np_dtype, np.character):
            return tile_extent
        if np.issubdtype(self.np_dtype, np.datetime64):
            unit = np.datetime_data(self.np_dtype)[0]
            return np.timedelta64(tile_extent, unit)
        return self.np_dtype.type(tile_extent)


# datatype pairs that have a 1-1 mapping between tiledb and numpy
_COMMON_DATATYPES = [
    (np.dtype("bool"), lt.DataType.BOOL),
    # signed int
    (np.dtype("int8"), lt.DataType.INT8),
    (np.dtype("int16"), lt.DataType.INT16),
    (np.dtype("int32"), lt.DataType.INT32),
    (np.dtype("int64"), lt.DataType.INT64),
    # unsigned int
    (np.dtype("uint8"), lt.DataType.UINT8),
    (np.dtype("uint16"), lt.DataType.UINT16),
    (np.dtype("uint32"), lt.DataType.UINT32),
    (np.dtype("uint64"), lt.DataType.UINT64),
    # float
    (np.dtype("float32"), lt.DataType.FLOAT32),
    (np.dtype("float64"), lt.DataType.FLOAT64),
    # datetime
    (np.dtype("<M8[Y]"), lt.DataType.DATETIME_YEAR),
    (np.dtype("<M8[M]"), lt.DataType.DATETIME_MONTH),
    (np.dtype("<M8[W]"), lt.DataType.DATETIME_WEEK),
    (np.dtype("<M8[D]"), lt.DataType.DATETIME_DAY),
    (np.dtype("<M8[h]"), lt.DataType.DATETIME_HR),
    (np.dtype("<M8[m]"), lt.DataType.DATETIME_MIN),
    (np.dtype("<M8[s]"), lt.DataType.DATETIME_SEC),
    (np.dtype("<M8[ms]"), lt.DataType.DATETIME_MS),
    (np.dtype("<M8[us]"), lt.DataType.DATETIME_US),
    (np.dtype("<M8[ns]"), lt.DataType.DATETIME_NS),
    (np.dtype("<M8[ps]"), lt.DataType.DATETIME_PS),
    (np.dtype("<M8[fs]"), lt.DataType.DATETIME_FS),
    (np.dtype("<M8[as]"), lt.DataType.DATETIME_AS),
    # timedelta
    (np.dtype("<m8[h]"), lt.DataType.TIME_HR),
    (np.dtype("<m8[m]"), lt.DataType.TIME_MIN),
    (np.dtype("<m8[s]"), lt.DataType.TIME_SEC),
    (np.dtype("<m8[ms]"), lt.DataType.TIME_MS),
    (np.dtype("<m8[us]"), lt.DataType.TIME_US),
    (np.dtype("<m8[ns]"), lt.DataType.TIME_NS),
    (np.dtype("<m8[ps]"), lt.DataType.TIME_PS),
    (np.dtype("<m8[fs]"), lt.DataType.TIME_FS),
    (np.dtype("<m8[as]"), lt.DataType.TIME_AS),
    # byte/string
    (np.dtype("S1"), lt.DataType.CHAR),
    (np.dtype("<U1"), lt.DataType.STRING_UTF8),
]
assert len(set(x for x, y in _COMMON_DATATYPES)) == len(_COMMON_DATATYPES)
assert len(set(y for x, y in _COMMON_DATATYPES)) == len(_COMMON_DATATYPES)

# numpy has complex, tiledb doesn't
_NUMPY_TO_TILEDB = {n: t for n, t in _COMMON_DATATYPES}
_NUMPY_TO_TILEDB[np.dtype("complex64")] = lt.DataType.FLOAT32
_NUMPY_TO_TILEDB[np.dtype("complex128")] = lt.DataType.FLOAT64

# tiledb has STRING_ASCII, BLOB, WKB and WKT types, numpy doesn't
_TILEDB_TO_NUMPY = {t: n for n, t in _COMMON_DATATYPES}
_TILEDB_TO_NUMPY[lt.DataType.STRING_ASCII] = np.dtype("S")
_TILEDB_TO_NUMPY[lt.DataType.BLOB] = np.dtype("S")
if hasattr(lt.DataType, "GEOM_WKB"):
    _TILEDB_TO_NUMPY[lt.DataType.GEOM_WKB] = np.dtype("S")
    _TILEDB_TO_NUMPY[lt.DataType.GEOM_WKT] = np.dtype("S")

# pre-populate the LRU caches with all ncell=1 datatypes
list(map(DataType.from_numpy, _NUMPY_TO_TILEDB.keys()))
assert DataType.from_numpy.cache_info().currsize == len(_NUMPY_TO_TILEDB)
list(map(DataType.from_tiledb, _TILEDB_TO_NUMPY.keys()))
assert DataType.from_tiledb.cache_info().currsize == len(_TILEDB_TO_NUMPY)

Looking at the code, the TileDB wrapper you shared earlier specifically converts embeddings to float32 before storage. Here's the relevant line from the earlier code:

input_vectors = np.array(embeddings).astype(np.float32)

So while TileDB itself can handle int8 data types, the vector store wrapper is specifically designed to work with float32 embeddings. If you pass int8 embeddings, they will be automatically converted to float32 during the storage process.
If you need to use int8 embeddings for some specific reason (like storage efficiency), you would need to modify the wrapper code to remove the forced float32 conversion. 
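
You can also verify the mapping directly against the `_COMMON_DATATYPES` table above (a sketch; it assumes the `DataType` class shown is importable as `tiledb.datatypes.DataType`):

```python
import numpy as np
from tiledb.datatypes import DataType

# int8/uint8 appear in _COMMON_DATATYPES, so they map cleanly:
print(DataType.from_numpy(np.dtype("int8")).tiledb_type)   # DataType.INT8
print(DataType.from_numpy(np.dtype("uint8")).tiledb_type)  # DataType.UINT8
```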

I'm only interested in utilizing the "int8" and "uint8" types in my Python program. What changes would I need to make to tiledb.py to support them?

Looking at the code, to support int8/uint8 types, you'd need to modify several parts of the TileDB wrapper where it forces float32 conversion. Here are the key changes needed:

1. In the `__from` method (around line 558), change this line:

```python
input_vectors = np.array(embeddings).astype(np.float32)
```

to something like:

```python
# Allow int8/uint8 embeddings to pass through unchanged
# (vector_type would need to be accepted as a new parameter of __from)
if vector_type in (np.dtype("int8"), np.dtype("uint8")):
    input_vectors = np.array(embeddings).astype(vector_type)
else:
    input_vectors = np.array(embeddings).astype(np.float32)
```
2. In the `similarity_search_with_score_by_vector` method (around line 297), change:

```python
np.array([np.array(embedding).astype(np.float32)]).astype(np.float32)
```

to:

```python
# Preserve the index's stored vector type (the index classes expose it as
# `dtype`, as seen in ivf_flat_index.py)
vector_type = self.vector_index.dtype
np.array([np.array(embedding).astype(vector_type)]).astype(vector_type)
```
3. Similar changes would be needed in `max_marginal_relevance_search_with_score_by_vector` and other methods that perform vector operations.

4. In the `add_texts` method (around line 512):

```python
vectors[i] = np.array(embeddings[i], dtype=np.float32)
```

would need to be modified to use the correct type, as sketched below.
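
A minimal sketch of that change (assuming, as in point 2, that the opened index exposes its stored dtype as `self.vector_index.dtype`, the attribute the index classes set):

```python
# Inside add_texts: reuse the index's stored dtype instead of hard-coding float32
vector_type = self.vector_index.dtype
vectors = np.empty((len(embeddings)), dtype="O")
for i in range(len(embeddings)):
    vectors[i] = np.array(embeddings[i], dtype=vector_type)
```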

However, there are some important caveats:

  1. The INDEX_METRICS currently only supports "euclidean" distance, which typically works better with floating-point values
  2. Many embedding models output float32 vectors by default
  3. Vector similarity calculations might be less accurate with integer types

You might want to add validation to ensure the int8/uint8 vectors stay within their valid ranges (-128 to 127 for int8, 0 to 255 for uint8).
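
For that last point, a small helper along these lines (a sketch; the function name is hypothetical) could run before any cast:

```python
import numpy as np

def check_integer_range(values: np.ndarray, target: np.dtype) -> None:
    # Sketch: refuse to cast values that fall outside the target integer dtype's range.
    info = np.iinfo(target)  # -128..127 for int8, 0..255 for uint8
    if values.min() < info.min or values.max() > info.max:
        raise ValueError(
            f"values outside [{info.min}, {info.max}] cannot be stored as {target}"
        )

check_integer_range(np.array([0, 255]), np.dtype("uint8"))   # passes
# check_integer_range(np.array([-1]), np.dtype("uint8"))     # would raise ValueError
```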

I want you to tell me if any of the source code, or the picture of the source code files, indicates where the euclidean distance is implemented. I want to check whether the tiledb library itself supports it, or whether langchain's tiledb.py only implements it as a design choice.

Looking at the tiledb.py wrapper code, we can see `INDEX_METRICS = frozenset(["euclidean"])` is defined at the top, but the actual distance calculation isn't implemented in this file. The closest reference to it is where the code calls the vector index's `query` method:

d, i = self.vector_index.query(
    np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
    k=k if filter is None else fetch_k,
    **kwargs,
)

So the distance computation itself happens inside tiledb-vector-search's index classes rather than in langchain's wrapper.

Here is ivf_flat_index.py:

"""
IVFFlat Index implementation.

IVFFlatIndex is based on k-means clustering and shuffling of the dataset vectors.

During ingestion, TileDB computes the k-means clusters and shuffles the vectors into partitions.
The vectors are stored grouped by partition in a 2D TileDB array allowing for partitions to be read
with minimal I/O overhead.

To answer a query, the search focuses only on a small number of partitions, based on the query’s proximity
to the k-means centroids. This is specified with a parameter called nprobe controlling how many partitions
are checked for each query.

IVFFlatIndex provides a vector search implementation that can trade-off accuracy for performance.

Queries can be run in multiple modes:

  • Local main memory:
    • Loads the entire index in memory during initialization and uses it to answer queries.
  • Local out of core:
    • Avoids loading index data in memory by interleaving I/O and query execution, respecting the
      memory budget defined by the user.
  • Distributed execution:
    • Executes the queries using multiple workers in TileDB Cloud.
      """
      import json
      import multiprocessing
      from threading import Thread
      from typing import Any, Mapping, Sequence

import numpy as np

from tiledb.cloud.dag import Mode
from tiledb.vector_search import index
from tiledb.vector_search.module import *
from tiledb.vector_search.storage_formats import STORAGE_VERSION
from tiledb.vector_search.storage_formats import storage_formats
from tiledb.vector_search.storage_formats import validate_storage_version
from tiledb.vector_search.utils import MAX_FLOAT32
from tiledb.vector_search.utils import MAX_INT32
from tiledb.vector_search.utils import MAX_UINT64
from tiledb.vector_search.utils import create_array_and_add_to_group
from tiledb.vector_search.utils import normalize_vector
from tiledb.vector_search.utils import normalize_vectors

TILE_SIZE_BYTES = 64000000  # 64MB
INDEX_TYPE = "IVF_FLAT"


def submit_local(d, func, *args, **kwargs):
    # Drop kwargs that only apply to remote execution
    kwargs.pop("image_name", None)
    kwargs.pop("resource_class", None)
    kwargs.pop("resources", None)
    return d.submit_local(func, *args, **kwargs)

class IVFFlatIndex(index.Index):
"""
Opens an IVFFlatIndex.

Parameters
----------
uri: str
    URI of the index.
config: Optional[Mapping[str, Any]]
    TileDB config dictionary.
timestamp: int or tuple(int)
    If int, open the index at a given timestamp.
    If tuple, open at the given start and end timestamps.
memory_budget: int
    Main memory budget, in number of vectors, for query execution.
    If not provided, all index data are loaded in main memory.
    Otherwise, no index data are loaded in main memory and this memory budget is
    applied during queries.
open_for_remote_query_execution: bool
    If `True`, do not load any index data in main memory locally, and instead load index data in the TileDB Cloud taskgraph created when a non-`None` `driver_mode` is passed to `query()`. We then load index data in the taskgraph based on `memory_budget`.
    If `False`, load index data in main memory locally according to `memory_budget`. Note that you can still use a taskgraph for query execution, you'll just end up loading the data both on your local machine and in the cloud taskgraph..
"""

    def __init__(
        self,
        uri: str,
        config: Optional[Mapping[str, Any]] = None,
        timestamp=None,
        memory_budget: int = -1,
        open_for_remote_query_execution: bool = False,
        group: tiledb.Group = None,
        **kwargs,
    ):
        self.index_open_kwargs = {
            "uri": uri,
            "config": config,
            "timestamp": timestamp,
            "memory_budget": memory_budget,
        }
        self.index_open_kwargs.update(kwargs)
        self.index_type = INDEX_TYPE
        super().__init__(
            uri=uri,
            config=config,
            timestamp=timestamp,
            open_for_remote_query_execution=open_for_remote_query_execution,
            group=group,
        )
        self.db_uri = self.group[
            storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
            + self.index_version
        ].uri
        self.centroids_uri = self.group[
            storage_formats[self.storage_version]["CENTROIDS_ARRAY_NAME"]
            + self.index_version
        ].uri
        self.index_array_uri = self.group[
            storage_formats[self.storage_version]["INDEX_ARRAY_NAME"]
            + self.index_version
        ].uri
        self.ids_uri = self.group[
            storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version
        ].uri
        self.memory_budget = memory_budget

        schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
        self.dimensions = schema.shape[0]

        self.dtype = self.group.meta.get("dtype", None)
        if self.dtype is None:
            self.dtype = np.dtype(schema.attr("values").dtype)
        else:
            self.dtype = np.dtype(self.dtype)

        if self.base_size == 0:
            self.size = 0
            self.partitions = 0
            return

        self.partition_history = [
            int(x)
            for x in list(json.loads(self.group.meta.get("partition_history", "[]")))
        ]
        if len(self.partition_history) == 0:
            schema = tiledb.ArraySchema.load(
                self.centroids_uri, ctx=tiledb.Ctx(self.config)
            )
            self.partitions = schema.domain.dim("cols").domain[1] + 1
        else:
            self.partitions = self.partition_history[self.history_index]

        if not open_for_remote_query_execution:
            self._centroids = load_as_matrix(
                self.centroids_uri,
                ctx=self.ctx,
                size=self.partitions,
                config=config,
                timestamp=self.base_array_timestamp,
            )
            self._index = read_vector_u64(
                self.ctx,
                self.index_array_uri,
                0,
                self.partitions + 1,
                self.base_array_timestamp,
            )

        if self.base_size == -1:
            self.size = self._index[self.partitions]
        else:
            self.size = self.base_size

        # TODO pass in a context
        if not open_for_remote_query_execution and self.memory_budget == -1:
            self._db = load_as_matrix(
                self.db_uri,
                ctx=self.ctx,
                config=config,
                size=self.size,
                timestamp=self.base_array_timestamp,
            )
            self._ids = read_vector_u64(
                self.ctx, self.ids_uri, 0, self.size, self.base_array_timestamp
            )

    def get_dimensions(self):
        """
        Returns the dimension of the vectors in the index.
        """
        return self.dimensions

    def query_internal(
        self,
        queries: np.ndarray,
        k: int = 10,
        nprobe: int = 1,
        nthreads: int = -1,
        use_nuv_implementation: bool = False,
        mode: Optional[Mode] = None,
        resource_class: Optional[str] = None,
        resources: Optional[Mapping[str, Any]] = None,
        num_partitions: int = -1,
        num_workers: int = -1,
        **kwargs,
    ):
        """
        Queries an `IVFFlatIndex`.

        Parameters
        ----------
        queries: np.ndarray
            2D array of query vectors. This can be used as a batch query interface by passing multiple queries in one call.
        k: int
            Number of results to return per query vector.
        nprobe: int
            Number of partitions to check per query.
            Use this parameter to trade-off accuracy for latency and cost.
            As a rule of thumb, configuring `nprobe` to be the square root of `partitions` should result in accuracy close to 100%.
        nthreads: int
            Number of threads to use for local query execution.
        use_nuv_implementation: bool
            Whether to use the nuv query implementation. Default: False
        mode: Mode
            If provided the query will be executed using TileDB cloud taskgraphs.
            For distributed execution you can use REALTIME or BATCH mode.
            For local execution you can use LOCAL mode.
        resource_class:
            The name of the resource class to use ("standard" or "large"). Resource classes define maximum
            limits for cpu and memory usage. Can only be used in REALTIME or BATCH mode.
            Cannot be used alongside resources.
            In REALTIME or BATCH mode if neither resource_class nor resources are provided,
            we default to the "large" resource class.
        resources:
            A specification for the amount of resources to use when executing using TileDB cloud
            taskgraphs, of the form: {"cpu": "6", "memory": "12Gi", "gpu": 1}. Can only be used
            in BATCH mode. Cannot be used alongside resource_class.
        num_partitions: int
            Only relevant for taskgraph based execution.
            If provided, we split the query execution in that many partitions.
        num_workers: int
            Only relevant for taskgraph based execution.
            If provided, this is the number of workers to use for the query execution.
        """
        if self.size == 0:
            return np.full((queries.shape[0], k), MAX_FLOAT32), np.full(
                (queries.shape[0], k), MAX_UINT64
            )

        if mode != Mode.BATCH and resources:
            raise TypeError("Can only pass resources in BATCH mode")
        if (mode != Mode.REALTIME and mode != Mode.BATCH) and resource_class:
            raise TypeError("Can only pass resource_class in REALTIME or BATCH mode")

        if queries.ndim == 1:
            queries = np.array([queries])

        if self.distance_metric == vspy.DistanceMetric.COSINE:
            queries = normalize_vectors(queries)

        if nthreads == -1:
            nthreads = multiprocessing.cpu_count()

        nprobe = min(nprobe, self.partitions)
        if mode is None:
            queries_m = array_to_matrix(np.transpose(queries))
            if self.memory_budget == -1:
                d, i = ivf_query_ram(
                    self.dtype,
                    self._db,
                    self._centroids,
                    queries_m,
                    self._index,
                    self._ids,
                    nprobe=nprobe,
                    k_nn=k,
                    nthreads=nthreads,
                    ctx=self.ctx,
                    use_nuv_implementation=use_nuv_implementation,
                    distance_metric=self.distance_metric,
                )
            else:
                d, i = ivf_query(
                    self.dtype,
                    self.db_uri,
                    self._centroids,
                    queries_m,
                    self._index,
                    self.ids_uri,
                    nprobe=nprobe,
                    k_nn=k,
                    memory_budget=self.memory_budget,
                    nthreads=nthreads,
                    ctx=self.ctx,
                    use_nuv_implementation=use_nuv_implementation,
                    timestamp=self.base_array_timestamp,
                    distance_metric=self.distance_metric,
                )

            return np.transpose(np.array(d)), np.transpose(np.array(i))
        else:
            return self._taskgraph_query(
                queries=queries,
                k=k,
                nthreads=nthreads,
                nprobe=nprobe,
                mode=mode,
                resource_class=resource_class,
                resources=resources,
                num_partitions=num_partitions,
                num_workers=num_workers,
                config=self.config,
                distance_metric=self.distance_metric,
            )

def update(self, vector: np.array, external_id: np.uint64, timestamp: int = None):
    if self.distance_metric == vspy.DistanceMetric.COSINE:
        vector = normalize_vector(vector)
    super().update(vector, external_id, timestamp)

def update_batch(
    self, vectors: np.ndarray, external_ids: np.array, timestamp: int = None
):
    if self.distance_metric == vspy.DistanceMetric.COSINE:
        vectors = normalize_vectors(vectors)
    super().update_batch(vectors, external_ids, timestamp)

def query(
    self,
    queries: np.ndarray,
    k: int,
    **kwargs,
):
    if self.distance_metric == vspy.DistanceMetric.COSINE:
        queries = normalize_vectors(queries)
    return super().query(
        queries=queries,
        k=k,
        **kwargs,
    )

def _taskgraph_query(
    self,
    queries: np.ndarray,
    k: int = 10,
    nprobe: int = 10,
    nthreads: int = -1,
    mode: Mode = None,
    resource_class: Optional[str] = None,
    resources: Optional[Mapping[str, Any]] = None,
    num_partitions: int = -1,
    num_workers: int = -1,
    config: Optional[Mapping[str, Any]] = None,
    distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.SUM_OF_SQUARES,
):
    """
    Query an IVF_FLAT index using TileDB cloud taskgraphs

    Parameters
    ----------
    queries: numpy.ndarray
        ND Array of queries
    k: int
        Number of top results to return per query
    nprobe: int
        Number of partitions to probe.
    nthreads: int
        Number of threads to use for query
    mode: Mode
        If provided the query will be executed using TileDB cloud taskgraphs.
        For distributed execution you can use REALTIME or BATCH mode.
        For local execution you can use LOCAL mode.
    resource_class:
        The name of the resource class to use ("standard" or "large"). Resource classes define maximum
        limits for cpu and memory usage. Can only be used in REALTIME or BATCH mode.
        Cannot be used alongside resources.
        In REALTIME or BATCH mode if neither resource_class nor resources are provided,
        we default to the "large" resource class.
    resources:
        A specification for the amount of resources to use when executing using TileDB cloud
        taskgraphs, of the form: {"cpu": "6", "memory": "12Gi", "gpu": 1}. Can only be used
        in BATCH mode. Cannot be used alongside resource_class.
    num_partitions: int
        Only relevant for taskgraph based execution.
        If provided, we split the query execution in that many partitions.
    num_workers: int
        Only relevant for taskgraph based execution.
        If provided, this is the number of workers to use for the query execution.
    config: None
        config dictionary, defaults to None
    """
    import math
    from functools import partial

    import numpy as np

    from tiledb.cloud import dag
    from tiledb.cloud.dag import Mode
    from tiledb.vector_search.module import array_to_matrix
    from tiledb.vector_search.module import dist_qv
    from tiledb.vector_search.module import partition_ivf_index

    if resource_class and resources:
        raise TypeError("Cannot provide both resource_class and resources")

    def dist_qv_udf(
        dtype: np.dtype,
        parts_uri: str,
        ids_uri: str,
        query_vectors: np.ndarray,
        active_partitions: np.array,
        active_queries: np.array,
        indices: np.array,
        k_nn: int,
        config: Optional[Mapping[str, Any]] = None,
        timestamp: int = 0,
        memory_budget: int = -1,
    ):
        queries_m = array_to_matrix(np.transpose(query_vectors))
        r = dist_qv(
            dtype=dtype,
            parts_uri=parts_uri,
            ids_uri=ids_uri,
            query_vectors=queries_m,
            active_partitions=active_partitions,
            active_queries=active_queries,
            indices=indices,
            k_nn=k_nn,
            ctx=Ctx(config),
            timestamp=timestamp,
            upper_bound=0 if memory_budget == -1 else memory_budget,
        )
        return [[r[q][j] for j in range(len(r[q]))] for q in range(len(r))]

    if num_partitions == -1:
        num_partitions = 5
    if num_workers == -1:
        num_workers = num_partitions
    if mode == Mode.BATCH:
        d = dag.DAG(
            name="vector-query",
            mode=Mode.BATCH,
            max_workers=num_workers,
        )
    elif mode == Mode.REALTIME:
        d = dag.DAG(
            name="vector-query",
            mode=Mode.REALTIME,
            max_workers=num_workers,
        )
    else:
        d = dag.DAG(
            name="vector-query",
            mode=Mode.REALTIME,
            max_workers=1,
            namespace="default",
        )
    submit = partial(submit_local, d)
    if mode == Mode.BATCH or mode == Mode.REALTIME:
        submit = d.submit

    queries_m = array_to_matrix(np.transpose(queries))
    active_partitions, active_queries = partition_ivf_index(
        centroids=self._centroids, query=queries_m, nprobe=nprobe, nthreads=nthreads
    )
    num_parts = len(active_partitions)

    parts_per_node = int(math.ceil(num_parts / num_partitions))
    nodes = []
    for part in range(0, num_parts, parts_per_node):
        part_end = min(part + parts_per_node, num_parts)
        aq = [[q for q in active_queries[tt]] for tt in range(part, part_end)]
        nodes.append(
            submit(
                dist_qv_udf,
                dtype=self.dtype,
                parts_uri=self.db_uri,
                ids_uri=self.ids_uri,
                query_vectors=queries,
                active_partitions=np.array(active_partitions)[part:part_end],
                active_queries=np.array(aq, dtype=object),
                indices=np.array(self._index),
                k_nn=k,
                config=config,
                timestamp=self.base_array_timestamp,
                memory_budget=self.memory_budget,
                resource_class="large"
                if (not resources and not resource_class)
                else resource_class,
                resources=resources,
                image_name="3.9-vectorsearch",
            )
        )

    d.compute()
    d.wait()
    results = []
    for node in nodes:
        res = node.result()
        results.append(res)

    results_per_query_d = []
    results_per_query_i = []
    for q in range(queries.shape[0]):
        tmp_results = []
        for j in range(k):
            for r in results:
                if len(r[q]) > j:
                    if r[q][j][0] > 0:
                        tmp_results.append(r[q][j])
        tmp = sorted(tmp_results, key=lambda t: t[0])[0:k]
        for j in range(len(tmp), k):
            tmp.append((float(0.0), int(0)))
        results_per_query_d.append(np.array(tmp, dtype=np.float32)[:, 0])
        results_per_query_i.append(np.array(tmp, dtype=np.uint64)[:, 1])
    return np.array(results_per_query_d), np.array(results_per_query_i)

def vacuum(self):
    """
    The vacuuming process permanently deletes index files that are consolidated through the consolidation
    process. TileDB separates consolidation from vacuuming, in order to make consolidation process-safe
    in the presence of concurrent reads and writes.

    Note:

    1. Vacuuming is not process-safe and you should take extra care when invoking it.
    2. Vacuuming may affect the granularity of the time traveling functionality.

    The IVFFlat class vacuums consolidated fragment, array metadata and commits for the `db`
    and `ids` arrays.
    """
    super().vacuum()
    if not self.uri.startswith("tiledb://"):
        modes = ["fragment_meta", "commits", "array_meta"]
        for mode in modes:
            conf = tiledb.Config(self.config)
            conf["sm.consolidation.mode"] = mode
            conf["sm.vacuum.mode"] = mode
            tiledb.vacuum(self.db_uri, config=conf)
            tiledb.vacuum(self.ids_uri, config=conf)

def create(
uri: str,
dimensions: int,
vector_type: np.dtype,
group_exists: bool = False,
config: Optional[Mapping[str, Any]] = None,
storage_version: str = STORAGE_VERSION,
distance_metric: vspy.DistanceMetric = vspy.DistanceMetric.SUM_OF_SQUARES,
group: tiledb.Group = None,
asset_creation_threads: Sequence[Thread] = None,
**kwargs,
) -> IVFFlatIndex:
"""
Creates an empty IVFFlatIndex.

Parameters
----------
uri: str
    URI of the index.
dimensions: int
    Number of dimensions for the vectors to be stored in the index.
vector_type: np.dtype
    Datatype of vectors.
    Supported values (uint8, int8, float32).
group_exists: bool
    If False it creates the TileDB group for the index.
    If True the method expects the TileDB group to be already created.
config: Optional[Mapping[str, Any]]
    TileDB config dictionary.
storage_version: str
    The TileDB vector search storage version to use.
    If not provided, use the latest stable storage version.
group: tiledb.Group
    TileDB group open in write mode.
    Internal, this is used to avoid opening the group multiple times during
    ingestion.
asset_creation_threads: Sequence[Thread]
    List of asset creation threads to append new threads.
    Internal, this is used to parallelize all asset creation during
    ingestion.

"""
validate_storage_version(storage_version)
if (
    distance_metric != vspy.DistanceMetric.SUM_OF_SQUARES
    and distance_metric != vspy.DistanceMetric.L2
    and distance_metric != vspy.DistanceMetric.COSINE
):
    raise ValueError(
        f"Distance metric {distance_metric} is not supported in IVF_FLAT"
    )

if group is None and asset_creation_threads is not None:
    raise ValueError(
        "Can't pass `asset_creation_threads` without a `group` argument."
    )

with tiledb.scope_ctx(ctx_or_config=config):
    if not group_exists:
        try:
            tiledb.group_create(uri)
        except tiledb.TileDBError as err:
            raise err
    if group is None:
        grp = tiledb.Group(uri, "w")
    else:
        grp = group

    if asset_creation_threads is not None:
        threads = asset_creation_threads
    else:
        threads = []

    index.create_metadata(
        group=grp,
        dimensions=dimensions,
        vector_type=vector_type,
        index_type=INDEX_TYPE,
        storage_version=storage_version,
        distance_metric=distance_metric,
    )

    tile_size = int(TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions)
    grp.meta["partition_history"] = json.dumps([0])
    centroids_array_name = storage_formats[storage_version]["CENTROIDS_ARRAY_NAME"]
    index_array_name = storage_formats[storage_version]["INDEX_ARRAY_NAME"]
    ids_array_name = storage_formats[storage_version]["IDS_ARRAY_NAME"]
    parts_array_name = storage_formats[storage_version]["PARTS_ARRAY_NAME"]
    updates_array_name = storage_formats[storage_version]["UPDATES_ARRAY_NAME"]
    centroids_uri = f"{uri}/{centroids_array_name}"
    index_array_uri = f"{uri}/{index_array_name}"
    ids_uri = f"{uri}/{ids_array_name}"
    parts_uri = f"{uri}/{parts_array_name}"
    updates_array_uri = f"{uri}/{updates_array_name}"

    centroids_array_rows_dim = tiledb.Dim(
        name="rows",
        domain=(0, dimensions - 1),
        tile=dimensions,
        dtype=np.dtype(np.int32),
    )
    centroids_array_cols_dim = tiledb.Dim(
        name="cols",
        domain=(0, MAX_INT32),
        tile=100000,
        dtype=np.dtype(np.int32),
    )
    centroids_array_dom = tiledb.Domain(
        centroids_array_rows_dim, centroids_array_cols_dim
    )
    centroids_attr = tiledb.Attr(
        name="centroids",
        dtype=np.dtype(np.float32),
        filters=storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"],
    )
    centroids_schema = tiledb.ArraySchema(
        domain=centroids_array_dom,
        sparse=False,
        attrs=[centroids_attr],
        cell_order="col-major",
        tile_order="col-major",
    )
    thread = Thread(
        target=create_array_and_add_to_group,
        kwargs={
            "array_uri": centroids_uri,
            "array_name": centroids_array_name,
            "group": grp,
            "schema": centroids_schema,
        },
    )
    thread.start()
    threads.append(thread)

    index_array_rows_dim = tiledb.Dim(
        name="rows",
        domain=(0, MAX_INT32),
        tile=100000,
        dtype=np.dtype(np.int32),
    )
    index_array_dom = tiledb.Domain(index_array_rows_dim)
    index_attr = tiledb.Attr(
        name="values",
        dtype=np.dtype(np.uint64),
        filters=storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"],
    )
    index_schema = tiledb.ArraySchema(
        domain=index_array_dom,
        sparse=False,
        attrs=[index_attr],
        cell_order="col-major",
        tile_order="col-major",
    )
    thread = Thread(
        target=create_array_and_add_to_group,
        kwargs={
            "array_uri": index_array_uri,
            "array_name": index_array_name,
            "group": grp,
            "schema": index_schema,
        },
    )
    thread.start()
    threads.append(thread)

    ids_array_rows_dim = tiledb.Dim(
        name="rows",
        domain=(0, MAX_INT32),
        tile=tile_size,
        dtype=np.dtype(np.int32),
    )
    ids_array_dom = tiledb.Domain(ids_array_rows_dim)
    ids_attr = tiledb.Attr(
        name="values",
        dtype=np.dtype(np.uint64),
        filters=storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"],
    )
    ids_schema = tiledb.ArraySchema(
        domain=ids_array_dom,
        sparse=False,
        attrs=[ids_attr],
        cell_order="col-major",
        tile_order="col-major",
    )
    thread = Thread(
        target=create_array_and_add_to_group,
        kwargs={
            "array_uri": ids_uri,
            "array_name": ids_array_name,
            "group": grp,
            "schema": ids_schema,
        },
    )
    thread.start()
    threads.append(thread)

    parts_array_rows_dim = tiledb.Dim(
        name="rows",
        domain=(0, dimensions - 1),
        tile=dimensions,
        dtype=np.dtype(np.int32),
    )
    parts_array_cols_dim = tiledb.Dim(
        name="cols",
        domain=(0, MAX_INT32),
        tile=tile_size,
        dtype=np.dtype(np.int32),
    )
    parts_array_dom = tiledb.Domain(parts_array_rows_dim, parts_array_cols_dim)
    parts_attr = tiledb.Attr(
        name="values",
        dtype=vector_type,
        filters=storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"],
    )
    parts_schema = tiledb.ArraySchema(
        domain=parts_array_dom,
        sparse=False,
        attrs=[parts_attr],
        cell_order="col-major",
        tile_order="col-major",
    )
    thread = Thread(
        target=create_array_and_add_to_group,
        kwargs={
            "array_uri": parts_uri,
            "array_name": parts_array_name,
            "group": grp,
            "schema": parts_schema,
        },
    )
    thread.start()
    threads.append(thread)

    external_id_dim = tiledb.Dim(
        name="external_id",
        domain=(0, MAX_UINT64 - 1),
        dtype=np.dtype(np.uint64),
    )
    dom = tiledb.Domain(external_id_dim)
    vector_attr = tiledb.Attr(name="vector", dtype=vector_type, var=True)
    updates_schema = tiledb.ArraySchema(
        domain=dom,
        sparse=True,
        attrs=[vector_attr],
        allows_duplicates=False,
    )
    thread = Thread(
        target=create_array_and_add_to_group,
        kwargs={
            "array_uri": updates_array_uri,
            "array_name": updates_array_name,
            "group": grp,
            "schema": updates_schema,
        },
    )
    thread.start()
    threads.append(thread)

    if asset_creation_threads is None:
        for thread in threads:
            thread.join()
    if group is None:
        grp.close()
        return IVFFlatIndex(uri=uri, config=config, memory_budget=1000000)
    else:
        return None

Looking at ivf_flat_index.py, we can see that TileDB's vector search actually supports multiple distance metrics. This is evident in the create() function where it checks the distance metric:

if (
    distance_metric != vspy.DistanceMetric.SUM_OF_SQUARES
    and distance_metric != vspy.DistanceMetric.L2
    and distance_metric != vspy.DistanceMetric.COSINE
):
    raise ValueError(
        f"Distance metric {distance_metric} is not supported in IVF_FLAT"
    )

TileDB supports:

SUM_OF_SQUARES
L2 (Euclidean)
COSINE
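
Against the library directly, something like this should already work (a sketch: the vspy import path is an assumption that may differ between versions, and the URI is illustrative):

import numpy as np
from tiledb.vector_search import ivf_flat_index
from tiledb.vector_search import _tiledbvspy as vspy  # import path may differ

index = ivf_flat_index.create(
    uri="/tmp/ivf_cosine_demo",
    dimensions=384,
    vector_type=np.dtype(np.float32),  # uint8 and int8 also accepted per the docstring
    distance_metric=vspy.DistanceMetric.COSINE,
)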

So the limitation to only Euclidean distance in the langchain wrapper (INDEX_METRICS = frozenset(["euclidean"])) appears to be a design choice by the wrapper authors, not a limitation of TileDB itself.
Additionally, this file shows special handling for cosine similarity in several methods like query(), update(), and update_batch() where vectors are normalized when using cosine distance:

if self.distance_metric == vspy.DistanceMetric.COSINE:
    queries = normalize_vectors(queries)

So if you wanted to use other distance metrics, you would need to modify the langchain wrapper to expose TileDB's full range of supported distance metrics.
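
For concreteness, here is a minimal sketch of what exposing those metrics in the wrapper could look like. METRIC_MAP and resolve_distance_metric are hypothetical names of mine, and the vspy import path is again an assumption:

from typing import Any

from langchain_core.utils import guard_import

# Hypothetical name -> TileDB DistanceMetric mapping for the wrapper.
METRIC_MAP = {
    "euclidean": "L2",
    "sum_of_squares": "SUM_OF_SQUARES",
    "cosine": "COSINE",
}
INDEX_METRICS = frozenset(METRIC_MAP)


def resolve_distance_metric(metric: str) -> Any:
    """Translate a wrapper-level metric name into a vspy.DistanceMetric."""
    # Assumed import path for the bindings; adjust to your installed version.
    vspy = guard_import("tiledb.vector_search._tiledbvspy")
    if metric not in METRIC_MAP:
        raise ValueError(
            f"Unsupported distance metric: {metric}. "
            f"Expected one of {sorted(INDEX_METRICS)}"
        )
    return getattr(vspy.DistanceMetric, METRIC_MAP[metric])

The wrapper's create() would then forward distance_metric=resolve_distance_metric(metric) into the flat_index.create / ivf_flat_index.create calls, and because IVFFlatIndex already normalizes vectors internally for COSINE (see the snippet above), the wrapper itself would not need to normalize anything.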

Do you have enough information to confidently modify the langchain tiledb.py to support the other distance metrics as well as "int8" and "uint8"?


@BBC-Esq changed the title from "formally support int8 and uint8 within langchain" to "formally support int8 and uint8 within langchain and 2 distance metrics" Jan 6, 2025
@BBC-Esq (Author) commented Jan 6, 2025

Draft revised tiledb.py for your consideration below.

"""Wrapper around TileDB vector database."""

from __future__ import annotations

import pickle
import random
import sys
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

INDEX_METRICS = frozenset(["euclidean"])
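# NOTE (draft): still Euclidean-only; to surface TileDB's SUM_OF_SQUARES and
# COSINE metrics this would need to grow into a name -> vspy.DistanceMetric
# mapping rather than a single-entry frozenset (see discussion above).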
DEFAULT_METRIC = "euclidean"
DOCUMENTS_ARRAY_NAME = "documents"
VECTOR_INDEX_NAME = "vectors"
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
MAX_FLOAT_32 = np.finfo(np.dtype("float32")).max
MAX_FLOAT = sys.float_info.max


def dependable_tiledb_import() -> Any:
    """Import tiledb-vector-search if available, otherwise raise error."""
    return (
        guard_import("tiledb.vector_search"),
        guard_import("tiledb"),
    )


def get_vector_index_uri_from_group(group: Any) -> str:
    """Get the URI of the vector index."""
    return group[VECTOR_INDEX_NAME].uri


def get_documents_array_uri_from_group(group: Any) -> str:
    """Get the URI of the documents array from group.

    Args:
        group: TileDB group object.

    Returns:
        URI of the documents array.
    """
    return group[DOCUMENTS_ARRAY_NAME].uri


def get_vector_index_uri(uri: str) -> str:
    """Get the URI of the vector index."""
    return f"{uri}/{VECTOR_INDEX_NAME}"


def get_documents_array_uri(uri: str) -> str:
    """Get the URI of the documents array."""
    return f"{uri}/{DOCUMENTS_ARRAY_NAME}"


class TileDB(VectorStore):
    """TileDB vector store.

    To use, you should have the ``tiledb-vector-search`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community import TileDB
            embeddings = OpenAIEmbeddings()
            db = TileDB(embeddings, index_uri, metric)

    """

    def __init__(
        self,
        embedding: Embeddings,
        index_uri: str,
        metric: str,
        *,
        vector_index_uri: str = "",
        docs_array_uri: str = "",
        config: Optional[Mapping[str, Any]] = None,
        timestamp: Any = None,
        allow_dangerous_deserialization: bool = False,
        **kwargs: Any,
    ):
        """Initialize with necessary components.

        Args:
            allow_dangerous_deserialization: whether to allow deserialization
                of the data which involves loading data using pickle.
                data can be modified by malicious actors to deliver a
                malicious payload that results in execution of
                arbitrary code on your machine.
        """
        if not allow_dangerous_deserialization:
            raise ValueError(
                "TileDB relies on pickle for serialization and deserialization. "
                "This can be dangerous if the data is intercepted and/or modified "
                "by malicious actors prior to being de-serialized. "
                "If you are sure that the data is safe from modification, you can "
                " set allow_dangerous_deserialization=True to proceed. "
                "Loading of compromised data using pickle can result in execution of "
                "arbitrary code on your machine."
            )
        self.embedding = embedding
        self.embedding_function = embedding.embed_query
        self.index_uri = index_uri
        self.metric = metric
        self.config = config

        tiledb_vs, tiledb = (
            guard_import("tiledb.vector_search"),
            guard_import("tiledb"),
        )
        with tiledb.scope_ctx(ctx_or_config=config):
            index_group = tiledb.Group(self.index_uri, "r")
            self.vector_index_uri = (
                vector_index_uri
                if vector_index_uri != ""
                else get_vector_index_uri_from_group(index_group)
            )
            self.docs_array_uri = (
                docs_array_uri
                if docs_array_uri != ""
                else get_documents_array_uri_from_group(index_group)
            )
            index_group.close()
            group = tiledb.Group(self.vector_index_uri, "r")
            self.index_type = group.meta.get("index_type")
            group.close()
            self.timestamp = timestamp
            if self.index_type == "FLAT":
                self.vector_index = tiledb_vs.flat_index.FlatIndex(
                    uri=self.vector_index_uri,
                    config=self.config,
                    timestamp=self.timestamp,
                    **kwargs,
                )
            elif self.index_type == "IVF_FLAT":
                self.vector_index = tiledb_vs.ivf_flat_index.IVFFlatIndex(
                    uri=self.vector_index_uri,
                    config=self.config,
                    timestamp=self.timestamp,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unsupported index type: {self.index_type}")

    @property
    def embeddings(self) -> Optional[Embeddings]:
        return self.embedding

    def process_index_results(
        self,
        ids: List[int],
        scores: List[float],
        *,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        score_threshold: float = MAX_FLOAT,
    ) -> List[Tuple[Document, float]]:
        """Turns TileDB results into a list of documents and scores.

        Args:
            ids: List of indices of the documents in the index.
            scores: List of distances of the documents in the index.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
            score_threshold: Optional, a floating point value to filter the
                resulting set of retrieved docs
        Returns:
            List of Documents and scores.
        """
        tiledb = guard_import("tiledb")
        docs = []
        docs_array = tiledb.open(
            self.docs_array_uri, "r", timestamp=self.timestamp, config=self.config
        )
        for idx, score in zip(ids, scores):
            if idx == 0 and score == 0:
                continue
            if idx == MAX_UINT64 and score == MAX_FLOAT_32:
                continue
            doc = docs_array[idx]
            if doc is None or len(doc["text"]) == 0:
                raise ValueError(f"Could not find document for id {idx}, got {doc}")
            pickled_metadata = doc.get("metadata")
            result_doc = Document(page_content=str(doc["text"][0]))
            if pickled_metadata is not None:
                metadata = pickle.loads(  # ignore[pickle]: explicit-opt-in
                    np.array(pickled_metadata.tolist()).astype(np.uint8).tobytes()
                )
                result_doc.metadata = metadata
            if filter is not None:
                filter = {
                    key: [value] if not isinstance(value, list) else value
                    for key, value in filter.items()
                }
                if all(
                    result_doc.metadata.get(key) in value
                    for key, value in filter.items()
                ):
                    docs.append((result_doc, score))
            else:
                docs.append((result_doc, score))
        docs_array.close()
        docs = [(doc, score) for doc, score in docs if score <= score_threshold]
        return docs[:k]

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        *,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            embedding: Embedding vector to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                      Defaults to 20.
            **kwargs: kwargs to be passed to similarity search. Can include:
                nprobe: Optional, number of partitions to check if using IVF_FLAT index
                score_threshold: Optional, a floating point value to filter the
                    resulting set of retrieved docs

        Returns:
            List of documents most similar to the query text and distance
            in float for each. Lower score represents more similarity.
        """
        if "score_threshold" in kwargs:
            score_threshold = kwargs.pop("score_threshold")
        else:
            score_threshold = MAX_FLOAT
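        # Draft note: the double float32 cast below (and its twin in the MMR
        # path) is a second place that would have to become dtype-aware before
        # int8/uint8 indexes can be queried without up-casting.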
        d, i = self.vector_index.query(
            np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
            k=k if filter is None else fetch_k,
            **kwargs,
        )
        return self.process_index_results(
            ids=i[0], scores=d[0], filter=filter, k=k, score_threshold=score_threshold
        )

    def similarity_search_with_score(
        self,
        query: str,
        *,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                      Defaults to 20.

        Returns:
            List of documents most similar to the query text with
            Distance as float. Lower score represents more similarity.
        """
        embedding = self.embedding_function(query)
        docs = self.similarity_search_with_score_by_vector(
            embedding,
            k=k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        return docs

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                      Defaults to 20.

        Returns:
            List of Documents most similar to the embedding.
        """
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding,
            k=k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                      Defaults to 20.

        Returns:
            List of Documents most similar to the query.
        """
        docs_and_scores = self.similarity_search_with_score(
            query, k=k, filter=filter, fetch_k=fetch_k, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    def max_marginal_relevance_search_with_score_by_vector(
        self,
        embedding: List[float],
        *,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and their similarity scores selected using the maximal marginal
            relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                     pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
        Returns:
            List of Documents and similarity scores selected by maximal marginal
                relevance and score for each.
        """
        if "score_threshold" in kwargs:
            score_threshold = kwargs.pop("score_threshold")
        else:
            score_threshold = MAX_FLOAT
        scores, indices = self.vector_index.query(
            np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
            k=fetch_k if filter is None else fetch_k * 2,
            **kwargs,
        )
        results = self.process_index_results(
            ids=indices[0],
            scores=scores[0],
            filter=filter,
            k=fetch_k if filter is None else fetch_k * 2,
            score_threshold=score_threshold,
        )
        embeddings = [
            self.embedding.embed_documents([doc.page_content])[0] for doc, _ in results
        ]
        mmr_selected = maximal_marginal_relevance(
            np.array([embedding], dtype=np.float32),
            embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        docs_and_scores = []
        for i in mmr_selected:
            docs_and_scores.append(results[i])
        return docs_and_scores

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                     pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )
        return [doc for doc, _ in docs_and_scores]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering (if needed) to
                     pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        embedding = self.embedding_function(query)
        docs = self.max_marginal_relevance_search_by_vector(
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )
        return docs

    @classmethod
    def create(
        cls,
        index_uri: str,
        index_type: str,
        dimensions: int,
        vector_type: np.dtype,
        *,
        metadatas: bool = True,
        config: Optional[Mapping[str, Any]] = None,
    ) -> None:
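        # Draft note: exposing the other metrics would require a
        # distance_metric (or metric-name) parameter here, forwarded to the
        # flat_index / ivf_flat_index create() calls below, which already
        # accept it.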
        tiledb_vs, tiledb = (
            guard_import("tiledb.vector_search"),
            guard_import("tiledb"),
        )
        with tiledb.scope_ctx(ctx_or_config=config):
            try:
                tiledb.group_create(index_uri)
            except tiledb.TileDBError as err:
                raise err
            group = tiledb.Group(index_uri, "w")
            vector_index_uri = get_vector_index_uri(group.uri)
            docs_uri = get_documents_array_uri(group.uri)
            if index_type == "FLAT":
                tiledb_vs.flat_index.create(
                    uri=vector_index_uri,
                    dimensions=dimensions,
                    vector_type=vector_type,
                    config=config,
                )
            elif index_type == "IVF_FLAT":
                tiledb_vs.ivf_flat_index.create(
                    uri=vector_index_uri,
                    dimensions=dimensions,
                    vector_type=vector_type,
                    config=config,
                )
            group.add(vector_index_uri, name=VECTOR_INDEX_NAME)

            # Create TileDB array to store Documents
            # TODO add a Document store API to tiledb-vector-search to allow storing
            #  different types of objects and metadata in a more generic way.
            dim = tiledb.Dim(
                name="id",
                domain=(0, MAX_UINT64 - 1),
                dtype=np.dtype(np.uint64),
            )
            dom = tiledb.Domain(dim)

            text_attr = tiledb.Attr(name="text", dtype=np.dtype("U1"), var=True)
            attrs = [text_attr]
            if metadatas:
                metadata_attr = tiledb.Attr(name="metadata", dtype=np.uint8, var=True)
                attrs.append(metadata_attr)
            schema = tiledb.ArraySchema(
                domain=dom,
                sparse=True,
                allows_duplicates=False,
                attrs=attrs,
            )
            tiledb.Array.create(docs_uri, schema)
            group.add(docs_uri, name=DOCUMENTS_ARRAY_NAME)
            group.close()

    @classmethod
    def __from(
        cls,
        texts: List[str],
        embeddings: List[List[float]],
        embedding: Embeddings,
        index_uri: str,
        *,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        metric: str = DEFAULT_METRIC,
        index_type: str = "FLAT",
        config: Optional[Mapping[str, Any]] = None,
        index_timestamp: int = 0,
        **kwargs: Any,
    ) -> TileDB:
        if metric not in INDEX_METRICS:
            raise ValueError(
                (
                    f"Unsupported distance metric: {metric}. "
                    f"Expected one of {list(INDEX_METRICS)}"
                )
            )
        tiledb_vs, tiledb = (
            guard_import("tiledb.vector_search"),
            guard_import("tiledb"),
        )

        # Draft diagnostics: inspect the structure of `embeddings` before the
        # dtype conversion below (left in for discussion; remove before merging).
        if isinstance(embeddings, list) and embeddings:
            # Determine if it's a list of lists or a flat list
            is_nested = all(isinstance(item, list) for item in embeddings)
            structure = "List of lists" if is_nested else "Flat list"

            # Analyze the first element
            first_item = embeddings[0]
            if is_nested:
                first_item_dtype = type(first_item[0]) if first_item else None
                try:
                    first_item_shape = np.array(first_item).shape
                except Exception as e:
                    first_item_shape = f"Cannot determine shape: {e}"
            else:
                first_item_dtype = type(first_item)
                first_item_shape = ()

            # Print results
            print("Structure:", structure)
            print("First item dtype:", first_item_dtype)
            print("First item shape:", first_item_shape)
        else:
            print("Embeddings is not a non-empty list.")


        
        input_vectors = np.array(embeddings).astype(np.float32)
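        # NOTE (draft): this unconditional float32 cast is exactly what blocks
        # int8/uint8 indexes today. A dtype-preserving variant (hypothetical)
        # might look like:
        #     arr = np.asarray(embeddings)
        #     if arr.dtype not in (np.int8, np.uint8):
        #         arr = arr.astype(np.float32)
        #     input_vectors = arr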
        cls.create(
            index_uri=index_uri,
            index_type=index_type,
            dimensions=input_vectors.shape[1],
            vector_type=input_vectors.dtype,
            metadatas=metadatas is not None,
            config=config,
        )
        with tiledb.scope_ctx(ctx_or_config=config):
            if not embeddings:
                raise ValueError("embeddings must be provided to build a TileDB index")

            vector_index_uri = get_vector_index_uri(index_uri)
            docs_uri = get_documents_array_uri(index_uri)
            if ids is None:
                ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]
            external_ids = np.array(ids).astype(np.uint64)

            tiledb_vs.ingestion.ingest(
                index_type=index_type,
                index_uri=vector_index_uri,
                input_vectors=input_vectors,
                external_ids=external_ids,
                index_timestamp=index_timestamp if index_timestamp != 0 else None,
                config=config,
                **kwargs,
            )
            with tiledb.open(docs_uri, "w") as A:
                if external_ids is None:
                    external_ids = np.zeros(len(texts), dtype=np.uint64)
                    for i in range(len(texts)):
                        external_ids[i] = i
                data = {}
                data["text"] = np.array(texts)
                if metadatas is not None:
                    metadata_attr = np.empty([len(metadatas)], dtype=object)
                    i = 0
                    for metadata in metadatas:
                        metadata_attr[i] = np.frombuffer(
                            pickle.dumps(metadata), dtype=np.uint8
                        )
                        i += 1
                    data["metadata"] = metadata_attr

                A[external_ids] = data
        return cls(
            embedding=embedding,
            index_uri=index_uri,
            metric=metric,
            config=config,
            **kwargs,
        )

    def delete(
        self, ids: Optional[List[str]] = None, timestamp: int = 0, **kwargs: Any
    ) -> Optional[bool]:
        """Delete by vector ID or other criteria.

        Args:
            ids: List of ids to delete.
            timestamp: Optional timestamp to delete with.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            Optional[bool]: True if deletion is successful,
            False otherwise, None if not implemented.
        """

        external_ids = np.array(ids).astype(np.uint64)
        self.vector_index.delete_batch(
            external_ids=external_ids, timestamp=timestamp if timestamp != 0 else None
        )
        return True

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        timestamp: int = 0,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional ids of each text object.
            timestamp: Optional timestamp to write new texts with.
            kwargs: vectorstore specific parameters

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        tiledb = guard_import("tiledb")
        embeddings = self.embedding.embed_documents(list(texts))
        if ids is None:
            ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]

        external_ids = np.array(ids).astype(np.uint64)
        vectors = np.empty((len(embeddings)), dtype="O")
        for i in range(len(embeddings)):
            vectors[i] = np.array(embeddings[i], dtype=np.float32)
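        # Draft note: another hard-coded float32 spot; for int8/uint8 support
        # these update vectors would have to match the index's vector dtype.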
        self.vector_index.update_batch(
            vectors=vectors,
            external_ids=external_ids,
            timestamp=timestamp if timestamp != 0 else None,
        )

        docs = {}
        docs["text"] = np.array(texts)
        if metadatas is not None:
            metadata_attr = np.empty([len(metadatas)], dtype=object)
            i = 0
            for metadata in metadatas:
                metadata_attr[i] = np.frombuffer(pickle.dumps(metadata), dtype=np.uint8)
                i += 1
            docs["metadata"] = metadata_attr

        docs_array = tiledb.open(
            self.docs_array_uri,
            "w",
            timestamp=timestamp if timestamp != 0 else None,
            config=self.config,
        )
        docs_array[external_ids] = docs
        docs_array.close()
        return ids

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        metric: str = DEFAULT_METRIC,
        index_uri: str = "/tmp/tiledb_array",
        index_type: str = "FLAT",
        config: Optional[Mapping[str, Any]] = None,
        index_timestamp: int = 0,
        **kwargs: Any,
    ) -> TileDB:
        """Construct a TileDB index from raw documents.

        Args:
            texts: List of documents to index.
            embedding: Embedding function to use.
            metadatas: List of metadata dictionaries to associate with documents.
            ids: Optional ids of each text object.
            metric: Metric to use for indexing. Defaults to "euclidean".
            index_uri: The URI to write the TileDB arrays
            index_type: Optional,  Vector index type ("FLAT", IVF_FLAT")
            config: Optional, TileDB config
            index_timestamp: Optional, timestamp to write new texts with.

        Example:
            .. code-block:: python

                from langchain_community import TileDB
                from langchain_community.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                index = TileDB.from_texts(texts, embeddings)
        """
        embeddings = embedding.embed_documents(texts)
        return cls.__from(
            texts=texts,
            embeddings=embeddings,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            metric=metric,
            index_uri=index_uri,
            index_type=index_type,
            config=config,
            index_timestamp=index_timestamp,
            **kwargs,
        )

    @classmethod
    def from_embeddings(
        cls,
        text_embeddings: List[Tuple[str, List[float]]],
        embedding: Embeddings,
        index_uri: str,
        *,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        metric: str = DEFAULT_METRIC,
        index_type: str = "FLAT",
        config: Optional[Mapping[str, Any]] = None,
        index_timestamp: int = 0,
        **kwargs: Any,
    ) -> TileDB:
        """Construct TileDB index from embeddings.

        Args:
            text_embeddings: List of tuples of (text, embedding)
            embedding: Embedding function to use.
            index_uri: The URI to write the TileDB arrays
            metadatas: List of metadata dictionaries to associate with documents.
            metric: Optional, Metric to use for indexing. Defaults to "euclidean".
            index_type: Optional, Vector index type ("FLAT", IVF_FLAT")
            config: Optional, TileDB config
            index_timestamp: Optional, timestamp to write new texts with.

        Example:
            .. code-block:: python

                from langchain_community import TileDB
                from langchain_community.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                text_embeddings = embeddings.embed_documents(texts)
                text_embedding_pairs = list(zip(texts, text_embeddings))
                db = TileDB.from_embeddings(text_embedding_pairs, embeddings)
        """
        texts = [t[0] for t in text_embeddings]
        embeddings = [t[1] for t in text_embeddings]

        return cls.__from(
            texts=texts,
            embeddings=embeddings,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            metric=metric,
            index_uri=index_uri,
            index_type=index_type,
            config=config,
            index_timestamp=index_timestamp,
            **kwargs,
        )

    @classmethod
    def load(
        cls,
        index_uri: str,
        embedding: Embeddings,
        *,
        metric: str = DEFAULT_METRIC,
        config: Optional[Mapping[str, Any]] = None,
        timestamp: Any = None,
        **kwargs: Any,
    ) -> TileDB:
        """Load a TileDB index from a URI.

        Args:
            index_uri: The URI of the TileDB vector index.
            embedding: Embeddings to use when generating queries.
            metric: Optional, Metric to use for indexing. Defaults to "euclidean".
            config: Optional, TileDB config
            timestamp: Optional, timestamp to use for opening the arrays.
        """
        return cls(
            embedding=embedding,
            index_uri=index_uri,
            metric=metric,
            config=config,
            timestamp=timestamp,
            **kwargs,
        )

    def consolidate_updates(self, **kwargs: Any) -> None:
        self.vector_index = self.vector_index.consolidate_updates(**kwargs)

@ihnorton transferred this issue from TileDB-Inc/TileDB-Py Jan 8, 2025
@ihnorton (Member) commented Jan 8, 2025

I moved this issue over to tiledb-vector-search to keep track. We won't be able to work on this immediately, but we do have some vector-search improvements scheduled in the coming months and may be able to include this.


re: #561 (comment)

I noticed that the underlying c++ library doesn't support float16 or bfloat16...where would I put in an issue for that support as well?

You can raise a feature request here: https://github.com/TileDB-Inc/TileDB

@BBC-Esq (Author) commented Jan 8, 2025

Did you give the correct link?

@ihnorton (Member):

Yes. (for clarity, I edited the comment to add a quote of the question I was responding to)
