- 
                Notifications
    You must be signed in to change notification settings 
- Fork 13
[WIP] Create symbolic type/shape inference logic #117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Draft
      
      
            justinchuby
  wants to merge
  31
  commits into
  main
  
    
      
        
          
  
    
      Choose a base branch
      
     
    
      
        
      
      
        
          
          
        
        
          
            
              
              
              
  
           
        
        
          
            
              
              
           
        
       
     
  
        
          
            
          
            
          
        
       
    
      
from
justinchu/symbolic-inference-claude
  
      
      
   
  
    
  
  
  
 
  
      
    base: main
Could not load branches
            
              
  
    Branch not found: {{ refName }}
  
            
                
      Loading
              
            Could not load tags
            
            
              Nothing to show
            
              
  
            
                
      Loading
              
            Are you sure you want to change the base?
            Some commits from the old base branch may be removed from the timeline,
            and old review comments may become outdated.
          
          
  
     Draft
                    Changes from 20 commits
      Commits
    
    
            Show all changes
          
          
            31 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      1adca71
              
                Scaffold
              
              
                justinchuby f48537b
              
                Add support for expr in SymbolicDim
              
              
                justinchuby 7464af1
              
                wip
              
              
                justinchuby 27ae0ca
              
                wip
              
              
                justinchuby 1afe64a
              
                Create NodeInferencer
              
              
                justinchuby 78ad6e0
              
                inference_common
              
              
                justinchuby 5aa2df7
              
                Update shapes
              
              
                justinchuby dbc3593
              
                update
              
              
                justinchuby b9f0528
              
                Claude - add sympy import
              
              
                justinchuby c9a35b7
              
                Claude and lint
              
              
                justinchuby 65e3dd2
              
                concat
              
              
                justinchuby 7960770
              
                Update _maybe_convert_to_symbolic_dim
              
              
                justinchuby a7704c5
              
                reshape
              
              
                justinchuby 922a597
              
                Update the way dim is set
              
              
                justinchuby 9183848
              
                Simplify
              
              
                justinchuby 9300aba
              
                Update
              
              
                justinchuby 8747a93
              
                Handle unknown dims
              
              
                justinchuby 92049c4
              
                Simplify
              
              
                justinchuby 720845e
              
                Create inclusive range
              
              
                justinchuby bae78ab
              
                WIP inference engine
              
              
                justinchuby a77f487
              
                Create readme
              
              
                justinchuby 6686457
              
                Result
              
              
                justinchuby 3207e84
              
                Summary of Complete Refactoring
              
              
                justinchuby a572145
              
                lint
              
              
                justinchuby 11f8958
              
                Removes unused shape inference code
              
              
                justinchuby f3c70da
              
                Summary of Shape Simplifications
              
              
                justinchuby 4b6d80d
              
                Create factory
              
              
                justinchuby e03733b
              
                Use Enum
              
              
                justinchuby 5a34891
              
                Update logging calls
              
              
                justinchuby ab09107
              
                Working on engine
              
              
                justinchuby 9256233
              
                todo
              
              
                justinchuby File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| """Symbolic shape and type inference for ONNX IR.""" | ||
