Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions paddle/cinn/hlir/pe/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ std::vector<std::vector<int>> GetMatmulNewShapes(
auto& new_y_shape = new_shape[1];
auto& out_shape = new_shape[2];

int x_dim = x_shape.size(), y_dim = y_shape.size();
int max_dim = std::max(x_shape.size(), y_shape.size());
size_t x_dim = x_shape.size(), y_dim = y_shape.size();
size_t max_dim = std::max(x_shape.size(), y_shape.size());
int out_dim = max_dim >= 3 ? 3 : (max_dim <= 2 ? 2 : max_dim);

auto get_input_shape = [out_dim](const std::vector<int>& old_shape) {
Expand Down Expand Up @@ -309,8 +309,9 @@ std::vector<Tensor> Matmul(const Tensor& A,
const std::string& name) {
std::vector<Expr> shape_A = A->shape;
std::vector<Expr> shape_B = B->shape;
int a_dim = shape_A.size();
int b_dim = shape_B.size();
// NOTE(large-tensor): tensor dimensions are small integers
int a_dim = static_cast<int>(shape_A.size());
int b_dim = static_cast<int>(shape_B.size());
PADDLE_ENFORCE_EQ(
a_dim == 3U || a_dim == 2U,
true,
Expand Down Expand Up @@ -347,7 +348,8 @@ std::vector<Tensor> Matmul(const Tensor& A,
auto temp = Compute(
output_shape,
[=](const std::vector<Expr>& indice) {
int out_dim = indice.size();
// NOTE(large-tensor): tensor dimensions are small integers
int out_dim = static_cast<int>(indice.size());
std::vector<Expr> A_indice;
std::vector<Expr> B_indice;
PADDLE_ENFORCE_EQ(
Expand Down Expand Up @@ -458,13 +460,15 @@ ir::Tensor Concat(const std::vector<ir::Tensor>& input_tensors,
int axis,
const std::string& name) {
// input size 1 is valid for Concat
int input_size = input_tensors.size();
// NOTE(large-tensor): input size is a small integer
int input_size = static_cast<int>(input_tensors.size());
PADDLE_ENFORCE_GE(input_size,
1U,
::common::errors::InvalidArgument(
"Concat should have at least 1 input tensors"));
std::vector<Expr> output_shape = input_tensors[0]->shape;
int input_dim = output_shape.size();
// NOTE(large-tensor): tensor dimensions are small integers
int input_dim = static_cast<int>(output_shape.size());
PADDLE_ENFORCE_EQ(
axis >= -input_dim && axis < input_dim,
true,
Expand Down Expand Up @@ -518,8 +522,9 @@ std::vector<Tensor> MatmulV2(const Tensor& A,
const cinn::common::Target& target) {
std::vector<Expr> shape_A = A->shape;
std::vector<Expr> shape_B = B->shape;
int a_dim = shape_A.size();
int b_dim = shape_B.size();
// NOTE(large-tensor): tensor dimensions are small integers
int a_dim = static_cast<int>(shape_A.size());
int b_dim = static_cast<int>(shape_B.size());
PADDLE_ENFORCE_EQ(
a_dim == 3U || a_dim == 2U,
true,
Expand Down Expand Up @@ -566,7 +571,8 @@ std::vector<Tensor> MatmulV2(const Tensor& A,
packedB_shape,
[=](const std::vector<Expr>& indice) {
std::vector<Expr> indice_b;
int indice_dim = indice.size();
// NOTE(large-tensor): tensor dimensions are small integers
int indice_dim = static_cast<int>(indice.size());
PADDLE_ENFORCE_GE(indice_dim,
3,
::common::errors::InvalidArgument(
Expand All @@ -590,7 +596,8 @@ std::vector<Tensor> MatmulV2(const Tensor& A,
[=](const std::vector<Expr>& indice) {
std::vector<Expr> indice_a;
std::vector<Expr> indice_b;
int out_dim = indice.size();
// NOTE(large-tensor): tensor dimensions are small integers
int out_dim = static_cast<int>(indice.size());
PADDLE_ENFORCE_EQ(
out_dim == 3U || out_dim == 2U,
true,
Expand Down Expand Up @@ -635,8 +642,9 @@ std::vector<Tensor> MatmulMKL(const Tensor& A,
"Mkl should be used in the cpu environment."));
std::vector<Expr> shape_A = A->shape;
std::vector<Expr> shape_B = B->shape;
int a_dim = shape_A.size();
int b_dim = shape_B.size();
// NOTE(large-tensor): tensor dimensions are small integers
int a_dim = static_cast<int>(shape_A.size());
int b_dim = static_cast<int>(shape_B.size());
PADDLE_ENFORCE_EQ(
a_dim == 3U || a_dim == 2U,
true,
Expand Down Expand Up @@ -930,8 +938,9 @@ std::vector<Tensor> MulMKL(const Tensor& A,
"Mkl should be used in the cpu environment."));
std::vector<Expr> shape_A = A->shape;
std::vector<Expr> shape_B = B->shape;
int a_dim = shape_A.size();
int b_dim = shape_B.size();
// NOTE(large-tensor): tensor dimensions are small integers
int a_dim = static_cast<int>(shape_A.size());
int b_dim = static_cast<int>(shape_B.size());
PADDLE_ENFORCE_EQ(
a_dim,
2U,
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/tensor_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ inline void TensorToVector(const phi::DenseTensor& src,
memory::Copy(dst_place, dst_ptr, src.place(), src_ptr, size, nullptr);
}
#endif
for (unsigned int i = 0; i < src.numel(); i++) {
for (int64_t i = 0; i < src.numel(); i++) {
(*dst)[i] = static_cast<bool>(array[i]);
}
delete[] array;
Expand Down Expand Up @@ -408,7 +408,7 @@ inline void TensorToVector(const phi::DenseTensor& src,

memory::Copy(dst_place, dst_ptr, src.place(), src_ptr, size);

for (unsigned int i = 0; i < src.numel(); i++) {
for (int64_t i = 0; i < src.numel(); i++) {
(*dst)[i] = static_cast<bool>(array[i]);
}
delete[] array;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/op_build_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
'LegacyInterpolateInferMeta',
'NceInferMeta',
'PyramidHashInferMeta',
'RmsNormInferMeta',
'FusedRmsNormQuantInferMeta',
'SigmoidCrossEntropyWithLogitsInferMeta',
'StackInferMeta',
'WeightOnlyLinearInferMeta',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@ bool GatherOpInferSymbolicShape(pir::Operation *op,
"3 when the axis is not set."));
const auto &axis_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(2));
// NOTE(large-tensor): axis is a small integer
axis =
static_cast<int>(axis_shape_or_data.data().value()[0].Get<int64_t>());
}
Expand All @@ -983,7 +984,10 @@ bool GatherOpInferSymbolicShape(pir::Operation *op,
const std::vector<symbol::DimExpr> &index_sym_shape =
index_shape_or_data.shape();

if (axis < 0) axis += input_sym_shape.size();
if (axis < 0) {
// NOTE(large-tensor): tensor rank is a small integer
axis += static_cast<int>(input_sym_shape.size());
}

const auto &out_sym_shape = [&] {
std::vector<symbol::DimExpr> out_sym_shape;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op,
infer_context->GetShapeOrDataForValue(input_values[0]).shape();

size_t rank = out_dims.size();
axis = axis >= 0 ? axis : std::max(int64_t(0), int64_t(axis + rank));
// NOTE(large-tensor): axis is a small integer.
axis = axis >= 0
? axis
: static_cast<int>(std::max(int64_t(0), int64_t(axis + rank)));

for (size_t i = 1; i < input_size; ++i) {
const auto &operand_shape_or_data =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void SliceDfsImpl(const ExprVec &datas,
int64_t start,
int64_t end,
int64_t cur_visit_axis,
int offset,
int64_t offset,
ExprVec *result) {
int64_t begin = 0;
int64_t stop = shape.at(cur_visit_axis);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,8 @@ bool BroadcastTensorsOpInferSymbolicShape(
// 1. Find Output rank = max(Inputs rank)
int target_rank = 0;
for (const auto &input_shape_or_data : input_shape_or_data_list) {
int tmp_rank = input_shape_or_data.shape().size();
// NOTE(large-tensor): tensor rank is a small integer
int tmp_rank = static_cast<int>(input_shape_or_data.shape().size());
target_rank = std::max(target_rank, tmp_rank);
}
// 2. Output dim(axis=x) = max(Inputs dim(axis=x))
Expand All @@ -749,7 +750,8 @@ bool BroadcastTensorsOpInferSymbolicShape(
for (int i = 0; i < target_rank; i++) {
auto tmp_dim = symbol::DimExpr{1};
for (const auto &input_shape_or_data : input_shape_or_data_list) {
int axis = input_shape_or_data.shape().size();
// NOTE(large-tensor): tensor rank is a small integer
int axis = static_cast<int>(input_shape_or_data.shape().size());
axis = i - target_rank + axis;
if (axis >= 0) {
infer_context->AddBroadcastableCstr(input_shape_or_data.shape()[axis],
Expand Down Expand Up @@ -1143,7 +1145,10 @@ bool CrossEntropyWithSoftmaxOpInferSymbolicShape(
const auto &index_dim = index_shape.shape();
const auto &attributes = op->attributes();
int axis = attributes.at("axis").dyn_cast<pir::Int32Attribute>().data();
if (axis < 0) axis += input_shape.shape().size();
if (axis < 0) {
// NOTE(large-tensor): tensor rank is a small integer
axis += static_cast<int>(input_shape.shape().size());
}
bool soft_label =
attributes.at("soft_label").dyn_cast<pir::BoolAttribute>().data();
PADDLE_ENFORCE(!soft_label || input_dim.size() == index_dim.size(),
Expand Down Expand Up @@ -1197,7 +1202,8 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op,
const auto &shape_data_list =
x_shape.dyn_cast<symbol::TensorListShapeOrDataDimExprs>();

size_t rank = shape_data_list.at(0).shape().size();
// NOTE(large-tensor): tensor rank is a small integer
int rank = static_cast<int>(shape_data_list.at(0).shape().size());
const int64_t axis = [&] {
int64_t axis = axis_expr.data()->at(0).dyn_cast<int64_t>();
return axis >= 0 ? axis : std::max(int64_t(0), int64_t(axis + rank));
Expand All @@ -1216,8 +1222,8 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op,

const std::vector<symbol::DimExpr> &out_dims = [&] {
std::vector<symbol::DimExpr> out_dims = shape_data_list.at(0).shape();
for (size_t i = 0; i < rank; ++i) {
if (i != static_cast<size_t>(axis)) {
for (int i = 0; i < rank; ++i) {
if (i != axis) {
details::BuildCstrEqForTensorListAlongAxis(
infer_context, shape_data_list, i);
continue;
Expand Down Expand Up @@ -2751,7 +2757,7 @@ bool GroupNormOpInferSymbolicShape(

infer_context->SetShapeOrDataForValue(op->result(0), x_shape);

int64_t channel_idx;
size_t channel_idx;
std::string data_format =
op->attribute<pir::StrAttribute>("data_format").AsString();
if (data_format == "NHWC") {
Expand Down Expand Up @@ -2867,9 +2873,10 @@ bool LerpOpInferSymbolicShape(pir::Operation *op,
std::vector<symbol::DimExpr> x_shape = x_shape_or_data.shape();
std::vector<symbol::DimExpr> y_shape = y_shape_or_data.shape();
std::vector<symbol::DimExpr> w_shape = w_shape_or_data.shape();
int x_ndims = x_shape.size();
int y_ndims = y_shape.size();
int w_ndims = w_shape.size();
// NOTE(large-tensor): tensor dimensions are small integers
int x_ndims = static_cast<int>(x_shape.size());
int y_ndims = static_cast<int>(y_shape.size());
int w_ndims = static_cast<int>(w_shape.size());
std::vector<symbol::DimExpr> out1_shape;
std::vector<symbol::DimExpr> out2_shape;
int diffxy = x_ndims - y_ndims;
Expand Down Expand Up @@ -3414,6 +3421,59 @@ bool RoiAlignOpInferSymbolicShape(
return true;
}

// Infers the symbolic output shapes of the RmsNorm op.
//
// Inputs read from `op`:
//   * operand 0 (x) and operand 1 (scale) shapes, via `infer_context`;
//   * the `normalized_shape` attribute (vector of int64_t).
//
// Effects on `infer_context`:
//   * adds equality constraints between the trailing dims of x and
//     `normalized_shape`;
//   * when scale is not a null shape, adds equality constraints between each
//     scale dim and the matching `normalized_shape` entry;
//   * sets result(0) to the shape of x;
//   * sets result(1) to the leading dims of x that precede the normalized
//     axes.
//
// Returns true. Rank violations are reported through PADDLE_ENFORCE_*.
bool RmsNormOpInferSymbolicShape(
    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
  // Get the shapes of input tensors
  const auto &x_shape_or_data =
      infer_context->GetShapeOrDataForValue(op->operand_source(0));
  const auto &scale_shape_or_data =
      infer_context->GetShapeOrDataForValue(op->operand_source(1));
  const std::vector<int64_t> normalized_shape =
      paddle::dialect::details::GetVectorAttr<int64_t>(op, "normalized_shape");

  // Work on a local copy of the x dims; it is reused for both outputs below.
  const std::vector<symbol::DimExpr> x_dims = x_shape_or_data.shape();
  // NOTE(large-tensor): tensor ranks are small integers
  const int x_dims_size = static_cast<int>(x_dims.size());
  const int normalized_shape_size = static_cast<int>(normalized_shape.size());
  const int begin_norm_axis = x_dims_size - normalized_shape_size;

  // The normalized axes must be a strict suffix of x's axes, so at least one
  // leading axis remains for result(1).
  // NOTE(review): other checks in this change pass a typed error such as
  // ::common::errors::InvalidArgument(...) -- consider doing the same for the
  // two checks in this function.
  PADDLE_ENFORCE_LT(normalized_shape_size,
                    x_dims_size,
                    "normalized_shape must be less than x_dims");
  // Walk both shapes from the last axis backwards and tie them together.
  for (int i = 0; i < normalized_shape_size; i++) {
    infer_context->AddEqualCstr(
        x_dims[x_dims_size - i - 1],
        symbol::DimExpr(normalized_shape[normalized_shape_size - i - 1]));
  }

  if (!scale_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
    const std::vector<symbol::DimExpr> scale_dims =
        scale_shape_or_data.shape();
    // NOTE(large-tensor): tensor rank is a small integer
    PADDLE_ENFORCE_EQ(
        static_cast<int>(scale_dims.size()),
        normalized_shape_size,
        "scale_dims.size() must be equal to normalized_shape_size");
    for (int i = 0; i < normalized_shape_size; i++) {
      infer_context->AddEqualCstr(scale_dims[i],
                                  symbol::DimExpr(normalized_shape[i]));
    }
  }

  // Set output shapes: result(0) has the same shape as x.
  infer_context->SetShapeOrDataForValue(
      op->result(0),
      symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_dims)});

  // Set invvar: result(1) keeps only the dims of x before begin_norm_axis.
  const std::vector<symbol::DimExpr> before_norm_dims(
      x_dims.begin(), x_dims.begin() + begin_norm_axis);
  infer_context->SetShapeOrDataForValue(
      op->result(1),
      symbol::ShapeOrDataDimExprs{
          symbol::TensorShapeOrDataDimExprs(before_norm_dims)});

  return true;
}

bool SpectralNormOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &weight_shape =
Expand Down Expand Up @@ -3688,7 +3748,7 @@ bool NceOpInferSymbolicShape(pir::Operation *op,
infer_context->AddEqualCstr(weight_shape[0], bias_shape[0]);
}

int num_total_classes =
int64_t num_total_classes =
op->attribute<pir::Int64Attribute>("num_total_classes").data();
infer_context->AddEqualCstr(symbol::DimExpr(num_total_classes),
weight_shape[0]);
Expand All @@ -3700,7 +3760,7 @@ bool NceOpInferSymbolicShape(pir::Operation *op,
symbol::TensorShapeOrDataDimExprs(out_shape)});

bool is_test = op->attribute<pir::BoolAttribute>("is_test").data();
int num_neg_samples =
int64_t num_neg_samples =
op->attribute<pir::Int64Attribute>("num_neg_samples").data();
if (!is_test) {
std::vector<symbol::DimExpr> sample_out_shape = {x_shape[0]};
Expand Down Expand Up @@ -4072,7 +4132,7 @@ bool RnnOpInferSymbolicShape(pir::Operation *op,
"The rank of PreState in RNN must be 3. But "
"the received rank is %d.",
pre_state_shape_or_data_list[0].shape().size()));
for (size_t i = 0; i < 3; ++i) {
for (int i = 0; i < 3; ++i) {
details::BuildCstrEqForTensorListAlongAxis(
infer_context, pre_state_shape_or_data_list, i);
}
Expand Down Expand Up @@ -4344,7 +4404,8 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
infer_context->GetShapeOrDataForValue(operand_source)
.dyn_cast<symbol::TensorListShapeOrDataDimExprs>();

size_t rank = shape_data_list.at(0).shape().size();
// NOTE(large-tensor): tensor rank is a small integer
int rank = static_cast<int>(shape_data_list.at(0).shape().size());
if (axis < 0) axis += rank + 1;
const symbol::ShapeOrDataDimExprs shape_data = [&] {
std::vector<symbol::DimExpr> result_shape = {};
Expand Down Expand Up @@ -4378,7 +4439,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
} else {
// case 2: data is empty, eg: shape_data_list =
// [[shape:{5,6,7},data:{}],...]
for (size_t i = 0; i < rank; ++i) {
for (int i = 0; i < rank; ++i) {
details::BuildCstrEqForTensorListAlongAxis(
infer_context, shape_data_list, i);
}
Expand Down Expand Up @@ -4759,9 +4820,10 @@ bool WhereOpInferSymbolicShape(pir::Operation *op,
const std::vector<pir::Value> &operands = {
op->operand_source(0), op->operand_source(1), op->operand_source(2)};

size_t rank = x_shape.size();
// NOTE(large-tensor): tensor rank is a small integer
int rank = static_cast<int>(x_shape.size());

for (size_t i = 0; i < rank; ++i) {
for (int i = 0; i < rank; ++i) {
paddle::dialect::details::BuildCstrEqForTensorListAlongAxis(
infer_context, operands, i);
}
Expand Down Expand Up @@ -4849,7 +4911,8 @@ bool YoloLossOpInferSymbolicShape(
infer_context->GetShapeOrDataForValue(op->operand_source(2)).shape();
const std::vector<int> &anchors_mask =
paddle::dialect::details::GetVectorAttr<int>(op, "anchor_mask");
int mask_num = anchors_mask.size();
// NOTE(large-tensor): mask number is a small integer
int mask_num = static_cast<int>(anchors_mask.size());
int class_num = op->attribute<pir::Int32Attribute>("class_num").data();

PADDLE_ENFORCE_EQ(x_shape.size(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(RandomRouting)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(RandomRouting_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedRmsNormQuant)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(RoiPool)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(RmsNorm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rnn)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rnn_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(RoiAlign)
Expand Down
Loading
Loading