@@ -53,7 +53,7 @@ ONNX_OPERATOR_KERNEL_EX(
53
53
54
54
Status MatMulNBitsProgram::GenerateShaderCode (ShaderHelper& shader) const {
55
55
const auto & a = shader.AddInput (" input_a" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
56
- const auto & b = shader.AddInput (" input_b" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
56
+ const auto & b = shader.AddInput (" input_b" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias );
57
57
const auto & scales = shader.AddInput (" scales" , ShaderUsage::UseUniform);
58
58
const auto & y = shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
59
59
@@ -102,8 +102,12 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
102
102
shader.MainFunctionBody () << " // The default zero point is 8 for unsigned 4-bit quantization.\n "
103
103
" let zero_point = output_element_t(8.0);\n " ;
104
104
}
105
- shader.MainFunctionBody () << " let scale = " << scales.GetByOffset (" b_row * n_blocks_per_col + block" ) << " ;\n "
106
- << " let b_data = " << b.GetByIndices (" input_b_indices_t(b_row, block, 0)" ) << " ;\n "
105
+ shader.MainFunctionBody () << " var scale = output_element_t(0);\n "
106
+ " var b_data = input_b_value_t(0);\n "
107
+ << " if (block < n_blocks_per_col) {\n "
108
+ << " scale = " << scales.GetByOffset (" b_row * n_blocks_per_col + block" ) << " ;\n "
109
+ << " b_data = " << b.GetByIndices (" input_b_indices_t(b_row, block, 0)" ) << " ;\n "
110
+ << " }\n "
107
111
<< " var word_offset = local_id.x * " << block_size / a.NumComponents () << " ;\n "
108
112
<< " for (var i: u32 = 0; i < " << components_b_ << " ; i++) {\n " ;
109
113
switch (a.NumComponents ()) {
0 commit comments