Skip to content

Opsets not imported for functions when the op is used in an if branch #1109

@justinchuby

Description

@justinchuby

The following function

import onnxscript

# `op` must be bound at module level: the scripted functions below call
# `op.Shape`, `op.Equal`, `op.Clip`, ... and the onnxscript converter resolves
# those names when it translates each function. The generated model imports
# the default ONNX domain at version 18 ("" : 18), so use the opset18 module.
from onnxscript import opset18 as op

# Custom function domains. A function decorated with
# @onnxscript.script(<opset>) is emitted as an ONNX function in that domain
# (see `domain: "pkg.onnxscript.torch_lib.common"` / `"pkg.onnxscript.torch_lib"`
# in the generated model).
common_opset = onnxscript.values.Opset(domain="pkg.onnxscript.torch_lib.common", version=1)
torchlib_opset = onnxscript.values.Opset(domain="pkg.onnxscript.torch_lib", version=1)

@onnxscript.script(common_opset)
def IsScalar(input):
    """Return whether the input has rank 0, or is a scalar."""

    # Shape(input) is a 1-D tensor with one entry per dimension, so its Size is
    # the rank of `input`; comparing it against 0 tests for a rank-0 tensor.
    # The parameter name `input` shadows the Python builtin, but it is kept
    # because it becomes the ONNX function's input name (`IsScalar (input)`).
    return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))


@onnxscript.script(torchlib_opset)
def aten_clamp_max(self, max_):
    """clamp_max(Tensor self, Tensor max) -> Tensor"""

    self_size = op.Size(self)
    max_shape = op.Shape(max_)
    # Each Python `if` below is converted into an ONNX `If` node whose branches
    # are subgraphs (thenGraph_*/elseGraph_* in the generated model).
    if self_size == 0:
        # `self` has zero elements: expand it to the shape of `max_`.
        result = op.Expand(self, max_shape)
    else:
        # NOTE: this call to the custom-domain function IsScalar sits inside a
        # nested subgraph. In the generated model, aten_clamp_max's
        # opset_import lists only "" : 18 and is missing
        # "pkg.onnxscript.torch_lib.common" -- the bug this snippet reproduces.
        if IsScalar(max_):
            # Scalar bound: match max_'s element type to `self`, then Clip with
            # no lower bound (None) and `max_` as the upper bound.
            max_ = op.CastLike(max_, self)
            result = op.Clip(self, None, max_)
        else:
            # Non-scalar bound: element-wise minimum of the two tensors.
            result = op.Min(self, max_)

    return result

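Here `IsScalar` is an `OnnxFunction` from a custom opset (`pkg.onnxscript.torch_lib.common`), but that opset is not imported by the generated `aten_clamp_max` function. Its `opset_import` lists only `"" : 18`, even though its body calls `pkg.onnxscript.torch_lib.common.IsScalar`. `IsScalar` is only called inside an `if` branch (a subgraph of an ONNX `If` node), so that may be the issue.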

Generated model:

E   <
E      ir_version: 8,
E      opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1],
E      producer_name: "pytorch",
E      producer_version: "2.2.0"
E   >
E   main_graph (float16[5] input_0, float16 input_1) => (float16[5] _val_2) {
E      _val_2 = pkg.onnxscript.torch_lib.aten_clamp_max (input_0, input_1)
E   }
E   <
E     domain: "pkg.onnxscript.torch_lib",
E     opset_import: ["" : 18]
E   >
E   aten_clamp_max (self, max_) => (result_5)
E   {
E      self_size = Size (self)
E      max_shape = Shape (max_)
E      int64_0 = Constant <value: tensor = int64 int64_0 {0}> ()
E      int64_0_cast = CastLike (int64_0, self_size)
E      cond = Equal (self_size, int64_0_cast)
E      result_5 = If (cond) <then_branch: graph = thenGraph_7 () => ( result) {
E         result = Expand (self, max_shape)
E      }, else_branch: graph = elseGraph_7 () => ( result_4) {
E         cond_0 = pkg.onnxscript.torch_lib.common.IsScalar (max_)
E         result_4 = If (cond_0) <then_branch: graph = thenGraph_10 () => ( result_2) {
E            max__1 = CastLike (max_, self)
E            result_2 = Clip (self, , max__1)
E         }, else_branch: graph = elseGraph_10 () => ( result_3) {
E            result_3 = Min (self, max_)
E         }>
E      }>
E   }
E   <
E     domain: "pkg.onnxscript.torch_lib.common",
E     opset_import: ["" : 18]
E   >
E   Rank (input) => (return_val)
E   {
E      tmp = Shape (input)
E      return_val = Size (tmp)
E   }
E   <
E     domain: "pkg.onnxscript.torch_lib.common",
E     opset_import: ["" : 18]
E   >
E   IsScalar (input) => (return_val)
E   {
E      tmp = Shape (input)
E      tmp_0 = Size (tmp)
E      tmp_1 = Constant <value_int: int = 0> ()
E      return_val = Equal (tmp_0, tmp_1)
E   }

Original issue onnx/onnx#5701

cc @gramalingam

Metadata

Metadata

Labels

bug — Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions