Skip to content

Commit

Permalink
fix: failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
dmadisetti committed Mar 7, 2025
1 parent b82af1d commit 5b01ba1
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 60 deletions.
4 changes: 3 additions & 1 deletion marimo/_ast/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,9 @@ def serialize(
cell = status._cell
assert cell is not None
if not toplevel_fn:
return to_functiondef(cell, status.name, extraction.allowed_refs)
return to_functiondef(
cell, status.name, extraction.unshadowed, None, fn="cell"
)
elif status.is_cell:
return to_functiondef(
cell,
Expand Down
8 changes: 5 additions & 3 deletions marimo/_ast/toplevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def update(

if self._cell is None:
try:
self._cell = compile_cell(self.code, cell_id=self.cell_id)
self._cell = compile_cell(
self.code, cell_id=self.cell_id
).configure(self.cell_config)
except SyntaxError:
# Keep default
self.type = TopLevelType.UNPARSABLE
Expand Down Expand Up @@ -211,8 +213,8 @@ def __init__(
# Refresh names
names = [status.name for status in self.statuses]

unshadowed_builtins = set(builtins.__dict__.keys()) - defs
self.allowed_refs.update(unshadowed_builtins)
self.unshadowed = set(builtins.__dict__.keys()) - defs
self.allowed_refs.update(self.unshadowed)
self.used_refs = refs

# Now toplevel, "allowed" defs have been determined, we can resolve
Expand Down
72 changes: 16 additions & 56 deletions marimo/_ast/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,67 +384,27 @@ def generic_visit(self, node: ast.AST) -> ast.AST:
super().generic_visit(node)
return node

def _visit_and_get_refs(self, node: ast.AST) -> set[Name]:
"""Create a ref scope for the variable to be declared (e.g. function,
class), visit the children the node, propagate the refs to the higher
scope and then return the refs."""

# TODO: Update for class_definition toplevel decorator.
self.ref_stack.append(set())
self.generic_visit(node)
refs = self.ref_stack.pop()
# The scope a level up from the one just investigated also is dependent
# on these refs. Consider the case:
# >> def foo():
# >> def bar(): <- current scope
# >> print(x)
#
# the variable `foo` needs to be aware that it may require the ref `x`
# during execution.
self.ref_stack[-1].update(refs)
return refs

def _extract_signature_keys(
self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> dict[str, ast.AST]:
signature_keys = {
# Include args so it's known they are not refs
"args": ast.arguments( # type: ignore[call-overload]
node.args.args
+ node.args.kwonlyargs
+ [
arg
for arg in (node.args.vararg, node.args.kwarg)
if arg is not None
]
),
}
if sys.version_info >= (3, 12):
signature_keys["type_params"] = node.type_params
return signature_keys

def _visit_and_get_fn_refs(
self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
def _visit_and_get_refs(
self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef]
) -> tuple[set[Name], set[Name]]:
"""Create a ref scope for the variable to be declared (e.g. function,
class), visit the children the node, propagate the refs to the higher
scope and then return the body refs and unbounded refs."""

# Handle function refs that are evaluated in the outer scope
self.ref_stack.append(set())
mock = deepcopy(node)
# Just signature and non-scope parts.
mock.body.clear()
self.generic_visit(mock)
# Collect the unbounded refs
unbounded_refs = set(self.ref_stack[-1])
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
# Handle function refs that are evaluated in the outer scope
# Remove the body, which keeps signature and non-scoped parts.
mock = deepcopy(node)
mock.body.clear()
self.generic_visit(mock)
# Collect the unbounded refs
unbounded_refs = set(self.ref_stack[-1])
else:
# TODO: Update for class_definition toplevel decorator.
unbounded_refs = set()

# Process the function body
# module_stub = ast.FunctionDef(
# name=node.name,
# body=node.body,
# **_extract_signature_keys(node),
# )
self.generic_visit(node)
refs = self.ref_stack.pop()
# The scope a level up from the one just investigated also is dependent
Expand All @@ -462,7 +422,7 @@ def _visit_and_get_fn_refs(
# ClassDef and FunctionDef nodes don't have ast.Name nodes as children
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
node.name = self._if_local_then_mangle(node.name)
refs = self._visit_and_get_refs(node)
refs, _ = self._visit_and_get_refs(node)
self._define(
node,
node.name,
Expand All @@ -474,7 +434,7 @@ def visit_AsyncFunctionDef(
self, node: ast.AsyncFunctionDef
) -> ast.AsyncFunctionDef:
node.name = self._if_local_then_mangle(node.name)
refs, unbounded_refs = self._visit_and_get_fn_refs(node)
refs, unbounded_refs = self._visit_and_get_refs(node)
self._define(
node,
node.name,
Expand All @@ -488,7 +448,7 @@ def visit_AsyncFunctionDef(

def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
node.name = self._if_local_then_mangle(node.name)
refs, unbounded_refs = self._visit_and_get_fn_refs(node)
refs, unbounded_refs = self._visit_and_get_refs(node)
self._define(
node,
node.name,
Expand Down

0 comments on commit 5b01ba1

Please sign in to comment.