From 48277987a504cddc51861948c5555187d5b00cdb Mon Sep 17 00:00:00 2001
From: Marcel Coetzee
Date: Tue, 4 Jun 2024 21:24:56 +0200
Subject: [PATCH] Refactor code to improve readability by reducing line breaks

Signed-off-by: Marcel Coetzee
---
 .../impl/lancedb/lancedb_client.py  | 71 ++++++-------------
 tests/load/lancedb/test_pipeline.py |  4 +-
 tests/load/lancedb/utils.py         |  7 +-
 3 files changed, 25 insertions(+), 57 deletions(-)

diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py
index 88de20acb3..12ced96b36 100644
--- a/dlt/destinations/impl/lancedb/lancedb_client.py
+++ b/dlt/destinations/impl/lancedb/lancedb_client.py
@@ -166,9 +166,7 @@ def _make_qualified_table_name(self, table_name: str) -> str:
     def get_table_schema(self, table_name: str) -> pa.Schema:
         return cast(pa.Schema, self.db_client[table_name].schema)
 
-    def _create_table(
-        self, table_name: str, schema: Union[pa.Schema, LanceModel]
-    ) -> None:
+    def _create_table(self, table_name: str, schema: Union[pa.Schema, LanceModel]) -> None:
         """Create a LanceDB Table from the provided LanceModel or PyArrow schema.
 
         Args:
@@ -176,9 +174,7 @@ def _create_table(
             table_name: The name of the table to create.
         """
 
-        self.db_client.create_table(
-            table_name, schema=schema, embedding_functions=self.model_func
-        )
+        self.db_client.create_table(table_name, schema=schema, embedding_functions=self.model_func)
 
     def delete_table(self, table_name: str) -> None:
         """Delete a LanceDB table.
@@ -243,9 +239,7 @@ def add_to_table(
         Returns:
             None
         """
-        self.db_client.open_table(table_name).add(
-            data, mode, on_bad_vectors, fill_value
-        )
+        self.db_client.open_table(table_name).add(data, mode, on_bad_vectors, fill_value)
 
     def drop_storage(self) -> None:
         """Drop the dataset from the LanceDB instance.
@@ -288,9 +282,7 @@ def is_storage_initialized(self) -> bool:
 
     def _create_sentinel_table(self) -> None:
         """Create an empty table to indicate that the storage is initialized."""
-        self._create_table(
-            schema=cast(LanceModel, NullSchema), table_name=self.sentinel_table
-        )
+        self._create_table(schema=cast(LanceModel, NullSchema), table_name=self.sentinel_table)
 
     def _delete_sentinel_table(self) -> None:
         """Delete the sentinel table."""
@@ -327,9 +319,7 @@ def _update_schema_in_storage(self, schema: Schema) -> None:
             "inserted_at": str(pendulum.now()),
             "schema": json.dumps(schema.to_dict()),
         }
-        version_table_name = self._make_qualified_table_name(
-            self.schema.version_table_name
-        )
+        version_table_name = self._make_qualified_table_name(self.schema.version_table_name)
         self._create_record(properties, VersionSchema, version_table_name)
 
     def _create_record(
@@ -345,9 +335,7 @@ def _create_record(
         try:
             tbl = self.db_client.open_table(self._make_qualified_table_name(table_name))
         except FileNotFoundError:
-            tbl = self.db_client.create_table(
-                self._make_qualified_table_name(table_name)
-            )
+            tbl = self.db_client.create_table(self._make_qualified_table_name(table_name))
         except Exception:
             raise
 
@@ -367,9 +355,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
         """Loads compressed state from destination storage by finding a load ID that was completed."""
         while True:
             try:
-                state_table_name = self._make_qualified_table_name(
-                    self.schema.state_table_name
-                )
+                state_table_name = self._make_qualified_table_name(self.schema.state_table_name)
                 state_records = (
                     self.db_client.open_table(state_table_name)
                     .search()
@@ -381,9 +367,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
                     return None
                 for state in state_records:
                     load_id = state["_dlt_load_id"]
-                    loads_table_name = self._make_qualified_table_name(
-                        self.schema.loads_table_name
-                    )
+                    loads_table_name = self._make_qualified_table_name(self.schema.loads_table_name)
                     load_records = (
                         self.db_client.open_table(loads_table_name)
                         .search()
@@ -415,9 +399,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> StorageSchemaInfo:
     def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
         """Retrieves newest schema from destination storage."""
         try:
-            version_table_name = self._make_qualified_table_name(
-                self.schema.version_table_name
-            )
+            version_table_name = self._make_qualified_table_name(self.schema.version_table_name)
             response = (
                 self.db_client[version_table_name]
                 .search()
@@ -454,9 +436,7 @@ def complete_load(self, load_id: str) -> None:
     def restore_file_load(self, file_path: str) -> LoadJob:
         return EmptyLoadJob.from_file_path(file_path, "completed")
 
-    def start_file_load(
-        self, table: TTableSchema, file_path: str, load_id: str
-    ) -> LoadJob:
+    def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
         return LoadLanceDBJob(
             self.schema,
             table,
@@ -501,9 +481,7 @@ def __init__(
         self.table_name = table_name
         self.table_schema: TTableSchema = table_schema
        self.unique_identifiers = self._list_unique_identifiers(table_schema)
-        self.embedding_fields = get_columns_names_with_prop(
-            table_schema, VECTORIZE_HINT
-        )
+        self.embedding_fields = get_columns_names_with_prop(table_schema, VECTORIZE_HINT)
         self.embedding_model_func = model_func
         self.embedding_model_dimensions = client_config.embedding_model_dimensions
 
@@ -562,21 +540,15 @@ def _upload_data(
         except Exception:
             raise
 
-        parsed_records: List[LanceModel] = [
-            lancedb_model(**record) for record in records
-        ]
+        parsed_records: List[LanceModel] = [lancedb_model(**record) for record in records]
 
         # Upsert using reserved ID as the key.
         tbl.merge_insert(
             self.id_field_name
-        ).when_matched_update_all().when_not_matched_insert_all().execute(
-            parsed_records
-        )
+        ).when_matched_update_all().when_not_matched_insert_all().execute(parsed_records)
 
     @staticmethod
-    def _generate_uuid(
-        data: DictStrAny, unique_identifiers: Sequence[str], table_name: str
-    ) -> str:
+    def _generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str:
         """Generates deterministic UUID - used for deduplication.
 
         Args:
@@ -628,12 +600,15 @@ def _create_template_schema(
                 Vector(embedding_model_dimensions or embedding_model_func.ndims()),
                 ...,
             )
-        return create_model(
-            "TemplateSchema",
-            __base__=LanceModel,
-            __module__=__name__,
-            __validators__={},
-            **special_fields,
+        return cast(
+            TLanceModel,
+            create_model(  # type: ignore[call-overload]
+                "TemplateSchema",
+                __base__=LanceModel,
+                __module__=__name__,
+                __validators__={},
+                **special_fields,
+            ),
         )
 
     def _make_field_schema(self, column_name: str, column: TColumnSchema) -> DictStrAny:
diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py
index 750b1fccf8..88d7432667 100644
--- a/tests/load/lancedb/test_pipeline.py
+++ b/tests/load/lancedb/test_pipeline.py
@@ -361,9 +361,7 @@ def test_empty_dataset_allowed() -> None:
     client: LanceDBClient = p.destination_client()  # type: ignore[assignment]
     assert p.dataset_name is None
 
-    info = p.run(
-        lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])
-    )
+    info = p.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"]))
     # dataset in load info is empty
     assert info.dataset_name is None
     client = p.destination_client()  # type: ignore[assignment]
diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py
index 1bbdbd2138..9c7a1cade4 100644
--- a/tests/load/lancedb/utils.py
+++ b/tests/load/lancedb/utils.py
@@ -27,12 +27,7 @@ def assert_table(
     assert exists
 
     qualified_collection_name = client._make_qualified_table_name(collection_name)
-    records = (
-        client.db_client.open_table(qualified_collection_name)
-        .search()
-        .limit(50)
-        .to_list()
-    )
+    records = client.db_client.open_table(qualified_collection_name).search().limit(50).to_list()
 
     if expected_items_count is not None:
         assert expected_items_count == len(records)
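
Note (not part of the patch): a minimal standalone sketch of the cast() + create_model() pattern
that the _create_template_schema() hunk switches to. The field names "id__"/"vector__" and the
2-dimensional Vector below are illustrative assumptions, not values taken from the repository;
Type[LanceModel] stands in for the module's TLanceModel type variable.

    # Illustrative only: build a LanceModel subclass dynamically with pydantic's
    # create_model(), then cast() so the dynamic return type satisfies mypy.
    from typing import Type, cast

    from lancedb.pydantic import LanceModel, Vector
    from pydantic import create_model

    TemplateSchema = cast(
        Type[LanceModel],
        create_model(  # type: ignore[call-overload]
            "TemplateSchema",
            __base__=LanceModel,
            id__=(str, ...),          # hypothetical reserved ID field
            vector__=(Vector(2), ...),  # hypothetical 2-dim embedding field
        ),
    )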