Skip to content

Commit

Permalink
cache iteration: various fixes- will invalidate existing cache (#3480)
Browse files Browse the repository at this point in the history
This has fixes for:

 - [x] Shadowed arguments
 - [x] Formatting causing issues with context block: #2633
 - [x] improved df "object detection": #2661
 
Following PR changes:

- Detect when execution hash relies on another hash object (cache
breaking) (#3270)
-  Allow for pickle hash as fallback for "unhashable" variables (#3270)
- Expand `@persistent_cache` api (this shouldn't cache bust, so I might
just follow up) (#2653)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dmadisetti and pre-commit-ci[bot] authored Jan 22, 2025
1 parent 9a96a9d commit ca092e9
Show file tree
Hide file tree
Showing 7 changed files with 507 additions and 78 deletions.
1 change: 1 addition & 0 deletions marimo/_ast/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

from dataclasses import dataclass
Expand Down
25 changes: 18 additions & 7 deletions marimo/_runtime/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,24 @@ def is_data_primitive(value: Any) -> bool:
or (hasattr(value.dtype, "hasobject") and value.dtype.hasobject)
)
elif hasattr(value, "dtypes"):
for dtype in value.dtypes:
# Capture pandas cases
if getattr(dtype, "hasobject", None):
return False
# Capture polars cases
if hasattr(dtype, "is_numeric") and not dtype.is_numeric:
return False
# Bit of discrepancy between objects like polars and pandas, so use
# narwhals to normalize the dataframe.
import narwhals as nw

try:
return bool(
nw.narwhalify(
lambda df: all(
df[col].dtype.is_numeric() for col in df.columns
)
)(value)
)
except Exception as err:
raise err from ValueError(
"Unexpected datatype, narwhals was unable to normalize "
"dataframe. Please report this to "
"github.com/marimo-team/marimo"
)
# Otherwise may be a closely related array object
return True

Expand Down
88 changes: 76 additions & 12 deletions marimo/_save/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from marimo._utils.variables import unmangle_local

ARG_PREFIX: str = "*"


class BlockException(Exception):
    """Raised when the target ``with`` block cannot be extracted."""
Expand Down Expand Up @@ -91,11 +93,22 @@ def generic_visit(self, node: ast.AST) -> tuple[ast.Module, ast.Module]: # typ

assert isinstance(node, list), "Unexpected block structure."
for n in node:
# There's a chance that the block is first evaluated somewhere in a
# multiline line expression, for instance:
# 1 >>> with cache as c:
# 2 >>> a = [
# 3 >>> f(x) # <<< first frame call is here, and not line 2
# 4 >>> ]
# so check that the "target" line is in the interval
# (i.e. 1 <= 3 <= 4)
parent = None
if n.lineno < self.target_line:
pre_block.append(n)
previous = n
elif n.lineno == self.target_line:
# the line is contained with this
elif n.lineno <= self.target_line <= n.end_lineno:
on_line.append(n)
parent = n
# The target line can easily be skipped if there are comments or
# white space or if the block is contained within another block.
else:
Expand All @@ -106,31 +119,80 @@ def generic_visit(self, node: ast.AST) -> tuple[ast.Module, ast.Module]: # typ
# excluding by omission try, for, classes and functions.
if len(on_line) == 0:
if isinstance(previous, (ast.With, ast.If)):
try:
# Recursion by referring the the containing block also
# captures the case where the target line number was not
# exactly hit.
return ExtractWithBlock(self.target_line).generic_visit(
previous.body # type: ignore[arg-type]
)
except BlockException:
on_line.append(previous)
on_line.append(previous)
# Captures both branches of If, and the With block.
bodies = (
[previous.body]
if isinstance(previous, ast.With)
else [previous.body, previous.orelse]
)
for body in bodies:
try:
# Recursion by referring to the containing block also
# captures the case where the target line number was not
# exactly hit.
# for instance:
# 1 >>> if True: # Captured ast node
# 2 >>> with fn() as x:
# 3 >>> with cache as c:
# 4 >>> a = 1 # <<< frame line
# will recurse through here thrice to get to the frame
# line.
# NB. the "extracted" With block is the one that
# invoked this call
return ExtractWithBlock(
self.target_line
).generic_visit(
body # type: ignore[arg-type]
)
except BlockException:
pass

else:
raise BlockException(
"persistent_cache cannot be invoked within a block "
"(try moving the block within the persistent_cache scope)."
)
# Intentionally not elif (on_line can be added in previous block)
if len(on_line) == 1:
assert isinstance(on_line[0], ast.With), "Unexpected block."
if parent and not isinstance(on_line[0], ast.With):
raise BlockException("Detected line is not a With statement.")
if not isinstance(on_line[0], ast.With):
raise BlockException(
"Unconventional formatting may lead to unexpected behavior. "
"Please format your code, and/or reduce nesting.\n"
"For instance, the following is not supported:\n"
">>>> with cache() as c: a = 1 # all one line"
)
return clean_to_modules(pre_block, on_line[0])
# It should be possible to relate the lines with the AST,
# but reduce potential bugs by just throwing an error.
raise BlockException(
"Saving on a shared line may lead to unexpected behavior."
"Unable to determine structure your call. Please"
" report this to github:marimo-team/marimo/issues"
)


class MangleArguments(ast.NodeTransformer):
    """Rename selected identifiers so they cannot shadow other names.

    Every ``ast.Name`` node whose ``id`` is contained in ``args`` is
    rewritten in place to ``prefix`` followed by the original id. Other
    nodes are handled by the default ``NodeTransformer`` traversal.
    """

    def __init__(
        self,
        args: set[str],
        *arg: Any,
        prefix: str = ARG_PREFIX,
        **kwargs: Any,
    ) -> None:
        """Record the names to mangle and the prefix to prepend.

        Any extra positional or keyword arguments are forwarded to the
        base class initializer.
        """
        super().__init__(*arg, **kwargs)
        self.args = args
        self.prefix = prefix

    def visit_Name(self, node: ast.Name) -> ast.Name:
        """Prefix ``node.id`` when it is one of the tracked names."""
        if node.id not in self.args:
            return node
        node.id = f"{self.prefix}{node.id}"
        return node


class DeprivateVisitor(ast.NodeTransformer):
"""Removes the mangling of private variables from a module."""

Expand All @@ -157,12 +219,14 @@ def visit_Return(self, node: ast.Return) -> ast.Expr:

def strip_function(fn: Callable[..., Any]) -> ast.Module:
    """Return the body of ``fn`` as a standalone module AST.

    The source of ``fn`` is dedented and re-parsed, the statements of its
    definition are lifted into an ``ast.Module``, and that module is run
    through ``RemoveReturns`` followed by ``MangleArguments``.
    """
    source_lines, _ = inspect.getsourcelines(fn)
    # Names handed to MangleArguments; per the Python data model,
    # co_varnames lists the parameters followed by the other local names.
    mangled_names = set(fn.__code__.co_varnames)
    parsed = ast.parse(textwrap.dedent("".join(source_lines)))
    fn_def = parsed.body.pop()
    assert isinstance(fn_def, (ast.FunctionDef, ast.AsyncFunctionDef)), (
        "Expected a function definition"
    )
    body_only = ast.Module(fn_def.body, type_ignores=[])
    without_returns = RemoveReturns().visit(body_only)
    stripped = MangleArguments(mangled_names).visit(without_returns)
    assert isinstance(stripped, ast.Module), "Expected a module"
    return stripped
68 changes: 42 additions & 26 deletions marimo/_save/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import struct
import sys
import types
from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Optional
from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple, Optional

from marimo._ast.visitor import Name, ScopedVisitor
from marimo._dependencies.dependencies import DependencyManager
Expand All @@ -22,7 +22,7 @@
is_pure_function,
)
from marimo._runtime.state import SetFunctor, State
from marimo._save.ast import DeprivateVisitor
from marimo._save.ast import DeprivateVisitor, strip_function
from marimo._save.cache import Cache, CacheType
from marimo._utils.variables import (
get_cell_from_local,
Expand Down Expand Up @@ -57,10 +57,6 @@ class SerialRefs(NamedTuple):
stateful_refs: set[Name]


class ShadowedRef:
"""Stub for scoped variables that may shadow global references"""


def hash_module(
code: Optional[CodeType], hash_type: str = DEFAULT_HASH
) -> bytes:
Expand All @@ -79,7 +75,9 @@ def process(code_obj: CodeType) -> None:
hash_alg.update(str(const).encode("utf8"))
# Concatenate the names and bytecode of the current code object
# Will cause invalidation of variable naming at the top level
hash_alg.update(bytes("|".join(code_obj.co_names), "utf8"))

names = [unmangle_local(name).name for name in code_obj.co_names]
hash_alg.update(bytes("|".join(names), "utf8"))
hash_alg.update(code_obj.co_code)

process(code)
Expand All @@ -92,7 +90,7 @@ def hash_raw_module(
# AST has to be compiled to code object prior to process.
return hash_module(
compile(
module,
DeprivateVisitor().visit(module),
"<hash>",
mode="exec",
flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT,
Expand All @@ -107,6 +105,38 @@ def hash_cell_impl(cell: CellImpl, hash_type: str = DEFAULT_HASH) -> bytes:
)


def hash_function(
    fn: Callable[..., Any], hash_type: str = DEFAULT_HASH
) -> bytes:
    """Hash the body of ``fn`` with the ``hash_type`` algorithm.

    The function is reduced to a module AST with ``strip_function``,
    passed through ``DeprivateVisitor`` and then hashed by
    ``hash_raw_module``.
    """
    visitor = DeprivateVisitor()
    body_module = visitor.visit(strip_function(fn))
    return hash_raw_module(body_module, hash_type)


def hash_cell_group(
    cell_ids: set[CellId_t],
    graph: DirectedGraph,
    hash_type: str = DEFAULT_HASH,
) -> bytes:
    """Return one digest covering every cell in ``cell_ids``.

    Each cell looked up in ``graph.cells`` is hashed on its own with
    ``hash_cell_impl``; the per-cell hashes are then folded into a single
    ``hash_type`` digest.
    """
    digest = hashlib.new(hash_type, usedforsecurity=False)
    # Sort after hashing: ``cell_ids`` is a set, so sorting the per-cell
    # hashes keeps the combined digest independent of iteration order.
    per_cell = sorted(
        hash_cell_impl(graph.cells[cell_id], digest.name)
        for cell_id in cell_ids
    )
    for cell_hash in per_cell:
        digest.update(cell_hash)
    return digest.digest()


def hash_cell_execution(
    cell_id: CellId_t, graph: DirectedGraph, hash_type: str = DEFAULT_HASH
) -> bytes:
    """Hash the group of cells returned by ``graph.ancestors(cell_id)``."""
    return hash_cell_group(graph.ancestors(cell_id), graph, hash_type)


def standardize_tensor(tensor: Tensor) -> Optional[Tensor]:
if (
hasattr(tensor, "__array__")
Expand Down Expand Up @@ -719,22 +749,6 @@ def serialize_and_dequeue_stateful_content_refs(
refs, inclusive=False
)

for ref in transitive_state_refs:
if ref in scope and isinstance(scope[ref], ShadowedRef):
# TODO(akshayka, dmadisetti): Lift this restriction once
# function args are rewritten.
#
# This makes more sense as a NameError, but the marimo's
# explainer text for NameError's doesn't make sense in this
# context. ("Definition expected in ...")
raise RuntimeError(
f"The cached function declares an argument '{ref}'"
"but a captured function or class uses the "
f"global variable '{ref}'. Please rename "
"the argument, or restructure the use "
f"of the global variable."
)

# Filter for relevant stateful cases.
refs |= set(
filter(
Expand Down Expand Up @@ -776,9 +790,8 @@ def hash_and_dequeue_execution_refs(self, refs: set[Name]) -> set[Name]:
*[self.graph.definitions.get(ref, set()) for ref in refs]
)
to_hash = ancestors & ref_cells
for ancestor_id in sorted(to_hash):
for ancestor_id in to_hash:
cell_impl = self.graph.cells[ancestor_id]
self.hash_alg.update(hash_cell_impl(cell_impl, self.hash_alg.name))
for ref in cell_impl.defs:
# Look for both, since mangle reference depends on the context
# of the definition.
Expand All @@ -787,6 +800,9 @@ def hash_and_dequeue_execution_refs(self, refs: set[Name]) -> set[Name]:
unmangled_ref, _ = unmangle_local(ref)
if unmangled_ref in refs:
refs.remove(unmangled_ref)
self.hash_alg.update(
hash_cell_group(to_hash, self.graph, self.hash_alg.name)
)
return refs

def hash_and_verify_context_refs(
Expand Down
19 changes: 6 additions & 13 deletions marimo/_save/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
from marimo._runtime.context import get_context
from marimo._runtime.runtime import notebook_dir
from marimo._runtime.state import State
from marimo._save.ast import ExtractWithBlock, strip_function
from marimo._save.ast import ARG_PREFIX, ExtractWithBlock, strip_function
from marimo._save.cache import Cache, CacheException
from marimo._save.hash import (
DEFAULT_HASH,
BlockHasher,
ShadowedRef,
cache_attempt_from_hash,
content_cache_attempt_from_base,
)
Expand Down Expand Up @@ -115,21 +114,13 @@ def _set_context(self, fn: Callable[..., Any]) -> None:
# checking a single frame- should be good enough.
f_locals = inspect.stack()[2 + self._frame_offset][0].f_locals
self.scope = {**ctx.globals, **f_locals}
# In case scope shadows variables
#
# TODO(akshayka, dmadisetti): rewrite function args with an AST pass
# to make them unique, deterministically based on function body; this
# will allow for lifting the error when a ShadowedRef is also used
# as a regular ref.
for arg in self._args:
self.scope[arg] = ShadowedRef()

# Scoped refs are references particular to this block, that may not be
# defined out of the context of the block, or the cell.
# For instance, the args of the invoked function are restricted to the
# block.
cell_id = ctx.cell_id or ctx.execution_context.cell_id or ""
self.scoped_refs = set(self._args)
self.scoped_refs = set([f"{ARG_PREFIX}{k}" for k in self._args])
# # As are the "locals" not in globals
self.scoped_refs |= set(f_locals.keys()) - set(ctx.globals.keys())
# Defined in the cell, and currently available in scope
Expand Down Expand Up @@ -184,16 +175,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
self._set_context(args[0])
return self

# Rewrite scoped args to prevent shadowed variables
arg_dict = {f"{ARG_PREFIX}{k}": v for (k, v) in zip(self._args, args)}
kwargs = {f"{ARG_PREFIX}{k}": v for (k, v) in kwargs.items()}
# Capture the call case
arg_dict = {k: v for (k, v) in zip(self._args, args)}
scope = {**self.scope, **get_context().globals, **arg_dict, **kwargs}
assert self._loader is not None, UNEXPECTED_FAILURE_BOILERPLATE
attempt = content_cache_attempt_from_base(
self.base_block,
scope,
self._loader(),
scoped_refs=self.scoped_refs,
required_refs=set(self._args),
required_refs=set([f"{ARG_PREFIX}{k}" for k in self._args]),
as_fn=True,
)

Expand Down
Loading

0 comments on commit ca092e9

Please sign in to comment.