diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc new file mode 100644 index 0000000000000..8d59570de9967 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/webgpu/tensor/cast.h" + +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { +const std::vector& CastOpTypeConstraints() { + // currently support boolean, integer and float types that explicitly allowed in WGSL: + // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section + // + static std::vector types{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + return types; +} +} // namespace + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 6, 8, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 9, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 13, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_KERNEL_EX( + Cast, + kOnnxDomain, + 19, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); + +Status Cast::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + auto* output_tensor = context.Output(0, input_tensor->Shape()); + int64_t size = input_tensor->Shape().Size(); + if (size == 0) { + return Status::OK(); + } + SafeInt vec_size = (size + 3) / 4; + + CastProgram program{to_}; + program + .AddInput({input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(vec_size)}, + }) + .CacheHint(std::to_string(to_)); + return context.RunProgram(program); +} + +Status CastProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& input = sh.AddInput("x", ShaderUsage::UseUniform); + const auto& output = sh.AddOutput("y", ShaderUsage::UseUniform); + std::string expression; + switch (to_) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + expression = "vec4(a)"; + break; + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + expression = "vec4(a)"; + break; + default: + ORT_NOT_IMPLEMENTED("Cast to type ", to_, " is not supported."); + } + sh.SetMainFunctionBody(sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + " let a = ", input.GetByOffset("global_idx"), ";\n ", + output.SetByOffset("global_idx", expression)); + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.h b/onnxruntime/core/providers/webgpu/tensor/cast.h new file mode 100644 index 0000000000000..c7216c18500fd --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/cast.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class CastProgram final : public Program { + public: + CastProgram(int32_t to) : Program{"Cast"}, to_{to} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + private: + int32_t to_; +}; + +class Cast final : public WebGpuKernel { + public: + Cast(const OpKernelInfo& info) : WebGpuKernel(info) { + int64_t to; + Status status = info.GetAttr("to", &to); + ORT_ENFORCE(status.IsOK(), "Attribute to is not set."); + to_ = SafeInt(to); + + // ignore attribute 'saturate' as float8 is not supported in WebGPU + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int32_t to_; +}; + +} // namespace webgpu +} // namespace onnxruntime