Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 64 additions & 7 deletions invokeai/app/services/shared/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,10 +816,57 @@ class GraphExecutionState(BaseModel):
# Optional priority; others follow in name order
ready_order: list[str] = Field(default_factory=list)
indegree: dict[str, int] = Field(default_factory=dict, description="Remaining unmet input count for exec nodes")
_iteration_path_cache: dict[str, tuple[int, ...]] = PrivateAttr(default_factory=dict)

def _type_key(self, node_obj: BaseInvocation) -> str:
return node_obj.__class__.__name__

def _get_iteration_path(self, exec_node_id: str) -> tuple[int, ...]:
"""Best-effort outer->inner iteration indices for an execution node, stopping at collectors."""
cached = self._iteration_path_cache.get(exec_node_id)
if cached is not None:
return cached

# Only prepared execution nodes participate; otherwise treat as non-iterated.
source_node_id = self.prepared_source_mapping.get(exec_node_id)
if source_node_id is None:
self._iteration_path_cache[exec_node_id] = ()
return ()

# Source-graph iterator ancestry, with edges into collectors removed so iteration context doesn't leak.
it_g = self._iterator_graph(self.graph.nx_graph())
iterator_sources = [
n for n in nx.ancestors(it_g, source_node_id) if isinstance(self.graph.get_node(n), IterateInvocation)
]

# Order iterators outer->inner via topo order of the iterator graph.
topo = list(nx.topological_sort(it_g))
topo_index = {n: i for i, n in enumerate(topo)}
iterator_sources.sort(key=lambda n: topo_index.get(n, 0))

# Map iterator source nodes to the prepared iterator exec nodes that are ancestors of exec_node_id.
eg = self.execution_graph.nx_graph()
path: list[int] = []
for it_src in iterator_sources:
prepared = self.source_prepared_mapping.get(it_src)
if not prepared:
continue
it_exec = next((p for p in prepared if nx.has_path(eg, p, exec_node_id)), None)
if it_exec is None:
continue
it_node = self.execution_graph.nodes.get(it_exec)
if isinstance(it_node, IterateInvocation):
path.append(it_node.index)

# If this exec node is itself an iterator, include its own index as the innermost element.
node_obj = self.execution_graph.nodes.get(exec_node_id)
if isinstance(node_obj, IterateInvocation):
path.append(node_obj.index)

result = tuple(path)
self._iteration_path_cache[exec_node_id] = result
return result

def _queue_for(self, cls_name: str) -> Deque[str]:
q = self._ready_queues.get(cls_name)
if q is None:
Expand All @@ -843,7 +890,15 @@ def _enqueue_if_ready(self, nid: str) -> None:
if self.indegree[nid] != 0 or nid in self.executed:
return
node_obj = self.execution_graph.nodes[nid]
self._queue_for(self._type_key(node_obj)).append(nid)
q = self._queue_for(self._type_key(node_obj))
nid_path = self._get_iteration_path(nid)
# Insert in lexicographic outer->inner order; preserve FIFO for equal paths.
for i, existing in enumerate(q):
if self._get_iteration_path(existing) > nid_path:
q.insert(i, nid)
break
else:
q.append(nid)

model_config = ConfigDict(
json_schema_extra={
Expand Down Expand Up @@ -1083,12 +1138,12 @@ def no_unexecuted_iter_ancestors(n: str) -> bool:

# Select the correct prepared parents for each iteration
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
# TODO: Handle a node mapping to none
eg = self.execution_graph.nx_graph_flat()
prepared_parent_mappings = [
[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
for it in iterator_node_prepared_combinations
] # type: ignore
prepared_parent_mappings = [m for m in prepared_parent_mappings if all(p[1] is not None for p in m)]

# Create execution node for each iteration
for iteration_mappings in prepared_parent_mappings:
Expand All @@ -1110,15 +1165,17 @@ def _get_iteration_node(
if len(prepared_nodes) == 1:
return next(iter(prepared_nodes))

# Check if the requested node is an iterator
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
if prepared_iterator is not None:
return prepared_iterator

# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]

# If the requested node is an iterator, only accept it if it is compatible with all parent iterators
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
if prepared_iterator is not None:
if all(nx.has_path(execution_graph, pit[0], prepared_iterator) for pit in parent_iterators):
return prepared_iterator
return None

return next(
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
None,
Expand Down
45 changes: 45 additions & 0 deletions tests/test_graph_execution_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,48 @@ def test_graph_iterate_execution_order(execution_number: int):
_ = invoke_next(g)
assert _[1].item == "Dinosaur Sushi"
_ = invoke_next(g)


# Because this tests deterministic ordering, we run it multiple times
@pytest.mark.parametrize("execution_number", range(5))
def test_graph_nested_iterate_execution_order(execution_number: int):
"""
Validates best-effort in-order execution for nodes expanded under nested iterators.
Expected lexicographic order by (outer_index, inner_index), subject to readiness.
"""
graph = Graph()

# Outer iterator: [0, 1]
graph.add_node(RangeInvocation(id="outer_range", start=0, stop=2, step=1))
graph.add_node(IterateInvocation(id="outer_iter"))

# Inner iterator is derived from the outer item:
# start = outer_item * 10
# stop = start + 2 => yields 2 items per outer item
graph.add_node(MultiplyInvocation(id="mul10", b=10))
graph.add_node(AddInvocation(id="stop_plus2", b=2))
graph.add_node(RangeInvocation(id="inner_range", start=0, stop=1, step=1))
graph.add_node(IterateInvocation(id="inner_iter"))

# Observe inner items (they encode outer via start=outer*10)
graph.add_node(AddInvocation(id="sum", b=0))

graph.add_edge(create_edge("outer_range", "collection", "outer_iter", "collection"))
graph.add_edge(create_edge("outer_iter", "item", "mul10", "a"))
graph.add_edge(create_edge("mul10", "value", "stop_plus2", "a"))
graph.add_edge(create_edge("mul10", "value", "inner_range", "start"))
graph.add_edge(create_edge("stop_plus2", "value", "inner_range", "stop"))
graph.add_edge(create_edge("inner_range", "collection", "inner_iter", "collection"))
graph.add_edge(create_edge("inner_iter", "item", "sum", "a"))

g = GraphExecutionState(graph=graph)
sum_values: list[int] = []

while True:
n, o = invoke_next(g)
if n is None:
break
if g.prepared_source_mapping[n.id] == "sum":
sum_values.append(o.value)

assert sum_values == [0, 1, 10, 11]