Skip to content

Commit

Permalink
Update embedding model config and schema in LanceDB
Browse files Browse the repository at this point in the history
Signed-off-by: Marcel Coetzee <[email protected]>
  • Loading branch information
Pipboyguy committed Jun 14, 2024
1 parent cf41859 commit bbc00a5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
5 changes: 2 additions & 3 deletions dlt/destinations/impl/lancedb/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration):
"""The model used by the embedding provider for generating embeddings.
Check with the embedding provider which options are available.
Reference https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/."""
embedding_model_dimensions: int = 1024
"""The dimensions of the embeddings generated. In most cases it will be automatically inferred,
embedding_model_dimensions: Optional[int] = None
"""The dimensions of the embeddings generated. In most cases it will be automatically inferred, by LanceDB,
but it is configurable in rare cases.
Make sure it corresponds with the associated embedding model's dimensionality."""
Expand All @@ -100,7 +100,6 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration):

__config_gen_annotations__: ClassVar[List[str]] = [
"embedding_model",
"embedding_model_dimensions",
"embedding_model_provider",
]

Expand Down
3 changes: 2 additions & 1 deletion dlt/destinations/impl/lancedb/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def make_arrow_table_schema(
arrow_schema.append(pa.field(id_field_name, pa.string()))

if embedding_fields:
# User's provided dimension config, if provided, takes precedence.
vec_size = embedding_model_dimensions or embedding_model_func.ndims()
arrow_schema.append(pa.field(vector_field_name, pa.list_(pa.float32(), vec_size)))

Expand All @@ -63,7 +64,7 @@ def make_arrow_table_schema(
arrow_schema.append(field)

metadata = {}
if embedding_model_func and embedding_fields:
if embedding_model_func:
# Get the registered alias if it exists, otherwise use the class name.
name = getattr(
embedding_model_func,
Expand Down

0 comments on commit bbc00a5

Please sign in to comment.