
Commit

[MLU] fix yolo_box and add set_value_with_tensor kernel (PaddlePaddle…
cifar10 authored Apr 6, 2023
1 parent 31e4378 commit 2d3390d
Showing 2 changed files with 161 additions and 0 deletions.
160 changes: 160 additions & 0 deletions backends/mlu/kernels/set_value_kernel.cc
@@ -324,7 +324,167 @@ void SetValueKernel(const Context& dev_ctx,
  TensorCopy(dev_ctx, in_temp, false, out);
}

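// SetTensorValueKernel implements set_value_with_tensor: it writes a tensor
// `value` (broadcast if necessary) into the strided slice of `x` selected by
// axes/starts/ends/steps and stores the result in `out`.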
template <typename T, typename Context>
void SetTensorValueKernel(const Context& dev_ctx,
                          const phi::DenseTensor& x,
                          const phi::DenseTensor& value,
                          const phi::IntArray& starts,
                          const phi::IntArray& ends,
                          const phi::IntArray& steps,
                          const std::vector<int64_t>& axes,
                          const std::vector<int64_t>& decrease_axes,
                          const std::vector<int64_t>& none_axes,
                          phi::DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);

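  // Take mutable copies of the slice attributes so they can be normalized
  // against the input shape below.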
  std::vector<int64_t> starts_local = starts.GetData();
  std::vector<int64_t> ends_local = ends.GetData();
  std::vector<int64_t> steps_local = steps.GetData();

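  // Clamp and normalize starts/ends/steps, then derive the sliced dims and
  // the dims that remain after decrease_axes are dropped.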
  auto in_dims = x.dims();
  custom_kernel::CheckAndUpdateSliceAttrs(
      in_dims, axes, &starts_local, &ends_local, &steps_local);
  auto slice_dims = custom_kernel::GetSliceDims(
      in_dims, axes, starts_local, ends_local, &steps_local);
  auto decrease_slice_dims =
      custom_kernel::GetDecreasedDims(slice_dims, decrease_axes);

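  // Re-insert size-1 dimensions for every entry in none_axes so `value` can
  // be broadcast to the exact shape of the region being assigned.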
  auto slice_dims_for_assign = decrease_slice_dims;
  if (!none_axes.empty()) {
    std::vector<int64_t> slice_dims_with_none;
    size_t none_axes_cur = 0, decrease_axes_cur = 0;
    for (int i = 0; i < slice_dims.size(); ++i) {
      while (none_axes_cur < none_axes.size() &&
             none_axes[none_axes_cur] <= i) {
        slice_dims_with_none.push_back(1);
        none_axes_cur++;
      }
      if (decrease_axes_cur < decrease_axes.size() &&
          decrease_axes[decrease_axes_cur] == i) {
        decrease_axes_cur++;
      } else {
        slice_dims_with_none.push_back(slice_dims[i]);
      }
    }
    while (none_axes_cur < none_axes.size()) {
      slice_dims_with_none.push_back(1);
      none_axes_cur++;
    }

    slice_dims_for_assign = phi::make_ddim(slice_dims_with_none);
  }
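  // Build per-dimension start/end/stride triples for the strided slice;
  // dimensions not listed in `axes` keep their full extent with stride 1.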
  int in_size = in_dims.size();
  int starts_indices[in_size] = {0};
  int ends_indices[in_size] = {0};
  int strides_indices[in_size] = {0};

  for (int i = 0; i < in_dims.size(); ++i) {
    starts_indices[i] = 0;
    ends_indices[i] = static_cast<int>(slice_dims[i]);
    strides_indices[i] = 1;
  }
  for (size_t i = 0; i < axes.size(); i++) {
    int axis_index = axes[i];
    starts_indices[axis_index] = static_cast<int>(starts_local[i]);
    ends_indices[axis_index] = static_cast<int>(ends_local[i]);
    strides_indices[axis_index] = static_cast<int>(steps_local[i]);
  }

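  // Broadcast `value` to the assignment shape when the two shapes differ.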
  phi::DenseTensor value_temp;
  if (slice_dims_for_assign == value.dims()) {
    value_temp = value;
  } else {
    value_temp.Resize(slice_dims_for_assign);
    dev_ctx.template Alloc<T>(&value_temp);
    MLUCnnlTensorDesc value_t_desc(value);
    MLUCnnlTensorDesc value_temp_desc(value_temp);
    MLUCnnl::BroadcastTo(dev_ctx,
                         value_t_desc.get(),
                         GetBasePtr(&value),
                         value_temp_desc.get(),
                         GetBasePtr(&value_temp));
  }

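  // Enumerate the linear index of every element of x (0 .. numel-1) on the
  // device; a strided slice of this index tensor later yields the flat
  // offsets that the assignment touches.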
  int64_t input_numel = phi::product(in_dims);
  int64_t value_numel = phi::product(value_temp.dims());
  phi::DenseTensor in_temp, out_temp, val_temp, index_out;
  int64_t stride_step = phi::product(in_dims);
  std::vector<int64_t> index_indices(stride_step);
  std::iota(index_indices.begin(), index_indices.end(), 0);
  phi::DenseTensor index_temp;
  in_temp = x;
  val_temp = value_temp;
  custom_kernel::TensorFromVector(dev_ctx, index_indices, dev_ctx, &index_temp);
  dev_ctx.Wait();
  index_temp.Resize(in_dims);
  auto index_dims = in_dims;
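  // Adjust start/end pairs that are still negative (possible with negative
  // steps) and compute how many elements the slice selects per dimension.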
  for (int i = 0; i < in_dims.size(); ++i) {
    if (starts_indices[i] < 0 || ends_indices[i] < 0) {
      starts_indices[i] -= in_dims[i];
      ends_indices[i] -= in_dims[i];
    }
    if (strides_indices[i] > 0)
      index_dims[i] =
          static_cast<int>((ends_indices[i] - starts_indices[i] - 1) /
                           strides_indices[i]) +
          1;
    else
      index_dims[i] =
          static_cast<int>((ends_indices[i] - starts_indices[i] + 1) /
                           strides_indices[i]) +
          1;
  }
  auto new_in_dims = phi::make_ddim({input_numel});
  auto new_val_dims = phi::make_ddim({value_numel});
  in_temp.Resize(new_in_dims);
  val_temp.Resize(new_val_dims);
  index_out.Resize(index_dims);
  dev_ctx.template Alloc<int64_t>(&index_out);
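  // Slice the index tensor with the same start/end/stride triples; index_out
  // then holds the flat offsets of exactly the elements being assigned.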
  cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE;
  MLUCnnlTensorDesc x_desc(in_temp);
  MLUCnnlTensorDesc indices_desc(index_temp);
  MLUCnnlTensorDesc indices_out_desc(index_out);
  MLUCnnlTensorDesc updates_desc(val_temp);
  MLUCnnlTensorDesc out_desc(*out);
  MLUCnnl::StridedSlice(dev_ctx,
                        starts_indices,
                        ends_indices,
                        strides_indices,
                        indices_desc.get(),
                        GetBasePtr(&index_temp),
                        indices_out_desc.get(),
                        GetBasePtr(&index_out));
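  // After checking that the number of selected offsets matches the number of
  // values, flatten the offsets and scatter the values into the flattened
  // input; CNNL_SCATTERREF_UPDATE overwrites the previous elements.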
  PADDLE_ENFORCE_EQ(
      static_cast<int64_t>(phi::product(index_out.dims())),
      phi::product(slice_dims_for_assign),
      phi::errors::InvalidArgument(
          "OP(set_value) error: the sliced indices do not match the number of "
          "elements to be assigned."));
  Tensor index_final;
  index_final = index_out;
  int64_t indices_numel = phi::product(index_dims);
  auto new_index_dims = phi::make_ddim({indices_numel});
  index_final.Resize(new_index_dims);
  MLUCnnlTensorDesc indices_final_desc(index_final);
  MLUCnnl::ScatterRefFunctor(dev_ctx,
                             x_desc.get(),
                             GetBasePtr(&in_temp),
                             updates_desc.get(),
                             GetBasePtr(&val_temp),
                             indices_final_desc.get(),
                             GetBasePtr(&index_final),
                             mode);
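  // Restore the original shape and copy the updated buffer to the kernel
  // output.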
  in_temp.Resize(in_dims);
  TensorCopy(dev_ctx, in_temp, false, out);
}

} // namespace custom_kernel

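// Kernel registrations: set_value keeps its existing entry, while
// set_value_with_tensor is the kernel added by this commit; both cover float
// and int.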
PD_REGISTER_PLUGIN_KERNEL(
    set_value, mlu, ALL_LAYOUT, custom_kernel::SetValueKernel, float, int) {}

PD_REGISTER_PLUGIN_KERNEL(set_value_with_tensor,
                          mlu,
                          ALL_LAYOUT,
                          custom_kernel::SetTensorValueKernel,
                          float,
                          int) {}
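
For intuition only (not part of the commit), the assignment strategy above (enumerate every linear index of x, strided-slice that index tensor, then scatter the values onto the selected offsets) can be sketched on the host in plain C++. The shapes, slice parameters, and values below are made-up illustrations:

// Host-side sketch of the kernel's index-enumeration + strided-slice + scatter
// strategy; emulates x[:, 1:4:2] = value for a 2x4 row-major tensor.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<float> x = {0, 1, 2, 3, 4, 5, 6, 7};  // 2x4, row-major
  std::vector<int64_t> dims = {2, 4};
  std::vector<int64_t> starts = {0, 1}, ends = {2, 4}, steps = {1, 2};

  // Step 1: enumerate the linear index of every element (the kernel uses iota).
  std::vector<int64_t> index(x.size());
  std::iota(index.begin(), index.end(), 0);

  // Step 2: strided-slice the index tensor to collect the flat offsets touched
  // by the assignment (the kernel uses MLUCnnl::StridedSlice).
  std::vector<int64_t> selected;
  for (int64_t i = starts[0]; i < ends[0]; i += steps[0])
    for (int64_t j = starts[1]; j < ends[1]; j += steps[1])
      selected.push_back(index[i * dims[1] + j]);

  // Step 3: scatter the (already broadcast) values onto those offsets
  // (the kernel uses MLUCnnl::ScatterRefFunctor in UPDATE mode).
  std::vector<float> value = {10, 11, 20, 21};  // 2x2, flattened
  for (size_t k = 0; k < selected.size(); ++k) x[selected[k]] = value[k];

  for (float v : x) std::cout << v << ' ';  // prints: 0 10 2 11 4 20 6 21
  std::cout << '\n';
}
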
1 change: 1 addition & 0 deletions backends/mlu/kernels/yolo_box_kernel.cc
@@ -75,6 +75,7 @@ void YoloBoxKernel(const Context& dev_ctx,
  Tensor anchors_temp;
  anchors_temp.Resize({size});
  custom_kernel::TensorFromVector(dev_ctx, anchors, dev_ctx, &anchors_temp);
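  // Newly added synchronization: wait for the asynchronous host-to-device
  // copy of the anchors to finish before anchors_temp is consumed below.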
  dev_ctx.Wait();
  MLUOpTensorDesc anchors_desc(anchors_temp);
  MLUCnnlTensorDesc boxes_desc_cnnl(
      4, boxes_out_dim.data(), ToCnnlDataType<T>());
