diff --git a/paddle/fluid/eager/pylayer/py_layer_node.h b/paddle/fluid/eager/pylayer/py_layer_node.h
index 3e4a7df2ccda9b..2a0cbdcf0135b7 100644
--- a/paddle/fluid/eager/pylayer/py_layer_node.h
+++ b/paddle/fluid/eager/pylayer/py_layer_node.h
@@ -45,8 +45,13 @@ class GradNodePyLayer : public GradNodeBase {
   GradNodePyLayer(const GradNodePyLayer& other) : GradNodeBase(other) {
     this->ctx_ = other.ctx_;
     Py_INCREF(this->ctx_);
+    this->name_ = other.name_;
     this->forward_outputs_meta_ = other.forward_outputs_meta_;
     this->forward_outputs_place_ = other.forward_outputs_place_;
+    this->forward_outputs_dist_attr_ = other.forward_outputs_dist_attr_;
+    this->forward_outputs_global_dims_ = other.forward_outputs_global_dims_;
+    this->forward_outputs_is_dist_meta_ = other.forward_outputs_is_dist_meta_;
+    this->grad_in_dtype_consistent_ = other.grad_in_dtype_consistent_;
   }
 
   ~GradNodePyLayer() override;
diff --git a/test/legacy_test/test_pylayer_op.py b/test/legacy_test/test_pylayer_op.py
index 6d48a95860619d..1506c0317e2fd4 100644
--- a/test/legacy_test/test_pylayer_op.py
+++ b/test/legacy_test/test_pylayer_op.py
@@ -779,6 +779,42 @@ def backward(ctx, dy):
         z.backward()
         self.assertEqual(cus_tanh_backward_input.dtype, paddle.float16)
 
+    def test_pylayer_with_partial_grad(self):
+        class tanh(PyLayer):
+            @staticmethod
+            def forward(ctx, x1, x2, func1, func2=paddle.square):
+                ctx.func = func2
+                y1 = func1(x1)
+                y2 = func1(x2)
+                ctx.save_for_backward(y1, y2)
+                ctx.mark_non_differentiable(y2)
+                return y1, 1, y2, None
+
+            @staticmethod
+            def backward(ctx, dy1, dy2):
+                y1, y2 = ctx.saved_tensor()
+                re1 = dy1 * (1 - ctx.func(y1))
+                re2 = dy2 * (1 - paddle.square(y2))
+                return re1, re2
+
+        input1 = paddle.randn([2, 3]).astype("float64")
+        input2 = input1.detach().clone()
+        input1.stop_gradient = False
+        input2.stop_gradient = False
+        z = tanh.apply(input1, input1, paddle.tanh, paddle.square)
+        z = z[0] + z[2]
+        out = z.mean()
+        (input1_grad,) = paddle.grad(out, [input1], retain_graph=True)
+
+        y2_0 = paddle.tanh(input2)
+        y2_1 = paddle.tanh(input2)
+        y2_1.stop_gradient = True
+        z2 = y2_0 + y2_1
+        out2 = z2.mean()
+        (input2_grad,) = paddle.grad(out2, [input2], retain_graph=True)
+
+        np.testing.assert_allclose(input1_grad, input2_grad)
+
 
 if __name__ == '__main__':
     unittest.main()