57 changes: 57 additions & 0 deletions sqlmesh/core/console.py
@@ -551,6 +551,23 @@ def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None:
"""Display list of models that failed during evaluation to the user."""

@abc.abstractmethod
def log_models_updated_during_restatement(
self,
snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]],
environment: EnvironmentSummary,
environment_naming_info: EnvironmentNamingInfo,
default_catalog: t.Optional[str],
) -> None:
"""Display a list of models where new versions got deployed to the specified :environment while we were restating data the old versions

Args:
snapshots: a list of (snapshot_we_restated, snapshot_it_got_replaced_with_during_restatement) tuples
environment: which environment got updated while we were restating models
environment_naming_info: how snapshots are named in that environment (for display name purposes)
default_catalog: the configured default catalog (for display name purposes)
"""

@abc.abstractmethod
def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
"""Starts loading and returns a unique ID that can be used to stop the loading. Optionally can display a message."""
@@ -771,6 +788,15 @@ def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None:
pass

def log_models_updated_during_restatement(
self,
snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]],
environment: EnvironmentSummary,
environment_naming_info: EnvironmentNamingInfo,
default_catalog: t.Optional[str],
) -> None:
pass

def log_destructive_change(
self,
snapshot_name: str,
@@ -2225,6 +2251,37 @@ def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None:
for node_name, msg in error_messages.items():
self._print(f" [red]{node_name}[/red]\n\n{msg}")

def log_models_updated_during_restatement(
self,
snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]],
environment: EnvironmentSummary,
environment_naming_info: EnvironmentNamingInfo,
default_catalog: t.Optional[str] = None,
) -> None:
if snapshots:
tree = Tree(
f"[yellow]The following models had new versions deployed in plan '{environment.plan_id}' while data was being restated:[/yellow]"
)

for restated_snapshot, updated_snapshot in snapshots:
display_name = restated_snapshot.display_name(
environment_naming_info,
default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
dialect=self.dialect,
)
current_branch = tree.add(display_name)
current_branch.add(f"restated version: '{restated_snapshot.version}'")
current_branch.add(f"currently active version: '{updated_snapshot.version}'")

self._print(tree)

self.log_warning(
f"\nThe '{environment.name}' environment currently points to [bold]different[/bold] versions of these models, not the versions that just got restated."
)
self._print(
"[yellow]If this is undesirable, please re-run this restatement plan which will apply it to the most recent versions of these models.[/yellow]\n"
)

def log_destructive_change(
self,
snapshot_name: str,
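As a rough illustration of what the Rich tree built by log_models_updated_during_restatement renders, here is a minimal self-contained sketch. It assumes only the public rich API; the model names, versions, and plan id are hypothetical:

from rich.console import Console
from rich.tree import Tree

# Hypothetical (name, restated_version, currently_active_version) rows standing
# in for the (SnapshotTableInfo, SnapshotTableInfo) tuples used above.
pairs = [
    ("db.schema.model_a", "v1", "v2"),
    ("db.schema.model_b", "v3", "v4"),
]

tree = Tree(
    "[yellow]The following models had new versions deployed in plan 'plan_123' "
    "while data was being restated:[/yellow]"
)
for name, restated_version, active_version in pairs:
    branch = tree.add(name)
    branch.add(f"restated version: '{restated_version}'")
    branch.add(f"currently active version: '{active_version}'")

Console().print(tree)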
68 changes: 49 additions & 19 deletions sqlmesh/core/plan/evaluator.py
@@ -22,7 +22,7 @@
from sqlmesh.core.console import Console, get_console
from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
from sqlmesh.core.macros import RuntimeStage
from sqlmesh.core.snapshot.definition import to_view_mapping
from sqlmesh.core.snapshot.definition import to_view_mapping, SnapshotTableInfo
from sqlmesh.core.plan import stages
from sqlmesh.core.plan.definition import EvaluatablePlan
from sqlmesh.core.scheduler import Scheduler
@@ -284,32 +284,62 @@ def visit_audit_only_run_stage(
def visit_restatement_stage(
self, stage: stages.RestatementStage, plan: EvaluatablePlan
) -> None:
snapshot_intervals_to_restate = {(s, i) for s, i in stage.snapshot_intervals.items()}

# Restating intervals on prod plans should mean that the intervals are cleared across
# all environments, not just the version currently in prod
# This ensures that work done in dev environments can still be promoted to prod
# by forcing dev environments to re-run intervals that changed in prod
# Restating intervals on prod plans means that once the data for the intervals being restated has been backfilled
# (which happens in the backfill stage), we need to clear those intervals *from state* across all other environments.
#
# This ensures that work done in dev environments can still be promoted to prod by forcing dev environments to
# re-run intervals that changed in prod (because after this stage runs they are cleared from state and thus show as missing)
#
# It also means that any new dev environments created while this restatement plan was running also get the
# correct intervals cleared, because we look up matching snapshots as of right now and not as of the time the plan
# was created, which could have been several hours ago if there was a lot of data to restate.
#
# Without this rule, it's possible that promoting a dev table to prod will introduce old data to prod
snapshot_intervals_to_restate.update(
{
(s.table_info, s.interval)
for s in identify_restatement_intervals_across_snapshot_versions(
state_reader=self.state_sync,
prod_restatements=plan.restatements,
disable_restatement_models=plan.disabled_restatement_models,
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
current_ts=to_timestamp(plan.execution_time or now()),
).values()
}
)

intervals_to_clear = identify_restatement_intervals_across_snapshot_versions(
state_reader=self.state_sync,
prod_restatements=plan.restatements,
disable_restatement_models=plan.disabled_restatement_models,
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
current_ts=to_timestamp(plan.execution_time or now()),
)

if not intervals_to_clear:
# Nothing to do
return

self.state_sync.remove_intervals(
snapshot_intervals=list(snapshot_intervals_to_restate),
snapshot_intervals=[(s.table_info, s.interval) for s in intervals_to_clear.values()],
remove_shared_versions=plan.is_prod,
)

# While the restatements were being processed, did any of the snapshots being restated get new versions deployed?
# If they did, the new versions will not reflect the data that just got restated, so we need to notify the user
if deployed_env := self.state_sync.get_environment(plan.environment.name):
promoted_snapshots_by_name = {s.name: s for s in deployed_env.snapshots}

deployed_during_restatement: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]] = []

for name in plan.restatements:
snapshot = stage.all_snapshots[name]
version = snapshot.table_info.version
if (
prod_snapshot := promoted_snapshots_by_name.get(name)
) and prod_snapshot.version != version:
deployed_during_restatement.append(
(snapshot.table_info, prod_snapshot.table_info)
)

if deployed_during_restatement:
self.console.log_models_updated_during_restatement(
deployed_during_restatement,
deployed_env.summary,
plan.environment.naming_info,
self.default_catalog,
)
# note: the plan will automatically fail at the promotion stage with a ConflictingPlanError because the environment was changed by another plan
# so there is no need to explicitly fail the plan here

def visit_environment_record_update_stage(
self, stage: stages.EnvironmentRecordUpdateStage, plan: EvaluatablePlan
) -> None:
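The version-mismatch check at the end of visit_restatement_stage reduces to comparing the snapshots that were restated against whatever the environment currently points at. A minimal sketch of that comparison, using a hypothetical TableInfo stand-in rather than the real SnapshotTableInfo:

import typing as t

# Hypothetical stand-in for SnapshotTableInfo (the real class lives in
# sqlmesh.core.snapshot.definition and carries far more metadata).
class TableInfo(t.NamedTuple):
    name: str
    version: str

def find_deployed_during_restatement(
    restated: t.Dict[str, TableInfo],
    promoted: t.Dict[str, TableInfo],
) -> t.List[t.Tuple[TableInfo, TableInfo]]:
    # A model changed under us if the environment now points at a different
    # version than the one that was just restated.
    return [
        (info, promoted[name])
        for name, info in restated.items()
        if name in promoted and promoted[name].version != info.version
    ]

# Usage: model_a got a new version deployed mid-restatement, model_b did not.
restated = {
    "db.model_a": TableInfo("db.model_a", "v1"),
    "db.model_b": TableInfo("db.model_b", "v7"),
}
promoted = {
    "db.model_a": TableInfo("db.model_a", "v2"),
    "db.model_b": TableInfo("db.model_b", "v7"),
}
assert find_deployed_during_restatement(restated, promoted) == [
    (TableInfo("db.model_a", "v1"), TableInfo("db.model_a", "v2"))
]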
91 changes: 87 additions & 4 deletions sqlmesh/core/plan/explainer.py
@@ -1,13 +1,21 @@
from __future__ import annotations

import abc
import typing as t
import logging
from dataclasses import dataclass

from rich.console import Console as RichConsole
from rich.tree import Tree
from sqlglot.dialects.dialect import DialectType
from sqlmesh.core import constants as c
from sqlmesh.core.console import Console, TerminalConsole, get_console
from sqlmesh.core.environment import EnvironmentNamingInfo
from sqlmesh.core.snapshot.definition import DeployabilityIndex
from sqlmesh.core.plan.common import (
SnapshotIntervalClearRequest,
identify_restatement_intervals_across_snapshot_versions,
)
from sqlmesh.core.plan.definition import EvaluatablePlan, SnapshotIntervals
from sqlmesh.core.plan import stages
from sqlmesh.core.plan.evaluator import (
@@ -45,6 +53,15 @@ def evaluate(
explainer_console = _get_explainer_console(
self.console, plan.environment, self.default_catalog
)

# add extra metadata that's only needed at this point for better --explain output
plan_stages = [
ExplainableRestatementStage.from_restatement_stage(stage, self.state_reader, plan)
if isinstance(stage, stages.RestatementStage)
else stage
for stage in plan_stages
]

explainer_console.explain(plan_stages)


@@ -54,6 +71,61 @@ def explain(self, stages: t.List[stages.PlanStage]) -> None:
pass


@dataclass
class ExplainableRestatementStage(stages.RestatementStage):
"""
This brings forward some calculations that would usually be done in the evaluator, so that the user can be given a better
indication of what might happen when they ask for the plan to be explained.
"""

snapshot_intervals_to_clear: t.Dict[str, SnapshotIntervalClearRequest]
"""Which snapshots from other environments would have intervals cleared as part of restatement, keyed by name"""

deployability_index: DeployabilityIndex
"""Deployability of those snapshots (which arent necessarily present in the current plan so we cant use the
plan deployability index), used for outputting physical table names"""

@classmethod
def from_restatement_stage(
cls: t.Type[ExplainableRestatementStage],
stage: stages.RestatementStage,
state_reader: StateReader,
plan: EvaluatablePlan,
) -> ExplainableRestatementStage:
all_restatement_intervals = identify_restatement_intervals_across_snapshot_versions(
state_reader=state_reader,
prod_restatements=plan.restatements,
disable_restatement_models=plan.disabled_restatement_models,
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
)

snapshot_intervals_to_clear = {}
deployability_index = DeployabilityIndex.all_deployable()

if all_restatement_intervals:
snapshot_intervals_to_clear = {
s_id.name: r for s_id, r in all_restatement_intervals.items()
}

# creating a deployability index over the "snapshot intervals to clear"
# allows us to print the physical names of the tables affected in the console output
# note that we can't use the DeployabilityIndex on the plan because it only includes
# snapshots for the current environment, not across all environments
deployability_index = DeployabilityIndex.create(
snapshots=state_reader.get_snapshots(
[s.snapshot_id for s in snapshot_intervals_to_clear.values()]
),
start=plan.start,
start_override_per_model=plan.start_override_per_model,
)

return cls(
snapshot_intervals_to_clear=snapshot_intervals_to_clear,
deployability_index=deployability_index,
all_snapshots=stage.all_snapshots,
)


MAX_TREE_LENGTH = 10


@@ -146,11 +218,22 @@ def visit_audit_only_run_stage(self, stage: stages.AuditOnlyRunStage) -> Tree:
tree.add(display_name)
return tree

def visit_restatement_stage(self, stage: stages.RestatementStage) -> Tree:
def visit_explainable_restatement_stage(self, stage: ExplainableRestatementStage) -> Tree:
return self.visit_restatement_stage(stage)

def visit_restatement_stage(
self, stage: t.Union[ExplainableRestatementStage, stages.RestatementStage]
) -> Tree:
tree = Tree("[bold]Invalidate data intervals as part of restatement[/bold]")
for snapshot_table_info, interval in stage.snapshot_intervals.items():
display_name = self._display_name(snapshot_table_info)
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")

if isinstance(stage, ExplainableRestatementStage) and (
snapshot_intervals := stage.snapshot_intervals_to_clear
):
for clear_request in snapshot_intervals.values():
display_name = self._display_name(clear_request.table_info)
interval = clear_request.interval
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")

return tree

def visit_backfill_stage(self, stage: stages.BackfillStage) -> Tree:
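ExplainableRestatementStage follows a general pattern: subclass the stage dataclass, precompute extra detail in a classmethod, and swap the enriched stage into the stage list only for --explain. A minimal sketch of the pattern with hypothetical stage types (not the real sqlmesh classes):

from __future__ import annotations

import typing as t
from dataclasses import dataclass

@dataclass
class Stage:
    name: str

@dataclass
class ExplainableStage(Stage):
    extra_detail: str

    @classmethod
    def from_stage(cls, stage: Stage) -> ExplainableStage:
        # Pull forward a calculation that would normally happen at evaluation
        # time so the explain output can show it up front.
        return cls(name=stage.name, extra_detail=f"precomputed for {stage.name}")

plan_stages: t.List[Stage] = [Stage("backfill"), Stage("restatement")]
plan_stages = [
    ExplainableStage.from_stage(s) if s.name == "restatement" else s
    for s in plan_stages
]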
38 changes: 22 additions & 16 deletions sqlmesh/core/plan/stages.py
@@ -12,7 +12,6 @@
Snapshot,
SnapshotTableInfo,
SnapshotId,
Interval,
)


@@ -98,14 +97,19 @@ class AuditOnlyRunStage:

@dataclass
class RestatementStage:
"""Restate intervals for given snapshots.
"""Clear intervals from state for snapshots in *other* environments, when restatements are requested in prod.

This stage is effectively a "marker" stage to trigger the plan evaluator to perform the "clear intervals" logic after the BackfillStage has completed.
The "clear intervals" logic is executed just-in-time using the latest state available in order to pick up new snapshots that may have
been created while the BackfillStage was running, which is why we do not build a list of snapshots to clear at plan time and defer to evaluation time.

Note that this stage is only present on `prod` plans because dev plans do not need to worry about clearing intervals in other environments.

Args:
snapshot_intervals: Intervals to restate.
all_snapshots: All snapshots in the plan by name.
all_snapshots: All snapshots in the plan by name. Note that this does not include the snapshots from other environments that will get their
intervals cleared; it's included here as an optimization to prevent having to re-fetch the current plan's snapshots.
"""

snapshot_intervals: t.Dict[SnapshotTableInfo, Interval]
all_snapshots: t.Dict[str, Snapshot]


@@ -321,10 +325,6 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
if audit_only_snapshots:
stages.append(AuditOnlyRunStage(snapshots=list(audit_only_snapshots.values())))

restatement_stage = self._get_restatement_stage(plan, snapshots_by_name)
if restatement_stage:
stages.append(restatement_stage)

if missing_intervals_before_promote:
stages.append(
BackfillStage(
@@ -349,6 +349,15 @@
)
)

# note: "restatement stage" (which is clearing intervals in state - not actually performing the restatements, that's the backfill stage)
# needs to come *after* the backfill stage so that at no time do other plans / runs see empty prod intervals and compete with this plan to try to fill them.
# in addition, when we update intervals in state, we only clear intervals from dev snapshots to force dev models to be backfilled based on the new prod data.
# we can leave prod intervals alone because by the time this plan finishes, the intervals in state have not actually changed, since restatement replaces
# data for existing intervals and does not produce new ones
restatement_stage = self._get_restatement_stage(plan, snapshots_by_name)
if restatement_stage:
stages.append(restatement_stage)

stages.append(
EnvironmentRecordUpdateStage(
no_gaps_snapshot_names={s.name for s in before_promote_snapshots}
@@ -443,15 +452,12 @@ def _get_after_all_stage(
def _get_restatement_stage(
self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot]
) -> t.Optional[RestatementStage]:
snapshot_intervals_to_restate = {}
for name, interval in plan.restatements.items():
restated_snapshot = snapshots_by_name[name]
restated_snapshot.remove_interval(interval)
snapshot_intervals_to_restate[restated_snapshot.table_info] = interval
if not snapshot_intervals_to_restate or plan.is_dev:
if not plan.restatements or plan.is_dev:
# The RestatementStage to clear intervals from state across all environments is not needed for plans against dev, only prod
return None

return RestatementStage(
snapshot_intervals=snapshot_intervals_to_restate, all_snapshots=snapshots_by_name
all_snapshots=snapshots_by_name,
)

def _get_physical_layer_update_stage(
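Since the ordering guarantee (backfill first, then interval clearing) is the crux of this change, a test along the following lines could verify it. The helper itself is hypothetical; only the stage classes that appear in this diff are assumed to exist:

import typing as t

from sqlmesh.core.plan.stages import BackfillStage, PlanStage, RestatementStage

def assert_restatement_after_backfill(plan_stages: t.List[PlanStage]) -> None:
    """Fail if interval clearing would run before the prod data is rebuilt."""
    backfill_idx = next(
        i for i, s in enumerate(plan_stages) if isinstance(s, BackfillStage)
    )
    restatement_idx = next(
        i for i, s in enumerate(plan_stages) if isinstance(s, RestatementStage)
    )
    # Clearing intervals from state before the backfill completes would let
    # other plans/runs observe empty prod intervals and race to fill them.
    assert backfill_idx < restatement_idx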