Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
clackhan committed Sep 22, 2023
1 parent 4ba3cc5 commit 7e154e1
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions oneflow/user/kernels/conv_quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class Conv2dQuantKernel final : public user_op::OpKernel, public user_op::CudaGr
cutlass::library::NumericTypeID::kS32, cutlass::library::NumericTypeID::kS32);
if (in->data_type() == DataType::kFloat16) {
key.element_A = cutlass::library::NumericTypeID::kF16;
return;
}
if (out->data_type() == DataType::kFloat) {
key.element_C = cutlass::library::NumericTypeID::kF32;
Expand Down
1 change: 1 addition & 0 deletions oneflow/user/kernels/fused_glu_quant_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ class GpuFusedGluQuantKernel final : public user_op::OpKernel, public user_op::C
);
if (input_x->data_type() == DataType::kFloat16) {
key.element_A = cutlass::library::NumericTypeID::kF16;
return;
}
if (data_type == DataType::kFloat) {
key.element_scalar = cutlass::library::NumericTypeID::kF32;
Expand Down
1 change: 1 addition & 0 deletions oneflow/user/kernels/grouped_matmul_quant_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ class GroupedMatmulQuantKernel final : public user_op::OpKernel, public user_op:
const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("as", 0);
if (a->data_type() == DataType::kFloat16) {
key.element_A = cutlass::library::NumericTypeID::kF16;
return;
}

if (GetDataType<OutType>::value == DataType::kFloat) {
Expand Down
1 change: 1 addition & 0 deletions oneflow/user/kernels/matmul_quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class MatmulQuantKernel final : public user_op::OpKernel {

if (a->data_type() == DataType::kFloat16) {
key.element_A = cutlass::library::NumericTypeID::kF16;
return;
}

if (out->data_type() == DataType::kFloat) {
Expand Down

0 comments on commit 7e154e1

Please sign in to comment.