feat: snowflake hints
donotpush committed Dec 12, 2024
1 parent 77d8ab6 · commit 9f0d67e
Showing 3 changed files with 67 additions and 4 deletions.
19 changes: 19 additions & 0 deletions dlt/destinations/impl/snowflake/configuration.py
@@ -138,6 +138,25 @@ class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration)
     query_tag: Optional[str] = None
     """A tag with placeholders to tag sessions executing jobs"""

+    # TODO: decide name - create_indexes vs create_constraints (create_indexes used in other destinations)
+    create_indexes: bool = False
+    """Whether UNIQUE and PRIMARY KEY constraints should be created"""
+
+    def __init__(
+        self,
+        *,
+        credentials: SnowflakeCredentials = None,
+        create_indexes: bool = False,
+        destination_name: str = None,
+        environment: str = None,
+    ) -> None:
+        super().__init__(
+            credentials=credentials,
+            destination_name=destination_name,
+            environment=environment,
+        )
+        self.create_indexes = create_indexes
+
     def fingerprint(self) -> str:
         """Returns a fingerprint of host part of a connection string"""
         if self.credentials and self.credentials.host:
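Usage note (not part of the diff): create_indexes defaults to False, and Snowflake treats UNIQUE and PRIMARY KEY constraints as informational rather than enforced, so the flag only changes the emitted DDL. A minimal sketch of enabling it, assuming the snowflake destination factory forwards keyword arguments to this configuration; the pipeline and dataset names are illustrative:

import dlt
from dlt.destinations import snowflake

# hypothetical pipeline wiring; the relevant part is create_indexes=True
pipeline = dlt.pipeline(
    pipeline_name="snowflake_hints_demo",
    destination=snowflake(create_indexes=True),  # emit UNIQUE / PRIMARY KEY column hints
    dataset_name="hints_demo_dataset",
)

The same flag should also be resolvable from standard dlt configuration (for example, setting create_indexes under the [destination.snowflake] section), since it is a regular configuration field.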
16 changes: 12 additions & 4 deletions dlt/destinations/impl/snowflake/snowflake.py
@@ -1,4 +1,4 @@
-from typing import Optional, Sequence, List
+from typing import Optional, Sequence, List, Dict
 from urllib.parse import urlparse, urlunparse

 from dlt.common.data_writers.configuration import CsvFormatConfiguration
@@ -17,7 +17,7 @@
 )
 from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url
 from dlt.common.storages.file_storage import FileStorage
-from dlt.common.schema import TColumnSchema, Schema
+from dlt.common.schema import TColumnSchema, Schema, TColumnHint
 from dlt.common.schema.typing import TColumnType

 from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS
@@ -29,6 +29,8 @@
 from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
 from dlt.destinations.job_impl import ReferenceFollowupJobRequest

+SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE", "primary_key": "PRIMARY KEY"}
+

 class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs):
     def __init__(
@@ -238,6 +240,7 @@ def __init__(
         self.config: SnowflakeClientConfiguration = config
         self.sql_client: SnowflakeSqlClient = sql_client  # type: ignore
         self.type_mapper = self.capabilities.get_type_mapper()
+        self.active_hints = SUPPORTED_HINTS if self.config.create_indexes else {}

     def create_load_job(
         self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False
@@ -288,9 +291,14 @@ def _from_db_type(
         return self.type_mapper.from_destination_type(bq_t, precision, scale)

     def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
-        name = self.sql_client.escape_column_name(c["name"])
+        hints_str = " ".join(
+            self.active_hints.get(h, "")
+            for h in self.active_hints.keys()
+            if c.get(h, False) is True
+        )
+        column_name = self.sql_client.escape_column_name(c["name"])
         return (
-            f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
+            f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
         )

     def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool:
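Illustrative sketch (not from the commit): how the SUPPORTED_HINTS mapping and the hints_str join above turn column schema flags into DDL fragments. The helper name and the column dicts are made up for the example:

SUPPORTED_HINTS = {"unique": "UNIQUE", "primary_key": "PRIMARY KEY"}

def hints_fragment(column: dict) -> str:
    # mirrors the join in _get_column_def_sql: emit the SQL keyword for every hint flag set to True
    return " ".join(sql for hint, sql in SUPPORTED_HINTS.items() if column.get(hint) is True)

print(hints_fragment({"name": "id", "primary_key": True}))  # -> PRIMARY KEY
print(hints_fragment({"name": "email", "unique": True}))    # -> UNIQUE
print(hints_fragment({"name": "note"}))                     # -> "" (no hint, column definition unchanged)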
36 changes: 36 additions & 0 deletions tests/load/snowflake/test_snowflake_table_builder.py
@@ -78,6 +78,42 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None:
     assert '"COL10" DATE NOT NULL' in sql


+def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None:
+    mod_update = deepcopy(TABLE_UPDATE)
+
+    mod_update[0]["primary_key"] = True
+    mod_update[0]["sort"] = True
+    mod_update[1]["unique"] = True
+    mod_update[4]["parent_key"] = True
+
+    sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, False))
+
+    assert sql.strip().startswith("CREATE TABLE")
+    assert "EVENT_TEST_TABLE" in sql
+    assert '"COL1" NUMBER(19,0) NOT NULL' in sql
+    assert '"COL2" FLOAT NOT NULL' in sql
+    assert '"COL3" BOOLEAN NOT NULL' in sql
+    assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql
+    assert '"COL5" VARCHAR' in sql
+    assert '"COL6" NUMBER(38,9) NOT NULL' in sql
+    assert '"COL7" BINARY' in sql
+    assert '"COL8" NUMBER(38,0)' in sql
+    assert '"COL9" VARIANT NOT NULL' in sql
+    assert '"COL10" DATE NOT NULL' in sql
+
+    # same thing with indexes
+    snowflake_client = snowflake().client(
+        snowflake_client.schema,
+        SnowflakeClientConfiguration(create_indexes=True)._bind_dataset_name(
+            dataset_name="test_" + uniq_id()
+        ),
+    )
+    sql = snowflake_client._get_table_update_sql("event_test_table", mod_update, False)[0]
+    sqlfluff.parse(sql)
+    assert '"COL1" NUMBER(19,0) PRIMARY KEY NOT NULL' in sql
+    assert '"COL2" FLOAT UNIQUE NOT NULL' in sql
+
+
 def test_alter_table(snowflake_client: SnowflakeClient) -> None:
     statements = snowflake_client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)
     assert len(statements) == 1
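To run only the new test locally, a hedged sketch assuming the repository's usual pytest layout and that Snowflake credentials for the load tests are configured:

import pytest

# select the new hints test by keyword; the file path comes from the diff above
pytest.main([
    "tests/load/snowflake/test_snowflake_table_builder.py",
    "-k", "test_create_table_with_hints",
    "-v",
])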
