Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
marigoold committed Sep 18, 2023
1 parent d5555d1 commit 2bdc3cd
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions oneflow/user/ops/fused_group_norm_min_max_observer_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,26 @@ namespace oneflow {
// check input shape
CHECK_GT_OR_RETURN(x.shape().NumAxes(), 2)
<< "fused_group_norm_min_max_observer: input should have at least 2 dims, got "
<< x.shape().NumAxes() << " dims";
<< x.shape().NumAxes() << " dims.";

// check num_group
int32_t num_groups = ctx->Attr<int32_t>("num_groups");
CHECK_EQ_OR_RETURN(channel_size % num_groups, 0) << "Channels should be divisble by num_groups. ";
CHECK_EQ_OR_RETURN(channel_size % num_groups, 0) << "fused_group_norm_min_max_observer: Channels should be divisble by num_groups.";

// check gamma and beta size
if (affine) {
const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0);
CHECK_EQ_OR_RETURN(gamma.shape().elem_cnt(), channel_size)
<< "The size of `gamma` must equal to channel_size, expected " << channel_size
<< " but got " << gamma.shape().elem_cnt();
<< "fused_group_norm_min_max_observer: The size of `gamma` must be equal to channel_size, expected " << channel_size
<< " but got " << gamma.shape().elem_cnt() << ".";
const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0);
CHECK_EQ_OR_RETURN(beta.shape().elem_cnt(), channel_size)
<< "The size of `beta` must equal to channel_size, expected " << channel_size << " but got "
<< beta.shape().elem_cnt();
<< "fused_group_norm_min_max_observer: The size of `beta` must be equal to channel_size, expected " << channel_size << " but got "
<< beta.shape().elem_cnt() << ".";
}

CHECK_OR_RETURN(ctx->Attr<bool>("per_layer_quantization"))
<< "dynamic quantization only supports per-layer quantization";
<< "fused_group_norm_min_max_observer: dynamic quantization only supports per-layer quantization.";
ctx->SetOutputShape("y", 0, x.shape());
ctx->SetOutputShape("y_scale", 0, Shape({1}));
ctx->SetOutputShape("y_zero_point", 0, Shape({1}));
Expand Down

0 comments on commit 2bdc3cd

Please sign in to comment.