Skip to content
Draft
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1adca71
Scaffold
justinchuby Jun 29, 2025
f48537b
Add support for expr in SymbolicDim
justinchuby Jun 29, 2025
7464af1
wip
justinchuby Jun 29, 2025
27ae0ca
wip
justinchuby Jun 29, 2025
1afe64a
Create NodeInferencer
justinchuby Jun 29, 2025
78ad6e0
inference_common
justinchuby Jun 29, 2025
5aa2df7
Update shapes
justinchuby Jun 29, 2025
dbc3593
update
justinchuby Jun 29, 2025
b9f0528
Claude - add sympy import
justinchuby Jun 30, 2025
c9a35b7
Claude and lint
justinchuby Jun 30, 2025
65e3dd2
concat
justinchuby Jun 30, 2025
7960770
Update _maybe_convert_to_symbolic_dim
justinchuby Jun 30, 2025
a7704c5
reshape
justinchuby Jun 30, 2025
922a597
Update the way dim is set
justinchuby Jun 30, 2025
9183848
Simplify
justinchuby Jun 30, 2025
9300aba
Update
justinchuby Jun 30, 2025
8747a93
Handle unknown dims
justinchuby Jun 30, 2025
92049c4
Simplify
justinchuby Jun 30, 2025
720845e
Create inclusive range
justinchuby Jun 30, 2025
bae78ab
WIP inference engine
justinchuby Jun 30, 2025
a77f487
Create readme
justinchuby Jun 30, 2025
6686457
Result
justinchuby Jun 30, 2025
3207e84
Summary of Complete Refactoring
justinchuby Jun 30, 2025
a572145
lint
justinchuby Jun 30, 2025
11f8958
Removes unused shape inference code
justinchuby Jun 30, 2025
f3c70da
Summary of Shape Simplifications
justinchuby Jun 30, 2025
4b6d80d
Create factory
justinchuby Jun 30, 2025
e03733b
Use Enum
justinchuby Jun 30, 2025
5a34891
Update logging calls
justinchuby Jun 30, 2025
ab09107
Working on engine
justinchuby Jun 30, 2025
9256233
todo
justinchuby Jun 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers = [
"Programming Language :: Python :: 3.13",
"License :: OSI Approved :: Apache Software License",
]
dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes"]
dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes", "sympy"]

[project.urls]
Homepage = "https://onnx.ai/ir-py"
Expand Down
35 changes: 29 additions & 6 deletions src/onnx_ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@

import ml_dtypes
import numpy as np
import sympy
import sympy.utilities.misc
from typing_extensions import TypeIs

import onnx_ir
Expand Down Expand Up @@ -1115,13 +1117,14 @@
It is immutable and can be compared or hashed.
"""

__slots__ = ("_value",)
__slots__ = ("_expr", "_value")

def __init__(self, value: str | None) -> None:
def __init__(self, value: str | None, /, expr: sympy.Expr | None = None) -> None:
"""Initialize a symbolic dimension.

Args:
value: The value of the dimension. It should not be an int.
expr: An optional sympy expression representing the dimension.

Raises:
TypeError: If value is an int.
Expand All @@ -1132,6 +1135,7 @@
"If you are creating a Shape, use int directly instead of SymbolicDim."
)
self._value = value
self._expr: sympy.Expr | None = expr

def __eq__(self, other: object) -> bool:
"""Check equality with another SymbolicDim or string/None."""
Expand All @@ -1148,11 +1152,24 @@
"""The value of the symbolic dimension (string or None)."""
return self._value

@property
def expr(self) -> sympy.Expr | None:
"""The sympy expression representing the symbolic dimension."""
return self._expr

Check warning on line 1158 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1158

Added line #L1158 was not covered by tests

def __str__(self) -> str:
return f"{self._value}"
if self._value is not None:
return str(self._value)

Check warning on line 1162 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1162

Added line #L1162 was not covered by tests
if self._expr is not None:
return str(self._expr)
return "?"

Check warning on line 1165 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1164-L1165

Added lines #L1164 - L1165 were not covered by tests

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._value})"
if self._expr is not None:
expr_text = f", expr={self._expr!r}"

Check warning on line 1169 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1169

Added line #L1169 was not covered by tests
else:
expr_text = ""
return f"{self.__class__.__name__}({self._value}{expr_text})"

Check warning on line 1172 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1171-L1172

Added lines #L1171 - L1172 were not covered by tests


def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
Expand Down Expand Up @@ -1190,10 +1207,16 @@
return SymbolicDim(dim)
if _is_int_compatible(dim):
return int(dim)
if isinstance(dim, sympy.Expr):
# If the dimension is a sympy expression, we create a SymbolicDim with it
expr = sympy.sympify(dim)

Check warning on line 1212 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1212

