Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -4891,6 +4891,15 @@
"target": "label"
}
},
"torch.nn.MSELoss": {
"Matcher": "MseLossMatcher",
"paddle_api": "paddle.nn.MSELoss",
"args_list": [
"size_average",
"reduce",
"reduction"
]
},
"torch.nn.functional.margin_ranking_loss": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.margin_ranking_loss",
Expand Down Expand Up @@ -8536,6 +8545,21 @@
"pos_weight"
]
},
"torch.nn.functional.binary_cross_entropy": {
"Matcher": "FunctionalBinaryCrossEntropyMatcher",
"paddle_api": "paddle.nn.functional.binary_cross_entropy",
"args_list": [
"input",
"target",
"weight",
"size_average",
"reduce",
"reduction"
],
"kwargs_change": {
"target": "label"
}
},
"torch.nn.functional.max_pool2d": {
"Matcher": "FunctionalMaxPool2DMatcher",
"paddle_api": "paddle.nn.functional.max_pool2d",
Expand Down Expand Up @@ -8612,6 +8636,16 @@
"pos_weight"
]
},
"torch.nn.BCELoss": {
"Matcher": "BCELossMatcher",
"paddle_api": "paddle.nn.BCELoss",
"args_list": [
"weight",
"size_average",
"reduce",
"reduction"
]
},
"torch.utils.data.BatchSampler": {
"Matcher": "TorchUtilDataBatchSampler",
"args_list": [
Expand Down Expand Up @@ -8642,6 +8676,49 @@
"input": "x"
}
},

"torch.nn.L1Loss": {
"Matcher": "L1LossMatcher",
"paddle_api": "paddle.nn.L1Loss",
"args_list": [
"size_average",
"reduce",
"reduction"
]
},
"torch.nn.Unfold": {
Copy link
Collaborator

@zhwesky2010 zhwesky2010 Jun 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以用genericmatcher吧,改成那个吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel_size 参数 pytorch支持tuple,paddle不支持,改为genericmatcher遇到tuple会报错

"Matcher": "UnfoldMatcher",
"paddle_api": "paddle.nn.Unfold",
"args_list": [
"kernel_size",
"dilation",
"padding",
"stride"
],
"kwargs_change": {
"kernel_size": "kernel_sizes",
"dilation": "dilations",
"padding": "paddings",
"stride": "strides"
}
},
"torch.nn.functional.unfold": {
Copy link
Collaborator

@zhwesky2010 zhwesky2010 Jun 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以用genericmatcher吧,改成那个吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同torch.nn.Unfold

"Matcher": "FunctionalUnfoldMatcher",
"paddle_api": "paddle.nn.functional.unfold",
"args_list": [
"input",
"kernel_size",
"dilation",
"padding",
"stride"
],
"kwargs_change": {
"input": "x",
"kernel_size": "kernel_sizes",
"dilation": "dilations",
"padding": "paddings",
"stride": "strides"

"torch.nn.modules.batchnorm._BatchNorm": {
"Matcher": "Modules_BatchNormBaseMatcher",
"paddle_api": "paddle.nn.layer.norm._BatchNormBase",
Expand All @@ -8656,6 +8733,7 @@
],
"kwargs_change": {
"eps": "epsilon"

}
}
}
254 changes: 254 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3594,6 +3594,55 @@ def generate_code(self, kwargs):
return code


