Skip to content

Commit

Permalink
Refactor!: make when_matched syntax compatible with merge syntax (#3497)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Dec 12, 2024
1 parent a637248 commit c512e63
Show file tree
Hide file tree
Showing 13 changed files with 323 additions and 183 deletions.
10 changes: 7 additions & 3 deletions docs/concepts/models/model_kinds.md
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,9 @@ MODEL (
name db.employees,
kind INCREMENTAL_BY_UNIQUE_KEY (
unique_key name,
when_matched WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)
when_matched (
WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)
)
)
);
```
Expand All @@ -334,8 +336,10 @@ MODEL (
name db.employees,
kind INCREMENTAL_BY_UNIQUE_KEY (
unique_key name,
when_matched WHEN MATCHED AND source.value IS NULL THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary),
WHEN MATCHED THEN UPDATE SET target.title = COALESCE(source.title, target.title)
when_matched (
WHEN MATCHED AND source.value IS NULL THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)
WHEN MATCHED THEN UPDATE SET target.title = COALESCE(source.title, target.title)
)
)
);
```
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"rich[jupyter]",
"ruamel.yaml",
"setuptools; python_version>='3.12'",
"sqlglot[rs]~=25.34.1",
"sqlglot[rs]~=26.0.0",
"tenacity",
],
extras_require={
Expand Down
27 changes: 16 additions & 11 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,13 +409,12 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
return None

name = key.name.lower()
if name == "when_matched":
value: t.Optional[t.Union[exp.Expression, t.List[exp.Expression]]] = (
self._parse_when_matched() # type: ignore
)
elif name == "time_data_type":
if name == "time_data_type":
# TODO: if we make *_data_type a convention to parse things into exp.DataType, we could make this more generic
value = self._parse_types(schema=True)
elif name == "when_matched":
# Parentheses around the WHEN clauses can be used to disambiguate them from other properties
value = self._parse_wrapped(self._parse_when_matched, optional=True)
elif self._match(TokenType.L_PAREN):
value = self.expression(exp.Tuple, expressions=self._parse_csv(self._parse_equality))
self._match_r_paren()
Expand Down Expand Up @@ -605,15 +604,11 @@ def _props_sql(self: Generator, expressions: t.List[exp.Expression]) -> str:
size = len(expressions)

for i, prop in enumerate(expressions):
value = prop.args.get("value")
if prop.name == "when_matched" and isinstance(value, list):
output_value = ", ".join(self.sql(v) for v in value)
else:
output_value = self.sql(prop, "value")
sql = self.indent(f"{prop.name} {output_value}")
sql = self.indent(f"{prop.name} {self.sql(prop, 'value')}")

if i < size - 1:
sql += ","

props.append(self.maybe_comment(sql, expression=prop))

return "\n".join(props)
Expand Down Expand Up @@ -648,6 +643,15 @@ def _macro_func_sql(self: Generator, expression: MacroFunc) -> str:
return self.maybe_comment(sql, expression)


def _whens_sql(self: Generator, expression: exp.Whens) -> str:
    """Generate SQL for a group of `WHEN` clauses.

    Args:
        self: The generator producing the SQL.
        expression: The `Whens` node to render.

    Returns:
        The generator's standard `whens_sql` output when the node's parent is
        a `MERGE` statement; otherwise the clauses joined by single spaces and
        wrapped in parentheses.
    """
    belongs_to_merge = isinstance(expression.parent, exp.Merge)

    if not belongs_to_merge:
        # WHEN clauses that live outside of a MERGE statement (e.g. in the
        # `MODEL` DDL) are emitted wrapped in parentheses.
        clauses_sql = self.expressions(expression, sep=" ", indent=False)
        return self.wrap(clauses_sql)

    return self.whens_sql(expression)


def _override(klass: t.Type[Tokenizer | Parser], func: t.Callable) -> None:
name = func.__name__
setattr(klass, f"_{name}", getattr(klass, name))
Expand Down Expand Up @@ -901,6 +905,7 @@ def extend_sqlglot() -> None:
ModelKind: _model_kind_sql,
PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False),
StagedFilePath: lambda self, e: self.table_sql(e),
exp.Whens: _whens_sql,
}
)

Expand Down
42 changes: 21 additions & 21 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,20 +1337,13 @@ def _merge(
target_table: TableName,
query: Query,
on: exp.Expression,
match_expressions: t.List[exp.When],
whens: exp.Whens,
) -> None:
this = exp.alias_(exp.to_table(target_table), alias=MERGE_TARGET_ALIAS, table=True)
using = exp.alias_(
exp.Subquery(this=query), alias=MERGE_SOURCE_ALIAS, copy=False, table=True
)
self.execute(
exp.Merge(
this=this,
using=using,
on=on,
expressions=match_expressions,
)
)
self.execute(exp.Merge(this=this, using=using, on=on, whens=whens))

