diff --git a/backends/mlu/kernels/set_value_kernel.cc b/backends/mlu/kernels/set_value_kernel.cc
index 77e447991..115a8a88c 100644
--- a/backends/mlu/kernels/set_value_kernel.cc
+++ b/backends/mlu/kernels/set_value_kernel.cc
@@ -324,7 +324,167 @@ void SetValueKernel(const Context& dev_ctx,
   TensorCopy(dev_ctx, in_temp, false, out);
 }
 
+template <typename T, typename Context>
+void SetTensorValueKernel(const Context& dev_ctx,
+                          const phi::DenseTensor& x,
+                          const phi::DenseTensor& value,
+                          const phi::IntArray& starts,
+                          const phi::IntArray& ends,
+                          const phi::IntArray& steps,
+                          const std::vector<int64_t>& axes,
+                          const std::vector<int64_t>& decrease_axes,
+                          const std::vector<int64_t>& none_axes,
+                          phi::DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+
+  std::vector<int64_t> starts_local = starts.GetData();
+  std::vector<int64_t> ends_local = ends.GetData();
+  std::vector<int64_t> steps_local = steps.GetData();
+
+  auto in_dims = x.dims();
+  custom_kernel::CheckAndUpdateSliceAttrs(
+      in_dims, axes, &starts_local, &ends_local, &steps_local);
+  auto slice_dims = custom_kernel::GetSliceDims(
+      in_dims, axes, starts_local, ends_local, &steps_local);
+  auto decrease_slice_dims =
+      custom_kernel::GetDecreasedDims(slice_dims, decrease_axes);
+
+  auto slice_dims_for_assign = decrease_slice_dims;
+  if (!none_axes.empty()) {
+    std::vector<int64_t> slice_dims_with_none;
+    size_t none_axes_cur = 0, decrease_axes_cur = 0;
+    for (int i = 0; i < slice_dims.size(); ++i) {
+      while (none_axes_cur < none_axes.size() &&
+             none_axes[none_axes_cur] <= i) {
+        slice_dims_with_none.push_back(1);
+        none_axes_cur++;
+      }
+      if (decrease_axes_cur < decrease_axes.size() &&
+          decrease_axes[decrease_axes_cur] == i) {
+        decrease_axes_cur++;
+      } else {
+        slice_dims_with_none.push_back(slice_dims[i]);
+      }
+    }
+    while (none_axes_cur < none_axes.size()) {
+      slice_dims_with_none.push_back(1);
+      none_axes_cur++;
+    }
+
+    slice_dims_for_assign = phi::make_ddim(slice_dims_with_none);
+  }
+  int in_size = in_dims.size();
+  int starts_indices[in_size] = {0};
+  int ends_indices[in_size] = {0};
+  int strides_indices[in_size] = {0};
+
+  for (int i = 0; i < in_dims.size(); ++i) {
+    starts_indices[i] = 0;
+    ends_indices[i] = static_cast<int>(slice_dims[i]);
+    strides_indices[i] = 1;
+  }
+  for (size_t i = 0; i < axes.size(); i++) {
+    int axis_index = axes[i];
+    starts_indices[axis_index] = static_cast<int>(starts_local[i]);
+    ends_indices[axis_index] = static_cast<int>(ends_local[i]);
+    strides_indices[axis_index] = static_cast<int>(steps_local[i]);
+  }
+
+  phi::DenseTensor value_temp;
+  if (slice_dims_for_assign == value.dims()) {
+    value_temp = value;
+  } else {
+    value_temp.Resize(slice_dims_for_assign);
+    dev_ctx.template Alloc<T>(&value_temp);
+    MLUCnnlTensorDesc value_t_desc(value);
+    MLUCnnlTensorDesc value_temp_desc(value_temp);
+    MLUCnnl::BroadcastTo(dev_ctx,
+                         value_t_desc.get(),
+                         GetBasePtr(&value),
+                         value_temp_desc.get(),
+                         GetBasePtr(&value_temp));
+  }
+
+  int64_t input_numel = phi::product(in_dims);
+  int64_t value_numel = phi::product(value_temp.dims());
+  phi::DenseTensor in_temp, out_temp, val_temp, index_out;
+  int64_t stride_step = phi::product(in_dims);
+  std::vector<int64_t> index_indices(stride_step);
+  std::iota(index_indices.begin(), index_indices.end(), 0);
+  phi::DenseTensor index_temp;
+  in_temp = x;
+  val_temp = value_temp;
+  custom_kernel::TensorFromVector(dev_ctx, index_indices, dev_ctx, &index_temp);
+  dev_ctx.Wait();
+  index_temp.Resize(in_dims);
+  auto index_dims = in_dims;
+  for (int i = 0; i < in_dims.size(); ++i) {
+    if (starts_indices[i] < 0 || ends_indices[i] < 0) {
+      starts_indices[i] -= in_dims[i];
+      ends_indices[i] -= in_dims[i];
+    }
+    if (strides_indices[i] > 0)
+      index_dims[i] =
+          static_cast<int>((ends_indices[i] - starts_indices[i] - 1) /
+                           strides_indices[i]) +
+          1;
+    else
+      index_dims[i] =
+          static_cast<int>((ends_indices[i] - starts_indices[i] + 1) /
+                           strides_indices[i]) +
+          1;
+  }
+  auto new_in_dims = phi::make_ddim({input_numel});
+  auto new_val_dims = phi::make_ddim({value_numel});
+  in_temp.Resize(new_in_dims);
+  val_temp.Resize(new_val_dims);
+  index_out.Resize(index_dims);
+  dev_ctx.template Alloc<int64_t>(&index_out);
+  cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE;
+  MLUCnnlTensorDesc x_desc(in_temp);
+  MLUCnnlTensorDesc indices_desc(index_temp);
+  MLUCnnlTensorDesc indices_out_desc(index_out);
+  MLUCnnlTensorDesc updates_desc(val_temp);
+  MLUCnnlTensorDesc out_desc(*out);
+  MLUCnnl::StridedSlice(dev_ctx,
+                        starts_indices,
+                        ends_indices,
+                        strides_indices,
+                        indices_desc.get(),
+                        GetBasePtr(&index_temp),
+                        indices_out_desc.get(),
+                        GetBasePtr(&index_out));
+  PADDLE_ENFORCE_EQ(
+      static_cast<int64_t>(phi::product(index_out.dims())),
+      phi::product(slice_dims_for_assign),
+      phi::errors::InvalidArgument(
+          "OP(set_value) error index indices and value update not match "));
+  Tensor index_final;
+  index_final = index_out;
+  int64_t indices_numel = phi::product(index_dims);
+  auto new_index_dims = phi::make_ddim({indices_numel});
+  index_final.Resize(new_index_dims);
+  MLUCnnlTensorDesc indices_final_desc(index_final);
+  MLUCnnl::ScatterRefFunctor(dev_ctx,
+                             x_desc.get(),
+                             GetBasePtr(&in_temp),
+                             updates_desc.get(),
+                             GetBasePtr(&val_temp),
+                             indices_final_desc.get(),
+                             GetBasePtr(&index_final),
+                             mode);
+  in_temp.Resize(in_dims);
+  TensorCopy(dev_ctx, in_temp, false, out);
+}
+
 }  // namespace custom_kernel
 
 PD_REGISTER_PLUGIN_KERNEL(
     set_value, mlu, ALL_LAYOUT, custom_kernel::SetValueKernel, float, int) {}
+
+PD_REGISTER_PLUGIN_KERNEL(set_value_with_tensor,
+                          mlu,
+                          ALL_LAYOUT,
+                          custom_kernel::SetTensorValueKernel,
+                          float,
+                          int) {}
diff --git a/backends/mlu/kernels/yolo_box_kernel.cc b/backends/mlu/kernels/yolo_box_kernel.cc
index 2772d2a16..3b4268d72 100644
--- a/backends/mlu/kernels/yolo_box_kernel.cc
+++ b/backends/mlu/kernels/yolo_box_kernel.cc
@@ -75,6 +75,7 @@ void YoloBoxKernel(const Context& dev_ctx,
   Tensor anchors_temp;
   anchors_temp.Resize({size});
   custom_kernel::TensorFromVector(dev_ctx, anchors, dev_ctx, &anchors_temp);
+  dev_ctx.Wait();
   MLUOpTensorDesc anchors_desc(anchors_temp);
   MLUCnnlTensorDesc boxes_desc_cnnl(
       4, boxes_out_dim.data(), ToCnnlDataType<T>());
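
Review note on the set_value_with_tensor hunk (not part of the patch): the kernel performs strided slice assignment indirectly. It fills an index tensor with the flat offsets 0..numel-1 over the input shape, runs MLUCnnl::StridedSlice on that index tensor to recover the flat offsets of the region being assigned, broadcasts value to the slice shape, and scatters the flattened value into the flattened input at those offsets with MLUCnnl::ScatterRefFunctor in CNNL_SCATTERREF_UPDATE mode. The host-side C++ sketch below illustrates only that indexing scheme; it is not the MLU implementation, StridedSliceOffsets and the 3x4 example data are invented for illustration, and negative steps are omitted.

// Host-side sketch of the iota-index + strided-slice + scatter scheme used by
// SetTensorValueKernel above (illustrative only).
#include <cassert>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Flat (row-major) offsets selected by (starts, ends, strides) on each axis.
// Positive strides only, for brevity.
std::vector<int64_t> StridedSliceOffsets(const std::vector<int64_t>& dims,
                                         const std::vector<int64_t>& starts,
                                         const std::vector<int64_t>& ends,
                                         const std::vector<int64_t>& strides) {
  const size_t rank = dims.size();
  std::vector<int64_t> row_stride(rank, 1);
  for (size_t i = rank - 1; i > 0; --i) {
    row_stride[i - 1] = row_stride[i] * dims[i];
  }
  std::vector<int64_t> counts(rank);  // number of selected elements per axis
  for (size_t i = 0; i < rank; ++i) {
    counts[i] = (ends[i] - starts[i] + strides[i] - 1) / strides[i];
    if (counts[i] <= 0) return {};  // empty slice selects nothing
  }
  std::vector<int64_t> coord(rank, 0);  // odometer over the sliced region
  std::vector<int64_t> offsets;
  while (true) {
    int64_t flat = 0;
    for (size_t i = 0; i < rank; ++i) {
      flat += (starts[i] + coord[i] * strides[i]) * row_stride[i];
    }
    offsets.push_back(flat);
    size_t axis = rank;
    while (axis > 0) {
      --axis;
      if (++coord[axis] < counts[axis]) break;
      coord[axis] = 0;
      if (axis == 0) return offsets;
    }
  }
}

int main() {
  // x is a 3x4 row-major tensor; emulate x[1:3, 0:4:2] = value.
  std::vector<int64_t> dims = {3, 4};
  std::vector<float> x(12);
  std::iota(x.begin(), x.end(), 0.0f);

  // Flat offsets of the target region: what StridedSlice over the iota index
  // tensor produces in the kernel.
  std::vector<int64_t> offsets =
      StridedSliceOffsets(dims, {1, 0}, {3, 4}, {1, 2});

  // value broadcast to the 2x2 slice and flattened, matching the kernel's
  // BroadcastTo + Resize({value_numel}) steps.
  std::vector<float> value = {-1.f, -2.f, -3.f, -4.f};
  assert(offsets.size() == value.size());

  // Scatter in "update" mode, the analogue of CNNL_SCATTERREF_UPDATE.
  for (size_t i = 0; i < offsets.size(); ++i) {
    x[offsets[i]] = value[i];
  }

  for (int r = 0; r < 3; ++r) {
    for (int c = 0; c < 4; ++c) std::cout << x[r * 4 + c] << ' ';
    std::cout << '\n';
  }
  return 0;
}

Compiled standalone, this prints the original 3x4 values with rows 1-2, even columns, replaced by -1..-4, which is the behavior the kernel reproduces on device memory before copying in_temp back to out.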
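
Review note on the yolo_box hunk: custom_kernel::TensorFromVector stages the host-side anchors vector into device memory, so the added dev_ctx.Wait() presumably ensures that copy has completed before anchors_temp is described and consumed by the op; the same TensorFromVector followed by dev_ctx.Wait() pairing appears in the new SetTensorValueKernel above. This is an inference from the diff rather than something the patch states.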