Skip to content

Commit

Permalink
Cast operator
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 17, 2024
1 parent 2e91a8b commit 87f9edb
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 0 deletions.
117 changes: 117 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/cast.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <vector>

#include "core/providers/webgpu/tensor/cast.h"

#include "core/providers/webgpu/shader_helper.h"

namespace onnxruntime {
namespace webgpu {

namespace {
const std::vector<MLDataType>& CastOpTypeConstraints() {
// currently support boolean, integer and float types that explicitly allowed in WGSL:
// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section
//
static std::vector<MLDataType> types{
DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>(),
DataTypeImpl::GetTensorType<bool>()};
return types;
}
} // namespace

// Kernel registrations for the Cast operator on the WebGPU execution provider.
//
// Each registration below uses the same kernel class (Cast) and the same type
// list (CastOpTypeConstraints()) for both the "T1" and "T2" constraints; the
// registrations differ only in the version numbers passed to the macro.
// NOTE(review): the version splits (6-8, 9-12, 13-18, 19) presumably mirror
// the revisions of the ONNX Cast operator schema -- confirm against the spec.

// Versions 6 through 8.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Cast,
kOnnxDomain,
6, 8,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", CastOpTypeConstraints())
.TypeConstraint("T2", CastOpTypeConstraints()),
Cast);
// Versions 9 through 12.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Cast,
kOnnxDomain,
9, 12,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", CastOpTypeConstraints())
.TypeConstraint("T2", CastOpTypeConstraints()),
Cast);
// Versions 13 through 18.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Cast,
kOnnxDomain,
13, 18,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", CastOpTypeConstraints())
.TypeConstraint("T2", CastOpTypeConstraints()),
Cast);
// Version 19 (registered with the non-versioned macro, i.e. no end version is given).
ONNX_OPERATOR_KERNEL_EX(
Cast,
kOnnxDomain,
19,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", CastOpTypeConstraints())
.TypeConstraint("T2", CastOpTypeConstraints()),
Cast);

// Executes the Cast kernel.
//
// Allocates output 0 with the same shape as input 0, then configures and runs
// a CastProgram that converts the input to the element type stored in `to_`.
// Returns Status::OK() without running a program when the input has no
// elements; otherwise returns the status of context.RunProgram().
Status Cast::ComputeInternal(ComputeContext& context) const {
  const auto* src = context.Input(0);
  const auto& shape = src->Shape();

  // The output is requested before the empty check so that output 0 is
  // produced even when there is nothing to convert.
  auto* dst = context.Output(0, shape);

  const int64_t element_count = shape.Size();
  if (element_count == 0) {
    return Status::OK();
  }

  // The shader operates on vec4 values (see CastProgram::GenerateShaderCode),
  // so the element count is rounded up to a whole number of 4-element groups.
  // SafeInt guards the narrowing from int64_t to uint32_t.
  SafeInt<uint32_t> vec_size = (element_count + 3) / 4;

  CastProgram program{to_};
  program.AddInput({src, ProgramTensorMetadataDependency::Type, {vec_size}, 4});
  program.AddOutput({dst, ProgramTensorMetadataDependency::None, {vec_size}, 4});
  // Ceiling division: enough workgroups to cover every vec4 group.
  program.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
  // The single uniform value is read by the shader as `uniforms.vec_size`.
  program.AddUniformVariables({
      {static_cast<uint32_t>(vec_size)},
  });
  // The generated shader text depends on the target type, so it is part of
  // the cache hint.
  program.CacheHint(std::to_string(to_));

  return context.RunProgram(program);
}

// Emits the WGSL main function body for the Cast program.
//
// The body guards against out-of-range invocations using `uniforms.vec_size`,
// loads one value from input "x" at `global_idx` into `a`, converts it with a
// vec4 constructor of the target type, and stores the result to output "y" at
// the same offset.
//
// Target types other than FLOAT16, FLOAT, INT32, UINT32 and BOOL are rejected
// via ORT_NOT_IMPLEMENTED.
Status CastProgram::GenerateShaderCode(ShaderHelper& sh) const {
  const auto& x = sh.AddInput("x", ShaderUsage::UseUniform);
  const auto& y = sh.AddOutput("y", ShaderUsage::UseUniform);

  // WGSL conversion expression applied to the loaded value `a`.
  std::string cast_expr;
  switch (to_) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: {
      cast_expr = "vec4<f16>(a)";
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
      cast_expr = "vec4<f32>(a)";
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
      cast_expr = "vec4<i32>(a)";
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32: {
      cast_expr = "vec4<u32>(a)";
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL: {
      cast_expr = "vec4<bool>(a)";
      break;
    }
    default: {
      ORT_NOT_IMPLEMENTED("Cast to type ", to_, " is not supported.");
    }
  }

  // Build the individual shader fragments, then hand them over in order.
  const auto bounds_guard = sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size");
  const auto load_value = x.GetByOffset("global_idx");
  const auto store_value = y.SetByOffset("global_idx", cast_expr);
  sh.SetMainFunctionBody(bounds_guard,
                         " let a = ", load_value, ";\n ",
                         store_value);

  return Status::OK();
}

} // namespace webgpu
} // namespace onnxruntime
39 changes: 39 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/cast.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace webgpu {

// WebGPU program that generates the shader for the Cast operator.
//
// The target element type is fixed at construction time and selects the
// conversion expression emitted by GenerateShaderCode().
//
// NOTE(review): the generated shader reads `uniforms.vec_size` and the kernel
// supplies one uniform value, but no uniform declaration is visible in this
// class -- confirm whether the Program framework requires one here.
class CastProgram final : public Program<CastProgram> {
 public:
  // `to` is the target element type, compared against
  // ONNX_NAMESPACE::TensorProto_DataType_* values in GenerateShaderCode().
  CastProgram(int32_t to)
      : Program<CastProgram>{"Cast"},
        to_{to} {
  }

  // Writes the shader main function body into `sh`.
  Status GenerateShaderCode(ShaderHelper& sh) const override;

 private:
  // Target element type of the cast.
  int32_t to_;
};

class Cast final : public WebGpuKernel {
public:
Cast(const OpKernelInfo& info) : WebGpuKernel(info) {
int64_t to;
Status status = info.GetAttr("to", &to);
ORT_ENFORCE(status.IsOK(), "Attribute to is not set.");
to_ = SafeInt<int32_t>(to);

// ignore attribute 'saturate' as float8 is not supported in WebGPU
}

Status ComputeInternal(ComputeContext& context) const override;

private:
int32_t to_;
};

} // namespace webgpu
} // namespace onnxruntime

0 comments on commit 87f9edb

Please sign in to comment.