From 386305beb83449d77fbfbd94a955ab1ba481082d Mon Sep 17 00:00:00 2001
From: LokeZhou
Date: Fri, 9 Jun 2023 07:52:46 +0000
Subject: [PATCH 1/5] add embedding test

---
 tests/test_nn_Embedding.py            | 61 +++++++++++++++++++++++++++
 tests/test_nn_functional_embedding.py | 54 ++++++++++++++++++++++++
 2 files changed, 115 insertions(+)
 create mode 100644 tests/test_nn_Embedding.py
 create mode 100644 tests/test_nn_functional_embedding.py

diff --git a/tests/test_nn_Embedding.py b/tests/test_nn_Embedding.py
new file mode 100644
index 000000000..a6ea74cf2
--- /dev/null
+++ b/tests/test_nn_Embedding.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.nn.Embedding")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        embedding = torch.nn.Embedding(4, 3)
+        w0 = torch.Tensor([[0., 0., 0.],
+                           [1., 1., 1.],
+                           [2., 2., 2.],
+                           [3., 3., 3.]])
+        with torch.no_grad():
+            embedding.weight[0] = w0[0]
+            embedding.weight[1] = w0[1]
+            embedding.weight[3] = w0[3]
+        x = torch.LongTensor([[0], [1], [3]])
+        result = embedding(x)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        padding_idx = 0
+        embedding = torch.nn.Embedding(4, 3, padding_idx=padding_idx)
+        w0 = torch.Tensor([[0., 0., 0.],
+                           [1., 1., 1.],
+                           [2., 2., 2.],
+                           [3., 3., 3.]])
+        with torch.no_grad():
+            embedding.weight[0] = w0[0]
+            embedding.weight[1] = w0[1]
+            embedding.weight[2] = w0[2]
+            embedding.weight[3] = w0[3]
+        x = torch.LongTensor([[0], [1], [3]])
+        result = embedding(x)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
diff --git a/tests/test_nn_functional_embedding.py b/tests/test_nn_functional_embedding.py
new file mode 100644
index 000000000..6a0556f47
--- /dev/null
+++ b/tests/test_nn_functional_embedding.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.nn.functional.embedding")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        import numpy as np
+        embedding_matrix = torch.Tensor([[0., 0., 0.],
+                                         [1., 1., 1.],
+                                         [2., 2., 2.],
+                                         [3., 3., 3.]])
+
+        x = torch.tensor(np.array([[0, 1], [2, 3]]))
+        result = torch.nn.functional.embedding(x, embedding_matrix)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        import numpy as np
+        embedding_matrix = torch.Tensor([[0., 0., 0.],
+                                         [1., 1., 1.],
+                                         [2., 2., 2.],
+                                         [3., 3., 3.]])
+
+        x = torch.tensor(np.array([[0, 1], [2, 3]]))
+        result = torch.nn.functional.embedding(x, embedding_matrix, padding_idx=0)
+        """
+    )
+
+    obj.run(pytorch_code, ["result"])

From 3b098e53c249aec60c1032ab09d778b868717701 Mon Sep 17 00:00:00 2001
From: LokeZhou
Date: Fri, 9 Jun 2023 08:02:13 +0000
Subject: [PATCH 2/5] add embedding

---
 paconvert/api_mapping.json | 39 ++++++++++++++++++++++++++++++++++++++
 paconvert/api_matcher.py   | 33 ++++++++++++++++++++++++++++
 2 files changed, 72 insertions(+)

diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json
index 00e31996a..d5f33787c 100644
--- a/paconvert/api_mapping.json
+++ b/paconvert/api_mapping.json
@@ -3640,6 +3640,45 @@
             "inplace"
         ]
     },
+    "torch.nn.Embedding": {
+        "Matcher": "EmbeddingMatcher",
+        "paddle_api": "paddle.nn.Embedding",
+        "args_list": [
+            "num_embeddings",
+            "embedding_dim",
+            "padding_idx",
+            "max_norm",
+            "norm_type",
+            "scale_grad_by_freq",
+            "sparse"
+        ],
+        "unsupport_args": [
+            "max_norm",
+            "norm_type",
+            "scale_grad_by_freq"
+        ]
+    },
+    "torch.nn.functional.embedding": {
+        "Matcher": "FunctionalEmbeddingMatcher",
+        "paddle_api": "paddle.nn.functional.embedding",
+        "args_list": [
+            "input",
+            "weight",
+            "padding_idx",
+            "max_norm",
+            "norm_type",
+            "scale_grad_by_freq",
+            "sparse"
+        ],
+        "kwargs_change": {
+            "input": "x"
+        },
+        "unsupport_args": [
+            "max_norm",
+            "norm_type",
+            "scale_grad_by_freq"
+        ]
+    },
     "torch.nn.Hardshrink": {
         "Matcher": "GenericMatcher",
         "paddle_api": "paddle.nn.Hardshrink",
diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py
index c7e6eef95..2b7128fa4 100644
--- a/paconvert/api_matcher.py
+++ b/paconvert/api_matcher.py
@@ -3614,3 +3614,36 @@ def generate_code(self, kwargs):
         if "n" in kwargs and kwargs["n"] != "(1)":
             return None
         return GenericMatcher.generate_code(self, kwargs)
+
+
+class EmbeddingMatcher(BaseMatcher):
+    def generate_code(self, kwargs):
+        if "max_norm" in kwargs and kwargs["max_norm"] is not None:
+            return None
+        if "norm_type" in kwargs:
+            return None
+        if "scale_grad_by_freq" in kwargs:
+            return None
+
+        code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(kwargs))
+        return code
+
+
+class FunctionalEmbeddingMatcher(BaseMatcher):
+    def generate_code(self, kwargs):
+        if "max_norm" in kwargs and kwargs["max_norm"] is not None:
+            return None
+        if "norm_type" in kwargs:
+            return None
+        if "scale_grad_by_freq" in kwargs:
+            return None
+
+        if "kwargs_change" in self.api_mapping:
+            kwargs_change = self.api_mapping["kwargs_change"]
+            for key in list(kwargs_change.keys()):
+                if key in kwargs:
+                    kwargs[kwargs_change[key]] = kwargs[key]
+                    kwargs.pop(key)
+
+        code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(kwargs))
+        return code

From b68fe02d94c08e4cd52439fe45f5036a94b7be6a Mon Sep 17 00:00:00 2001
From: LokeZhou
Date: Wed, 14 Jun 2023 13:09:21 +0000
Subject: [PATCH 3/5] fix embedding

---
 paconvert/api_mapping.json |  2 +-
 paconvert/api_matcher.py   | 13 -------------
 tests/test_nn_Embedding.py | 22 ++++++++++++++++++++++
 3 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json
index d5f33787c..dec37d3ce 100644
--- a/paconvert/api_mapping.json
+++ b/paconvert/api_mapping.json
@@ -3641,7 +3641,7 @@
         ]
     },
     "torch.nn.Embedding": {
-        "Matcher": "EmbeddingMatcher",
+        "Matcher": "GenericMatcher",
         "paddle_api": "paddle.nn.Embedding",
         "args_list": [
             "num_embeddings",
diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py
index 2b7128fa4..cebe0948b 100644
--- a/paconvert/api_matcher.py
+++ b/paconvert/api_matcher.py
@@ -3616,19 +3616,6 @@ def generate_code(self, kwargs):
         return GenericMatcher.generate_code(self, kwargs)
 
 
-class EmbeddingMatcher(BaseMatcher):
-    def generate_code(self, kwargs):
-        if "max_norm" in kwargs and kwargs["max_norm"] is not None:
-            return None
-        if "norm_type" in kwargs:
-            return None
-        if "scale_grad_by_freq" in kwargs:
-            return None
-
-        code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(kwargs))
-        return code
-
-
 class FunctionalEmbeddingMatcher(BaseMatcher):
     def generate_code(self, kwargs):
         if "max_norm" in kwargs and kwargs["max_norm"] is not None:
diff --git a/tests/test_nn_Embedding.py b/tests/test_nn_Embedding.py
index a6ea74cf2..0a7076035 100644
--- a/tests/test_nn_Embedding.py
+++ b/tests/test_nn_Embedding.py
@@ -59,3 +59,25 @@ def test_case_2():
         """
     )
     obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        padding_idx = 0
+        embedding = torch.nn.Embedding(4, 3, padding_idx=padding_idx, max_norm=2.0)
+        w0 = torch.Tensor([[0., 0., 0.],
+                           [1., 1., 1.],
+                           [2., 2., 2.],
+                           [3., 3., 3.]])
+        with torch.no_grad():
+            embedding.weight[0] = w0[0]
+            embedding.weight[1] = w0[1]
+            embedding.weight[2] = w0[2]
+            embedding.weight[3] = w0[3]
+        x = torch.LongTensor([[0], [1], [3]])
+        result = embedding(x)
+        """
+    )
+    obj.run(pytorch_code, unsupport=True, reason="paddle does not support max_norm")

From 6e8045fb908f6691b80228b09ae4445007fb1c81 Mon Sep 17 00:00:00 2001
From: LokeZhou
Date: Mon, 19 Jun 2023 12:19:13 +0000
Subject: [PATCH 4/5] fix nn_functional_embedding

---
 paconvert/api_mapping.json            |  2 +-
 paconvert/api_matcher.py              | 36 ++++++++++++++--------------
 tests/test_nn_functional_embedding.py | 15 +++++++++++++++
 3 files changed, 34 insertions(+), 19 deletions(-)

diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json
index dec37d3ce..71cdc5d20 100644
--- a/paconvert/api_mapping.json
+++ b/paconvert/api_mapping.json
@@ -3659,7 +3659,7 @@
         ]
     },
     "torch.nn.functional.embedding": {
-        "Matcher": "FunctionalEmbeddingMatcher",
+        "Matcher": "GenericMatcher",
         "paddle_api": "paddle.nn.functional.embedding",
         "args_list": [
             "input",
diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py
index cebe0948b..913128070 100644
--- a/paconvert/api_matcher.py
+++ b/paconvert/api_matcher.py
@@ -3616,21 +3616,21 @@ def generate_code(self, kwargs):
         return GenericMatcher.generate_code(self, kwargs)
 
 
-class FunctionalEmbeddingMatcher(BaseMatcher):
-    def generate_code(self, kwargs):
-        if "max_norm" in kwargs and kwargs["max_norm"] is not None:
-            return None
-        if "norm_type" in kwargs:
-            return None
-        if "scale_grad_by_freq" in kwargs:
-            return None
-
-        if "kwargs_change" in self.api_mapping:
-            kwargs_change = self.api_mapping["kwargs_change"]
-            for key in list(kwargs_change.keys()):
-                if key in kwargs:
-                    kwargs[kwargs_change[key]] = kwargs[key]
-                    kwargs.pop(key)
-
-        code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(kwargs))
-        return code
+# class FunctionalEmbeddingMatcher(BaseMatcher):
+#     def generate_code(self, kwargs):
+#         if "max_norm" in kwargs and kwargs["max_norm"] is not None:
+#             return None
+#         if "norm_type" in kwargs:
+#             return None
+#         if "scale_grad_by_freq" in kwargs:
+#             return None
+
+#         if "kwargs_change" in self.api_mapping:
+#             kwargs_change = self.api_mapping["kwargs_change"]
+#             for key in list(kwargs_change.keys()):
+#                 if key in kwargs:
+#                     kwargs[kwargs_change[key]] = kwargs[key]
+#                     kwargs.pop(key)
+
+#         code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(kwargs))
+#         return code
diff --git a/tests/test_nn_functional_embedding.py b/tests/test_nn_functional_embedding.py
index 6a0556f47..abc4a11e8 100644
--- a/tests/test_nn_functional_embedding.py
+++ b/tests/test_nn_functional_embedding.py
@@ -52,3 +52,18 @@ def test_case_2():
     )
 
     obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        embedding_matrix = torch.Tensor([[0., 0., 0.],
+                                         [1., 1., 1.],
+                                         [2., 2., 2.],
+                                         [3., 3., 3.]])
+        x = torch.LongTensor([[0, 1], [2, 3]])
+        result = torch.nn.functional.embedding(x, embedding_matrix, padding_idx=0, max_norm=2)
+        """
+    )
+    obj.run(pytorch_code, unsupport=True, reason="paddle does not support max_norm")

From 03a5cdcabd1a091df8749d2f5a1d66fad87a2ae6 Mon Sep 17 00:00:00 2001
From: LokeZhou
Date: Mon, 19 Jun 2023 12:21:28 +0000
Subject: [PATCH 5/5] remove commented-out code

---
 paconvert/api_matcher.py | 20 --------------------
 1 file changed, 20 deletions(-)

diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py
index 913128070..c7e6eef95 100644
--- a/paconvert/api_matcher.py
+++ b/paconvert/api_matcher.py
@@ -3614,23 +3614,3 @@ def generate_code(self, kwargs):
         if "n" in kwargs and kwargs["n"] != "(1)":
             return None
         return GenericMatcher.generate_code(self, kwargs)
-
-
-# class FunctionalEmbeddingMatcher(BaseMatcher):
-#     def generate_code(self, kwargs):
-#         if "max_norm" in kwargs and kwargs["max_norm"] is not None:
-#             return None
-#         if "norm_type" in kwargs:
-#             return None
-#         if "scale_grad_by_freq" in kwargs:
-#             return None
-
-#         if "kwargs_change" in self.api_mapping:
-#             kwargs_change = self.api_mapping["kwargs_change"]
-#             for key in list(kwargs_change.keys()):
-#                 if key in kwargs:
-#                     kwargs[kwargs_change[key]] = kwargs[key]
-#                     kwargs.pop(key)
-
-#         code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(kwargs))
-#         return code
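
After patches 3-5, torch.nn.Embedding and torch.nn.functional.embedding are both routed through GenericMatcher, driven entirely by the api_mapping.json entries added in patch 2. Below is a minimal, self-contained sketch of the behavior those entries imply, assuming GenericMatcher renames arguments via kwargs_change and refuses conversion when an unsupport_args argument is used; EMBEDDING_MAPPING and convert_embedding_call are illustrative names, not paconvert's real API.

    # Illustrative sketch only: mirrors the "torch.nn.functional.embedding"
    # entry added to api_mapping.json in patch 2; convert_embedding_call is a
    # hypothetical helper, not a real paconvert function.
    EMBEDDING_MAPPING = {
        "paddle_api": "paddle.nn.functional.embedding",
        "kwargs_change": {"input": "x"},
        "unsupport_args": ["max_norm", "norm_type", "scale_grad_by_freq"],
    }


    def convert_embedding_call(kwargs):
        # Refuse conversion when the torch call uses an argument paddle does
        # not support; paconvert then reports the API instead of converting it.
        for arg in EMBEDDING_MAPPING["unsupport_args"]:
            if arg in kwargs:
                return None
        # Rename torch keyword names to their paddle equivalents (input -> x).
        renamed = {EMBEDDING_MAPPING["kwargs_change"].get(k, k): v
                   for k, v in kwargs.items()}
        args = ", ".join("{}={}".format(k, v) for k, v in renamed.items())
        return "{}({})".format(EMBEDDING_MAPPING["paddle_api"], args)


    print(convert_embedding_call({"input": "x", "weight": "w", "padding_idx": "0"}))
    # -> paddle.nn.functional.embedding(x=x, weight=w, padding_idx=0)
    print(convert_embedding_call({"input": "x", "weight": "w", "max_norm": "2.0"}))
    # -> None

Under these assumptions, a call such as torch.nn.functional.embedding(input=x, weight=w, padding_idx=0) maps to paddle.nn.functional.embedding(x=x, weight=w, padding_idx=0), while any call passing max_norm yields None, which is the behavior the unsupport=True test cases in patches 3 and 4 assert.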