Skip to content
15 changes: 11 additions & 4 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
sign,
sin,
sum,
tan,
tanh,
)
from paddle.base.libpaddle import DataType
Expand Down Expand Up @@ -125,7 +126,6 @@
sqrt_,
square,
square_,
tan,
tan_,
)

Expand Down Expand Up @@ -5043,7 +5043,10 @@ def rad2deg(x: Tensor, name: str | None = None) -> Tensor:
return out


def deg2rad(x: Tensor, name: str | None = None) -> Tensor:
@param_one_alias(['x', 'input'])
def deg2rad(
x: Tensor, name: str | None = None, *, out: Tensor | None = None
) -> Tensor:
r"""
Convert each of the elements of input x from degrees to angles in radians.

Expand All @@ -5054,6 +5057,7 @@ def deg2rad(x: Tensor, name: str | None = None) -> Tensor:
Args:
x (Tensor): An N-D Tensor, the data type is float32, float64, int32, int64.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
out (Tensor|None, optional): The output Tensor. If set, the result will be stored in this Tensor. Default is None.

Returns:
out (Tensor): An N-D Tensor, the shape and data type is the same with input (The output data type is float32 when the input data type is int).
Expand All @@ -5080,7 +5084,7 @@ def deg2rad(x: Tensor, name: str | None = None) -> Tensor:
if in_dynamic_or_pir_mode():
if convert_dtype(x.dtype) in ['int32', 'int64']:
x = cast(x, dtype="float32")
return _C_ops.scale(x, deg2rad_scale, 0.0, True)
return _C_ops.scale(x, deg2rad_scale, 0.0, True, out=out)
else:
check_variable_and_dtype(
x, 'x', ['int32', 'int64', 'float32', 'float64'], 'deg2rad'
Expand All @@ -5097,7 +5101,10 @@ def deg2rad(x: Tensor, name: str | None = None) -> Tensor:
outputs={'Out': out_cast},
attrs={'in_dtype': x.dtype, 'out_dtype': paddle.float32},
)
out = helper.create_variable_for_type_inference(dtype=out_cast.dtype)
if out is None:
out = helper.create_variable_for_type_inference(
dtype=out_cast.dtype
)
helper.append_op(
type='scale',
inputs={'X': out_cast},
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6377,6 +6377,7 @@ class TestActivationAPI_Compatibility(unittest.TestCase):
("paddle.tanh", np.tanh, {'min_val': -1.0, 'max_val': 1.0}),
("paddle.cosh", np.cosh, {'min_val': -1.0, 'max_val': 1.0}),
("paddle.sinh", np.sinh, {'min_val': -1.0, 'max_val': 1.0}),
("paddle.tan", np.tan, {'min_val': -1.0, 'max_val': 1.0}),
]
ACTIVATION_NOT_METHOD_CONFIGS = [
(
Expand Down
130 changes: 127 additions & 3 deletions test/legacy_test/test_deg2rad.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

os.environ['FLAGS_enable_pir_api'] = '0'

import numpy as np
from op_test import get_device_place

Expand Down Expand Up @@ -61,22 +64,143 @@ def test_dygraph(self):


class TestDeg2radAPI2(TestDeg2radAPI):
# Test input data type is int
# Test input data type is int64
def setUp(self):
self.x_np = [180]
self.x_np = np.array([180]).astype(np.int64)
self.x_shape = [1]
self.out_np = np.pi
self.x_dtype = 'int64'

def test_dygraph(self):
paddle.disable_static()

x2 = paddle.to_tensor([180])
# Test int64 input
x2 = paddle.to_tensor([180], dtype="int64")
result2 = paddle.deg2rad(x2)
np.testing.assert_allclose(np.pi, result2.numpy(), rtol=1e-05)

paddle.enable_static()


class TestDeg2radAPI3(TestDeg2radAPI):
# Test input data type is int32
def setUp(self):
self.x_np = np.array([180]).astype(np.int32)
self.x_shape = [1]
self.out_np = np.pi
self.x_dtype = 'int32'

def test_dygraph(self):
paddle.disable_static()

# Test int32 input
x3 = paddle.to_tensor([180], dtype="int32")
result3 = paddle.deg2rad(x3)
np.testing.assert_allclose(np.pi, result3.numpy(), rtol=1e-05)

paddle.enable_static()


class TestDeg2radAPI4(TestDeg2radAPI):
# Test input data type is float32
def setUp(self):
self.x_np = np.array(
[180.0, -180.0, 360.0, -360.0, 90.0, -90.0]
).astype(np.float32)
self.x_shape = [6]
self.out_np = np.deg2rad(self.x_np)
self.x_dtype = 'float32'


class TestDeg2radAliasAndOut(unittest.TestCase):
def test_alias(self):
paddle.disable_static()
x = paddle.to_tensor([180.0])
expected = np.deg2rad(180.0)

# Test alias
res = paddle.deg2rad(input=x)
np.testing.assert_allclose(res.numpy(), expected, rtol=1e-05)

paddle.enable_static()

def test_out(self):
paddle.disable_static()
x = paddle.to_tensor([180.0])
expected = np.deg2rad(180.0)

# Test without out parameter (default None)
res_no_out = paddle.deg2rad(x)
np.testing.assert_allclose(res_no_out.numpy(), expected, rtol=1e-05)

# Test out parameter with float input
out = paddle.zeros([1], dtype="float32")
res = paddle.deg2rad(x, out=out)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PIR无需测out

np.testing.assert_allclose(out.numpy(), expected, rtol=1e-05)
self.assertTrue(res is out)

# Test out parameter with int64 input
x_int = paddle.to_tensor([180], dtype="int64")
out_float = paddle.zeros([1], dtype="float32")
res = paddle.deg2rad(x_int, out=out_float)
np.testing.assert_allclose(out_float.numpy(), expected, rtol=1e-05)
self.assertTrue(res is out_float)

# Test out parameter with int32 input
x_int32 = paddle.to_tensor([180], dtype="int32")
out_float32 = paddle.zeros([1], dtype="float32")
res = paddle.deg2rad(x_int32, out=out_float32)
np.testing.assert_allclose(out_float32.numpy(), expected, rtol=1e-05)
self.assertTrue(res is out_float32)

paddle.enable_static()


class TestDeg2radStaticOut(unittest.TestCase):
def test_static_out_float(self):
"""Test out parameter in static graph with float input"""
paddle.enable_static()
startup_program = paddle.static.Program()
train_program = paddle.static.Program()
with paddle.static.program_guard(startup_program, train_program):
x = paddle.static.data(name='input', dtype='float32', shape=[1])
out = paddle.static.data(name='out', dtype='float32', shape=[1])
result = paddle.deg2rad(x, out=out)

place = get_device_place()
exe = base.Executor(place)
x_np = np.array([180.0]).astype(np.float32)
out_np = np.zeros([1]).astype(np.float32)
expected = np.deg2rad(180.0)

res, out_res = exe.run(
feed={'input': x_np, 'out': out_np},
fetch_list=[result, out],
)
np.testing.assert_allclose(out_res, expected, rtol=1e-05)

def test_static_out_int(self):
"""Test out parameter in static graph with int input"""
paddle.enable_static()
startup_program = paddle.static.Program()
train_program = paddle.static.Program()
with paddle.static.program_guard(startup_program, train_program):
x = paddle.static.data(name='input', dtype='int64', shape=[1])
out = paddle.static.data(name='out', dtype='float32', shape=[1])
result = paddle.deg2rad(x, out=out)

place = get_device_place()
exe = base.Executor(place)
x_np = np.array([180]).astype(np.int64)
out_np = np.zeros([1]).astype(np.float32)
expected = np.deg2rad(180.0)

res, out_res = exe.run(
feed={'input': x_np, 'out': out_np},
fetch_list=[result, out],
)
np.testing.assert_allclose(out_res, expected, rtol=1e-05)


if __name__ == '__main__':
unittest.main()
Loading