Skip to content

Commit

Permalink
fix: ensure metric_macro expands templates
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Feb 21, 2025
1 parent f820f9a commit d47ff5e
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 81 deletions.
56 changes: 20 additions & 36 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,10 @@ def process_template(self, sql: str, **kwargs: Any) -> str:
kwargs.update(self._context)

context = validate_template_context(self.engine, kwargs)
return template.render(context)
try:
return template.render(context)
except RecursionError as ex:
raise SupersetTemplateException("Cyclic filters detected") from ex


class JinjaTemplateProcessor(BaseTemplateProcessor):
Expand Down Expand Up @@ -659,11 +662,18 @@ def set_context(self, **kwargs: Any) -> None:
"filter_values": partial(safe_proxy, extra_cache.filter_values),
"get_filters": partial(safe_proxy, extra_cache.get_filters),
"dataset": partial(safe_proxy, dataset_macro_with_context),
"metric": partial(safe_proxy, metric_macro),
"get_time_filter": partial(safe_proxy, extra_cache.get_time_filter),
}
)

# The `metric` filter needs the full context, in order to expand other filters
self._context["metric"] = partial(
safe_proxy,
metric_macro,
self.env,
self._context,
)


class NoOpTemplateProcessor(BaseTemplateProcessor):
def process_template(self, sql: str, **kwargs: Any) -> str:
Expand Down Expand Up @@ -889,27 +899,12 @@ def get_dataset_id_from_context(metric_key: str) -> int:
raise SupersetTemplateException(exc_message)


def has_metric_macro(template_string: str, env: Environment) -> bool:
"""
Checks if a template string contains a metric macro.
>>> has_metric_macro("{{ metric('my_metric') }}")
True
"""
ast = env.parse(template_string)

def visit_node(node: Node) -> bool:
return (
isinstance(node, Call)
and isinstance(node.node, nodes.Name)
and node.node.name == "metric"
) or any(visit_node(child) for child in node.iter_child_nodes())

return visit_node(ast)


def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str:
def metric_macro(
env: Environment,
context: dict[str, Any],
metric_key: str,
dataset_id: Optional[int] = None,
) -> str:
"""
Given a metric key, returns its syntax.
Expand Down Expand Up @@ -943,18 +938,7 @@ def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str:
)

definition = metrics[metric_key]

env = SandboxedEnvironment(undefined=DebugUndefined)
context = {"metric": partial(safe_proxy, metric_macro)}
while has_metric_macro(definition, env):
old_definition = definition
template = env.from_string(definition)
try:
definition = template.render(context)
except RecursionError as ex:
raise SupersetTemplateException("Cyclic metric macro detected") from ex

if definition == old_definition:
break
template = env.from_string(definition)
definition = template.render(context)

return definition
Loading

0 comments on commit d47ff5e

Please sign in to comment.