Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow to select schema from pipeline dataset factory #2075

Merged
merged 2 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion dlt/destinations/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def _destination_client(self, schema: Schema) -> JobClientBase:

def _ensure_client_and_schema(self) -> None:
"""Lazy load schema and client"""

# full schema given, nothing to do
if not self._schema and isinstance(self._provided_schema, Schema):
self._schema = self._provided_schema
Expand All @@ -259,6 +260,8 @@ def _ensure_client_and_schema(self) -> None:
stored_schema = client.get_stored_schema(self._provided_schema)
if stored_schema:
self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema))
else:
self._schema = Schema(self._provided_schema)

# no schema name given, load newest schema from destination
elif not self._schema:
Expand All @@ -268,7 +271,7 @@ def _ensure_client_and_schema(self) -> None:
if stored_schema:
self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema))

# default to empty schema with dataset name if nothing found
# default to empty schema with dataset name
if not self._schema:
self._schema = Schema(self._dataset_name)

Expand Down
9 changes: 7 additions & 2 deletions dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
cast,
get_type_hints,
ContextManager,
Union,
)

from dlt import version
Expand Down Expand Up @@ -1790,11 +1791,15 @@ def __getstate__(self) -> Any:
# pickle only the SupportsPipeline protocol fields
return {"pipeline_name": self.pipeline_name}

def _dataset(self, dataset_type: TDatasetType = "dbapi") -> SupportsReadableDataset:
def _dataset(
    self, schema: Union[Schema, str, None] = None, dataset_type: TDatasetType = "dbapi"
) -> SupportsReadableDataset:
    """Return a readable dataset bound to this pipeline's destination.

    Args:
        schema: Explicit schema (or schema name) to read with. When omitted,
            the pipeline's default schema is used if one exists.
        dataset_type: Which readable-dataset implementation to build
            (defaults to "dbapi").
    """
    # fall back to the pipeline's default schema only when the caller
    # did not pick one and a default schema name is actually set
    selected_schema = schema
    if selected_schema is None and self.default_schema_name:
        selected_schema = self.default_schema
    return dataset(
        self._destination,
        self.dataset_name,
        schema=selected_schema,
        dataset_type=dataset_type,
    )
)
30 changes: 29 additions & 1 deletion tests/load/test_read_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,34 @@ def test_column_selection(populated_pipeline: Pipeline) -> None:
arrow_table = table_relationship.select("unknown_column").head().arrow()


@pytest.mark.no_load
@pytest.mark.essential
@pytest.mark.parametrize(
    "populated_pipeline",
    configs,
    indirect=True,
    ids=lambda c: c.name,
)
def test_schema_arg(populated_pipeline: Pipeline) -> None:
    """Ensure a schema may be selected via the `schema` arg of `_dataset`."""

    # without an argument the pipeline's default schema is used
    dataset = populated_pipeline._dataset()
    assert dataset.schema.name == populated_pipeline.default_schema_name  # type: ignore
    assert "items" in dataset.schema.tables  # type: ignore

    # an unknown schema name triggers a lookup on the destination, finds
    # nothing, and falls back to an empty schema carrying that name
    dataset = populated_pipeline._dataset(schema="unknown_schema")
    assert dataset.schema.name == "unknown_schema"  # type: ignore
    assert "items" not in dataset.schema.tables  # type: ignore

    # passing the default schema's name loads that stored schema
    dataset = populated_pipeline._dataset(schema=populated_pipeline.default_schema_name)
    assert dataset.schema.name == populated_pipeline.default_schema_name  # type: ignore
    assert "items" in dataset.schema.tables  # type: ignore


@pytest.mark.no_load
@pytest.mark.essential
@pytest.mark.parametrize(
Expand Down Expand Up @@ -422,7 +450,7 @@ def test_standalone_dataset(populated_pipeline: Pipeline) -> None:
),
)
assert "items" not in dataset.schema.tables
assert dataset.schema.name == populated_pipeline.dataset_name
assert dataset.schema.name == "wrong_schema_name"

# check that schema is loaded if no schema name given
dataset = cast(
Expand Down
Loading