def scd_type_2_by_time(
self,
Expand Down Expand Up @@ -1807,7 +1800,7 @@ def merge(
source_table: QueryOrDF,
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
unique_key: t.Sequence[exp.Expression],
when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None,
when_matched: t.Optional[exp.Whens] = None,
) -> None:
source_queries, columns_to_types = self._get_source_queries_and_columns_to_types(
source_table, columns_to_types, target_table=target_table
Expand All @@ -1820,17 +1813,23 @@ def merge(
)
)
if not when_matched:
when_matched = exp.When(
matched=True,
source=False,
then=exp.Update(
expressions=[
exp.column(col, MERGE_TARGET_ALIAS).eq(exp.column(col, MERGE_SOURCE_ALIAS))
for col in columns_to_types
],
when_matched = exp.Whens()
when_matched.append(
"expressions",
exp.When(
matched=True,
source=False,
then=exp.Update(
expressions=[
exp.column(col, MERGE_TARGET_ALIAS).eq(
exp.column(col, MERGE_SOURCE_ALIAS)
)
for col in columns_to_types
],
),
),
)
when_matched = ensure_list(when_matched)

when_not_matched = exp.When(
matched=False,
source=False,
Expand All @@ -1841,14 +1840,15 @@ def merge(
),
),
)
match_expressions = when_matched + [when_not_matched]
when_matched.append("expressions", when_not_matched)

for source_query in source_queries:
with source_query as query:
self._merge(
target_table=target_table,
query=query,
on=on,
match_expressions=match_expressions,
whens=when_matched,
)

def rename_table(
Expand Down
8 changes: 5 additions & 3 deletions sqlmesh/core/engine_adapter/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def merge(
source_table: QueryOrDF,
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
unique_key: t.Sequence[exp.Expression],
when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None,
when_matched: t.Optional[exp.Whens] = None,
) -> None:
logical_merge(
self,
Expand Down Expand Up @@ -105,7 +105,9 @@ def _insert_overwrite_by_condition(
target_table=table_name,
query=query,
on=exp.false(),
match_expressions=[when_not_matched_by_source, when_not_matched_by_target],
whens=exp.Whens(
expressions=[when_not_matched_by_source, when_not_matched_by_target]
),
)


Expand Down Expand Up @@ -406,7 +408,7 @@ def logical_merge(
source_table: QueryOrDF,
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
unique_key: t.Sequence[exp.Expression],
when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None,
when_matched: t.Optional[exp.Whens] = None,
) -> None:
"""
Merge implementation for engine adapters that do not support merge natively.
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/engine_adapter/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def merge(
source_table: QueryOrDF,
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
unique_key: t.Sequence[exp.Expression],
when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None,
when_matched: t.Optional[exp.Whens] = None,
) -> None:
# Merge isn't supported until Postgres 15
merge_impl = (
Expand Down
47 changes: 18 additions & 29 deletions sqlmesh/core/model/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from pydantic import Field
from sqlglot import exp
from sqlglot.helper import ensure_list
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers
from sqlglot.optimizer.simplify import gen
Expand Down Expand Up @@ -423,48 +422,38 @@ class IncrementalByUniqueKeyKind(_IncrementalBy):
ModelKindName.INCREMENTAL_BY_UNIQUE_KEY
)
unique_key: SQLGlotListOfFields
when_matched: t.Optional[t.List[exp.When]] = None
when_matched: t.Optional[exp.Whens] = None
batch_concurrency: t.Literal[1] = 1

@field_validator("when_matched", mode="before")
@field_validator_v1_args
def _when_matched_validator(
cls,
v: t.Optional[t.Union[exp.When, str, t.List[exp.When], t.List[str]]],
v: t.Optional[t.Union[str, exp.Whens]],
values: t.Dict[str, t.Any],
) -> t.Optional[t.List[exp.When]]:
) -> t.Optional[exp.Whens]:
def replace_table_references(expression: exp.Expression) -> exp.Expression:
from sqlmesh.core.engine_adapter.base import (
MERGE_SOURCE_ALIAS,
MERGE_TARGET_ALIAS,
)
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS

if isinstance(expression, exp.Column):
if expression.table.lower() == "target":
expression.set(
"table",
exp.to_identifier(MERGE_TARGET_ALIAS),
)
expression.set("table", exp.to_identifier(MERGE_TARGET_ALIAS))
elif expression.table.lower() == "source":
expression.set(
"table",
exp.to_identifier(MERGE_SOURCE_ALIAS),
)
expression.set("table", exp.to_identifier(MERGE_SOURCE_ALIAS))

return expression

if not v:
return v # type: ignore

result = []
list_v = ensure_list(v)
for value in ensure_list(list_v):
if isinstance(value, str):
result.append(
t.cast(exp.When, d.parse_one(value, into=exp.When, dialect=get_dialect(values)))
)
else:
result.append(t.cast(exp.When, value.transform(replace_table_references))) # type: ignore
return result
if v is None:
return v
if isinstance(v, str):
# Whens wrap the WHEN clauses, but the parentheses aren't parsed by sqlglot
v = v.strip()
if v.startswith("("):
v = v[1:-1]

return t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=get_dialect(values)))

return t.cast(exp.Whens, v.transform(replace_table_references))

@property
def data_hash_values(self) -> t.List[t.Optional[str]]:
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/model/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def managed_columns(self) -> t.Dict[str, exp.DataType]:
return getattr(self.kind, "managed_columns", {})

@property
def when_matched(self) -> t.Optional[t.List[exp.When]]:
def when_matched(self) -> t.Optional[exp.Whens]:
if isinstance(self.kind, IncrementalByUniqueKeyKind):
return self.kind.when_matched
return None
Expand Down
85 changes: 85 additions & 0 deletions sqlmesh/migrations/v0064_join_when_matched_strings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Join list of `WHEN [NOT] MATCHED` strings into a single string."""

import json

import pandas as pd
from sqlglot import exp

from sqlmesh.utils.migration import index_text_type, blob_text_type


def migrate(state_sync, **kwargs):  # type: ignore
    """Collapse list-valued `when_matched` entries of stored snapshots into one string.

    Every row of the snapshots table is read. When a snapshot's node has a
    kind whose `when_matched` value is a list, its items are joined with a
    single space. If any rows were read, the table is emptied and all rows are
    written back with the re-serialized snapshot payload.

    Args:
        state_sync: Object exposing `engine_adapter` and `schema`.
        **kwargs: Unused by this migration.
    """
    adapter = state_sync.engine_adapter
    schema = state_sync.schema
    index_type = index_text_type(adapter.dialect)
    table = f"{schema}._snapshots" if schema else "_snapshots"

    query = exp.select(
        "name",
        "identifier",
        "version",
        "snapshot",
        "kind_name",
        "updated_ts",
        "unpaused_ts",
        "ttl_ms",
        "unrestorable",
    ).from_(table)

    rewritten_rows = []

    for (
        name,
        identifier,
        version,
        snapshot,
        kind_name,
        updated_ts,
        unpaused_ts,
        ttl_ms,
        unrestorable,
    ) in adapter.fetchall(query, quote_identifiers=True):
        payload = json.loads(snapshot)
        node = payload["node"]

        if "kind" in node:
            clauses = node["kind"].get("when_matched")
            if isinstance(clauses, list):
                node["kind"]["when_matched"] = " ".join(clauses)

        # Every row is kept (changed or not) because the table is wiped
        # before the rows are inserted again below.
        rewritten_rows.append(
            {
                "name": name,
                "identifier": identifier,
                "version": version,
                "snapshot": json.dumps(payload),
                "kind_name": kind_name,
                "updated_ts": updated_ts,
                "unpaused_ts": unpaused_ts,
                "ttl_ms": ttl_ms,
                "unrestorable": unrestorable,
            }
        )

    if not rewritten_rows:
        return

    adapter.delete_from(table, "TRUE")
    blob_type = blob_text_type(adapter.dialect)

    # Column name -> type string, in the same order as the row dictionaries.
    column_type_names = {
        "name": index_type,
        "identifier": index_type,
        "version": index_type,
        "snapshot": blob_type,
        "kind_name": index_type,
        "updated_ts": "bigint",
        "unpaused_ts": "bigint",
        "ttl_ms": "bigint",
        "unrestorable": "boolean",
    }

    adapter.insert_append(
        table,
        pd.DataFrame(rewritten_rows),
        columns_to_types={
            column: exp.DataType.build(type_name)
            for column, type_name in column_type_names.items()
        },
    )
2 changes: 2 additions & 0 deletions tests/core/engine_adapter/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2334,6 +2334,8 @@ def test_value_normalization(
)
],
)
if ctx.dialect == "tsql" and column_type == exp.DataType.Type.DATETIME:
full_column_type = exp.DataType.build("DATETIME2", dialect="tsql")

columns_to_types = {
"_idx": exp.DataType.build(DATA_TYPE.INT),
Expand Down
Loading

0 comments on commit c512e63

Please sign in to comment.