Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/fluid/eager/pylayer/py_layer_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,13 @@ class GradNodePyLayer : public GradNodeBase {
GradNodePyLayer(const GradNodePyLayer& other) : GradNodeBase(other) {
this->ctx_ = other.ctx_;
Py_INCREF(this->ctx_);
this->name_ = other.name_;
this->forward_outputs_meta_ = other.forward_outputs_meta_;
this->forward_outputs_place_ = other.forward_outputs_place_;
this->forward_outputs_dist_attr_ = other.forward_outputs_dist_attr_;
this->forward_outputs_global_dims_ = other.forward_outputs_global_dims_;
this->forward_outputs_is_dist_meta_ = other.forward_outputs_is_dist_meta_;
this->grad_in_dtype_consistent_ = other.grad_in_dtype_consistent_;
Copy link

Copilot AI Dec 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The member variable grad_in_dtype_consistent_ is being copied in the copy constructor but it's not initialized in the primary constructor. This could lead to undefined behavior when a GradNodePyLayer object is created using the primary constructor, as the boolean member will have an indeterminate value. Consider adding initialization in the primary constructor's member initializer list or providing a default value in the class declaration.

Copilot uses AI. Check for mistakes.
}

~GradNodePyLayer() override;
Expand Down
36 changes: 36 additions & 0 deletions test/legacy_test/test_pylayer_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,42 @@ def backward(ctx, dy):
z.backward()
self.assertEqual(cus_tanh_backward_input.dtype, paddle.float16)

def test_pylayer_with_partial_grad(self):
class tanh(PyLayer):
@staticmethod
def forward(ctx, x1, x2, func1, func2=paddle.square):
ctx.func = func2
y1 = func1(x1)
y2 = func1(x2)
ctx.save_for_backward(y1, y2)
ctx.mark_non_differentiable(y2)
return y1, 1, y2, None

@staticmethod
def backward(ctx, dy1, dy2):
y1, y2 = ctx.saved_tensor()
re1 = dy1 * (1 - ctx.func(y1))
re2 = dy2 * (1 - paddle.square(y2))
return re1, re2

input1 = paddle.randn([2, 3]).astype("float64")
input2 = input1.detach().clone()
input1.stop_gradient = False
input2.stop_gradient = False
z = tanh.apply(input1, input1, paddle.tanh, paddle.square)
z = z[0] + z[2]
out = z.mean()
(input1_grad,) = paddle.grad(out, [input1], retain_graph=True)

y2_0 = paddle.tanh(input2)
y2_1 = paddle.tanh(input2)
y2_1.stop_gradient = True
z2 = y2_0 + y2_1
out2 = z2.mean()
(input2_grad,) = paddle.grad(out2, [input2], retain_graph=True)

np.testing.assert_allclose(input1_grad, input2_grad)


if __name__ == '__main__':
unittest.main()
Loading