diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 00e31996a..71cdc5d20 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -3640,6 +3640,45 @@ "inplace" ] }, + "torch.nn.Embedding": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.nn.Embedding", + "args_list": [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "max_norm", + "norm_type", + "scale_grad_by_freq", + "sparse" + ], + "unsupport_args": [ + "max_norm", + "norm_type", + "scale_grad_by_freq" + ] + }, + "torch.nn.functional.embedding": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.nn.functional.embedding", + "args_list": [ + "input", + "weight", + "padding_idx", + "max_norm", + "norm_type", + "scale_grad_by_freq", + "sparse" + ], + "kwargs_change": { + "input": "x" + }, + "unsupport_args": [ + "max_norm", + "norm_type", + "scale_grad_by_freq" + ] + }, "torch.nn.Hardshrink": { "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.Hardshrink", diff --git a/tests/test_nn_Embedding.py b/tests/test_nn_Embedding.py new file mode 100644 index 000000000..0a7076035 --- /dev/null +++ b/tests/test_nn_Embedding.py @@ -0,0 +1,83 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.Embedding") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + embedding = torch.nn.Embedding(4, 3) + w0 = torch.Tensor([[0., 0., 0.], + [1., 1., 1.], + [2., 2., 2.], + [3., 3., 3.]]) + with torch.no_grad(): + embedding.weight[0]=w0[0] + embedding.weight[1]=w0[1] + embedding.weight[3]=w0[3] + x = torch.LongTensor([[0],[1],[3]]) + result = embedding(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + padding_idx = 0 + embedding = torch.nn.Embedding(4, 3,padding_idx=padding_idx) + w0 = torch.Tensor([[0., 0., 0.], + [1., 1., 1.], + [2., 2., 2.], + [3., 3., 3.]]) + with torch.no_grad(): + embedding.weight[0]=w0[0] + embedding.weight[1]=w0[1] + embedding.weight[2]=w0[2] + embedding.weight[3]=w0[3] + x = torch.LongTensor([[0],[1],[3]]) + result = embedding(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + padding_idx = 0 + embedding = torch.nn.Embedding(4, 3,padding_idx=padding_idx,max_norm=2.0) + w0 = torch.Tensor([[0., 0., 0.], + [1., 1., 1.], + [2., 2., 2.], + [3., 3., 3.]]) + with torch.no_grad(): + embedding.weight[0]=w0[0] + embedding.weight[1]=w0[1] + embedding.weight[2]=w0[2] + embedding.weight[3]=w0[3] + x = torch.LongTensor([[0],[1],[3]]) + result = embedding(x) + """ + ) + obj.run(pytorch_code, unsupport=True, reason="paddle unsupport") diff --git a/tests/test_nn_functional_embedding.py b/tests/test_nn_functional_embedding.py new file mode 100644 index 000000000..abc4a11e8 --- /dev/null +++ b/tests/test_nn_functional_embedding.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.functional.embedding") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import numpy as np + embedding_matrix = torch.Tensor([[0., 0., 0.], + [1., 1., 1.], + [2., 2., 2.], + [3., 3., 3.]]) + + x = torch.tensor(np.array([[0,1],[2,3]])) + result = torch.nn.functional.embedding(x,embedding_matrix) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import numpy as np + embedding_matrix = torch.Tensor([[0., 0., 0.], + [1., 1., 1.], + [2., 2., 2.], + [3., 3., 3.]]) + + x = torch.tensor(np.array([[0,1],[2,3]])) + result = torch.nn.functional.embedding(x,embedding_matrix,padding_idx=0) + """ + ) + + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + w0 = torch.Tensor([[0., 0., 0.], + [1., 1., 1.], + [2., 2., 2.], + [3., 3., 3.]]) + result = torch.nn.functional.embedding(x,embedding_matrix,padding_idx=0,max_norm=2) + """ + ) + obj.run(pytorch_code, unsupport=True, reason="paddle unsupport")