Added line #L1212 was not covered by tests
if expr.is_integer:
return sympy.utilities.misc.as_int(expr)
return SymbolicDim(str(expr), expr=sympy.sympify(expr))

Check warning on line 1215 in src/onnx_ir/_core.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_core.py#L1214-L1215

Added lines #L1214 - L1215 were not covered by tests
if isinstance(dim, SymbolicDim):
return dim
raise TypeError(
f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
f"Expected int, str, sympy.Expr, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
)


Expand Down Expand Up @@ -1334,7 +1357,7 @@
def __getitem__(self, index):
return tuple(self._dims)[index]

def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None:
def __setitem__(self, index: int, value: int | SymbolicDim | str | sympy.Expr | None) -> None:
"""Set the dimension at the index.

Args:
Expand Down
2 changes: 2 additions & 0 deletions src/onnx_ir/_shape_type_inference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class SymbolicInferenceEngine:
pass
127 changes: 127 additions & 0 deletions src/onnx_ir/_shape_type_inference/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""Symbolic shape inference for ONNX IR."""

from __future__ import annotations

import abc
import dataclasses
from collections.abc import Collection, Sequence
import functools
from typing import Any, TypeVar, Callable
import numpy as np
import sympy

import onnx_ir as ir


def get_expr(shape: ir.Shape, index: int) -> sympy.Expr:
"""Get the expression or value at a specific index in the shape.

Args:
shape: The shape to get the expression from.
index: The index of the dimension to get.

Returns:
The expression or value at the specified index.
"""
dim = shape[index]

Check warning on line 26 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L26

Added line #L26 was not covered by tests
if isinstance(dim, ir.SymbolicDim):
if dim.expr is not None:
return dim.expr
return sympy.Symbol(dim.value)
return sympy.Integer(dim)

Check warning on line 31 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L29-L31

Added lines #L29 - L31 were not covered by tests


@dataclasses.dataclass
class InferenceResult:
values: Sequence[ir.Value] | None = None
failure: str | None = None


class NodeInferrer(abc.ABC):
"""Base class for node inferrers.

This class provides a common interface for all node inferrers.
"""

def __init__(self, op_type: str, opsets: Collection[int], domain: str = "") -> None:
"""Initialize the node inferrer.

Args:
op_type: The type of the operation.
opsets: A collection of ONNX opset versions supported by this inferrer.
domain: The domain of the operation, default is an empty string.
"""
self.op_type = op_type
self.opsets = opsets
self.domain = domain

Check warning on line 56 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L54-L56

Added lines #L54 - L56 were not covered by tests

@abc.abstractmethod
def infer(self, node: ir.Node) -> InferenceResult:
"""Infer the shape for the node.

Args:
node: The ONNX node to infer the type and shape for.

Returns:
A sequence of ONNX values containing the inferred shapes.
"""
raise NotImplementedError

Check warning on line 68 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L68

Added line #L68 was not covered by tests


def requires_non_none_inputs(
count: int, /
) -> Callable[[Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult]]:
"""Ensure that the node has a specific number of non-None inputs.

Args:
count: The exact number of non-None inputs required for the node.

Returns:
A decorator that checks the number of inputs and their non-None status.
"""

def decorator(
func: Callable[[Any, ir.Node], InferenceResult],
) -> Callable[[Any, ir.Node], InferenceResult]:
@functools.wraps(func)
def wrapper(self, node: ir.Node) -> InferenceResult:
if len(node.inputs) != count:
return InferenceResult(

Check warning on line 89 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L89

Added line #L89 was not covered by tests
failure=f"[{node.op_type} must have {count} inputs, got {len(node.inputs)}."
)
for i, inp in enumerate(node.inputs):
if inp is None:
return InferenceResult(failure=f"{node.op_type} input {i} cannot be None.")
return func(self, node)

Check warning on line 95 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L94-L95

Added lines #L94 - L95 were not covered by tests

return wrapper

return decorator


def requires_outputs(
count: int, /
) -> Callable[[Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult]]:
"""Ensure that the node has a specific number of outputs.

Args:
count: The exact number of outputs required for the node.

Returns:
A decorator that checks the number of outputs.
"""

def decorator(
func: Callable[[Any, ir.Node], InferenceResult],
) -> Callable[[Any, ir.Node], InferenceResult]:
@functools.wraps(func)
def wrapper(self, node: ir.Node) -> InferenceResult:
if len(node.outputs) != count:
return InferenceResult(

Check warning on line 120 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L120

Added line #L120 was not covered by tests
failure=f"[{node.op_type} must have {count} outputs, got {len(node.outputs)}."
)
return func(self, node)

Check warning on line 123 in src/onnx_ir/_shape_type_inference/_common.py

View check run for this annotation

Codecov / codecov/patch

src/onnx_ir/_shape_type_inference/_common.py#L123

Added line #L123 was not covered by tests

return wrapper

return decorator
Loading
Loading