class MseLossMatcher(BaseMatcher):
def generate_code(self, kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉这一块的重复度很高,是否可以统一成一个Matcher

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if "size_average" in kwargs:
size_average = kwargs.pop("size_average")
if "True" in size_average:
size_average = True
elif "False" in size_average:
size_average = False
else:
size_average = None
else:
size_average = None

if "reduce" in kwargs:
reduce = kwargs.pop("reduce")
if "True" in reduce:
reduce = True
elif "False" in reduce:
reduce = False
else:
reduce = None
else:
reduce = None

if size_average is not None or reduce is not None:
if size_average is None:
size_average = True
if reduce is None:
reduce = True

if size_average and reduce:
reduction = '"""mean"""'
elif reduce:
reduction = '"""sum"""'
else:
reduction = '"""none"""'

kwargs["reduction"] = reduction

API_TEMPLATE = textwrap.dedent(
"""
paddle.nn.MSELoss({})
"""
)
code = API_TEMPLATE.format(self.kwargs_to_str(kwargs))

return code


class TupleAssignMatcher(BaseMatcher):
def generate_code(self, kwargs):
kwargs_change = {}
Expand Down Expand Up @@ -3728,6 +3777,210 @@ def generate_code(self, kwargs):
return GenericMatcher.generate_code(self, kwargs)


class L1LossMatcher(BaseMatcher):
def generate_code(self, kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

逻辑可以写成对每个kwargs遍历,判断是否kwargs,每个分支里再判断是否list,一共4个分支。用new_kwargs来接收kwargs,不然参数顺序会改变,导致代码风格不太好

for k in list(kwargs.keys()):
    if kwargs_change:
          if tuple:
          else:
    else:
          if tuple:
          else:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


if "size_average" in kwargs:
size_average = kwargs.pop("size_average")
if "True" in size_average:
size_average = True
elif "False" in size_average:
size_average = False
else:
size_average = None
else:
size_average = None

if "reduce" in kwargs:
reduce = kwargs.pop("reduce")
if "True" in reduce:
reduce = True
elif "False" in reduce:
reduce = False
else:
reduce = None
else:
reduce = None

if size_average is not None or reduce is not None:
if size_average is None:
size_average = True
if reduce is None:
reduce = True

if size_average and reduce:
reduction = '"""mean"""'
elif reduce:
reduction = '"""sum"""'
else:
reduction = '"""none"""'

kwargs["reduction"] = reduction

API_TEMPLATE = textwrap.dedent(
"""
paddle.nn.L1Loss({})
"""
)

code = API_TEMPLATE.format(self.kwargs_to_str(kwargs))

return code


class BCELossMatcher(BaseMatcher):
def generate_code(self, kwargs):

if "size_average" in kwargs:
size_average = kwargs.pop("size_average")
if "True" in size_average:
size_average = True
elif "False" in size_average:
size_average = False
else:
size_average = None
else:
size_average = None

if "reduce" in kwargs:
reduce = kwargs.pop("reduce")
if "True" in reduce:
reduce = True
elif "False" in reduce:
reduce = False
else:
reduce = None
else:
reduce = None

if size_average is not None or reduce is not None:
if size_average is None:
size_average = True
if reduce is None:
reduce = True

if size_average and reduce:
reduction = '"""mean"""'
elif reduce:
reduction = '"""sum"""'
else:
reduction = '"""none"""'

kwargs["reduction"] = reduction

API_TEMPLATE = textwrap.dedent(
"""
paddle.nn.BCELoss({})
"""
)

code = API_TEMPLATE.format(self.kwargs_to_str(kwargs))

return code


class FunctionalBinaryCrossEntropyMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "size_average" in kwargs:
size_average = kwargs.pop("size_average")
if "True" in size_average:
size_average = True
elif "False" in size_average:
size_average = False
else:
size_average = None
else:
size_average = None

if "reduce" in kwargs:
reduce = kwargs.pop("reduce")
if "True" in reduce:
reduce = True
elif "False" in reduce:
reduce = False
else:
reduce = None
else:
reduce = None

if size_average is not None or reduce is not None:
if size_average is None:
size_average = True
if reduce is None:
reduce = True

if size_average and reduce:
reduction = '"""mean"""'
elif reduce:
reduction = '"""sum"""'
else:
reduction = '"""none"""'

kwargs["reduction"] = reduction

if "kwargs_change" in self.api_mapping:
kwargs_change = self.api_mapping["kwargs_change"]
for key in list(kwargs_change.keys()):
if key in kwargs:
kwargs[kwargs_change[key]] = kwargs[key]
kwargs.pop(key)

API_TEMPLACE = textwrap.dedent(
"""
paddle.nn.functional.binary_cross_entropy({})
"""
)
code = API_TEMPLACE.format(self.kwargs_to_str(kwargs))

return code


class UnfoldMatcher(BaseMatcher):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以起一个通用的名字,这个主要功能是把tuple转成list:
可以叫Tuple2ListMatcher

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

def generate_code(self, kwargs):
if "kwargs_change" in self.api_mapping:
kwargs_change = self.api_mapping["kwargs_change"]
for key in list(kwargs_change.keys()):
if key in kwargs:
if isinstance(ast.literal_eval(kwargs[key]), tuple):
kwargs[key] = list(ast.literal_eval(kwargs[key]))
kwargs[kwargs_change[key]] = kwargs[key]
kwargs.pop(key)

if "paddings" not in kwargs:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个是默认值就不用单独设置了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

kwargs["paddings"] = 0

API_TEMPLACE = textwrap.dedent(
"""
paddle.nn.Unfold({})
"""
)
code = API_TEMPLACE.format(self.kwargs_to_str(kwargs))

return code


class FunctionalUnfoldMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "kwargs_change" in self.api_mapping:
kwargs_change = self.api_mapping["kwargs_change"]
for key in list(kwargs_change.keys()):
if key in kwargs:
if "input" not in key:
if isinstance(ast.literal_eval(kwargs[key]), tuple):
kwargs[key] = list(ast.literal_eval(kwargs[key]))
kwargs[kwargs_change[key]] = kwargs[key]
kwargs.pop(key)

API_TEMPLACE = textwrap.dedent(
"""
paddle.nn.functional.unfold({})
"""
)

code = API_TEMPLACE.format(self.kwargs_to_str(kwargs))

return code

class ParameterMatcher(BaseMatcher):
def get_paddle_nodes(self, args, kwargs):
kwargs = self.parse_args_and_kwargs(args, kwargs)
Expand Down Expand Up @@ -3948,3 +4201,4 @@ def generate_code(self, kwargs):
if "dim" not in kwargs:
return None
return GenericMatcher.generate_code(self, kwargs)

Loading