support qgemm with s32 bias on x86 platform (#1045)
chenbohua3 authored Feb 27, 2023
1 parent 8ab4818 commit fa76eb5
Showing 2 changed files with 207 additions and 1 deletion.
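Note: the test added below fake-quantizes the linear bias to a signed 32-bit range with a per-channel scale equal to input_scale * weight_scale, which is the s32 bias the new x86 qgemm kernel consumes. A minimal sketch of that bias-quantization pattern (tensor sizes and scale values are illustrative assumptions, not code from this commit):

import torch

# Sketch only: per-channel s32 bias fake-quantization, mirroring the pattern
# the new test exercises. Sizes and scales below are assumed for illustration.
input_scale = 0.1
weight_scale = torch.rand(64) * 0.05 + 0.01    # assumed per-channel weight scales
bias = torch.randn(64)

bias_scale = input_scale * weight_scale        # bias scale = s_input * s_weight[c]
bias_zero_point = torch.zeros(64, dtype=torch.int32)
quant_bias = torch.fake_quantize_per_channel_affine(
    bias, bias_scale, bias_zero_point, 0, -2**31, 2**31 - 1)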
63 changes: 62 additions & 1 deletion pytorch_blade/tests/disc/pdl/test_e2e/test_quantization.py
@@ -163,7 +163,7 @@ def forward(self, x):
                    "The patterns corresponding to pytorch before version "
                    "1.9.0 has not yet been implemented ")
class TestX86CPULiner(X86CPUDiscPdlQuantizationE2ETestCase):
-    def test_s8s8s8s32_per_channel_with_bias(self):
+    def test_s8s8s8f32_per_channel_with_bias(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
@@ -233,6 +233,67 @@ def forward(self, x):
        model = Model().eval().to(self.device)
        self._test_e2e(model, inp, pdll_files=pdll_files, enable_int8=True)

    def test_s8s8s8s32_per_channel_with_bias(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.input_scale = 0.1
                self.input_zero_point = 0
                self.output_scale = 0.2
                self.output_zero_point = 0
                self.register_buffer("weight_scale", torch.randn(64))
                self.register_buffer("weight_zero_point",
                                     torch.zeros(64).to(zero_point_dtype))
                self.register_buffer("bias_zero_point",
                                     torch.zeros(64).to(zero_point_dtype))
                self.weight_quant_min = -128
                self.weight_quant_max = 127
                self.activation_quant_min = -128
                self.activation_quant_max = 127
                self.bias_quant_min = -2**31
                self.bias_quant_max = 2**31 - 1
                self.register_buffer("weight", torch.randn(64, 64))
                self.register_buffer("bias", torch.randn(64))
                self.ch_axis = 0

            def forward(self, x):
                x = torch.fake_quantize_per_tensor_affine(
                    x, self.input_scale, self.input_zero_point,
                    self.activation_quant_min, self.activation_quant_max
                )
                weight = torch.fake_quantize_per_channel_affine(
                    self.weight, self.weight_scale, self.weight_zero_point,
                    self.ch_axis, self.weight_quant_min, self.weight_quant_max
                )
                bias_scale = self.input_scale * self.weight_scale
                quant_bias = torch.fake_quantize_per_channel_affine(
                    self.bias, bias_scale, self.bias_zero_point,
                    self.ch_axis, self.bias_quant_min, self.bias_quant_max
                )
                x = F.linear(x, weight, bias=quant_bias)
                x = torch.fake_quantize_per_tensor_affine(
                    x, self.output_scale, self.output_zero_point,
                    self.activation_quant_min, self.activation_quant_max
                )
                return x

        pdll_files = [
            os.path.join(self.common_pdll_dir, "fake_quant.pdll"),
            os.path.join(self.device_pdll_dir, "dequant_gemm_quant.pdll")
        ]
        pdll_files = ",".join(pdll_files)
        inp = torch.randn(1, 64).to(self.device)
        model = Model().eval().to(self.device)
        self._test_e2e(model, inp, pdll_files=pdll_files, enable_int8=True)

        inp = torch.randn(1, 2, 64).to(self.device)
        model = Model().eval().to(self.device)
        self._test_e2e(model, inp, pdll_files=pdll_files, enable_int8=True)

        inp = torch.randn(1, 2, 3, 64).to(self.device)
        model = Model().eval().to(self.device)
        self._test_e2e(model, inp, pdll_files=pdll_files, enable_int8=True)

    def test_s8s8s8_per_channel_without_bias(self):
        class Model(nn.Module):
            def __init__(self):
145 changes: 145 additions & 0 deletions tao_compiler/mlir/xla/ral/context/common_context_impl_quantization.cc
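Note: the kernel added below performs an s8 x s8 GEMM with an int32 bias and requantizes the 32-bit accumulator back to int8 using the input scale, per-channel weight scales, and output scale/zero point passed through the RAL interface. As a rough reference for that math (an assumption sketched for illustration, not code from this commit or from oneDNN; int64 stands in for the 32-bit accumulator):

import torch

def ref_qgemm_s8s8s8_s32_bias(x_s8, w_s8, bias_s32, in_scale, w_scale,
                              out_scale, out_zero_point):
    # Accumulate the s8 x s8 matmul plus the s32 bias (int64 emulates the
    # 32-bit accumulator), then requantize per output channel and clamp to s8.
    acc = x_s8.to(torch.int64) @ w_s8.to(torch.int64) + bias_s32.to(torch.int64)
    y = torch.round(acc.to(torch.float32) * (in_scale * w_scale / out_scale))
    return (y + out_zero_point).clamp(-128, 127).to(torch.int8)

In the committed kernel these scales and zero points are handed to ideep::matmul_forward::prepare/compute, which take care of the requantization.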
@@ -1404,6 +1404,145 @@ MemRefType<int8_t, NDims> ral_pdll_qgemm_onednn_s8_s8_s8_f32_per_channel(
  return result;
}

template <int NDims>
MemRefType<int8_t, NDims> ral_pdll_qgemm_onednn_s8_s8_s8_s32_per_channel(
    ExecutionContext* ctx, opaque_t /*stream_handle*/,
    MemRefType<int8_t, NDims> input, MemRefType<int8_t, 2> weight,
    MemRefType<int32_t, 1> bias, MemRefType<float, 0> inputScales,
    MemRefType<int32_t, 0> inputZeroPoints, MemRefType<float, 1> weightScales,
    MemRefType<int32_t, 1> weightZeroPoints, MemRefType<float, 0> resultScales,
    MemRefType<int32_t, 0> resultZeroPoints, void* customAttrs) {
  CpuTimer timer("ral_pdll_qgemm_onednn_s8_s8_s8_s32_per_channel");
  int64_t resultSizes[NDims] = {0};
  if (isEmptyMemref(input) || isEmptyMemref(weight) || isEmptyMemref(bias)) {
    TAO_VLOG(1)
        << "ral_pdll_qgemm_onednn_s8_s8_s8_s32_per_channel: early return for "
           "empty tensor";
    return assignMemRef<int8_t, NDims>(nullptr, resultSizes);
  }
  if (TAO_VLOG_IS_ON(1)) {
    for (int i = 0; i < Size(input); ++i) {
      TAO_VLOG(0) << "input[" << i
                  << "] = " << static_cast<int32_t>(input.data[i]);
    }
    for (int i = 0; i < Size(weight); ++i) {
      TAO_VLOG(0) << "weight[" << i
                  << "] = " << static_cast<int32_t>(weight.data[i]);
    }
    for (int i = 0; i < Size(bias); ++i) {
      TAO_VLOG(0) << "bias[" << i << "] = " << static_cast<int32_t>(bias.data[i]);
    }
  }
  auto attr = getOrParsePDLAttr(
      ctx, customAttrs, "ral_pdll_qgemm_onednn_s8_s8_s8_s32_per_channel");
  if (!attr) {
    ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n");
  }
  auto& dictAttr = attr->as<DictPDLAttr>();
  bool tp_a = dictAttr.get("transpose_a").as<BoolPDLAttr>().getValue();
  bool tp_b = dictAttr.get("transpose_b").as<BoolPDLAttr>().getValue();
  int64_t m = 1;
  int64_t k;
  if (tp_a) {
    for (int i = NDims - 1; i > 0; i--) {
      m = m * input.sizes[i];
    }
    k = input.sizes[0];
  } else {
    for (int i = 0; i < NDims - 1; i++) {
      m = m * input.sizes[i];
    }
    k = input.sizes[NDims - 1];
  }
  if (k != (tp_b ? weight.sizes[1] : weight.sizes[0])) {
    ctx->signalError(Context::FAILURE,
                     "mismatch contraction dim for "
                     "ral_pdll_qgemm_onednn_s8_s8_s8_s32_per_channel");
  }

  int64_t n = (tp_b ? weight.sizes[0] : weight.sizes[1]);
  resultSizes[0] = m;
  resultSizes[1] = n;
  auto driver = ctx->getDriver<cpu::CPUDriver>(cpu::CPUDriver::name());
  auto data = static_cast<int8_t*>(driver->alloc(ctx, m * n * sizeof(int8_t)));
  int64_t gemmResultSizes[2];
  gemmResultSizes[0] = m;
  gemmResultSizes[1] = n;
  int64_t gemmInputSizes[2] = {m, k};
  auto gemmResult = assignMemRef<int8_t, 2>(data, gemmResultSizes);

  ideep::tensor input_t{dims{m, k}, ideep::data_type::s8,
                        tp_a ? format_tag::ba : format_tag::ab, input.data};
  ideep::tensor weight_t{dims{k, n}, ideep::data_type::s8,
                         tp_b ? format_tag::ba : format_tag::ab, weight.data};
  ideep::tensor bias_t{dims{1, n}, ideep::data_type::s32, format_tag::ab,
                       bias.data};
  ideep::tensor output_t{dims{m, n}, ideep::data_type::s8, format_tag::ab,
                         gemmResult.data};

  std::vector<float> input_scales({inputScales.data[0]});
  std::vector<int32_t> input_zero_point({inputZeroPoints.data[0]});
  std::vector<float> weight_scales(weightScales.data,
                                   weightScales.data + weightScales.sizes[0]);
  std::vector<int32_t> weight_zero_point(
      weightZeroPoints.data, weightZeroPoints.data + weightZeroPoints.sizes[0]);
  std::vector<float> output_scales({resultScales.data[0]});
  std::vector<int32_t> output_zero_point({resultZeroPoints.data[0]});

  ideep::matmul_forward_params param;
  // TODO: Only calculate the scale & zero_point once.
  ideep::matmul_forward::prepare(
      param, input_t, weight_t, bias_t, output_t, input_scales, weight_scales,
      output_scales, input_zero_point, output_zero_point, 1.0f, 1.0f,
      ideep::attr_t(), ideep::data_type::s8, ideep::lowp_kind::s8s8,
      ideep::engine::cpu_engine());

  bool weight_is_const =
      dictAttr.get("weight_is_const").as<BoolPDLAttr>().getValue();
  bool bias_is_const =
      dictAttr.get("bias_is_const").as<BoolPDLAttr>().getValue();
  if (!isWeightPrePackingEnabled() || !weight_is_const || !bias_is_const) {
    // TODO: add a template to control whether bias should be reordered
    ideep::matmul_forward::compute<true, true>(param, input_t, weight_t, bias_t,
                                               output_t);
  } else {
    ideep::tensor packed_weight;
    ideep::tensor packed_bias;
    std::string unique_name =
        "disc.ral_pdll_qgemm_onednn_s8_s8_s8_s32_per_channel";
    auto state = ctx->getOrCreateResource<OnednnGemmState>(
        unique_name, []() { return new OnednnGemmState; });
    packed_weight = state->get_or_create_packed_weight(
        weight.data, weight_t, param.pd.weights_desc(), param.weights_attr);
    packed_bias = state->get_or_create_packed_bias(
        bias.data, bias_t, param.pd.bias_desc(), param.bias_attr);

    ideep::matmul_forward::compute<true, false>(param, input_t, packed_weight,
                                                packed_bias, output_t);
  }
  if (tp_a) {
    for (int i = NDims - 1; i > 0; i--) {
      resultSizes[i] = input.sizes[i];
    }
    resultSizes[0] = n;
  } else {
    for (int i = 0; i < NDims - 1; i++) {
      resultSizes[i] = input.sizes[i];
    }
    resultSizes[NDims - 1] = n;
  }
  auto result = assignMemRef<int8_t, NDims>(gemmResult.data, resultSizes);

  if (TAO_VLOG_IS_ON(1)) {
    for (int i = 0; i < Size(result); ++i) {
      TAO_VLOG(0) << "output[" << i
                  << "] = " << static_cast<int32_t>(result.data[i]);
    }
  }
  timer.Stop();
  return result;
}

template <int NDims>
MemRefType<int8_t, NDims> ral_pdll_qgemm_onednn_s8_s8_s8_per_channel(
    ExecutionContext* ctx, opaque_t /*stream_handle*/,
@@ -1565,6 +1704,12 @@ TAO_RAL_API("ral_pdll_qgemm", "cpu",
ral_pdll_qgemm_onednn_s8_s8_s8_f32_per_channel<3>);
TAO_RAL_API("ral_pdll_qgemm", "cpu",
ral_pdll_qgemm_onednn_s8_s8_s8_f32_per_channel<4>);
TAO_RAL_API("ral_pdll_qgemm", "cpu",
ral_pdll_qgemm_onednn_s8_s8_s8_s32_per_channel<2>);
TAO_RAL_API("ral_pdll_qgemm", "cpu",
ral_pdll_qgemm_onednn_s8_s8_s8_s32_per_channel<3>);
TAO_RAL_API("ral_pdll_qgemm", "cpu",
ral_pdll_qgemm_onednn_s8_s8_s8_s32_per_channel<4>);
TAO_RAL_API("ral_pdll_qgemm", "cpu",
ral_pdll_qgemm_onednn_s8_s8_s8_per_channel<2>);
TAO_RAL_API("ral_pdll_qgemm", "cpu",
