9 changes: 5 additions & 4 deletions paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h
@@ -1245,10 +1245,11 @@ Tensor elu_decomp(const Tensor& x, const float alpha) {
}

template <typename T>
Tensor lerp_decomp(const Tensor& x, const Tensor& y, const Tensor& weight) {
Tensor lerp_decomp(const Tensor& x, const Tensor& y, const Scalar& weight) {
Tensor x_cast = ConvertToMT<T>(x);
Tensor y_cast = ConvertToMT<T>(y);
Tensor weight_cast = ConvertToMT<T>(weight);
Tensor weight_tensor =
full_scalar<T>(weight.to<double>(), x_cast.dtype(), x_cast.place());
Tensor half = full_scalar<T>((0.5), x_cast.dtype(), x_cast.place());
Tensor one = full_scalar<T>(1.0, x_cast.dtype(), x_cast.place());
Tensor zero;
@@ -1258,10 +1259,10 @@ Tensor lerp_decomp(const Tensor& x, const Tensor& y, const Tensor& weight) {
Tensor zero_x = backend::full_with_tensor<T>(shape64<T>(x), 0.0, x.dtype());
Tensor zero_y = backend::full_with_tensor<T>(shape64<T>(y), 0.0, x.dtype());
zero = zero_x + zero_y;
weight_expended = backend::expand<T>(weight_cast, shape64<T>(zero));
weight_expended = backend::expand<T>(weight_tensor, shape64<T>(zero));
} else {
auto out_dims = phi::funcs::BroadcastTwoDims(x.dims(), y.dims());
weight_expended = expand<T>(weight_cast, phi::vectorize(out_dims));
weight_expended = expand<T>(weight_tensor, phi::vectorize(out_dims));
}
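// At this point weight_tensor (the scalar weight materialized via full_scalar)
// has been expanded to the broadcast output shape, so the elementwise select
// below can be applied directly.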

Tensor res = where<T>(weight_expended.abs() < half,
7 changes: 1 addition & 6 deletions paddle/phi/infermeta/ternary.cc
@@ -1610,15 +1610,10 @@ void LayerNormGradInferMeta(const MetaTensor& x,
}
}

void LerpInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
MetaTensor* out) {
void LerpInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
auto x_dims = x.dims();
auto y_dims = y.dims();
auto w_dims = weight.dims();
DDim out_dims = funcs::GetOutputDimsForDynamicShape(x_dims, y_dims);
out_dims = funcs::GetOutputDimsForDynamicShape(out_dims, w_dims);
out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
1 change: 0 additions & 1 deletion paddle/phi/infermeta/ternary.h
@@ -272,7 +272,6 @@ PADDLE_API void LayerNormGradInferMeta(const MetaTensor& x,

PADDLE_API void LerpInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
MetaTensor* out);

PADDLE_API void LinearV2InferMeta(const MetaTensor& input,
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/lerp_grad_kernel.cc
@@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lerp_grad_kernel_impl.h"

PD_REGISTER_KERNEL(
lerp_grad, CPU, ALL_LAYOUT, phi::LerpGradKernel, float, double) {}
PD_REGISTER_KERNEL(lerp_grad,
CPU,
ALL_LAYOUT,
phi::LerpGradKernel,
phi::float16,
phi::bfloat16,
float,
double) {}
9 changes: 8 additions & 1 deletion paddle/phi/kernels/cpu/lerp_kernel.cc
@@ -18,4 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lerp_kernel_impl.h"

PD_REGISTER_KERNEL(lerp, CPU, ALL_LAYOUT, phi::LerpKernel, float, double) {}
PD_REGISTER_KERNEL(lerp,
CPU,
ALL_LAYOUT,
phi::LerpKernel,
phi::float16,
phi::bfloat16,
float,
double) {}
137 changes: 108 additions & 29 deletions paddle/phi/kernels/gpu/lerp_grad_kernel.cu
@@ -18,6 +18,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/phi/common/amp_type_traits.h"
@@ -83,6 +84,53 @@ __global__ void LerpGradKernelCompatibleImpl(const T* weight,
}
}

template <typename T>
__global__ void LerpGradScalarValueKernelImpl(T weight_value,
const T* dout,
T* dx,
T* dy,
const int64_t out_size,
const int64_t x_size,
const int64_t y_size) {
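// out = x + w * (y - x), so dL/dx = (1 - w) * dout and dL/dy = w * dout.
// The product w * dout is computed in double before casting back to T,
// which limits rounding error for low-precision types such as float16.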
CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) {
double temp_dx =
static_cast<double>(weight_value) * static_cast<double>(dout[idx]);
if (dx) {
if (idx < x_size) {
dx[idx] = static_cast<T>(static_cast<double>(dout[idx]) - temp_dx);
}
}
if (dy) {
if (idx < y_size) {
dy[idx] = static_cast<T>(temp_dx);
}
}
}
}

template <typename T>
__global__ void LerpGradScalarValueKernelCompatibleImpl(T weight_value,
const T* dout,
T* dx,
T* dy,
const int64_t out_size,
const int64_t x_size,
const int64_t y_size) {
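// Accuracy-compatible variant (selected via FLAGS_use_accuracy_compatible_kernel):
// keeps (1 - w) and the gradient products in the native precision T instead
// of promoting to double.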
T remaining_weight_value = static_cast<T>(1) - weight_value;
CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) {
if (dx) {
if (idx < x_size) {
dx[idx] = remaining_weight_value * dout[idx];
}
}
if (dy) {
if (idx < y_size) {
dy[idx] = weight_value * dout[idx];
}
}
}
}

template <typename T, typename WeightT = T>
__global__ void LerpGradScalarKernelImpl(const WeightT* weight,
const T* dout,
@@ -280,9 +328,9 @@ template <typename T, typename Context>
void LerpGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& weight,
const DenseTensor& out,
const DenseTensor& out_grad,
const Scalar& weight,
DenseTensor* x_grad,
DenseTensor* y_grad) {
if (out_grad.numel() == 0) {
@@ -311,57 +359,89 @@ void LerpGradKernel(const Context& dev_ctx,
"less than or equal to 6, but the value received is %d.",
rank));

// check if x_grad and y_grad need to be reduced
// if x has a different dimension with y or weight in the middle axis, then
// they need to be broadcast and then reduced.
// Weight is always a scalar now - extract the value
T weight_value = weight.to<T>();
int64_t x_grad_size = 0, y_grad_size = 0;
T* x_grad_data = NULL;
T* y_grad_data = NULL;

const T* out_grad_data = out_grad.data<T>();
const int64_t out_size = out_grad.numel();
auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_size);

bool reduce_flag = XYNeedReduce(x, y, out);
if (!reduce_flag) {
int64_t x_grad_size = 0, y_grad_size = 0;
T* x_grad_data = NULL;
T* y_grad_data = NULL;

if (!reduce_flag) {
if (x_grad) {
x_grad_data = dev_ctx.template Alloc<T>(x_grad);
x_grad_size = x.numel();
}

if (y_grad) {
y_grad_data = dev_ctx.template Alloc<T>(y_grad);
y_grad_size = y.numel();
}

SwitchKernel<T, Context>(dev_ctx,
weight,
out_grad,
x_grad_size,
y_grad_size,
x_grad_data,
y_grad_data);

if (FLAGS_use_accuracy_compatible_kernel) {
LerpGradScalarValueKernelCompatibleImpl<T>
<<<gpu_config.GetGridSize(),
gpu_config.GetBlockSize(),
0,
dev_ctx.stream()>>>(weight_value,
out_grad_data,
x_grad_data,
y_grad_data,
out_size,
x_grad_size,
y_grad_size);
} else {
LerpGradScalarValueKernelImpl<T><<<gpu_config.GetGridSize(),
gpu_config.GetBlockSize(),
0,
dev_ctx.stream()>>>(weight_value,
out_grad_data,
x_grad_data,
y_grad_data,
out_size,
x_grad_size,
y_grad_size);
}
} else {
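// x and/or y were broadcast up to out's shape, so full-size gradients are
// written into out-shaped buffers first and reduced back to the input
// shapes further below.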
int64_t x_grad_size = 0, y_grad_size = 0;
DenseTensor b_xgrad = EmptyLike<T, Context>(dev_ctx, out_grad);
DenseTensor b_ygrad = EmptyLike<T, Context>(dev_ctx, out_grad);
T* x_grad_data = NULL;
T* y_grad_data = NULL;

if (x_grad) {
x_grad_data = dev_ctx.template Alloc<T>(&b_xgrad);
x_grad_size = out.numel();
}

if (y_grad) {
y_grad_data = dev_ctx.template Alloc<T>(&b_ygrad);
y_grad_size = out.numel();
}

SwitchKernel<T, Context>(dev_ctx,
weight,
out_grad,
x_grad_size,
y_grad_size,
x_grad_data,
y_grad_data);
if (FLAGS_use_accuracy_compatible_kernel) {
LerpGradScalarValueKernelCompatibleImpl<T>
<<<gpu_config.GetGridSize(),
gpu_config.GetBlockSize(),
0,
dev_ctx.stream()>>>(weight_value,
out_grad_data,
x_grad_data,
y_grad_data,
out_size,
x_grad_size,
y_grad_size);
} else {
LerpGradScalarValueKernelImpl<T><<<gpu_config.GetGridSize(),
gpu_config.GetBlockSize(),
0,
dev_ctx.stream()>>>(weight_value,
out_grad_data,
x_grad_data,
y_grad_data,
out_size,
x_grad_size,
y_grad_size);
}

auto zero_dim = common::make_ddim(std::vector<int64_t>(1, 1));
if (x_grad) {
@@ -376,7 +456,6 @@ void LerpGradKernel(const Context& dev_ctx,
x_grad->ShareDataWith(b_xgrad);
}
}

if (y_grad) {
std::vector<int> reduce_axis_y =
funcs::GetReduceDim(y_grad->dims().size() ? y_grad->dims() : zero_dim,
73 changes: 25 additions & 48 deletions paddle/phi/kernels/gpu/lerp_kernel.cu
@@ -17,6 +17,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
@@ -34,6 +35,22 @@
}
};

template <typename T>
struct LerpScalarValueFunctor {
T weight_value_;

HOSTDEVICE inline LerpScalarValueFunctor(T weight_value)
: weight_value_(weight_value) {}

HOSTDEVICE inline T operator()(const T x, const T y) const {
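// Pick between the two algebraically equivalent forms so the result stays
// accurate both when the weight is near 0 (first branch) and near 1
// (second branch).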
if (abs(static_cast<float>(weight_value_)) < 0.5f) {
return x + weight_value_ * (y - x);
} else {
return y - (y - x) * (static_cast<T>(1) - weight_value_);
}
}
};

template <typename T, typename WeightT = T>
struct LerpScalarDirectCUDAFunctor {
const WeightT* weight_;
@@ -55,7 +72,7 @@ template <typename T, typename Context>
void LerpKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& weight,
const Scalar& weight,
DenseTensor* out) {
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
@@ -74,54 +91,14 @@
dev_ctx.template Alloc<T>(out);
std::vector<DenseTensor*> outputs = {out};

// Weight is always a scalar now - extract the value
T weight_value = weight.to<T>();
std::vector<const DenseTensor*> inputs;
if (weight.numel() == 1) {
inputs.reserve(2);
inputs.emplace_back(&x);
inputs.emplace_back(&y);
if (weight.dtype() == DataType::FLOAT64) {
const double* weight_ptr = weight.data<double>();
auto functor = LerpScalarDirectCUDAFunctor<T, double>(weight_ptr);
funcs::BroadcastKernel<T>(dev_ctx, inputs, &outputs, functor);
} else {
const T* weight_ptr = weight.data<T>();
auto functor = LerpScalarDirectCUDAFunctor<T>(weight_ptr);
funcs::BroadcastKernel<T>(dev_ctx, inputs, &outputs, functor);
}
} else {
inputs.reserve(3);
auto functor = LerpElementWiseDirectCUDAFunctor<T>();
DenseTensor b_min = EmptyLike<T>(dev_ctx, *out);
if (x.dims().size() != y.dims().size() &&
weight.dims().size() != y.dims().size()) {
if (x.dims().size() < y.dims().size() &&
x.dims().size() < weight.dims().size()) {
// x broadcast to b_min
ExpandKernel<T, Context>(dev_ctx, x, vectorize(b_min.dims()), &b_min);
inputs.emplace_back(&b_min);
inputs.emplace_back(&y);
inputs.emplace_back(&weight);
} else if (y.dims().size() < weight.dims().size()) {
// y broadcast to b_min
ExpandKernel<T, Context>(dev_ctx, y, vectorize(b_min.dims()), &b_min);
inputs.emplace_back(&x);
inputs.emplace_back(&b_min);
inputs.emplace_back(&weight);
} else {
// weight broadcast to b_min
ExpandKernel<T, Context>(
dev_ctx, weight, vectorize(b_min.dims()), &b_min);
inputs.emplace_back(&x);
inputs.emplace_back(&y);
inputs.emplace_back(&b_min);
}
} else {
inputs.emplace_back(&x);
inputs.emplace_back(&y);
inputs.emplace_back(&weight);
}
funcs::BroadcastKernel<T>(dev_ctx, inputs, &outputs, functor);
}
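// x and y may still differ in shape; BroadcastKernel broadcasts them, while
// the scalar weight is captured by value inside the functor.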
inputs.reserve(2);
inputs.emplace_back(&x);
inputs.emplace_back(&y);
auto functor = LerpScalarValueFunctor<T>(weight_value);
funcs::BroadcastKernel<T>(dev_ctx, inputs, &outputs, functor);
}

} // namespace phi