40 changes: 16 additions & 24 deletions paddle/phi/kernels/legacy/gpu/moe_combine_grad_kernel.cu
@@ -149,13 +149,11 @@ void MoeCombineGradKernel(const Context& dev_ctx,
DenseTensor* grad_combine_weights_helper) {
dev_ctx.template Alloc<T>(grad_x);
dev_ctx.template Alloc<T>(grad_combine_weights_helper);
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x);
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(grad_combine_weights_helper->dims())),
0,
grad_combine_weights_helper);
Full<T, Context>(dev_ctx, grad_x->dims(), 0, grad_x);
Full<T, Context>(dev_ctx,
grad_combine_weights_helper->dims(),
0,
grad_combine_weights_helper);
auto x_shape = x.dims();
auto combine_weights_shape = combine_weights.dims();
moe_combine_bwd<T, Context>(dev_ctx,
@@ -182,18 +180,13 @@ void MoeCombineAutoGradKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(grad_combine_weights_helper);
dev_ctx.template Alloc<int32_t>(grad_scatter_index);

phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x);
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(grad_combine_weights_helper->dims())),
0,
grad_combine_weights_helper);
phi::Full<int32_t, Context>(
dev_ctx,
phi::IntArray(common::vectorize(grad_scatter_index->dims())),
0,
grad_scatter_index);
Full<T, Context>(dev_ctx, grad_x->dims(), 0, grad_x);
Full<T, Context>(dev_ctx,
grad_combine_weights_helper->dims(),
0,
grad_combine_weights_helper);
Full<int32_t, Context>(
dev_ctx, grad_scatter_index->dims(), 0, grad_scatter_index);

// TODO(nieyuntao): Temporarily use 'grad_combine_weight_intermediate' to
// bypass the grad_combine_weights_helper's shape mismatch to kernel shape
@@ -207,11 +200,10 @@ void MoeCombineAutoGradKernel(const Context& dev_ctx,
x.dims()[1]}));
grad_combine_weight_intermediate_meta.set_dtype(combine_weights.dtype());
dev_ctx.template Alloc<T>(grad_combine_weight_intermediate);
phi::Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(
grad_combine_weight_intermediate->dims())),
0,
grad_combine_weight_intermediate);
Full<T, Context>(dev_ctx,
grad_combine_weight_intermediate->dims(),
0,
grad_combine_weight_intermediate);

auto x_shape = x.dims();
auto combine_weights_shape = combine_weights.dims();
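Note: the shortened call sites in this file (and in the files below) all follow one pattern: the explicit `phi::IntArray(common::vectorize(tensor->dims()))` wrapping is dropped and the tensor's dims object is passed to `Full` directly. The diff itself does not show which `Full` overload or implicit conversion makes the shorter form compile, so the standalone sketch below uses toy stand-ins (`DDim`, `DenseTensor`, `IntArray`, `vectorize`, and both `Full` overloads are local mock types here, not phi's actual declarations) purely to illustrate how a dims-accepting convenience overload removes the boilerplate at every call site.

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Toy stand-ins for common::DDim / phi::DenseTensor / phi::IntArray.
struct DDim {
  std::vector<int64_t> d;
};

struct DenseTensor {
  DDim dims_;
  std::vector<float> data;
  const DDim& dims() const { return dims_; }
};

struct IntArray {
  std::vector<int64_t> v;
  explicit IntArray(std::vector<int64_t> x) : v(std::move(x)) {}
};

// Toy stand-in for common::vectorize.
std::vector<int64_t> vectorize(const DDim& d) { return d.d; }

// Original-style overload: the shape arrives as an IntArray.
void Full(const IntArray& shape, float value, DenseTensor* out) {
  int64_t n = 1;
  for (int64_t s : shape.v) n *= s;
  out->data.assign(static_cast<size_t>(n), value);
}

// Convenience overload accepting the dims object directly; this is the kind
// of overload (or implicit conversion) the shortened call sites rely on.
void Full(const DDim& shape, float value, DenseTensor* out) {
  Full(IntArray(vectorize(shape)), value, out);
}

int main() {
  DenseTensor grad_x;
  grad_x.dims_ = {{2, 3}};

  // Before: explicit vectorize + IntArray wrapping at every call site.
  Full(IntArray(vectorize(grad_x.dims())), 0.0f, &grad_x);

  // After: pass dims() straight through; the overload does the conversion.
  Full(grad_x.dims(), 0.0f, &grad_x);

  std::cout << grad_x.data.size() << " elements zero-initialized\n";
  return 0;
}
```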
3 changes: 1 addition & 2 deletions paddle/phi/kernels/legacy/gpu/moe_combine_kernel.cu
@@ -109,8 +109,7 @@ void MoeCombineKernel(const Context& dev_ctx,
DenseTensor* y) {
dev_ctx.template Alloc<T>(y); // T cannot support phi::dtype::float8 very
// well, maybe replaced with x.dtype();
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y);
Full<T, Context>(dev_ctx, y->dims(), 0, y);
auto combine_weights_shape = combine_weights.dims();
auto x_shape = x.dims();
moe_combine_fwd<T, Context>(dev_ctx,
@@ -105,8 +105,7 @@ void MoeCombineNoWeightGradKernel(const Context& dev_ctx,
const int64_t k = scatter_index_shape[1];

dev_ctx.template Alloc<T>(grad_x);
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x);
Full<T, Context>(dev_ctx, grad_x->dims(), 0, grad_x);

moe_combine_no_weight_bwd<T>(combine_weights.data<T>(),
scatter_index.data<int>(),
@@ -367,8 +367,7 @@ void MoeDispatchAndQuantKernel(const Context &dev_ctx,
sizeof(phi::float8_e4m3fn) * out_fp8->numel(),
dev_ctx.stream());

phi::Full<float, Context>(
dev_ctx, phi::IntArray(common::vectorize(scale->dims())), 1, scale);
Full<float, Context>(dev_ctx, scale->dims(), 1, scale);

const auto &x_shape = x.dims();
const auto &gate_logits_shape = gate_logits.dims();
3 changes: 1 addition & 2 deletions paddle/phi/kernels/legacy/gpu/moe_gate_dispatch_kernel.cu
@@ -129,8 +129,7 @@ void MoeGateDispatchKernel(const Context &dev_ctx,
dev_ctx.template Alloc<float>(combine_weights);
dev_ctx.template Alloc<T>(y);

phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y);
Full<T, Context>(dev_ctx, y->dims(), 0, y);
auto x_dims = x.dims();
auto gate_logits_dims = gate_logits.dims();

@@ -134,8 +134,7 @@ void MoEDispatchPermuteKernel(const Context &dev_ctx,
dev_ctx.template Alloc<int>(scatter_index);
dev_ctx.template Alloc<float>(combine_weights);
dev_ctx.template Alloc<T>(y);
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y);
Full<T, Context>(dev_ctx, y->dims(), 0, y);
const auto &x_shape = x.dims();
const auto &gate_logits_shape = gate_logits.dims();
int64_t num_rows = x_shape[0];
@@ -114,11 +114,8 @@ void MoeGateDispatchPartialNoSoftMaxTopkGradKernel(
DenseTensor* combine_weights_grad) {
dev_ctx.template Alloc<T>(x_grad);
dev_ctx.template Alloc<float>(combine_weights_grad);
phi::Full<float, Context>(
dev_ctx,
phi::IntArray(common::vectorize(combine_weights_grad->dims())),
0,
combine_weights_grad);
Full<float, Context>(
dev_ctx, combine_weights_grad->dims(), 0, combine_weights_grad);
DenseTensor t_scatter_index;
phi::Transpose<int, Context>(
dev_ctx, scatter_index, {1, 0}, &t_scatter_index);
@@ -439,8 +439,7 @@ void apply_moe_dispatch_fwd(
y->Resize({expert_offset_host.back(), x.dims()[1]});
dev_ctx.template Alloc<T>(y);
}
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y);
Full<T, Context>(dev_ctx, y->dims(), 0, y);
copy_unpermuted_to_permuted_kernelLauncher(
x.data<T>(),
y->data<T>(), // out
@@ -526,31 +525,14 @@ void MoeGateDispatchPartialNoSoftMaxTopkKernel(
dev_ctx.template Alloc<int64_t>(expert_offset);
dev_ctx.template Alloc<int64_t>(expert_nums_local);
dev_ctx.template Alloc<float>(combine_weights_out);
phi::Full<int32_t, Context>(
dev_ctx,
phi::IntArray(common::vectorize(scatter_index->dims())),
0,
scatter_index);
phi::Full<int32_t, Context>(
dev_ctx,
phi::IntArray(common::vectorize(scatter_index_rev->dims())),
0,
scatter_index_rev);
phi::Full<int64_t, Context>(
dev_ctx,
phi::IntArray(common::vectorize(expert_offset->dims())),
0,
expert_offset);
phi::Full<int64_t, Context>(
dev_ctx,
phi::IntArray(common::vectorize(expert_nums_local->dims())),
0,
expert_nums_local);
phi::Full<float, Context>(
dev_ctx,
phi::IntArray(common::vectorize(combine_weights_out->dims())),
0,
combine_weights_out);
Full<int32_t, Context>(dev_ctx, scatter_index->dims(), 0, scatter_index);
Full<int32_t, Context>(
dev_ctx, scatter_index_rev->dims(), 0, scatter_index_rev);
Full<int64_t, Context>(dev_ctx, expert_offset->dims(), 0, expert_offset);
Full<int64_t, Context>(
dev_ctx, expert_nums_local->dims(), 0, expert_nums_local);
Full<float, Context>(
dev_ctx, combine_weights_out->dims(), 0, combine_weights_out);
phi::Copy(
dev_ctx, combine_weights, dev_ctx.GetPlace(), false, combine_weights_out);
const auto &x_shape = x.dims();
11 changes: 2 additions & 9 deletions paddle/phi/kernels/stride/indexing_kernel.cu
@@ -332,11 +332,7 @@ void IndexPutGradKernel_V2(const Context& dev_ctx,
dev_ctx.template Alloc<T>(x_grad);
// Fill value_grad with 0.
if (value_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(value_grad->dims())),
0,
value_grad);
phi::Full<T, Context>(dev_ctx, value_grad->dims(), 0, value_grad);
}
return;
}
@@ -390,10 +386,7 @@ void IndexPutGradKernel_V2(const Context& dev_ctx,
x_grad->ShareInplaceVersionCounterWith(out_grad);
} else {
DenseTensor value_zero;
phi::Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(value.dims())),
0,
&value_zero);
phi::Full<T, Context>(dev_ctx, value.dims(), 0, &value_zero);
if (funcs::IsInUint32Range(x_grad->numel(), value.numel())) {
LaunchIndexPutKernel_V2<T, Context>(
dev_ctx, out_grad, indices, value_zero, false, x_grad);
6 changes: 2 additions & 4 deletions paddle/phi/kernels/stride/reduce_stride_kernel.cu
@@ -320,8 +320,7 @@ void ProdStrideKernel(const Context& dev_ctx,

if (x_.numel() == 0) {
// fill with 1.
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 1, out);
phi::Full<T, Context>(dev_ctx, out->dims(), 1, out);
return;
}

@@ -647,8 +646,7 @@ void MeanStrideKernel(const Context& dev_ctx,
}

if (x_.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
phi::Full<T, Context>(dev_ctx, out->dims(), NAN, out);
return;
}
