Skip to content
8 changes: 8 additions & 0 deletions hugr-py/src/hugr/hugr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,14 @@ def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, No
)
return mapping

def _overwrite_hugr(self, new_hugr: Hugr) -> None:
"""Modify a Hugr in place by replacing contents with those from a new Hugr."""
self.module_root = new_hugr.module_root
self.entrypoint = new_hugr.entrypoint
self._nodes = new_hugr._nodes
self._links = new_hugr._links
self._free_nodes = new_hugr._free_nodes

def _to_serial(self) -> SerialHugr:
"""Serialize the HUGR."""

Expand Down
30 changes: 25 additions & 5 deletions hugr-py/src/hugr/passes/_composable_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol, runtime_checkable

Expand All @@ -16,9 +17,23 @@
class ComposablePass(Protocol):
"""A Protocol which represents a composable Hugr transformation."""

def __call__(self, hugr: Hugr) -> None:
def __call__(self, hugr: Hugr, *, inplace: bool = True) -> Hugr:
"""Call the pass to transform a HUGR."""
...
if inplace:
self._apply_inplace(hugr)
return hugr
else:
return self._apply(hugr)

# At least one of the following _apply methods must be overriden
def _apply(self, hugr: Hugr) -> Hugr:
hugr = deepcopy(hugr)
self._apply_inplace(hugr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we were really concerned....you could have an _in_progress:bool on the instance of ComposablePass (not sure but this might mean setattr??), and then check that you didn't get back here with that set (and error if you do)....

I don't see a great solution tho, no

return hugr

def _apply_inplace(self, hugr: Hugr) -> None:
new_hugr = self._apply(hugr)
hugr._overwrite_hugr(new_hugr)

@property
def name(self) -> str:
Expand Down Expand Up @@ -48,10 +63,15 @@ class ComposedPass(ComposablePass):

passes: list[ComposablePass]

def __call__(self, hugr: Hugr):
"""Call all of the passes in sequence."""
def _apply(self, hugr: Hugr) -> Hugr:
result_hugr = hugr
for comp_pass in self.passes:
result_hugr = comp_pass(result_hugr, inplace=False)
return result_hugr

def _apply_inplace(self, hugr: Hugr) -> None:
for comp_pass in self.passes:
comp_pass(hugr)
comp_pass(hugr, inplace=True)

@property
def name(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions hugr-py/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

def test_composable_pass() -> None:
class MyDummyPass(ComposablePass):
def __call__(self, hugr: Hugr) -> None:
return self(hugr)
def __call__(self, hugr: Hugr, inplace: bool = True) -> Hugr:
return self(hugr, inplace)

dummy = MyDummyPass()

Expand Down
Loading