From bbc00a5237f08d18a5d81c04caf6cd25bf84fe94 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 14 Jun 2024 16:31:40 +0200 Subject: [PATCH] Update embedding model config and schema in LanceDB Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/configuration.py | 5 ++--- dlt/destinations/impl/lancedb/schema.py | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py index 796aa0b0cc..c5236deb75 100644 --- a/dlt/destinations/impl/lancedb/configuration.py +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -85,8 +85,8 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration): """The model used by the embedding provider for generating embeddings. Check with the embedding provider which options are available. Reference https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/.""" - embedding_model_dimensions: int = 1024 - """The dimensions of the embeddings generated. In most cases it will be automatically inferred, + embedding_model_dimensions: Optional[int] = None + """The dimensions of the embeddings generated. In most cases it will be automatically inferred by LanceDB, but it is configurable in rare cases. 
Make sure it corresponds with the associated embedding model's dimensionality.""" @@ -100,7 +100,6 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration): __config_gen_annotations__: ClassVar[List[str]] = [ "embedding_model", - "embedding_model_dimensions", "embedding_model_provider", ] diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index 14b505a90e..66d0cdaec6 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -55,6 +55,7 @@ def make_arrow_table_schema( arrow_schema.append(pa.field(id_field_name, pa.string())) if embedding_fields: + # The user-configured dimension, if set, takes precedence over the embedding function's ndims(). vec_size = embedding_model_dimensions or embedding_model_func.ndims() arrow_schema.append(pa.field(vector_field_name, pa.list_(pa.float32(), vec_size))) @@ -63,7 +64,7 @@ def make_arrow_table_schema( arrow_schema.append(field) metadata = {} - if embedding_model_func and embedding_fields: + if embedding_model_func: # Get the registered alias if it exists, otherwise use the class name. name = getattr( embedding_model_func,