From 9f0d67e43a556f23517f0244e82bec74e44a7481 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:45:42 +0100 Subject: [PATCH] feat: snowflake hints --- .../impl/snowflake/configuration.py | 19 ++++++++++ dlt/destinations/impl/snowflake/snowflake.py | 16 ++++++--- .../snowflake/test_snowflake_table_builder.py | 36 +++++++++++++++++++ 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 4a89a1564b..4355edb09c 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -138,6 +138,25 @@ class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration) query_tag: Optional[str] = None """A tag with placeholders to tag sessions executing jobs""" + # TODO: decide name - create_indexes vs create_constraints (create_indexes used in other destinations) + create_indexes: bool = False + """Whether UNIQUE or PRIMARY KEY constrains should be created""" + + def __init__( + self, + *, + credentials: SnowflakeCredentials = None, + create_indexes: bool = False, + destination_name: str = None, + environment: str = None, + ) -> None: + super().__init__( + credentials=credentials, + destination_name=destination_name, + environment=environment, + ) + self.create_indexes = create_indexes + def fingerprint(self) -> str: """Returns a fingerprint of host part of a connection string""" if self.credentials and self.credentials.host: diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index e5146139f2..c6220fd65e 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, List +from typing import Optional, Sequence, List, Dict from urllib.parse import urlparse, urlunparse from dlt.common.data_writers.configuration import CsvFormatConfiguration @@ -17,7 +17,7 @@ ) from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url from dlt.common.storages.file_storage import FileStorage -from dlt.common.schema import TColumnSchema, Schema +from dlt.common.schema import TColumnSchema, Schema, TColumnHint from dlt.common.schema.typing import TColumnType from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS @@ -29,6 +29,8 @@ from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.job_impl import ReferenceFollowupJobRequest +SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE", "primary_key": "PRIMARY KEY"} + class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( @@ -238,6 +240,7 @@ def __init__( self.config: SnowflakeClientConfiguration = config self.sql_client: SnowflakeSqlClient = sql_client # type: ignore self.type_mapper = self.capabilities.get_type_mapper() + self.active_hints = SUPPORTED_HINTS if self.config.create_indexes else {} def create_load_job( self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False @@ -288,9 +291,14 @@ def _from_db_type( return self.type_mapper.from_destination_type(bq_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - name = self.sql_client.escape_column_name(c["name"]) + hints_str = " ".join( + self.active_hints.get(h, "") + for h in self.active_hints.keys() + if c.get(h, False) is True + ) + column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index 1fc0034f43..4b8c4e1b2a 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -78,6 +78,42 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None: assert '"COL10" DATE NOT NULL' in sql +def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None: + mod_update = deepcopy(TABLE_UPDATE) + + mod_update[0]["primary_key"] = True + mod_update[0]["sort"] = True + mod_update[1]["unique"] = True + mod_update[4]["parent_key"] = True + + sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, False)) + + assert sql.strip().startswith("CREATE TABLE") + assert "EVENT_TEST_TABLE" in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL5" VARCHAR' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL7" BINARY' in sql + assert '"COL8" NUMBER(38,0)' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql + + # same thing with indexes + snowflake_client = snowflake().client( + snowflake_client.schema, + SnowflakeClientConfiguration(create_indexes=True)._bind_dataset_name( + dataset_name="test_" + uniq_id() + ), + ) + sql = snowflake_client._get_table_update_sql("event_test_table", mod_update, False)[0] + sqlfluff.parse(sql) + assert '"COL1" NUMBER(19,0) PRIMARY KEY NOT NULL' in sql + assert '"COL2" FLOAT UNIQUE NOT NULL' in sql + + def test_alter_table(snowflake_client: SnowflakeClient) -> None: statements = snowflake_client._get_table_update_sql("event_test_table", TABLE_UPDATE, True) assert len(statements) == 1