Skip to content

Commit

Permalink
allow scalar axes for Unsqueeze for WebGPU (#22054)
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire authored Sep 12, 2024
1 parent 951b1b7 commit 84f7332
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/js/operators/unsqueeze.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ class Unsqueeze final : public JsKernel, public UnsqueezeBase {
if (num_inputs == 2) { // axes is an input
const Tensor* axes_tensor = context->Input<Tensor>(1);
ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1,
"An axes tensor must be a vector tensor.");
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 ||
axes_tensor->Shape().NumDimensions() == 1,
"An axes tensor must be a scalar or a vector tensor.");
auto nDims = static_cast<size_t>(axes_tensor->Shape()[0]);
const auto* data = axes_tensor->Data<int64_t>();
axes.assign(data, data + nDims);
Expand Down

0 comments on commit 84f7332

Please sign in to comment.