Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cache iteration: various fixes - will invalidate existing cache #3480

Merged
merged 9 commits into from
Jan 22, 2025
1 change: 1 addition & 0 deletions marimo/_ast/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

from dataclasses import dataclass
Expand Down
25 changes: 18 additions & 7 deletions marimo/_runtime/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,24 @@ def is_data_primitive(value: Any) -> bool:
or (hasattr(value.dtype, "hasobject") and value.dtype.hasobject)
)
elif hasattr(value, "dtypes"):
for dtype in value.dtypes:
# Capture pandas cases
if getattr(dtype, "hasobject", None):
return False
# Capture polars cases
if hasattr(dtype, "is_numeric") and not dtype.is_numeric:
return False
# Bit of discrepancy between objects like polars and pandas, so use
# narwhals to normalize the dataframe.
import narwhals as nw

try:
return bool(
nw.narwhalify(
lambda df: all(
df[col].dtype.is_numeric() for col in df.columns
)
)(value)
)
except Exception as err:
raise err from ValueError(
"Unexpected datatype, narwhals was unable to normalize "
"dataframe. Please report this to "
"github.com/marimo-team/marimo"
)
# Otherwise may be a closely related array object
return True

Expand Down
88 changes: 76 additions & 12 deletions marimo/_save/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from marimo._utils.variables import unmangle_local

ARG_PREFIX: str = "*"


class BlockException(Exception):
pass
Expand Down Expand Up @@ -91,11 +93,22 @@ def generic_visit(self, node: ast.AST) -> tuple[ast.Module, ast.Module]: # typ

assert isinstance(node, list), "Unexpected block structure."
for n in node:
# There's a chance that the block is first evaluated somewhere in a
# multiline expression, for instance:
# 1 >>> with cache as c:
# 2 >>> a = [
# 3 >>> f(x) # <<< first frame call is here, and not line 2
# 4 >>> ]
# so check that the "target" line is in the interval
# (i.e. 1 <= 3 <= 4)
parent = None
if n.lineno < self.target_line:
pre_block.append(n)
previous = n
elif n.lineno == self.target_line:
# the target line is contained within this node
elif n.lineno <= self.target_line <= n.end_lineno:
on_line.append(n)
parent = n
# The target line can easily be skipped if there are comments or
# white space or if the block is contained within another block.
else:
Expand All @@ -106,31 +119,80 @@ def generic_visit(self, node: ast.AST) -> tuple[ast.Module, ast.Module]: # typ
# excluding by omission try, for, classes and functions.
if len(on_line) == 0:
if isinstance(previous, (ast.With, ast.If)):
try:
# Recursion by referring the the containing block also
# captures the case where the target line number was not
# exactly hit.
return ExtractWithBlock(self.target_line).generic_visit(
previous.body # type: ignore[arg-type]
)
except BlockException:
on_line.append(previous)
on_line.append(previous)
# Captures both branches of If, and the With block.
bodies = (
[previous.body]
if isinstance(previous, ast.With)
else [previous.body, previous.orelse]
)
for body in bodies:
try:
# Recursion by referring to the containing block also
# captures the case where the target line number was not
# exactly hit.
# for instance:
# 1 >>> if True: # Captured ast node
# 2 >>> with fn() as x:
# 3 >>> with cache as c:
# 4 >>> a = 1 # <<< frame line
# will recurse through here thrice to get to the frame
# line.
# NB. the "extracted" With block is the one that
# invoked this call
return ExtractWithBlock(
self.target_line
).generic_visit(
body # type: ignore[arg-type]
)
except BlockException:
pass

