diff --git a/hugr-py/src/hugr/model/__init__.py b/hugr-py/src/hugr/model/__init__.py index ac97177cd2..7b7273167d 100644 --- a/hugr-py/src/hugr/model/__init__.py +++ b/hugr-py/src/hugr/model/__init__.py @@ -1,9 +1,10 @@ """HUGR model data structures.""" -from collections.abc import Sequence +from abc import ABC +from collections.abc import Generator, Sequence from dataclasses import dataclass, field from enum import Enum -from typing import Protocol +from typing import Optional from semver import Version @@ -21,7 +22,7 @@ def _current_version() -> Version: CURRENT_VERSION: Version = _current_version() -class Term(Protocol): +class Term(ABC): """A model term for static data such as types, constants and metadata.""" def __str__(self) -> str: @@ -33,6 +34,26 @@ def from_str(s: str) -> "Term": """Read the term from its string representation.""" return rust.string_to_term(s) + def to_list_parts(self) -> Generator["SeqPart"]: + if isinstance(self, List): + for part in self.parts: + if isinstance(part, Splice): + yield from part.seq.to_list_parts() + else: + yield part + else: + yield Splice(self) + + def to_tuple_parts(self) -> Generator["SeqPart"]: + if isinstance(self, Tuple): + for part in self.parts: + if isinstance(part, Splice): + yield from part.seq.to_tuple_parts() + else: + yield part + else: + yield Splice(self) + @dataclass(frozen=True) class Wildcard(Term): @@ -129,9 +150,13 @@ def from_str(s: str) -> "Symbol": return rust.string_to_symbol(s) -class Op(Protocol): +class Op(ABC): """The operation of a node.""" + def symbol_name(self) -> str | None: + """Returns name of the symbol introduced by this node, if any.""" + return None + @dataclass(frozen=True) class InvalidOp(Op): @@ -159,6 +184,9 @@ class DefineFunc(Op): symbol: Symbol + def symbol_name(self) -> str | None: + return self.symbol.name + @dataclass(frozen=True) class DeclareFunc(Op): @@ -166,6 +194,9 @@ class DeclareFunc(Op): symbol: Symbol + def symbol_name(self) -> str | None: + return self.symbol.name + @dataclass(frozen=True) class CustomOp(Op): @@ -181,6 +212,9 @@ class DefineAlias(Op): symbol: Symbol value: Term + def symbol_name(self) -> str | None: + return self.symbol.name + @dataclass(frozen=True) class DeclareAlias(Op): @@ -188,6 +222,9 @@ class DeclareAlias(Op): symbol: Symbol + def symbol_name(self) -> str | None: + return self.symbol.name + @dataclass(frozen=True) class TailLoop(Op): @@ -205,6 +242,9 @@ class DeclareConstructor(Op): symbol: Symbol + def symbol_name(self) -> str | None: + return self.symbol.name + @dataclass(frozen=True) class DeclareOperation(Op): @@ -212,6 +252,9 @@ class DeclareOperation(Op): symbol: Symbol + def symbol_name(self) -> str | None: + return self.symbol.name + @dataclass(frozen=True) class Import(Op): @@ -219,6 +262,9 @@ class Import(Op): name: str + def symbol_name(self) -> str | None: + return self.name + @dataclass class Node: diff --git a/hugr-py/src/hugr/model/export.py b/hugr-py/src/hugr/model/export.py index 8bef32e452..e354c0e165 100644 --- a/hugr-py/src/hugr/model/export.py +++ b/hugr-py/src/hugr/model/export.py @@ -5,8 +5,8 @@ from typing import Generic, TypeVar, cast import hugr.model as model -from hugr.hugr.base import Hugr, Node -from hugr.hugr.node_port import InPort, OutPort +from hugr.hugr.base import Hugr +from hugr.hugr.node_port import InPort, Node, OutPort from hugr.ops import ( CFG, DFG, diff --git a/hugr-py/src/hugr/model/load.py b/hugr-py/src/hugr/model/load.py new file mode 100644 index 0000000000..cd54293baa --- /dev/null +++ b/hugr-py/src/hugr/model/load.py @@ -0,0 +1,739 @@ +from collections.abc import Generator, Iterable +from dataclasses import dataclass, field +from typing import List, Tuple, Dict, Any +import json + +import hugr.model as model +from hugr import val +from hugr.hugr import InPort, OutPort +from hugr.hugr.base import Hugr +from hugr.hugr.node_port import Node +from hugr.std.int import IntVal +from hugr.std.float import FloatVal +from hugr.std.collections.array import ArrayVal, Array +from hugr.ops import DFG, Case, Conditional, Custom, FuncDecl, FuncDefn, Op, TailLoop, Input, Output +from hugr.tys import ( + BoundedNatArg, + BoundedNatParam, + BytesArg, + BytesParam, + ConstParam, + FloatArg, + FloatParam, + FunctionType, + ListArg, + ListConcatArg, + ListParam, + Opaque, + PolyFuncType, + RowVariable, + StringArg, + StringParam, + Sum, + TupleArg, + TupleConcatArg, + TupleParam, + Type, + TypeArg, + TypeBound, + TypeParam, + TypeTypeParam, + Variable, +) + +ImportContext = model.Term | model.Node | model.Region | str + + +class ModelImportError(Exception): + """Exception raised when importing from the model representation fails.""" + + def __init__(self, message: str, location: ImportContext | None = None): + self.message = message + self.location = location + + match location: + case model.Term() as term: + location_error = f"Error caused by term:\n```\n{term}\n```" + case model.Region() as region: + location_error = f"Error caused by region:\n```\n{region}\n```" + case model.Node() as node: + location_error = f"Error caused by node:\n```\n{node}\n```" + case str() as other: + location_error = other + case None: + location_error = "Error in unspecified location." + + super().__init__(f"{message}\n{location_error}") + + +class ModelImport: + local_vars: dict[str, "LocalVarData"] + current_symbol: str | None + link_ports_in: list[dict[str, list[InPort]]] + link_ports_out: list[dict[str, list[OutPort]]] + static_edges: list[tuple[Node, Node]] + + module: model.Module + symbols: dict[str, model.Node] + hugr: Hugr + + def __init__(self, module: model.Module): + self.local_vars = {} + self.current_symbol = None + self.module = module + self.symbols = {} + self.hugr = Hugr() + self.link_ports_in = [] + self.link_ports_out = [] + self.static_edges = [] + + for node in module.root.children: + symbol_name = node.operation.symbol_name() + + if symbol_name is None: + continue + + if symbol_name in self.symbols: + error = f"Duplicate symbol name `{symbol_name}`." + raise ModelImportError(error, node) + + self.symbols[symbol_name] = node + + def add_node(self, node: model.Node, operation: Op, parent: Node) -> Node: + metadata = self.import_meta_json(node) + node_id = self.hugr.add_node(operation, parent, metadata = metadata) + self.record_in_links(node_id, node.inputs) + self.record_out_links(node_id, node.outputs) + return node_id + + def record_in_links(self, node: Node, links: Iterable[str]): + link_ports_in = self.link_ports_in[len(self.link_ports_in) - 1] + + for offset, link in enumerate(links): + in_port = InPort(node=node, offset=offset) + link_ports_in.setdefault(link, []).append(in_port) + + def record_out_links(self, node: Node, links: Iterable[str]): + link_ports_out = self.link_ports_out[len(self.link_ports_out) - 1] + + for offset, link in enumerate(links): + out_port = OutPort(node=node, offset=offset) + link_ports_out.setdefault(link, []).append(out_port) + + + def link_ports(self): + link_ports_in = self.link_ports_in[len(self.link_ports_in) - 1] + link_ports_out = self.link_ports_out[len(self.link_ports_out) - 1] + + links = link_ports_in.keys() | link_ports_out.keys() + + for link in links: + in_ports = link_ports_in[link] + out_ports = link_ports_out[link] + + match in_ports, out_ports: + case [[], []]: + assert False + case _, [out_port]: + for in_port in in_ports: + self.hugr.add_link(out_port, in_port) + case [[in_port], _]: + for out_port in out_ports: + self.hugr.add_link(out_port, in_port) + case _, _: + error = f"Link `{link}` has multiple inputs and outputs." + raise ModelImportError(error) + + def link_static_ports(self): + for src, dst in self.static_edges: + out_port_offset = self.hugr.num_out_ports(src) - 1 + out_port = OutPort(node=src, offset=out_port_offset) + + in_port_offset = self.hugr.num_in_ports(dst) - 1 + in_port = InPort(node=dst, offset=in_port_offset) + + self.hugr.add_link(out_port, in_port) + + def import_dfg_region( + self, region: model.Region, parent: Node, isolated: bool = False + ): + if isolated: + self.link_ports_in.append({}) + self.link_ports_out.append({}) + + signature = self.import_signature(region.signature) + + input_node = self.hugr.add_node(Input(signature.input)) + self.record_out_links(input_node, region.sources) + + output_node = self.hugr.add_node(Output(signature.output)) + self.record_in_links(output_node, region.targets) + + order_data = self.import_meta_order_region(region) + order_data.add_node_keys(input_node, order_data.input_keys) + order_data.add_node_keys(output_node, order_data.output_keys) + + for child in region.children: + child_id = self.import_node_in_dfg(child, parent) + child_order_keys = self.import_meta_order_keys(child) + order_data.add_node_keys(child_id, child_order_keys) + + for src_key, tgt_key in order_data.edges: + src_node = order_data.get_node_by_key(src_key) + tgt_node = order_data.get_node_by_key(tgt_key) + self.hugr.add_order_link(src_node, tgt_node) + + if isolated: + self.link_ports_in.pop() + self.link_ports_out.pop() + + def import_node_in_dfg(self, node: model.Node, parent: Node) -> Node: + def import_dfg_node() -> Node: + match node.regions: + case [body]: + pass + case _: + raise ModelImportError("DFG node expects a dataflow region.", node) + + signature = self.import_signature(node.signature) + node_id = self.add_node( + node, DFG(signature.input, signature.output), parent + ) + self.import_dfg_region(body, node_id) + return node_id + + def import_tail_loop() -> Node: + match node.regions: + case [body]: + pass + case _: + raise ModelImportError("Loop node expects a dataflow region.", node) + + match body.signature: + case model.Apply("core.fn", [_, body_outputs]): + pass + case _: + error = "Tail loop body expects `(core.fn _ _)` signature." + raise ModelImportError(error, node) + + match list(import_closed_list(body_outputs)): + case [model.Apply("core.adt", [variants]), *rest]: + pass + case _: + error = "TailLoop body expects `(core.adt _)` as first target type." + raise ModelImportError(error, node) + + match list(import_closed_list(variants)): + case [just_inputs, just_outputs]: + pass + case _: + raise ModelImportError( + "TailLoop body expects sum type with two variants.", node + ) + + node_id = self.add_node( + node, + TailLoop( + just_inputs=self.import_type_row(just_inputs), + rest=[self.import_type(t) for t in rest], + _just_outputs=self.import_type_row(just_outputs), + ), + parent, + ) + self.import_dfg_region(body, node_id) + return node_id + + def import_custom_node(op: model.Term) -> Node: + match op: + case model.Apply(symbol, args): + extension, op_name = split_extension_name(symbol) + case _: + raise ModelImportError( + "The operation of a custom node must be a symbol application.", + node, + ) + + return self.add_node( + node, + Custom( + op_name=op_name, + extension=extension, + signature=self.import_signature(node.signature), + args=[self.import_type_arg(arg) for arg in args], + ), + parent, + ) + + def import_cfg() -> Node: ... + + def import_conditional() -> Node: + match node.signature: + case model.Apply("core.fn", [inputs, outputs]): + pass + case _: + error = "Conditional node expects `(core.fn _ _)` signature." + raise ModelImportError(error, node) + + match list(import_closed_list(inputs)): + case [model.Apply("core.adt", [variants]), *other_inputs]: + sum_ty = Sum( + [ + self.import_type_row(variant) + for variant in import_closed_list(variants) + ] + ) + case _: + error = ( + "Conditional node expects `(core.adt _)` as first input type." + ) + raise ModelImportError( + error, + node, + ) + + node_id = self.add_node( + node, + Conditional( + sum_ty=sum_ty, + other_inputs=[self.import_type(t) for t in other_inputs], + _outputs=self.import_type_row(outputs) + ), + parent, + ) + + for case_body in node.regions: + case_signature = self.import_signature(case_body.signature) + case_id = self.hugr.add_node( + Case(inputs=case_signature.input, _outputs=case_signature.output), + node_id, + ) + self.import_dfg_region(case_body, case_id) + + return node_id + + match node.operation: + case model.InvalidOp(): + error = "Invalid operation can not be imported." + raise ModelImportError(error, node) + case model.Dfg(): + return import_dfg_node() + case model.Cfg(): + return import_cfg() + case model.Block(): + error = "Unexpected basic block." + raise ModelImportError(error, node) + case model.CustomOp(op): + return import_custom_node(op) + case model.TailLoop(): + return import_tail_loop() + case model.Conditional(): + return import_conditional() + case _: + error = "Unexpected node in DFG region." + raise ModelImportError(error, node) + + def import_node_in_module(self, node: model.Node, parent: Node) -> Node | None: + def import_declare_func(symbol: model.Symbol) -> Node: + title = self.import_meta_title(node) + f_name = symbol.name if title is None else title + signature = self.enter_symbol(symbol) + node_id = self.add_node( + node, + FuncDecl( + f_name=f_name, signature=signature, visibility=symbol.visibility + ), + parent, + ) + self.exit_symbol() + return node_id + + def import_define_func(symbol: model.Symbol) -> Node: + title = self.import_meta_title(node) + f_name = symbol.name if title is None else title + signature = self.enter_symbol(symbol) + node_id = self.add_node( + node, + FuncDefn( + f_name=f_name, + inputs=signature.body.input, + _outputs=signature.body.output, + params=signature.params, + visibility=symbol.visibility, + ), + parent, + ) + + match node.regions: + case [body]: + pass + case _: + error = "Function definition expects a single region." + raise ModelImportError(error, node) + + self.import_dfg_region(body, node_id, isolated=True) + self.exit_symbol() + return node_id + + match node.operation: + case model.DeclareFunc(symbol): + return import_declare_func(symbol) + case model.DefineFunc(symbol): + return import_define_func(symbol) + case model.DeclareAlias(): + error = "Aliases unsupported for now." + raise ModelImportError(error, node) + case model.DefineAlias(): + error = "Aliases unsupported for now." + raise ModelImportError(error, node) + case model.Import(): + return None + case model.DeclareConstructor(): + return None + case model.DeclareOperation(): + return None + case _: + error = "Unexpected node in module region." + raise ModelImportError(error, node) + + def enter_symbol(self, symbol: model.Symbol) -> PolyFuncType: + assert len(self.local_vars) == 0 + + bounds: Dict[str, TypeBound] = {} + + for constraint in symbol.constraints: + match constraint: + case model.Apply("core.nonlinear", [model.Var(name)]): + bounds[name] = TypeBound.Copyable + case _: + error = "Constraint other than `core.nonlinear` on a variable." + raise ModelImportError(error, constraint) + + param_types: List[TypeParam] = [] + + for index, param in enumerate(symbol.params): + bound = bounds[param.name] if param.name in bounds else TypeBound.Linear + type = self.import_type_param(param.type, bound = bound) + self.local_vars[param.name] = LocalVarData(index, type) + param_types.append(type) + + body = self.import_signature(symbol.signature) + return PolyFuncType(param_types, body) + + def exit_symbol(self): + self.local_vars = {} + + def import_signature(self, term: model.Term | None) -> FunctionType: + match term: + case None: + error = "Signature required." + raise ModelImportError(error) + case model.Apply("core.fn", [inputs, outputs]): + return FunctionType( + self.import_type_row(inputs), self.import_type_row(outputs) + ) + case _: + error = "Invalid signature." + raise ModelImportError(error, term) + + def lookup_var(self, name: str) -> "LocalVarData": + if name in self.local_vars: + error = f"Unknown variable `{name}`." + raise ImportError(error) + + return self.local_vars[name] + + def import_type_param(self, term: model.Term, bound: TypeBound = TypeBound.Linear) -> TypeParam: + """Import a TypeParam from a model Term.""" + match term: + case model.Apply("core.nat"): + return BoundedNatParam() + case model.Apply("core.str"): + return StringParam() + case model.Apply("core.float"): + return FloatParam() + case model.Apply("core.bytes"): + return BytesParam() + case model.Apply("core.type"): + return TypeTypeParam(bound) + case model.Apply("core.list", [item_type]): + return ListParam(self.import_type_param(item_type)) + case model.Apply("core.tuple", [item_types]): + return TupleParam( + [ + self.import_type_param(item_type) + for item_type in import_closed_list(item_types) + ] + ) + case model.Apply("core.const", [runtime_type]): + return ConstParam(self.import_type(runtime_type)) + case _: + error = "Failed to import TypeParam." + raise ModelImportError(error, term) + + def import_type_arg(self, term: model.Term) -> TypeArg: + """Import a TypeArg from a model Term.""" + + def import_list(term: model.Term) -> TypeArg: + lists: list[TypeArg] = [] + + for group in group_seq_parts(term.to_list_parts()): + if isinstance(group, list): + lists.append( + ListArg([self.import_type_arg(item) for item in group]) + ) + else: + lists.append(self.import_type_arg(group)) + + return ListConcatArg(lists).flatten() + + def import_tuple(term: model.Term) -> TypeArg: + tuples: list[TypeArg] = [] + + for group in group_seq_parts(term.to_list_parts()): + if isinstance(group, list): + tuples.append( + TupleArg([self.import_type_arg(item) for item in group]) + ) + else: + tuples.append(self.import_type_arg(group)) + + return TupleConcatArg(tuples).flatten() + + # TODO: TypeTypeArg + + match term: + case model.Literal(str() as value): + return StringArg(value) + case model.Literal(int() as value): + return BoundedNatArg(value) + case model.Literal(float() as value): + return FloatArg(value) + case model.Literal(bytes() as value): + return BytesArg(value) + case model.List(): + return import_list(term) + case model.Tuple(): + return import_tuple(term) + case _: + error = "Failed to import TypeArg." + raise ModelImportError(error, term) + + def import_type(self, term: model.Term) -> Type: + """Import the type from a model Term.""" + match term: + case model.Apply("core.fn", [inputs, outputs]): + return FunctionType( + self.import_type_row(inputs), self.import_type_row(outputs) + ) + case model.Apply("core.adt", [variants]): + return Sum( + [ + self.import_type_row(variant) + for variant in import_closed_list(variants) + ] + ) + case model.Apply(symbol, args): + extension, id = split_extension_name(symbol) + return Opaque( + id=id, + extension=extension, + bound=TypeBound.Linear, + args=[self.import_type_arg(arg) for arg in args], + ) + case model.Var(name): + var_data = self.lookup_var(name) + return Variable(idx=var_data.index, bound=var_data.bound) + case _: + error = "Failed to import Type." + raise ModelImportError(error, term) + + def import_type_row(self, term: model.Term) -> list[Type]: + def import_part(part: model.SeqPart) -> Type: + if isinstance(part, model.Splice): + if isinstance(part.seq, model.Var): + var_data = self.lookup_var(part.seq.name) + return RowVariable(var_data.index, var_data.bound) + else: + error = "Can only import spliced variables." + raise ModelImportError(error, term) + else: + return self.import_type(part) + + return [import_part(part) for part in term.to_list_parts()] + + def import_meta_json(self, node: model.Node) -> Dict[str, Any]: + """Collects the `core.meta_json` metadata on the given node.""" + + metadata = {} + + for meta in node.meta: + match meta: + case model.Apply("compat.meta_json", [ + model.Literal(str() as key), + model.Literal(str() as value) + ]): + pass + case _: + continue + + try: + decoded = json.loads(value) + except json.JSONDecodeError: + error = "Failed to decode JSON metadata." + raise ModelImportError(error, node) + + metadata[key] = decoded + + return metadata + + def import_meta_title(self, node: model.Node) -> str | None: + """Searches for `core.title` metadata on the given node.""" + for meta in node.meta: + match meta: + case model.Apply("core.title", [model.Literal(str() as title)]): + return title + case model.Apply("core.title"): + error = "Invalid instance of `core.title` metadata." + raise ModelImportError(error, meta) + case _: + pass + + return None + + def import_meta_order_region(self, region: model.Region) -> "RegionOrderHints": + """Searches for order hint metadata on the given region.""" + data = RegionOrderHints() + + for meta in region.meta: + match meta: + case model.Apply( + "core.order_hint.input_key", + [model.Literal(int() as key)] + ): + data.input_keys.append(key) + case model.Apply( + "core.order_hint.output_key", + [model.Literal(int() as key)] + ): + data.output_keys.append(key) + case model.Apply( + "core.order_hint.order", + [model.Literal(int() as before), model.Literal(int() as after)] + ): + data.edges.append((before, after)) + case _: + pass + + return data + + def import_meta_order_keys(self, node: model.Node) -> List[int]: + """Collects all order hint keys in the metadata of a node.""" + keys = [] + + for meta in node.meta: + match meta: + case model.Apply( + "core.order_hint.key", + [model.Literal(int() as key)] + ): + keys.append(key) + case _: + pass + + return keys + + def import_value(self, term: model.Term) -> val.Value: + match term: + case model.Apply("arithmetic.int.const", [ + model.Literal(int() as int_bitwidth), + model.Literal(int() as int_value) + ]): + return IntVal(int_value, int_bitwidth) + case model.Apply("arithmetic.float.const_f64", [model.Literal(float() as float_value)]): + return FloatVal(float_value) + case model.Apply("collections.array.const", [ + _, + array_item_type, + array_items + ]): + raise NotImplementedError("Import array constants") + case model.Apply("compat.const_json", [ + model.Literal(str() as json), + ]): + # TODO + raise NotImplementedError("Import json encoded constants") + case _: + error = "Unsupported constant value." + raise ModelImportError(error, term) + + +@dataclass +class LocalVarData: + index: int + type: TypeParam + bound: TypeBound = field(default=TypeBound.Linear) + +@dataclass +class RegionOrderHints: + input_keys: list[int] = field(default = []) + output_keys: list[int] = field(default = []) + edges: List[Tuple[int, int]] = field(default = []) + key_to_node: Dict[int, Node] = field(default = {}) + + def add_node_keys(self, node: Node, keys: Iterable[int]): + for key in keys: + if key in self.key_to_node: + error = f"Duplicate order key `{key}`." + raise ModelImportError(error) + + self.key_to_node[key] = node + + def get_node_by_key(self, key: int) -> Node: + if key in self.key_to_node: + error = f"Unknown order key `{key}`." + raise ModelImportError(error) + + return self.key_to_node[key] + +def group_seq_parts( + parts: Iterable[model.SeqPart], +) -> Generator[model.Term | list[model.Term]]: + group: list[model.Term] = [] + + for part in parts: + if isinstance(part, model.Splice): + if len(group) > 0: + yield group + group = [] + yield part.seq + else: + group.append(part) + + if len(group) > 0: + yield group + + +def import_closed_list(term: model.Term) -> Generator[model.Term]: + for part in term.to_list_parts(): + if isinstance(part, model.Splice): + raise ModelImportError("Expected closed list.", term) + else: + yield part + + +def import_closed_tuple(term: model.Term) -> Generator[model.Term]: + for part in term.to_tuple_parts(): + if isinstance(part, model.Splice): + raise ModelImportError("Expected closed tuple.", term) + else: + yield part + + +def split_extension_name(name: str) -> tuple[str, str]: + match name.rsplit(".", 1): + case [extension, id]: + return (extension, id) + case [id]: + return ("", id) + case _: + assert False diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index e17d39653a..99ca320d2c 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -1,7 +1,5 @@ """HUGR edge kinds, types, type parameters and type arguments.""" -from __future__ import annotations - import base64 from dataclasses import dataclass, field from typing import TYPE_CHECKING, Literal, Protocol, cast, runtime_checkable @@ -364,6 +362,15 @@ def to_model(self) -> model.Term: [model.Splice(cast(model.Term, elem.to_model())) for elem in self.lists] ) + def flatten(self) -> TypeArg: + match self.lists: + case []: + return ListArg([]) + case [item]: + return item + case _: + return self + @dataclass(frozen=True) class TupleArg(TypeArg): @@ -405,6 +412,15 @@ def to_model(self) -> model.Term: [model.Splice(cast(model.Term, elem.to_model())) for elem in self.tuples] ) + def flatten(self) -> TypeArg: + match self.tuples: + case []: + return TupleArg([]) + case [item]: + return item + case _: + return self + @dataclass(frozen=True) class VariableArg(TypeArg):