Merge pull request #675 from dbt-labs/remove-non-production-methods-from-sql-client

Remove DDL and other test-only methods from SqlClient protocol
tlento authored Jul 26, 2023
2 parents 1a5611f + 940aa09 commit 98a3e4a
Showing 19 changed files with 232 additions and 260 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230724-152228.yaml
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Remove DDL and other unused methods from SqlClient protocol
time: 2023-07-24T15:22:28.930711-07:00
custom:
Author: tlento
Issue: None
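
To make the calling-convention change concrete, here is a minimal sketch of the before/after pattern this PR applies wherever code previously relied on the removed helpers. The names sql_client and output_table are placeholders assumed to be in scope; the actual substitution appears in the execution_plan.py and test_tasks.py hunks below.

    # Sketch only: `sql_client` is any object satisfying the SqlClient protocol and
    # `output_table` is a SqlTable; both are assumed to already exist in the caller.

    # Before this change, dropping a table went through a test-only protocol method:
    #     sql_client.drop_table(output_table)

    # After this change, callers issue the DDL directly through execute():
    sql_client.execute(f"DROP TABLE IF EXISTS {output_table.sql}")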
130 changes: 0 additions & 130 deletions metricflow/cli/dbt_connectors/adapter_backed_client.py
@@ -2,16 +2,13 @@

import enum
import logging
import textwrap
import time
from typing import Optional, Sequence

import pandas as pd
from dbt.adapters.base.impl import BaseAdapter
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.pretty_print import pformat_big_objects

from metricflow.dataflow.sql_table import SqlTable
from metricflow.errors.errors import SqlBindParametersNotSupportedError
from metricflow.logging.formatting import indent_log_line
from metricflow.protocols.sql_client import SqlEngine
@@ -189,133 +186,6 @@ def dry_run(
logger.info(f"Finished running the dry_run in {stop - start:.2f}s")
return

def create_table_from_dataframe(
self,
sql_table: SqlTable,
df: pd.DataFrame,
chunk_size: Optional[int] = None,
) -> None:
"""Create a table in the data warehouse containing the contents of the dataframe.
Only used in tutorials and tests.
Args:
sql_table: The SqlTable object representing the table location to use
df: The Pandas DataFrame object containing the column schema and data to load
chunk_size: The number of rows to insert per transaction
"""
logger.info(f"Creating table '{sql_table.sql}' from a DataFrame with {df.shape[0]} row(s)")
start_time = time.time()
with self._adapter.connection_named("MetricFlow_create_from_dataframe"):
# Create table
# update dtypes to convert None to NA in boolean columns.
# This mirrors the SQLAlchemy schema detection logic in pandas.io.sql
df = df.convert_dtypes()
columns = df.columns
columns_to_insert = []
for i in range(len(df.columns)):
# Format as "column_name column_type"
columns_to_insert.append(
f"{columns[i]} {self._get_type_from_pandas_dtype(str(df[columns[i]].dtype).lower())}"
)
self._adapter.execute(
f"CREATE TABLE IF NOT EXISTS {sql_table.sql} ({', '.join(columns_to_insert)})",
auto_begin=True,
fetch=False,
)
self._adapter.commit_if_has_connection()

# Insert rows
values = []
for row in df.itertuples(index=False, name=None):
cells = []
for cell in row:
if pd.isnull(cell):
# use null keyword instead of isNA/None/etc.
cells.append("null")
elif type(cell) in [str, pd.Timestamp]:
# Wrap cell in quotes & escape existing single quotes
escaped_cell = str(cell).replace("'", "''")
cells.append(f"'{escaped_cell}'")
else:
cells.append(str(cell))

values.append(f"({', '.join(cells)})")
if chunk_size and len(values) == chunk_size:
value_string = ",\n".join(values)
self._adapter.execute(
f"INSERT INTO {sql_table.sql} VALUES {value_string}", auto_begin=True, fetch=False
)
values = []
if values:
value_string = ",\n".join(values)
self._adapter.execute(
f"INSERT INTO {sql_table.sql} VALUES {value_string}", auto_begin=True, fetch=False
)
# Commit all insert transaction at once
self._adapter.commit_if_has_connection()

logger.info(f"Created table '{sql_table.sql}' from a DataFrame in {time.time() - start_time:.2f}s")

def _get_type_from_pandas_dtype(self, dtype: str) -> str:
"""Helper method to get the engine-specific type value.
The dtype dict here is non-exhaustive but should be adequate for our needs.
"""
# TODO: add type handling for string/bool/bigint types for all engines
if dtype == "string" or dtype == "object":
return "text"
elif dtype == "boolean" or dtype == "bool":
return "boolean"
elif dtype == "int64":
return "bigint"
elif dtype == "float64":
return self._sql_query_plan_renderer.expr_renderer.double_data_type
elif dtype == "datetime64[ns]":
return self._sql_query_plan_renderer.expr_renderer.timestamp_data_type
else:
raise ValueError(f"Encountered unexpected Pandas dtype ({dtype})!")

def list_tables(self, schema_name: str) -> Sequence[str]:
"""Get a list of the table names in a given schema. Only used in tutorials and tests."""
# TODO: Short term, make this work with as many engines as possible. Medium term, remove this altogether.
if self.sql_engine_type is SqlEngine.SNOWFLAKE:
# Snowflake likes capitalizing things, except when it doesn't. We can get away with this due to its
# limited scope of usage.
schema_name = schema_name.upper()

df = self.query(
textwrap.dedent(
f"""\
SELECT table_name FROM information_schema.tables
WHERE table_schema = '{schema_name}'
"""
),
)
if df.empty:
return []

# Lower casing table names and data frame names for consistency between Snowflake and other clients.
# As above, we can do this because it isn't used in any consequential situations.
df.columns = df.columns.str.lower()
return [t.lower() for t in df["table_name"]]

def table_exists(self, sql_table: SqlTable) -> bool:
"""Check if a given table exists. Only used in tutorials and tests."""
return sql_table.table_name in self.list_tables(sql_table.schema_name)

def create_schema(self, schema_name: str) -> None:
"""Create the given schema in a data warehouse. Only used in tutorials and tests."""
self.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}")

def drop_schema(self, schema_name: str, cascade: bool = True) -> None:
"""Drop the given schema from the data warehouse. Only used in tests."""
self.execute(f"DROP SCHEMA IF EXISTS {schema_name}{' CASCADE' if cascade else ''}")

def drop_table(self, sql_table: SqlTable) -> None:
"""Drop the given table from the data warehouse. Only used in tutorials and tests."""
self.execute(f"DROP TABLE IF EXISTS {sql_table.sql}")

def close(self) -> None: # noqa: D
self._adapter.cancel_open_connections()

2 changes: 1 addition & 1 deletion metricflow/execution/execution_plan.py
@@ -185,7 +185,7 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D
def execute(self) -> TaskExecutionResult: # noqa: D
start_time = time.time()
logger.info(f"Dropping table {self._output_table} in case it already exists")
self._sql_client.drop_table(self._output_table)
self._sql_client.execute(f"DROP TABLE IF EXISTS {self._output_table.sql}")
logger.info(f"Creating table {self._output_table} using a SELECT query")
sql_query = self.sql_query
assert sql_query
44 changes: 1 addition & 43 deletions metricflow/protocols/sql_client.py
@@ -2,11 +2,10 @@

from abc import abstractmethod
from enum import Enum
from typing import Optional, Protocol, Sequence
from typing import Protocol

from pandas import DataFrame

from metricflow.dataflow.sql_table import SqlTable
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_request.sql_request_attributes import SqlJsonTag
@@ -48,22 +47,6 @@ def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer:
"""
raise NotImplementedError

@abstractmethod
def create_table_from_dataframe(
self,
sql_table: SqlTable,
df: DataFrame,
chunk_size: Optional[int] = None,
) -> None:
"""Creates a table and populates it with the contents of the dataframe.
Args:
sql_table: The SqlTable metadata of the table to create
df: The Pandas DataFrame with the contents of the target table
chunk_size: The number of rows to write per query
"""
raise NotImplementedError

@abstractmethod
def query(
self,
@@ -93,31 +76,6 @@ def dry_run(
"""Base dry_run method."""
raise NotImplementedError

@abstractmethod
def list_tables(self, schema_name: str) -> Sequence[str]:
"""List the tables in the given schema."""
raise NotImplementedError

@abstractmethod
def table_exists(self, sql_table: SqlTable) -> bool:
"""Determines whether or not the given table exists in the data warehouse."""
raise NotImplementedError

@abstractmethod
def drop_table(self, sql_table: SqlTable) -> None:
"""Drop the given table from the data warehouse."""
raise NotImplementedError

@abstractmethod
def create_schema(self, schema_name: str) -> None:
"""Create the given schema if it doesn't already exist."""
raise NotImplementedError

@abstractmethod
def drop_schema(self, schema_name: str, cascade: bool) -> None: # noqa: D
"""Drop the given schema if it exists. If cascade is set, drop the tables in the schema first."""
raise NotImplementedError

@abstractmethod
def close(self) -> None: # noqa: D
"""Close the connections / engines used by this client."""
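
The DDL methods removed from the protocol above do not disappear from the codebase: the test fixtures below import a SqlClientWithDDLMethods protocol and an AdapterBackedDDLSqlClient from metricflow/test/fixtures/sql_clients/. Those definitions are not part of this diff, so the following is only a sketch of what the test-only protocol plausibly looks like, assuming it extends SqlClient with the same signatures removed here.

    from __future__ import annotations

    from abc import abstractmethod
    from typing import Optional, Protocol, Sequence

    from pandas import DataFrame

    from metricflow.dataflow.sql_table import SqlTable
    from metricflow.protocols.sql_client import SqlClient


    class SqlClientWithDDLMethods(SqlClient, Protocol):
        """Sketch: the core SqlClient plus the DDL helpers that tests and tutorials need."""

        @abstractmethod
        def create_table_from_dataframe(
            self, sql_table: SqlTable, df: DataFrame, chunk_size: Optional[int] = None
        ) -> None:
            """Create a table and populate it with the contents of the dataframe."""
            raise NotImplementedError

        @abstractmethod
        def list_tables(self, schema_name: str) -> Sequence[str]:
            """List the tables in the given schema."""
            raise NotImplementedError

        @abstractmethod
        def create_schema(self, schema_name: str) -> None:
            """Create the given schema if it doesn't already exist."""
            raise NotImplementedError

        @abstractmethod
        def drop_schema(self, schema_name: str, cascade: bool) -> None:
            """Drop the given schema if it exists, optionally dropping its tables first."""
            raise NotImplementedError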
3 changes: 1 addition & 2 deletions metricflow/test/execution/test_tasks.py
@@ -49,7 +49,6 @@ def test_write_table_task(mf_test_session_state: MetricFlowTestSessionState, sql
results = SequentialPlanExecutor().execute_plan(execution_plan)

assert not results.contains_task_errors
assert sql_client.table_exists(output_table)

assert_dataframes_equal(
actual=sql_client.query(f"SELECT * FROM {output_table.sql}"),
@@ -59,4 +58,4 @@
),
compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
)
sql_client.drop_table(output_table)
sql_client.execute(f"DROP TABLE IF EXISTS {output_table.sql}")
26 changes: 20 additions & 6 deletions metricflow/test/fixtures/sql_client_fixtures.py
@@ -9,12 +9,13 @@
from dbt.adapters.factory import get_adapter_by_type
from dbt.cli.main import dbtRunner

from metricflow.cli.dbt_connectors.adapter_backed_client import AdapterBackedSqlClient
from metricflow.protocols.sql_client import SqlClient
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState, dialect_from_url
from metricflow.test.fixtures.sql_clients.adapter_backed_ddl_client import AdapterBackedDDLSqlClient
from metricflow.test.fixtures.sql_clients.big_query import BigQuerySqlClient
from metricflow.test.fixtures.sql_clients.common_client import SqlDialect
from metricflow.test.fixtures.sql_clients.databricks import DatabricksSqlClient
from metricflow.test.fixtures.sql_clients.ddl_sql_client import SqlClientWithDDLMethods
from metricflow.test.fixtures.sql_clients.duckdb import DuckDbSqlClient
from metricflow.test.fixtures.sql_clients.redshift import RedshiftSqlClient

@@ -66,7 +67,7 @@ def __initialize_dbt() -> None:
dbtRunner().invoke(["-q", "debug"], project_dir=dbt_dir, PROFILES_DIR=dbt_dir)


def make_test_sql_client(url: str, password: str, schema: str) -> SqlClient:
def make_test_sql_client(url: str, password: str, schema: str) -> SqlClientWithDDLMethods:
"""Build SQL client based on env configs."""
# TODO: Switch on an enum of adapter type when all engines are cut over
dialect = dialect_from_url(url=url)
Expand All @@ -80,13 +81,13 @@ def make_test_sql_client(url: str, password: str, schema: str) -> SqlClient:
assert len(warehouses) == 1, f"Found more than 1 warehouse in Snowflake URL: `{warehouses}`"
os.environ[MF_SQL_ENGINE_WAREHOUSE] = warehouses[0]
__initialize_dbt()
return AdapterBackedSqlClient(adapter=get_adapter_by_type("snowflake"))
return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("snowflake"))
elif dialect == SqlDialect.BIGQUERY:
return BigQuerySqlClient.from_connection_details(url, password)
elif dialect == SqlDialect.POSTGRESQL:
configure_test_env_from_url(url, schema)
__initialize_dbt()
return AdapterBackedSqlClient(adapter=get_adapter_by_type("postgres"))
return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("postgres"))
elif dialect == SqlDialect.DUCKDB:
return DuckDbSqlClient.from_connection_details(url, password)
elif dialect == SqlDialect.DATABRICKS:
@@ -96,8 +97,12 @@


@pytest.fixture(scope="session")
def sql_client(mf_test_session_state: MetricFlowTestSessionState) -> Generator[SqlClient, None, None]:
"""Provides an SqlClient requiring warehouse access."""
def ddl_sql_client(mf_test_session_state: MetricFlowTestSessionState) -> Generator[SqlClientWithDDLMethods, None, None]:
"""Provides a SqlClient with the necessary DDL and data loading methods for test configuration.
This allows us to provide the operations necessary for executing the test suite without exposing those methods in
MetricFlow's core SqlClient protocol.
"""
sql_client = make_test_sql_client(
url=mf_test_session_state.sql_engine_url,
password=mf_test_session_state.sql_engine_password,
@@ -122,3 +127,12 @@ def sql_client(mf_test_session_state: MetricFlowTestSessionState) -> Generator[S

sql_client.close()
return None


@pytest.fixture(scope="session")
def sql_client(ddl_sql_client: SqlClientWithDDLMethods) -> SqlClient:
"""Provides a standard SqlClient instance for running MetricFlow tests.
Unless the test case itself requires the DDL methods, this is the fixture we should use.
"""
return ddl_sql_client
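
As a usage illustration of how the two fixtures divide responsibilities: the test bodies below are hypothetical; only the fixture names and the create_schema/drop_schema/query calls are taken from this PR, and the fixtures are assumed to be registered in the suite's conftest the same way the existing session-scoped fixtures are.

    from metricflow.protocols.sql_client import SqlClient
    from metricflow.test.fixtures.sql_clients.ddl_sql_client import SqlClientWithDDLMethods


    def test_needs_schema_setup(ddl_sql_client: SqlClientWithDDLMethods) -> None:
        """Hypothetical test: setup and teardown go through the DDL-capable fixture."""
        ddl_sql_client.create_schema("mf_demo_schema")
        try:
            # The DDL client still satisfies SqlClient, so ordinary queries work too.
            df = ddl_sql_client.query("SELECT 1 AS one")
            assert not df.empty
        finally:
            ddl_sql_client.drop_schema("mf_demo_schema", cascade=True)


    def test_read_only_query(sql_client: SqlClient) -> None:
        """Hypothetical test: most tests only need the slimmed-down SqlClient protocol."""
        df = sql_client.query("SELECT 1 AS one")
        assert not df.empty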