|  | ||
| __all__ = [ | ||
| "SymbolicInferenceEngine", | ||
| "InferenceError", | ||
| "NodeInferrer", | ||
| "InferenceResult", | ||
| ] | ||
|  | ||
|  | ||
| from onnx_ir._shape_type_inference._common import InferenceResult, NodeInferrer | ||
| from onnx_ir._shape_type_inference._engine import ( | ||
| InferenceError, | ||
| SymbolicInferenceEngine, | ||
| ) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| """Symbolic shape inference for ONNX IR.""" | ||
|          | ||
|  | ||
| from __future__ import annotations | ||
|          | ||
|  | ||
| import abc | ||
| import dataclasses | ||
| import functools | ||
| from collections.abc import Collection, Sequence | ||
| from typing import Any, Callable | ||
|  | ||
| import sympy | ||
|  | ||
| import onnx_ir as ir | ||
|  | ||
|  | ||
| MAX_SUPPORTED_OPSET = 23 | ||
|  | ||
|  | ||
| def get_expr(shape: ir.Shape, index: int) -> sympy.Expr: | ||
| """Get the expression or value at a specific index in the shape. | ||
|  | ||
| Args: | ||
| shape: The shape to get the expression from. | ||
| index: The index of the dimension to get. | ||
|  | ||
| Returns: | ||
| The expression or value at the specified index. | ||
| """ | ||
| dim = shape[index] | ||
| if isinstance(dim, ir.SymbolicDim): | ||
| if dim.expr is not None: | ||
| return dim.expr | ||
| if dim.value is None: | ||
| return sympy.Symbol("__unknown__") | ||
| return sympy.Symbol(dim.value) | ||
| return sympy.Integer(dim) | ||
|  | ||
|  | ||
| @dataclasses.dataclass | ||
| class InferenceResult: | ||
| values: Sequence[ir.Value] | None = None | ||
| failure: str | None = None | ||
|  | ||
|  | ||
| class NodeInferrer(abc.ABC): | ||
| """Base class for node inferrers. | ||
|  | ||
| This class provides a common interface for all node inferrers. | ||
| """ | ||
|  | ||
| def __init__(self, op_type: str, opsets: Collection[int], domain: str = "") -> None: | ||
| """Initialize the node inferrer. | ||
|  | ||
| Args: | ||
| op_type: The type of the operation. | ||
| opsets: A collection of ONNX opset versions supported by this inferrer. | ||
| domain: The domain of the operation, default is an empty string. | ||
| """ | ||
| self.op_type = op_type | ||
| self.opsets = opsets | ||
| self.domain = domain | ||
|  | ||
| def __repr__(self) -> str: | ||
| """Return a string representation of the node inferrer.""" | ||
| return f"{self.__class__.__name__}(op_type={self.op_type}, opsets={self.opsets}, domain={self.domain})" | ||
|  | ||
| @abc.abstractmethod | ||
| def infer(self, node: ir.Node) -> InferenceResult: | ||
| """Infer the shape for the node. | ||
|  | ||
| Args: | ||
| node: The ONNX node to infer the type and shape for. | ||
|  | ||
| Returns: | ||
| A sequence of ONNX values containing the inferred shapes. | ||
| """ | ||
| raise NotImplementedError | ||
|  | ||
|  | ||
| def requires_non_none_inputs( | ||
| count: int, / | ||
| ) -> Callable[ | ||
| [Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult] | ||
| ]: | ||
| """Ensure that the node has a specific number of non-None inputs. | ||
|  | ||
| Args: | ||
| count: The exact number of non-None inputs required for the node. | ||
|  | ||
| Returns: | ||
| A decorator that checks the number of inputs and their non-None status. | ||
| """ | ||
|  | ||
| def decorator( | ||
| func: Callable[[Any, ir.Node], InferenceResult], | ||
| ) -> Callable[[Any, ir.Node], InferenceResult]: | ||
| @functools.wraps(func) | ||
| def wrapper(self, node: ir.Node) -> InferenceResult: | ||
| if len(node.inputs) != count: | ||
| return InferenceResult( | ||
| failure=f"[{node.op_type} must have {count} inputs, got {len(node.inputs)}." | ||
| ) | ||
| for i, inp in enumerate(node.inputs): | ||
| if inp is None: | ||
| return InferenceResult(failure=f"{node.op_type} input {i} cannot be None.") | ||
| return func(self, node) | ||
|  | ||
| return wrapper | ||
|  | ||
| return decorator | ||
|  | ||
|  | ||
| def requires_outputs( | ||
| count: int, / | ||
| ) -> Callable[ | ||
| [Callable[[Any, ir.Node], InferenceResult]], Callable[[Any, ir.Node], InferenceResult] | ||
| ]: | ||
| """Ensure that the node has a specific number of outputs. | ||
|  | ||
| Args: | ||
| count: The exact number of outputs required for the node. | ||
|  | ||
| Returns: | ||
| A decorator that checks the number of outputs. | ||
| """ | ||
|  | ||
| def decorator( | ||
| func: Callable[[Any, ir.Node], InferenceResult], | ||
| ) -> Callable[[Any, ir.Node], InferenceResult]: | ||
| @functools.wraps(func) | ||
| def wrapper(self, node: ir.Node) -> InferenceResult: | ||
| if len(node.outputs) != count: | ||
| return InferenceResult( | ||
| failure=f"[{node.op_type} must have {count} outputs, got {len(node.outputs)}." | ||
| ) | ||
| return func(self, node) | ||
|  | ||
| return wrapper | ||
|  | ||
| return decorator | ||
|  | ||
|  | ||
| def inclusive_range(start_or_end: int = 0, end: int | None = None) -> range: | ||
| """Create an inclusive range from start to end with a given step. | ||
|  | ||
| Args: | ||
| start_or_end: The starting value of the range. | ||
| end: The ending value of the range (inclusive). | ||
|  | ||
| Returns: | ||
| A range object that includes both start and end. | ||
| """ | ||
| if end is None: | ||
| end = start_or_end | ||
| start = 0 | ||
| else: | ||
| start = start_or_end | ||
|  | ||
| return range(start, end + 1) | ||
      
      Oops, something went wrong.
        
    
  
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.