Skip to content

Commit

Permalink
Refactor code to improve readability by reducing line breaks
Browse files — browse the repository at this point in the history
Signed-off-by: Marcel Coetzee <[email protected]>
  • Loading branch information
Pipboyguy committed Jun 4, 2024
1 parent bfcc8bb commit 4827798
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 57 deletions.
71 changes: 23 additions & 48 deletions dlt/destinations/impl/lancedb/lancedb_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,19 +166,15 @@ def _make_qualified_table_name(self, table_name: str) -> str:
def get_table_schema(self, table_name: str) -> pa.Schema:
return cast(pa.Schema, self.db_client[table_name].schema)

def _create_table(
self, table_name: str, schema: Union[pa.Schema, LanceModel]
) -> None:
def _create_table(self, table_name: str, schema: Union[pa.Schema, LanceModel]) -> None:
"""Create a LanceDB Table from the provided LanceModel or PyArrow schema.
Args:
schema: The table schema to create.
table_name: The name of the table to create.
"""

self.db_client.create_table(
table_name, schema=schema, embedding_functions=self.model_func
)
self.db_client.create_table(table_name, schema=schema, embedding_functions=self.model_func)

def delete_table(self, table_name: str) -> None:
"""Delete a LanceDB table.
Expand Down Expand Up @@ -243,9 +239,7 @@ def add_to_table(
Returns:
None
"""
self.db_client.open_table(table_name).add(
data, mode, on_bad_vectors, fill_value
)
self.db_client.open_table(table_name).add(data, mode, on_bad_vectors, fill_value)

def drop_storage(self) -> None:
"""Drop the dataset from the LanceDB instance.
Expand Down Expand Up @@ -288,9 +282,7 @@ def is_storage_initialized(self) -> bool:

def _create_sentinel_table(self) -> None:
"""Create an empty table to indicate that the storage is initialized."""
self._create_table(
schema=cast(LanceModel, NullSchema), table_name=self.sentinel_table
)
self._create_table(schema=cast(LanceModel, NullSchema), table_name=self.sentinel_table)

def _delete_sentinel_table(self) -> None:
"""Delete the sentinel table."""
Expand Down Expand Up @@ -327,9 +319,7 @@ def _update_schema_in_storage(self, schema: Schema) -> None:
"inserted_at": str(pendulum.now()),
"schema": json.dumps(schema.to_dict()),
}
version_table_name = self._make_qualified_table_name(
self.schema.version_table_name
)
version_table_name = self._make_qualified_table_name(self.schema.version_table_name)
self._create_record(properties, VersionSchema, version_table_name)

def _create_record(
Expand All @@ -345,9 +335,7 @@ def _create_record(
try:
tbl = self.db_client.open_table(self._make_qualified_table_name(table_name))
except FileNotFoundError:
tbl = self.db_client.create_table(
self._make_qualified_table_name(table_name)
)
tbl = self.db_client.create_table(self._make_qualified_table_name(table_name))
except Exception:
raise

Expand All @@ -367,9 +355,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
"""Loads compressed state from destination storage by finding a load ID that was completed."""
while True:
try:
state_table_name = self._make_qualified_table_name(
self.schema.state_table_name
)
state_table_name = self._make_qualified_table_name(self.schema.state_table_name)
state_records = (
self.db_client.open_table(state_table_name)
.search()
Expand All @@ -381,9 +367,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
return None
for state in state_records:
load_id = state["_dlt_load_id"]
loads_table_name = self._make_qualified_table_name(
self.schema.loads_table_name
)
loads_table_name = self._make_qualified_table_name(self.schema.loads_table_name)
load_records = (
self.db_client.open_table(loads_table_name)
.search()
Expand Down Expand Up @@ -415,9 +399,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> StorageSchemaInfo:
def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage."""
try:
version_table_name = self._make_qualified_table_name(
self.schema.version_table_name
)
version_table_name = self._make_qualified_table_name(self.schema.version_table_name)
response = (
self.db_client[version_table_name]
.search()
Expand Down Expand Up @@ -454,9 +436,7 @@ def complete_load(self, load_id: str) -> None:
def restore_file_load(self, file_path: str) -> LoadJob:
return EmptyLoadJob.from_file_path(file_path, "completed")

def start_file_load(
self, table: TTableSchema, file_path: str, load_id: str
) -> LoadJob:
def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
return LoadLanceDBJob(
self.schema,
table,
Expand Down Expand Up @@ -501,9 +481,7 @@ def __init__(
self.table_name = table_name
self.table_schema: TTableSchema = table_schema
self.unique_identifiers = self._list_unique_identifiers(table_schema)
self.embedding_fields = get_columns_names_with_prop(
table_schema, VECTORIZE_HINT
)
self.embedding_fields = get_columns_names_with_prop(table_schema, VECTORIZE_HINT)
self.embedding_model_func = model_func
self.embedding_model_dimensions = client_config.embedding_model_dimensions

Expand Down Expand Up @@ -562,21 +540,15 @@ def _upload_data(
except Exception:
raise

parsed_records: List[LanceModel] = [
lancedb_model(**record) for record in records
]
parsed_records: List[LanceModel] = [lancedb_model(**record) for record in records]

# Upsert using reserved ID as the key.
tbl.merge_insert(
self.id_field_name
).when_matched_update_all().when_not_matched_insert_all().execute(
parsed_records
)
).when_matched_update_all().when_not_matched_insert_all().execute(parsed_records)

@staticmethod
def _generate_uuid(
data: DictStrAny, unique_identifiers: Sequence[str], table_name: str
) -> str:
def _generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str:
"""Generates deterministic UUID - used for deduplication.
Args:
Expand Down Expand Up @@ -628,12 +600,15 @@ def _create_template_schema(
Vector(embedding_model_dimensions or embedding_model_func.ndims()),
...,
)
return create_model(
"TemplateSchema",
__base__=LanceModel,
__module__=__name__,
__validators__={},
**special_fields,
return cast(
TLanceModel,
create_model( # type: ignore[call-overload]
"TemplateSchema",
__base__=LanceModel,
__module__=__name__,
__validators__={},
**special_fields,
),
)

def _make_field_schema(self, column_name: str, column: TColumnSchema) -> DictStrAny:
Expand Down
4 changes: 1 addition & 3 deletions tests/load/lancedb/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,7 @@ def test_empty_dataset_allowed() -> None:
client: LanceDBClient = p.destination_client() # type: ignore[assignment]

assert p.dataset_name is None
info = p.run(
lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])
)
info = p.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"]))
# dataset in load info is empty
assert info.dataset_name is None
client = p.destination_client() # type: ignore[assignment]
Expand Down
7 changes: 1 addition & 6 deletions tests/load/lancedb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,7 @@ def assert_table(
assert exists

qualified_collection_name = client._make_qualified_table_name(collection_name)
records = (
client.db_client.open_table(qualified_collection_name)
.search()
.limit(50)
.to_list()
)
records = client.db_client.open_table(qualified_collection_name).search().limit(50).to_list()

if expected_items_count is not None:
assert expected_items_count == len(records)
Expand Down

0 comments on commit 4827798

Please sign in to comment.