python/paddle/tensor/math.py (18 changes: 15 additions & 3 deletions)

@@ -4821,8 +4821,14 @@ def logit_(
     return _C_ops.logit_(x, eps)


+@param_two_alias(["x", "input"], ["y", "end"])
 def lerp(
-    x: Tensor, y: Tensor, weight: float | Tensor, name: str | None = None
+    x: Tensor,
+    y: Tensor,
+    weight: float | Tensor,
+    name: str | None = None,
+    *,
+    out: Tensor | None = None,
 ) -> Tensor:
     r"""
     Does a linear interpolation between x and y based on weight.
@@ -4834,9 +4840,12 @@ def lerp(

     Args:
         x (Tensor): An N-D Tensor with starting points, the data type is bfloat16, float16, float32, float64.
+            Alias: ``input`` .
         y (Tensor): An N-D Tensor with ending points, the data type is bfloat16, float16, float32, float64.
+            Alias: ``end`` .
         weight (float|Tensor): The weight for the interpolation formula. When weight is Tensor, the data type is bfloat16, float16, float32, float64.
         name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+        out (Tensor|None, optional): The output Tensor. If provided, the result will be stored in `out`. Default is None.

     Returns:
         out (Tensor): An N-D Tensor, the shape and data type is the same with input.
@@ -4862,7 +4871,8 @@ def lerp(
         weight = paddle.full(shape=[], fill_value=weight, dtype=x.dtype)

     if in_dynamic_or_pir_mode():
-        return _C_ops.lerp(x, y, weight)
+        if out is None:
+            return _C_ops.lerp(x, y, weight, out=out)
Contributor (on the new `if out is None:` guard): If `out` is not None, it is not handled here. In fact, no handling is needed here at all, because `_C_ops` itself deals with the argument being None or not.
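A minimal sketch of the simplification this comment suggests, assuming (as the comment states) that `_C_ops.lerp` accepts `out=None` on its own:

```python
# Dynamic/PIR branch without the guard: _C_ops.lerp is said to handle
# out=None internally, so a single call covers both cases.
if in_dynamic_or_pir_mode():
    return _C_ops.lerp(x, y, weight, out=out)
```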

Contributor: Take a look at ops.yaml and see whether this can be sunk into the op definition by declaring the argument as `Scalar(double) weight`.
Contributor (Author): I don't quite understand, could you explain a bit?
Contributor:

> I don't quite understand, could you explain a bit?

The main obstacle to sinking it is that `weight` can be passed as either a float or a Tensor. Paddle's internal `Scalar` type can accept a Tensor as well as arbitrary scalar types such as float/int/double, so the op argument could be changed to `Scalar`. You can give it a try; `allclose` is implemented with `Scalar` and can serve as a reference. The kernel's parameter type would need to change as well, though, and compatibility issues cannot be ruled out. If that becomes a problem, fall back to the decorator.
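To make that duality concrete, here is an illustration (plain Paddle usage, nothing PR-specific) of the two call styles a `Scalar`-typed op argument would have to absorb; today the Python wrapper normalizes the float case itself via the `paddle.full(shape=[], ...)` line visible in the hunk above:

```python
import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
y = paddle.to_tensor([10.0, 20.0, 30.0])

# weight as a plain Python float: currently normalized in Python into a
# 0-D tensor via paddle.full(shape=[], fill_value=weight, dtype=x.dtype).
out_a = paddle.lerp(x, y, 0.5)

# weight as a 0-D Tensor: forwarded to the op as-is.
out_b = paddle.lerp(x, y, paddle.to_tensor(0.5))

# With a Scalar(double) op argument, both forms could be passed straight
# through and the Python-side normalization would become unnecessary.
```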

     else:
         check_variable_and_dtype(
             x, 'x', ['uint16', 'float16', 'float32', 'float64'], 'lerp'
@@ -4879,11 +4889,13 @@ def lerp(

         helper = LayerHelper('lerp', **locals())
         inputs = {'X': x, 'Y': y, 'Weight': weight}
-        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+        if out is None:
+            out = helper.create_variable_for_type_inference(dtype=x.dtype)
         helper.append_op(type='lerp', inputs=inputs, outputs={'Out': out})
         return out


+@param_two_alias(["x", "input"], ["y", "end"])
 @inplace_apis_in_dygraph_only
 def lerp_(
     x: Tensor, y: Tensor, weight: float | Tensor, name: str | None = None
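`param_two_alias` itself is not shown in this diff. For context, a minimal sketch of what such a decorator plausibly does; this reimplementation is an assumption for illustration, not Paddle's actual helper:

```python
import functools


def param_two_alias(pair1, pair2):
    # Each pair is [canonical_name, alias], e.g. ["x", "input"].
    (name1, alias1), (name2, alias2) = pair1, pair2

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Rewrite alias keywords to the canonical parameter names
            # before delegating to the wrapped function.
            if alias1 in kwargs:
                kwargs[name1] = kwargs.pop(alias1)
            if alias2 in kwargs:
                kwargs[name2] = kwargs.pop(alias2)
            return func(*args, **kwargs)

        return wrapper

    return decorator
```

With something like this in place, `paddle.lerp(input=x, end=y, weight=w)` resolves to the same call as `paddle.lerp(x=x, y=y, weight=w)`, which is exactly what the new `test_alias` case below exercises.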
test/legacy_test/test_lerp_op.py (26 changes: 26 additions & 0 deletions)

@@ -234,6 +234,32 @@ def test_x_y_broadcast_w(self):
         np.testing.assert_allclose(res_ref, out.numpy(), rtol=1e-05)
         paddle.enable_static()

+    def test_alias(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(self.x)
+        y = paddle.to_tensor(self.y)
+        w = paddle.to_tensor(self.w)
+
+        # Test with the alias keywords: input, end
+        out1 = paddle.lerp(input=x, end=y, weight=w)
+        np.testing.assert_allclose(out1.numpy(), self.res_ref, rtol=1e-05)
+
+        # Test with the canonical keywords: x, y
+        out2 = paddle.lerp(x=x, y=y, weight=w)
+        np.testing.assert_allclose(out2.numpy(), self.res_ref, rtol=1e-05)
+        paddle.enable_static()
+
+    def test_out(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(self.x)
+        y = paddle.to_tensor(self.y)
+        w = paddle.to_tensor(self.w)
+        out = paddle.empty_like(x)
+
+        paddle.lerp(x, y, w, out=out)
+        np.testing.assert_allclose(out.numpy(), self.res_ref, rtol=1e-05)
+        paddle.enable_static()
+

 @unittest.skipIf(
     not (core.is_compiled_with_cuda() or is_custom_device())
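Taken together, the branch is aiming at calls like the following (a usage sketch that assumes the reviewer's simplification of the dynamic-mode guard lands, so that `out=` also works in eager mode):

```python
import paddle

x = paddle.to_tensor([0.0, 1.0])
y = paddle.to_tensor([10.0, 11.0])
buf = paddle.empty_like(x)

# Alias keywords (input/end) plus the new out= buffer, used together.
paddle.lerp(input=x, end=y, weight=0.5, out=buf)
print(buf)  # expected: [5.0, 6.0]
```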