From 4d32c5fdbeb96b832ec9a07069f62e2c9d49dbd1 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 19 Nov 2024 10:24:01 +0100 Subject: [PATCH 1/2] allow to select schema from pipeline dataset factory --- dlt/destinations/dataset.py | 5 ++++- dlt/pipeline/pipeline.py | 9 +++++++-- tests/load/test_read_interfaces.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py index cffdc0f059..5d0cfec169 100644 --- a/dlt/destinations/dataset.py +++ b/dlt/destinations/dataset.py @@ -248,6 +248,7 @@ def _destination_client(self, schema: Schema) -> JobClientBase: def _ensure_client_and_schema(self) -> None: """Lazy load schema and client""" + # full schema given, nothing to do if not self._schema and isinstance(self._provided_schema, Schema): self._schema = self._provided_schema @@ -259,6 +260,8 @@ def _ensure_client_and_schema(self) -> None: stored_schema = client.get_stored_schema(self._provided_schema) if stored_schema: self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) + else: + self._schema = Schema(self._provided_schema) # no schema name given, load newest schema from destination elif not self._schema: @@ -268,7 +271,7 @@ def _ensure_client_and_schema(self) -> None: if stored_schema: self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema)) - # default to empty schema with dataset name if nothing found + # default to empty schema with dataset name if not self._schema: self._schema = Schema(self._dataset_name) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 037458f9c1..51443ef617 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -15,6 +15,7 @@ cast, get_type_hints, ContextManager, + Union, ) from dlt import version @@ -1790,11 +1791,15 @@ def __getstate__(self) -> Any: # pickle only the SupportsPipeline protocol fields return {"pipeline_name": self.pipeline_name} - def _dataset(self, dataset_type: TDatasetType = "dbapi") -> SupportsReadableDataset: + def _dataset( + self, schema: Union[Schema, str, None] = None, dataset_type: TDatasetType = "dbapi" + ) -> SupportsReadableDataset: """Access helper to dataset""" + if schema is None: + schema = self.default_schema if self.default_schema_name else None return dataset( self._destination, self.dataset_name, - schema=(self.default_schema if self.default_schema_name else None), + schema=schema, dataset_type=dataset_type, ) diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index c6019ecf2d..6af02d42e0 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -379,6 +379,34 @@ def test_column_selection(populated_pipeline: Pipeline) -> None: arrow_table = table_relationship.select("unknown_column").head().arrow() +@pytest.mark.no_load +@pytest.mark.essential +@pytest.mark.parametrize( + "populated_pipeline", + configs, + indirect=True, + ids=lambda x: x.name, +) +def test_schema_arg(populated_pipeline: Pipeline) -> None: + """Simple test to ensure schemas may be selected via schema arg""" + + # if there is no arg, the defautl schema is used + dataset = populated_pipeline._dataset() + assert dataset.schema.name == populated_pipeline.default_schema_name # type: ignore + assert "items" in dataset.schema.tables # type: ignore + + # setting a different schema name will try to load that schema, + # not find one and create an empty schema with that name + dataset = populated_pipeline._dataset(schema="unknown_schema") + assert dataset.schema.name == "unknown_schema" # type: ignore + assert "items" not in dataset.schema.tables # type: ignore + + # providing the schema name of the right schema will load it + dataset = populated_pipeline._dataset(schema=populated_pipeline.default_schema_name) + assert dataset.schema.name == populated_pipeline.default_schema_name # type: ignore + assert "items" in dataset.schema.tables # type: ignore + + @pytest.mark.no_load @pytest.mark.essential @pytest.mark.parametrize( From 0d59ab5b30d17a69f90822eb3b432837d471e820 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 19 Nov 2024 10:29:14 +0100 Subject: [PATCH 2/2] fix existing test --- tests/load/test_read_interfaces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index 6af02d42e0..aac32875a5 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -450,7 +450,7 @@ def test_standalone_dataset(populated_pipeline: Pipeline) -> None: ), ) assert "items" not in dataset.schema.tables - assert dataset.schema.name == populated_pipeline.dataset_name + assert dataset.schema.name == "wrong_schema_name" # check that schema is loaded if no schema name given dataset = cast(