-
Notifications
You must be signed in to change notification settings - Fork 80
Open
Labels
bug: Something isn't working
Description
The following function
import onnxscript
# Custom opset that hosts shared helper functions such as IsScalar.
common_opset = onnxscript.values.Opset(domain="pkg.onnxscript.torch_lib.common", version=1)
# Custom opset that hosts the operator function aten_clamp_max, which calls
# into the helper opset above.
torchlib_opset = onnxscript.values.Opset(domain="pkg.onnxscript.torch_lib", version=1)
@onnxscript.script(common_opset)
def IsScalar(input):
    """Return whether the input has rank 0, or is a scalar."""
    # Shape(input) is a 1-D tensor with one entry per dimension, so its Size
    # is the rank of `input`; comparing that rank with 0 yields the result.
    return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))
@onnxscript.script(torchlib_opset)
def aten_clamp_max(self, max_):
    """clamp_max(Tensor self, Tensor max) -> Tensor"""
    # Element count of `self` and shape of `max_`, both computed up front so
    # they are available to the branches below.
    self_size = op.Size(self)
    max_shape = op.Shape(max_)
    if self_size == 0:
        # `self` has no elements: just expand it to the shape of `max_`.
        result = op.Expand(self, max_shape)
    else:
        # NOTE: this call to IsScalar (a function from `common_opset`) sits
        # inside the else-branch subgraph of the outer If.
        if IsScalar(max_):
            # Rank-0 `max_`: cast it to the type of `self` and clip with an
            # upper bound only (the lower bound is left unset).
            max_ = op.CastLike(max_, self)
            result = op.Clip(self, None, max_)
        else:
            # Non-scalar `max_`: take the element-wise minimum instead.
            result = op.Min(self, max_)
    return result
where IsScalar
is an OnnxFunction from a custom opset, produces a model in which that opset is not imported for the function. I notice IsScalar is called inside an if
branch/subgraph, so that may be the issue.
Generated model:
E <
E ir_version: 8,
E opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1],
E producer_name: "pytorch",
E producer_version: "2.2.0"
E >
E main_graph (float16[5] input_0, float16 input_1) => (float16[5] _val_2) {
E _val_2 = pkg.onnxscript.torch_lib.aten_clamp_max (input_0, input_1)
E }
E <
E domain: "pkg.onnxscript.torch_lib",
E opset_import: ["" : 18]
E >
E aten_clamp_max (self, max_) => (result_5)
E {
E self_size = Size (self)
E max_shape = Shape (max_)
E int64_0 = Constant <value: tensor = int64 int64_0 {0}> ()
E int64_0_cast = CastLike (int64_0, self_size)
E cond = Equal (self_size, int64_0_cast)
E result_5 = If (cond) <then_branch: graph = thenGraph_7 () => ( result) {
E result = Expand (self, max_shape)
E }, else_branch: graph = elseGraph_7 () => ( result_4) {
E cond_0 = pkg.onnxscript.torch_lib.common.IsScalar (max_)
E result_4 = If (cond_0) <then_branch: graph = thenGraph_10 () => ( result_2) {
E max__1 = CastLike (max_, self)
E result_2 = Clip (self, , max__1)
E }, else_branch: graph = elseGraph_10 () => ( result_3) {
E result_3 = Min (self, max_)
E }>
E }>
E }
E <
E domain: "pkg.onnxscript.torch_lib.common",
E opset_import: ["" : 18]
E >
E Rank (input) => (return_val)
E {
E tmp = Shape (input)
E return_val = Size (tmp)
E }
E <
E domain: "pkg.onnxscript.torch_lib.common",
E opset_import: ["" : 18]
E >
E IsScalar (input) => (return_val)
E {
E tmp = Shape (input)
E tmp_0 = Size (tmp)
E tmp_1 = Constant <value_int: int = 0> ()
E return_val = Equal (tmp_0, tmp_1)
E }
Original issue onnx/onnx#5701
cc @gramalingam
Metadata
Metadata
Assignees
Labels
bug: Something isn't working