else:
raise BlockException(
"persistent_cache cannot be invoked within a block "
"(try moving the block within the persistent_cache scope)."
)
# Intentionally not elif (on_line can be added in previous block)
if len(on_line) == 1:
assert isinstance(on_line[0], ast.With), "Unexpected block."
if parent and not isinstance(on_line[0], ast.With):
raise BlockException("Detected line is not a With statement.")
if not isinstance(on_line[0], ast.With):
raise BlockException(
"Unconventional formatting may lead to unexpected behavior. "
"Please format your code, and/or reduce nesting.\n"
"For instance, the following is not supported:\n"
">>>> with cache() as c: a = 1 # all one line"
)
return clean_to_modules(pre_block, on_line[0])
# It should be possible to relate the lines with the AST,
# but reduce potential bugs by just throwing an error.
raise BlockException(
"Saving on a shared line may lead to unexpected behavior."
"Unable to determine structure your call. Please"
" report this to github:marimo-team/marimo/issues"
)


class MangleArguments(ast.NodeTransformer):
    """Mangles argument names to prevent shadowing issues in analysis.

    Every `ast.Name` node whose identifier is a member of `args` gets
    `prefix` prepended to its identifier. Nodes are modified in place.

    Args:
        args: identifiers that should be rewritten.
        *arg: forwarded positionally to `ast.NodeTransformer.__init__`.
        prefix: string prepended to each matching identifier; defaults to
            the module-level `ARG_PREFIX`.
        **kwargs: forwarded by keyword to `ast.NodeTransformer.__init__`.
    """

    def __init__(
        self,
        args: set[str],
        *arg: Any,
        prefix: str = ARG_PREFIX,
        **kwargs: Any,
    ) -> None:
        super().__init__(*arg, **kwargs)
        self.args = args
        self.prefix = prefix

    def visit_Name(self, node: ast.Name) -> ast.Name:
        """Prefix `node.id` when it is one of the tracked identifiers."""
        # Untracked names pass through untouched.
        if node.id not in self.args:
            return node
        node.id = f"{self.prefix}{node.id}"
        return node


class DeprivateVisitor(ast.NodeTransformer):
"""Removes the mangling of private variables from a module."""

Expand All @@ -157,12 +219,14 @@ def visit_Return(self, node: ast.Return) -> ast.Expr:

def strip_function(fn: Callable[..., Any]) -> ast.Module:
    """Return the body of `fn` as a standalone, rewritten `ast.Module`.

    The source of `fn` is fetched with `inspect`, dedented and parsed; the
    statements of the function definition are lifted into a fresh module,
    which is then passed through `RemoveReturns` and `MangleArguments`.

    Args:
        fn: the callable whose source should be extracted.

    Returns:
        The transformed module holding the function's body.

    Raises:
        AssertionError: if the last parsed top-level statement is not a
            (possibly async) function definition, or if the transformers
            do not hand back an `ast.Module`.
    """
    source_lines, _ = inspect.getsourcelines(fn)
    # NOTE(review): co_varnames lists the arguments *and* the other local
    # variable names of the code object, so locals are mangled as well --
    # confirm this is intended.
    mangled_names = set(fn.__code__.co_varnames)
    parsed = ast.parse(textwrap.dedent("".join(source_lines)))
    # The definition is taken from the end of the parsed module body.
    definition = parsed.body.pop()
    assert isinstance(definition, (ast.FunctionDef, ast.AsyncFunctionDef)), (
        "Expected a function definition"
    )
    body_module = ast.Module(definition.body, type_ignores=[])
    without_returns = RemoveReturns().visit(body_module)
    module = MangleArguments(mangled_names).visit(without_returns)
    assert isinstance(module, ast.Module), "Expected a module"
    return module
68 changes: 42 additions & 26 deletions marimo/_save/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import struct
import sys
import types
from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Optional
from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple, Optional

from marimo._ast.visitor import Name, ScopedVisitor
from marimo._dependencies.dependencies import DependencyManager
Expand All @@ -22,7 +22,7 @@
is_pure_function,
)
from marimo._runtime.state import SetFunctor, State
from marimo._save.ast import DeprivateVisitor
from marimo._save.ast import DeprivateVisitor, strip_function
from marimo._save.cache import Cache, CacheType
from marimo._utils.variables import (
get_cell_from_local,
Expand Down Expand Up @@ -57,10 +57,6 @@ class SerialRefs(NamedTuple):
stateful_refs: set[Name]


class ShadowedRef:
"""Stub for scoped variables that may shadow global references"""


def hash_module(
code: Optional[CodeType], hash_type: str = DEFAULT_HASH
) -> bytes:
Expand All @@ -79,7 +75,9 @@ def process(code_obj: CodeType) -> None:
hash_alg.update(str(const).encode("utf8"))
# Concatenate the names and bytecode of the current code object
# Will cause invalidation of variable naming at the top level
hash_alg.update(bytes("|".join(code_obj.co_names), "utf8"))

names = [unmangle_local(name).name for name in code_obj.co_names]
hash_alg.update(bytes("|".join(names), "utf8"))
hash_alg.update(code_obj.co_code)

process(code)
Expand All @@ -92,7 +90,7 @@ def hash_raw_module(
# AST has to be compiled to code object prior to process.
return hash_module(
compile(
module,
DeprivateVisitor().visit(module),
"<hash>",
mode="exec",
flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT,
Expand All @@ -107,6 +105,38 @@ def hash_cell_impl(cell: CellImpl, hash_type: str = DEFAULT_HASH) -> bytes:
)


def hash_function(
    fn: Callable[..., Any], hash_type: str = DEFAULT_HASH
) -> bytes:
    """Hash the body of `fn`.

    The function is reduced to a module with `strip_function`, run through
    `DeprivateVisitor`, and the result is hashed by `hash_raw_module`.

    Args:
        fn: the callable to hash.
        hash_type: name of the hash algorithm, passed on to
            `hash_raw_module`.

    Returns:
        The digest produced by `hash_raw_module`.
    """
    # NOTE(review): hash_raw_module appears to apply DeprivateVisitor as
    # well, so this pass is presumably redundant -- confirm before removing.
    deprivate = DeprivateVisitor()
    module = deprivate.visit(strip_function(fn))
    return hash_raw_module(module, hash_type)


def hash_cell_group(
    cell_ids: set[CellId_t],
    graph: DirectedGraph,
    hash_type: str = DEFAULT_HASH,
) -> bytes:
    """Combine the hashes of a set of cells into a single digest.

    Each cell in `cell_ids` is looked up in `graph.cells` and hashed with
    `hash_cell_impl`; the individual digests are then fed, in sorted order,
    into one hash object.

    Args:
        cell_ids: ids of the cells to hash.
        graph: graph whose `cells` mapping resolves each id.
        hash_type: name of the hash algorithm given to `hashlib.new`.

    Returns:
        The digest over all per-cell hashes.
    """
    hasher = hashlib.new(hash_type, usedforsecurity=False)
    # `cell_ids` is a set, so sort the digests (not the ids) to make the
    # combined result independent of iteration order.
    cell_digests = sorted(
        hash_cell_impl(graph.cells[cell_id], hasher.name)
        for cell_id in cell_ids
    )
    for digest in cell_digests:
        hasher.update(digest)
    return hasher.digest()


def hash_cell_execution(
    cell_id: CellId_t, graph: DirectedGraph, hash_type: str = DEFAULT_HASH
) -> bytes:
    """Hash the cells that `graph.ancestors(cell_id)` returns.

    Args:
        cell_id: id of the cell whose ancestors are hashed.
        graph: graph providing `ancestors` and the cell lookup.
        hash_type: name of the hash algorithm, passed to `hash_cell_group`.

    Returns:
        The combined digest from `hash_cell_group`.
    """
    return hash_cell_group(graph.ancestors(cell_id), graph, hash_type)


def standardize_tensor(tensor: Tensor) -> Optional[Tensor]:
if (
hasattr(tensor, "__array__")
Expand Down Expand Up @@ -719,22 +749,6 @@ def serialize_and_dequeue_stateful_content_refs(
refs, inclusive=False
)

for ref in transitive_state_refs:
if ref in scope and isinstance(scope[ref], ShadowedRef):
# TODO(akshayka, dmadisetti): Lift this restriction once
# function args are rewritten.
#
# This makes more sense as a NameError, but the marimo's
# explainer text for NameError's doesn't make sense in this
# context. ("Definition expected in ...")
raise RuntimeError(
f"The cached function declares an argument '{ref}'"
"but a captured function or class uses the "
f"global variable '{ref}'. Please rename "
"the argument, or restructure the use "
f"of the global variable."
)

# Filter for relevant stateful cases.
refs |= set(
filter(
Expand Down Expand Up @@ -776,9 +790,8 @@ def hash_and_dequeue_execution_refs(self, refs: set[Name]) -> set[Name]:
*[self.graph.definitions.get(ref, set()) for ref in refs]
)
to_hash = ancestors & ref_cells
for ancestor_id in sorted(to_hash):
for ancestor_id in to_hash:
cell_impl = self.graph.cells[ancestor_id]
self.hash_alg.update(hash_cell_impl(cell_impl, self.hash_alg.name))
for ref in cell_impl.defs:
# Look for both, since mangle reference depends on the context
# of the definition.
Expand All @@ -787,6 +800,9 @@ def hash_and_dequeue_execution_refs(self, refs: set[Name]) -> set[Name]:
unmangled_ref, _ = unmangle_local(ref)
if unmangled_ref in refs:
refs.remove(unmangled_ref)
self.hash_alg.update(
hash_cell_group(to_hash, self.graph, self.hash_alg.name)
)
return refs

def hash_and_verify_context_refs(
Expand Down
19 changes: 6 additions & 13 deletions marimo/_save/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
from marimo._runtime.context import get_context
from marimo._runtime.runtime import notebook_dir
from marimo._runtime.state import State
from marimo._save.ast import ExtractWithBlock, strip_function
from marimo._save.ast import ARG_PREFIX, ExtractWithBlock, strip_function
from marimo._save.cache import Cache, CacheException
from marimo._save.hash import (
DEFAULT_HASH,
BlockHasher,
ShadowedRef,
cache_attempt_from_hash,
content_cache_attempt_from_base,
)
Expand Down Expand Up @@ -115,21 +114,13 @@ def _set_context(self, fn: Callable[..., Any]) -> None:
# checking a single frame- should be good enough.
f_locals = inspect.stack()[2 + self._frame_offset][0].f_locals
self.scope = {**ctx.globals, **f_locals}
# In case scope shadows variables
#
# TODO(akshayka, dmadisetti): rewrite function args with an AST pass
# to make them unique, deterministically based on function body; this
# will allow for lifting the error when a ShadowedRef is also used
# as a regular ref.
for arg in self._args:
self.scope[arg] = ShadowedRef()

# Scoped refs are references particular to this block, that may not be
# defined out of the context of the block, or the cell.
# For instance, the args of the invoked function are restricted to the
# block.
cell_id = ctx.cell_id or ctx.execution_context.cell_id or ""
self.scoped_refs = set(self._args)
self.scoped_refs = set([f"{ARG_PREFIX}{k}" for k in self._args])
# # As are the "locals" not in globals
self.scoped_refs |= set(f_locals.keys()) - set(ctx.globals.keys())
# Defined in the cell, and currently available in scope
Expand Down Expand Up @@ -184,16 +175,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
self._set_context(args[0])
return self

# Rewrite scoped args to prevent shadowed variables
arg_dict = {f"{ARG_PREFIX}{k}": v for (k, v) in zip(self._args, args)}
kwargs = {f"{ARG_PREFIX}{k}": v for (k, v) in kwargs.items()}
# Capture the call case
arg_dict = {k: v for (k, v) in zip(self._args, args)}
scope = {**self.scope, **get_context().globals, **arg_dict, **kwargs}
assert self._loader is not None, UNEXPECTED_FAILURE_BOILERPLATE
attempt = content_cache_attempt_from_base(
self.base_block,
scope,
self._loader(),
scoped_refs=self.scoped_refs,
required_refs=set(self._args),
required_refs=set([f"{ARG_PREFIX}{k}" for k in self._args]),
as_fn=True,
)

Expand Down
Loading
Loading