Skip to content

Commit 5a24314

Browse files
committed
add boundary check
1 parent 27043c9 commit 5a24314

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

+7 -3
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ ONNX_OPERATOR_KERNEL_EX(
5353

5454
Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
5555
const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
56-
const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
56+
const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
5757
const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform);
5858
const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
5959

@@ -102,8 +102,12 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
102102
shader.MainFunctionBody() << " // The default zero point is 8 for unsigned 4-bit quantization.\n"
103103
" let zero_point = output_element_t(8.0);\n";
104104
}
105-
shader.MainFunctionBody() << " let scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n"
106-
<< " let b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n"
105+
shader.MainFunctionBody() << " var scale = output_element_t(0);\n"
106+
" var b_data = input_b_value_t(0);\n"
107+
<< " if (block < n_blocks_per_col) {\n"
108+
<< " scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n"
109+
<< " b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n"
110+
<< " }\n"
107111
<< " var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n"
108112
<< " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n";
109113
switch (a.NumComponents()) {

0 commit comments

Comments
 (0)