From eae4c3f22937b2bc40f11a32a7c6ac5094c3cc3e Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Thu, 12 Sep 2024 21:01:45 -0700
Subject: [PATCH] make shape/stride correct when component != 1

---
Reviewer notes (commentary only; the hunks below are unchanged in content):

* program.h: the inline constructor bodies of ProgramInput and ProgramOutput
  are removed and replaced by declarations; the definitions now live in
  program.cc.

* program.cc: the (tensor, dependency, component) constructors used to
  initialize use_override_shape{false}. They now initialize it to
  `component > 1` and, when it is true, assign
  override_shape = GetReducedShape(tensor->Shape(), component).

* GetReducedShape is file-local (anonymous namespace). It copies the shape
  and divides the last dimension by `component`. ORT_ENFORCE requires
  NumDimensions() > 0 and that the last dimension is a multiple of
  `component`.

* The single-argument constructors and the constructors taking an explicit
  override_shape keep the same initializers they had in the header.

* shader_helper.cc: ValidateVariableShape drops the
  `else if (num_components > 1)` branch, which checked last-dimension
  divisibility with ORT_RETURN_IF_NOT when no override shape was given.
  The divisibility condition is now checked with ORT_ENFORCE inside
  GetReducedShape, i.e. when the ProgramInput/ProgramOutput is constructed.

* NOTE(review): the removed check required `origin_shape.Size() > 0`, while
  GetReducedShape requires `shape.NumDimensions() > 0`. These are different
  conditions -- confirm the difference is intended.

 onnxruntime/core/providers/webgpu/program.cc | 51 +++++++++++++++++++
 onnxruntime/core/providers/webgpu/program.h  | 34 +++----------
 .../core/providers/webgpu/shader_helper.cc   |  4 --
 3 files changed, 57 insertions(+), 32 deletions(-)

diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc
index 023fa78a4196b..f12f6fb8a01c4 100644
--- a/onnxruntime/core/providers/webgpu/program.cc
+++ b/onnxruntime/core/providers/webgpu/program.cc
@@ -182,6 +182,57 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp
   }
 }
 
+namespace {
+TensorShape GetReducedShape(const TensorShape& shape, int component /* > 1 */) {
+  ORT_ENFORCE(shape.NumDimensions() > 0 && shape.GetDims()[shape.NumDimensions() - 1] % component == 0,
+              "Cannot reduce shape ", shape.ToString(), " by component=", component);
+  TensorShape reduced_shape = shape;
+  reduced_shape[reduced_shape.NumDimensions() - 1] /= component;
+  return reduced_shape;
+}
+}  // namespace
+
+ProgramInput::ProgramInput(const Tensor* tensor) : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {}
+
+ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component)
+    : tensor{tensor},
+      dependency{dependency},
+      var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
+      use_override_shape{component > 1},
+      override_shape{} {
+  if (use_override_shape) {
+    override_shape = GetReducedShape(tensor->Shape(), component);
+  }
+}
+
+ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component)
+    : tensor{tensor},
+      dependency{dependency},
+      var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
+      use_override_shape{true},
+      override_shape{override_shape} {}
+
+ProgramOutput::ProgramOutput(Tensor* tensor)
+    : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {}
+
+ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component)
+    : tensor{tensor},
+      dependency{dependency},
+      var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
+      use_override_shape{component > 1},
+      override_shape{} {
+  if (use_override_shape) {
+    override_shape = GetReducedShape(tensor->Shape(), component);
+  }
+}
+
+ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component)
+    : tensor{tensor},
+      dependency{dependency},
+      var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
+      use_override_shape{true},
+      override_shape{override_shape} {}
+
 ProgramBase::ProgramBase(const std::string& name)
     : name_{name},
       dispatch_group_size_x_{0},
diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h
index 0daf247661362..2a2d4160e1617 100644
--- a/onnxruntime/core/providers/webgpu/program.h
+++ b/onnxruntime/core/providers/webgpu/program.h
@@ -208,20 +208,9 @@ int NumberOfComponents(ProgramVariableDataType type);
 ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1);
 
 struct ProgramInput {
-  ProgramInput(const Tensor* tensor)
-      : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {}
-  ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1)
-      : tensor{tensor},
-        dependency{dependency},
-        var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
-        use_override_shape{false},
-        override_shape{} {}
-  ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component)
-      : tensor{tensor},
-        dependency{dependency},
-        var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
-        use_override_shape{true},
-        override_shape{override_shape} {}
+  ProgramInput(const Tensor* tensor);
+  ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1);
+  ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component);
 
   const Tensor* tensor;
   ProgramTensorMetadataDependency dependency;
@@ -231,20 +220,9 @@ struct ProgramInput {
 };
 
 struct ProgramOutput {
-  ProgramOutput(Tensor* tensor)
-      : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {}
-  ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1)
-      : tensor{tensor},
-        dependency{dependency},
-        var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
-        use_override_shape{false},
-        override_shape{} {}
-  ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component)
-      : tensor{tensor},
-        dependency{dependency},
-        var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
-        use_override_shape{true},
-        override_shape{override_shape} {}
+  ProgramOutput(Tensor* tensor);
+  ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1);
+  ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component);
 
   Tensor* tensor;
   ProgramTensorMetadataDependency dependency;
diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc
index be89efae5fc97..64ed98c78507b 100644
--- a/onnxruntime/core/providers/webgpu/shader_helper.cc
+++ b/onnxruntime/core/providers/webgpu/shader_helper.cc
@@ -156,10 +156,6 @@ Status ValidateVariableShape(const TensorShape& origin_shape,
     // if override shape specified, assert override_size == ceil( origin_size / 4 )
     ORT_RETURN_IF_NOT((origin_shape.Size() + num_components - 1) / num_components == override_shape.Size(),
                       "Tensor original shape ", origin_shape, " cannot reshape to ", override_shape, " with component number ", num_components);
-  } else if (num_components > 1) {
-    // if shape is not overriden, assert origin_shape[-1] % 4 == 0
-    ORT_RETURN_IF_NOT(origin_shape.Size() > 0 && origin_shape[origin_shape.NumDimensions() - 1] % num_components == 0,
-                      "Tensor original shape ", origin_shape, " cannot be divided by component number ", num_components);
   }
 
   return Status::OK();