Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
IMPORT_TEMPLATE = """
import paddle
from paddle import _C_ops
from paddle.tensor import magic_method_func
from .. import core
"""

Expand Down Expand Up @@ -56,10 +57,18 @@ def _{name}(*args, **kwargs):
return _C_ops.{name}(*args, **kwargs)
"""
SET_METHOD_TEMPLATE = """
# set methods for paddle.Tensor in dygraph
# set methods && magical methods for paddle.Tensor in dygraph
local_tensor = core.eager.Tensor

magic_method_dict = {v: k for k, v in magic_method_func}

for method_name, method in methods_map:
setattr(local_tensor, method_name, method)

magic_name = magic_method_dict.get(method_name)
if magic_name:
setattr(local_tensor, magic_name, method)

setattr(paddle.tensor, method_name, method)

"""
Expand Down
19 changes: 19 additions & 0 deletions paddle/fluid/pybind/arg_pre_process.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,25 @@ void LogsumexpPreProcess(pir::Value* x,
void SumPreProcess(Value* x, Value* axis) {
paddle::dialect::SetStopGradient(axis);
}

void BinCountPreProcess(Tensor* x,
paddle::optional<Tensor>* weights,
Scalar* minlength) {
CheckDataType("bincount",
"x",
x->dtype(),
{phi::DataType::INT32, phi::DataType::INT64});
}

void BinCountPreProcess(Value* x,
paddle::optional<Value>* weights,
Value* minlength) {
CheckDataType("bincount",
"x",
pir::GetValueDtype(*x),
{phi::DataType::INT32, phi::DataType::INT64});
}

void IsClosePreProcess(Value* x, Value* y, Value* rtol, Value* atol) {
/*
if in_pir_mode():
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/pybind/arg_pre_process.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace pybind {
using Tensor = paddle::Tensor;
using Value = pir::Value;
using IntArray = paddle::experimental::IntArray;
using Scalar = paddle::experimental::Scalar;
using IntVector = std::vector<int64_t>;

void ExpandAsPreProcess(paddle::Tensor* x,
Expand All @@ -39,6 +40,13 @@ void ExpandAsPreProcess(Value* x,
void RollPreProcess(Tensor* x, IntArray* shifts, IntVector* axis);
void RollPreProcess(Value* x, Value* shifts, IntVector* axis);

void BinCountPreProcess(Tensor* x,
paddle::optional<Tensor>* weights,
Scalar* minlength);
void BinCountPreProcess(Value* x,
paddle::optional<Value>* weights,
Value* minlength);

void LogsumexpPreProcess(Tensor* x, std::vector<int>* axis, bool* reduce_all);
void LogsumexpPreProcess(Value* x, std::vector<int>* axis, bool* reduce_all);

Expand Down
Loading
Loading