Feat!: Add support for merge_filter and dbt incremental_predicates for Incremental By Unique Key (#3540)
themisvaltinos authored Dec 20, 2024
1 parent 1f71537 commit 0e51dd8
Showing 13 changed files with 354 additions and 15 deletions.
18 changes: 18 additions & 0 deletions sqlmesh/core/dialect.py
@@ -429,6 +429,8 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
lambda: _parse_macro_or_clause(self, self._parse_when_matched),
optional=True,
)
elif name == "merge_filter":
value = self._parse_conjunction()
elif self._match(TokenType.L_PAREN):
value = self.expression(exp.Tuple, expressions=self._parse_csv(self._parse_equality))
self._match_r_paren()
@@ -1260,3 +1262,19 @@ def extract_func_call(

def is_meta_expression(v: t.Any) -> bool:
return isinstance(v, (Audit, Metric, Model))


def replace_merge_table_aliases(expression: exp.Expression) -> exp.Expression:
"""
Replaces references to the "source" and "target" tables (or their DBT equivalents)
with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS)
"""
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS

if isinstance(expression, exp.Column):
if expression.table.lower() in ("target", "dbt_internal_dest"):
expression.set("table", exp.to_identifier(MERGE_TARGET_ALIAS))
elif expression.table.lower() in ("source", "dbt_internal_source"):
expression.set("table", exp.to_identifier(MERGE_SOURCE_ALIAS))

return expression
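
For reference, a minimal usage sketch of the new helper (not part of the commit; the expected output is shown as a comment): a predicate written against `source`/`target` columns is rewritten to the internal MERGE aliases via sqlglot's transform.

from sqlglot import parse_one
from sqlmesh.core.dialect import replace_merge_table_aliases

# Illustrative only: rewrite user-facing source/target references to the
# aliases the engine adapter uses when generating MERGE statements.
predicate = parse_one("source.id > 0 AND target.ts < '2020-02-05'")
print(predicate.transform(replace_merge_table_aliases).sql())
# expected: __MERGE_SOURCE__.id > 0 AND __MERGE_TARGET__.ts < '2020-02-05'
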
4 changes: 4 additions & 0 deletions sqlmesh/core/engine_adapter/base.py
@@ -1804,6 +1804,7 @@ def merge(
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
unique_key: t.Sequence[exp.Expression],
when_matched: t.Optional[exp.Whens] = None,
merge_filter: t.Optional[exp.Expression] = None,
) -> None:
source_queries, columns_to_types = self._get_source_queries_and_columns_to_types(
source_table, columns_to_types, target_table=target_table
@@ -1815,6 +1816,9 @@
for part in unique_key
)
)
if merge_filter:
on = exp.and_(merge_filter, on)

if not when_matched:
when_matched = exp.Whens()
when_matched.append(
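
A small sketch of the ON-clause composition introduced above (illustrative, plain sqlglot; the alias strings mirror MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS, and the unique key value is made up):

from sqlglot import exp

# Illustrative sketch: the unique-key equalities are ANDed together, and the
# optional merge_filter is prepended with another AND, as in the hunk above.
MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS = "__MERGE_SOURCE__", "__MERGE_TARGET__"
unique_key = ["id"]

on = exp.and_(
    *(
        exp.column(key, MERGE_TARGET_ALIAS).eq(exp.column(key, MERGE_SOURCE_ALIAS))
        for key in unique_key
    )
)
merge_filter = exp.condition("__MERGE_SOURCE__.id > 0")
on = exp.and_(merge_filter, on)
print(on.sql())
# expected: __MERGE_SOURCE__.id > 0 AND __MERGE_TARGET__.id = __MERGE_SOURCE__.id
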
9 changes: 7 additions & 2 deletions sqlmesh/core/engine_adapter/mixins.py
@@ -31,6 +31,7 @@ def merge(
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
unique_key: t.Sequence[exp.Expression],
when_matched: t.Optional[exp.Whens] = None,
merge_filter: t.Optional[exp.Expression] = None,
) -> None:
logical_merge(
self,
Expand All @@ -39,6 +40,7 @@ def merge(
columns_to_types,
unique_key,
when_matched=when_matched,
merge_filter=merge_filter,
)


@@ -409,6 +411,7 @@ def logical_merge(
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
unique_key: t.Sequence[exp.Expression],
when_matched: t.Optional[exp.Whens] = None,
merge_filter: t.Optional[exp.Expression] = None,
) -> None:
"""
Merge implementation for engine adapters that do not support merge natively.
@@ -420,10 +423,12 @@
within the temporary table are omitted.
4. Drop the temporary table.
"""
if when_matched:
if when_matched or merge_filter:
prop = "when_matched" if when_matched else "merge_filter"
raise SQLMeshError(
"This engine does not support MERGE expressions and therefore `when_matched` is not supported."
f"This engine does not support MERGE expressions and therefore `{prop}` is not supported."
)

engine_adapter._replace_by_key(
target_table, source_table, columns_to_types, unique_key, is_unique_key=True
)
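
The practical consequence for engines without native MERGE support: a merge_filter now fails fast, mirroring the existing when_matched behavior. A hypothetical check (the positional argument order is assumed from the adapter merge signature; the guard raises before the adapter or the tables are ever touched):

import pytest
from sqlglot import exp
from sqlmesh.core.engine_adapter.mixins import logical_merge
from sqlmesh.utils.errors import SQLMeshError

# Hypothetical: a placeholder adapter is enough because the guard raises first.
with pytest.raises(SQLMeshError, match="merge_filter"):
    logical_merge(
        None,  # engine adapter is never reached
        "target_table",
        "SELECT 1 AS id",
        {"id": exp.DataType.build("int")},
        [exp.to_identifier("id")],
        merge_filter=exp.condition("source.id > 0"),
    )
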
2 changes: 2 additions & 0 deletions sqlmesh/core/engine_adapter/postgres.py
@@ -107,6 +107,7 @@ def merge(
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
unique_key: t.Sequence[exp.Expression],
when_matched: t.Optional[exp.Whens] = None,
merge_filter: t.Optional[exp.Expression] = None,
) -> None:
# Merge isn't supported until Postgres 15
merge_impl = (
@@ -120,4 +121,5 @@
columns_to_types,
unique_key,
when_matched=when_matched,
merge_filter=merge_filter,
)
31 changes: 19 additions & 12 deletions sqlmesh/core/model/kind.py
@@ -444,6 +444,7 @@ class IncrementalByUniqueKeyKind(_IncrementalBy):
)
unique_key: SQLGlotListOfFields
when_matched: t.Optional[exp.Whens] = None
merge_filter: t.Optional[exp.Expression] = None
batch_concurrency: t.Literal[1] = 1

@field_validator("when_matched", mode="before")
@@ -453,17 +454,6 @@ def _when_matched_validator(
v: t.Optional[t.Union[str, exp.Whens]],
values: t.Dict[str, t.Any],
) -> t.Optional[exp.Whens]:
def replace_table_references(expression: exp.Expression) -> exp.Expression:
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS

if isinstance(expression, exp.Column):
if expression.table.lower() == "target":
expression.set("table", exp.to_identifier(MERGE_TARGET_ALIAS))
elif expression.table.lower() == "source":
expression.set("table", exp.to_identifier(MERGE_SOURCE_ALIAS))

return expression

if v is None:
return v
if isinstance(v, str):
@@ -474,14 +464,30 @@ def replace_table_references(expression: exp.Expression) -> exp.Expression:

return t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=get_dialect(values)))

return t.cast(exp.Whens, v.transform(replace_table_references))
return t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases))

@field_validator("merge_filter", mode="before")
@field_validator_v1_args
def _merge_filter_validator(
cls,
v: t.Optional[exp.Expression],
values: t.Dict[str, t.Any],
) -> t.Optional[exp.Expression]:
if v is None:
return v
if isinstance(v, str):
v = v.strip()
return d.parse_one(v, dialect=get_dialect(values))

return v.transform(d.replace_merge_table_aliases)

@property
def data_hash_values(self) -> t.List[t.Optional[str]]:
return [
*super().data_hash_values,
*(gen(k) for k in self.unique_key),
gen(self.when_matched) if self.when_matched is not None else None,
gen(self.merge_filter) if self.merge_filter is not None else None,
]

def to_expression(
@@ -494,6 +500,7 @@ def to_expression(
{
"unique_key": exp.Tuple(expressions=self.unique_key),
"when_matched": self.when_matched,
"merge_filter": self.merge_filter,
}
),
],
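
A rough sketch of what the new validator provides at the model-kind level (assumes the kind can be constructed standalone with the default dialect; values are illustrative): a string merge_filter is parsed, its source/target references are rewritten to the MERGE aliases, and the filter now participates in the data hash.

from sqlmesh.core.model.kind import IncrementalByUniqueKeyKind

# Rough sketch, not from the commit: string input should be parsed and
# source/target rewritten to the internal MERGE aliases by the validator.
kind = IncrementalByUniqueKeyKind(
    unique_key=["id"],
    merge_filter="source.id > 0 AND target.updated_at > '2024-01-01'",
)
print(kind.merge_filter.sql())
# expected: __MERGE_SOURCE__.id > 0 AND __MERGE_TARGET__.updated_at > '2024-01-01'
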
6 changes: 6 additions & 0 deletions sqlmesh/core/model/meta.py
@@ -436,6 +436,12 @@ def when_matched(self) -> t.Optional[exp.Whens]:
return self.kind.when_matched
return None

@property
def merge_filter(self) -> t.Optional[exp.Expression]:
if isinstance(self.kind, IncrementalByUniqueKeyKind):
return self.kind.merge_filter
return None

@property
def catalog(self) -> t.Optional[str]:
"""Returns the catalog of a model."""
2 changes: 2 additions & 0 deletions sqlmesh/core/snapshot/evaluator.py
@@ -1392,6 +1392,7 @@ def insert(
columns_to_types=model.columns_to_types,
unique_key=model.unique_key,
when_matched=model.when_matched,
merge_filter=model.merge_filter,
)

def append(
@@ -1407,6 +1408,7 @@
columns_to_types=model.columns_to_types,
unique_key=model.unique_key,
when_matched=model.when_matched,
merge_filter=model.merge_filter,
)


11 changes: 11 additions & 0 deletions sqlmesh/dbt/model.py
@@ -294,6 +294,17 @@ def model_kind(self, context: DbtContext) -> ModelKind:
f"{self.canonical_name(context)}: SQLMesh incremental by unique key strategy is not compatible with '{strategy}'"
f" incremental strategy. Supported strategies include {collection_to_str(INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES)}."
)

if self.incremental_predicates:
dialect = self.dialect(context)
incremental_kind_kwargs["merge_filter"] = exp.and_(
*[
d.parse_one(predicate, dialect=dialect)
for predicate in self.incremental_predicates
],
dialect=dialect,
).transform(d.replace_merge_table_aliases)

return IncrementalByUniqueKeyKind(
unique_key=self.unique_key,
disable_restatement=disable_restatement,
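
For dbt projects, the effect is that incremental_predicates are folded into a single merge_filter. A hedged sketch of just the expression-building step above (predicate values and the dialect are illustrative):

from sqlglot import exp
from sqlmesh.core import dialect as d

# Illustrative values only; mirrors the transformation performed in model_kind().
incremental_predicates = [
    "DBT_INTERNAL_DEST.session_start > DATEADD(DAY, -7, CURRENT_DATE)",
    "DBT_INTERNAL_SOURCE.id > 0",
]
dialect = "snowflake"

merge_filter = exp.and_(
    *[d.parse_one(p, dialect=dialect) for p in incremental_predicates],
    dialect=dialect,
).transform(d.replace_merge_table_aliases)

print(merge_filter.sql(dialect=dialect))
# expected: __MERGE_TARGET__.session_start > DATEADD(DAY, -7, CURRENT_DATE) AND __MERGE_SOURCE__.id > 0
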
73 changes: 73 additions & 0 deletions tests/core/engine_adapter/test_base.py
@@ -1126,6 +1126,79 @@ def test_merge_when_matched_multiple(make_mocked_engine_adapter: t.Callable, assert_exp_eq):
)


def test_merge_filter(make_mocked_engine_adapter: t.Callable, assert_exp_eq):
adapter = make_mocked_engine_adapter(EngineAdapter)

adapter.merge(
target_table="target",
source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')),
columns_to_types={
"ID": exp.DataType.build("int"),
"ts": exp.DataType.build("timestamp"),
"val": exp.DataType.build("int"),
},
unique_key=[exp.to_identifier("ID", quoted=True)],
when_matched=exp.Whens(
expressions=[
exp.When(
matched=True,
source=False,
then=exp.Update(
expressions=[
exp.column("val", "__MERGE_TARGET__").eq(
exp.column("val", "__MERGE_SOURCE__")
),
exp.column("ts", "__MERGE_TARGET__").eq(
exp.Coalesce(
this=exp.column("ts", "__MERGE_SOURCE__"),
expressions=[exp.column("ts", "__MERGE_TARGET__")],
)
),
],
),
)
]
),
merge_filter=exp.And(
this=exp.GT(
this=exp.column("ID", "__MERGE_SOURCE__"),
expression=exp.Literal(this="0", is_string=False),
),
expression=exp.LT(
this=exp.column("ts", "__MERGE_TARGET__"),
expression=exp.Timestamp(this=exp.column("2020-02-05", quoted=True)),
),
),
)

assert_exp_eq(
adapter.cursor.execute.call_args[0][0],
"""
MERGE INTO "target" AS "__MERGE_TARGET__"
USING (
SELECT "ID", "ts", "val"
FROM "source"
) AS "__MERGE_SOURCE__"
ON (
"__MERGE_SOURCE__"."ID" > 0
AND "__MERGE_TARGET__"."ts" < TIMESTAMP("2020-02-05")
)
AND "__MERGE_TARGET__"."ID" = "__MERGE_SOURCE__"."ID"
WHEN MATCHED THEN
UPDATE SET
"__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val",
"__MERGE_TARGET__"."ts" = COALESCE("__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."ts")
WHEN NOT MATCHED THEN
INSERT ("ID", "ts", "val")
VALUES (
"__MERGE_SOURCE__"."ID",
"__MERGE_SOURCE__"."ts",
"__MERGE_SOURCE__"."val"
);
""",
)


def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
adapter = make_mocked_engine_adapter(EngineAdapter)
