Skip to content

Commit

Permalink
[torchlib] Fix aten_instance_norm (#1964)
Browse files Browse the repository at this point in the history
Otherwise exporter raises `Could not determine the dtype for the input
'inputs'.`
  • Loading branch information
justinchuby authored Nov 25, 2024
1 parent 9592227 commit e282467
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4391,7 +4391,10 @@ def aten_instance_norm(
), "running_mean and running_var must be provided when use_input_stats is False"

batch_size = op.Shape(input, start=0, end=1)
bn_input = op.Reshape(input, op.Concat([1, -1], op.Shape(input, start=2), axis=0))
bn_input = op.Reshape(
input,
op.Concat(op.Constant(value_ints=[1, -1]), op.Shape(input, start=2), axis=0),
)
weight = op.Tile(weight, batch_size)
bias = op.Tile(bias, batch_size)
running_mean = op.Tile(running_mean, batch_size)
Expand Down

0 comments on commit e282467

Please sign in to comment.