Skip to content

Commit b9e69e8

Browse files
anandoleecopybara-github
authored andcommitted
Add "absl::StatusOr<PythonConstMessagePointer> GetConstMessagePointer(PyObject* msg)" in proto_api which works with cpp extension, upb and pure python.
PiperOrigin-RevId: 699316527
1 parent aded9b7 commit b9e69e8

File tree

4 files changed

+209
-36
lines changed

4 files changed

+209
-36
lines changed

python/build_targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def build_targets(name):
439439
visibility = ["//visibility:public"],
440440
deps = [
441441
"//src/google/protobuf",
442+
"//src/google/protobuf/io",
442443
"@com_google_absl//absl/log:absl_check",
443444
"@com_google_absl//absl/status",
444445
"@system_python//:python_headers",

python/google/protobuf/proto_api.cc

+85
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
#include "google/protobuf/proto_api.h"
22

3+
#include <Python.h>
4+
5+
#include <memory>
36
#include <string>
47

58
#include "absl/log/absl_check.h"
9+
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
610
#include "google/protobuf/message.h"
711
namespace google {
812
namespace protobuf {
@@ -52,6 +56,87 @@ PythonMessageMutator PyProto_API::CreatePythonMessageMutator(
5256
return PythonMessageMutator(owned_msg, msg, py_msg);
5357
}
5458

59+
PythonConstMessagePointer::PythonConstMessagePointer(Message* owned_msg,
60+
const Message* message,
61+
PyObject* py_msg)
62+
: owned_msg_(owned_msg), message_(message), py_msg_(py_msg) {
63+
ABSL_DCHECK(py_msg != nullptr);
64+
ABSL_DCHECK(message != nullptr);
65+
Py_INCREF(py_msg_);
66+
}
67+
68+
PythonConstMessagePointer::PythonConstMessagePointer(
69+
PythonConstMessagePointer&& other)
70+
: owned_msg_(other.owned_msg_ == nullptr ? nullptr
71+
: other.owned_msg_.release()),
72+
message_(other.message_),
73+
py_msg_(other.py_msg_) {
74+
other.message_ = nullptr;
75+
other.py_msg_ = nullptr;
76+
}
77+
78+
bool PythonConstMessagePointer::NotChanged() {
79+
ABSL_DCHECK(!PyErr_Occurred());
80+
if (owned_msg_ == nullptr) {
81+
return false;
82+
}
83+
84+
PyObject* py_serialized_pb(
85+
PyObject_CallMethod(py_msg_, "SerializeToString", nullptr));
86+
if (py_serialized_pb == nullptr) {
87+
PyErr_Format(PyExc_ValueError, "Fail to serialize py_msg");
88+
return false;
89+
}
90+
char* data;
91+
Py_ssize_t len;
92+
if (PyBytes_AsStringAndSize(py_serialized_pb, &data, &len) < 0) {
93+
Py_DECREF(py_serialized_pb);
94+
PyErr_Format(PyExc_ValueError, "Fail to get bytes from serialized data");
95+
return false;
96+
}
97+
98+
// Even if serialize python message deterministic above, the
99+
// serialize result may still diff between languages. So parse to
100+
// another c++ message for compare.
101+
std::unique_ptr<google::protobuf::Message> parsed_msg(owned_msg_->New());
102+
parsed_msg->ParseFromArray(data, static_cast<int>(len));
103+
std::string wire_other;
104+
google::protobuf::io::StringOutputStream stream_other(&wire_other);
105+
google::protobuf::io::CodedOutputStream output_other(&stream_other);
106+
output_other.SetSerializationDeterministic(true);
107+
parsed_msg->SerializeToCodedStream(&output_other);
108+
109+
std::string wire;
110+
google::protobuf::io::StringOutputStream stream(&wire);
111+
google::protobuf::io::CodedOutputStream output(&stream);
112+
output.SetSerializationDeterministic(true);
113+
owned_msg_->SerializeToCodedStream(&output);
114+
115+
if (wire == wire_other) {
116+
Py_DECREF(py_serialized_pb);
117+
return true;
118+
}
119+
PyErr_Format(PyExc_ValueError, "pymessage has been changed");
120+
Py_DECREF(py_serialized_pb);
121+
return false;
122+
}
123+
124+
PythonConstMessagePointer::~PythonConstMessagePointer() {
125+
if (py_msg_ == nullptr) {
126+
ABSL_DCHECK(message_ == nullptr);
127+
ABSL_DCHECK(owned_msg_ == nullptr);
128+
return;
129+
}
130+
ABSL_DCHECK(owned_msg_ != nullptr);
131+
ABSL_DCHECK(NotChanged());
132+
Py_DECREF(py_msg_);
133+
}
134+
135+
PythonConstMessagePointer PyProto_API::CreatePythonConstMessagePointer(
136+
Message* owned_msg, const Message* msg, PyObject* py_msg) const {
137+
return PythonConstMessagePointer(owned_msg, msg, py_msg);
138+
}
139+
55140
} // namespace python
56141
} // namespace protobuf
57142
} // namespace google

python/google/protobuf/proto_api.h

+41-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
// PyProtoAPICapsuleName(), 0));
1818
// if (!py_proto_api) { ...handle ImportError... }
1919
// Then use the methods of the returned class:
20-
// py_proto_api->GetMessagePointer(...);
20+
// py_proto_api->GetConstMessagePointer(...);
2121

2222
#ifndef GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
2323
#define GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
@@ -31,11 +31,14 @@
3131
#include "google/protobuf/descriptor_database.h"
3232
#include "google/protobuf/message.h"
3333

34+
PyObject* pymessage_mutate_const(PyObject* self, PyObject* args);
35+
3436
namespace google {
3537
namespace protobuf {
3638
namespace python {
3739

3840
class PythonMessageMutator;
41+
class PythonConstMessagePointer;
3942

4043
// Note on the implementation:
4144
// This API is designed after
@@ -55,23 +58,36 @@ struct PyProto_API {
5558
// Side-effect: The message will definitely be cleared. *When* the message
5659
// gets cleared is undefined (C++ will clear it up-front, python/upb will
5760
// clear it on destruction). Nothing should rely on the python message
58-
// during the lifetime of this object
61+
// during the lifetime of this object.
5962
// User should not hold onto the returned PythonMessageMutator while
60-
// calling back into Python
63+
// calling back into Python.
6164
// Warning: there is a risk of deadlock with Python/C++ if users use the
6265
// returned message->GetDescriptor()->file->pool()
6366
virtual absl::StatusOr<PythonMessageMutator> GetClearedMessageMutator(
6467
PyObject* msg) const = 0;
6568

69+
// Returns a PythonConstMessagePointer. For UPB and Pure Python, it points
70+
// to a new c++ message copied from python message. For cpp extension, it
71+
// points the internal c++ message.
72+
// User should not hold onto the returned PythonConstMessagePointer
73+
// while calling back into Python.
74+
virtual absl::StatusOr<PythonConstMessagePointer> GetConstMessagePointer(
75+
PyObject* msg) const = 0;
76+
6677
// If the passed object is a Python Message, returns its internal pointer.
6778
// Otherwise, returns NULL with an exception set.
79+
// TODO: Remove deprecated GetMessagePointer().
80+
[[deprecated(
81+
"GetMessagePointer() only work with Cpp Extension, "
82+
"please migrate to GetConstMessagePointer().")]]
6883
virtual const Message* GetMessagePointer(PyObject* msg) const = 0;
6984

7085
// If the passed object is a Python Message, returns a mutable pointer.
7186
// Otherwise, returns NULL with an exception set.
7287
// This function will succeed only if there are no other Python objects
7388
// pointing to the message, like submessages or repeated containers.
7489
// With the current implementation, only empty messages are in this case.
90+
// TODO: Remove deprecated GetMutableMessagePointer().
7591
[[deprecated(
7692
"GetMutableMessagePointer() only work with Cpp Extension, "
7793
"please migrate to GetClearedMessageMutator().")]]
@@ -133,6 +149,8 @@ struct PyProto_API {
133149
PythonMessageMutator CreatePythonMessageMutator(Message* owned_msg,
134150
Message* msg,
135151
PyObject* py_msg) const;
152+
PythonConstMessagePointer CreatePythonConstMessagePointer(
153+
Message* owned_msg, const Message* msg, PyObject* py_msg) const;
136154
};
137155

138156
// User should not hold onto this object while calling back into Python
@@ -161,6 +179,26 @@ class PythonMessageMutator {
161179
PyObject* py_msg_;
162180
};
163181

182+
class PythonConstMessagePointer {
183+
public:
184+
PythonConstMessagePointer(PythonConstMessagePointer&& other);
185+
~PythonConstMessagePointer();
186+
187+
const Message& get() { return *message_; }
188+
189+
private:
190+
friend struct google::protobuf::python::PyProto_API;
191+
PythonConstMessagePointer(Message* owned_msg, const Message* message,
192+
PyObject* py_msg);
193+
194+
friend PyObject* ::pymessage_mutate_const(PyObject* self, PyObject* args);
195+
// Check if the const message has been changed.
196+
bool NotChanged();
197+
std::unique_ptr<Message> owned_msg_;
198+
const Message* message_;
199+
PyObject* py_msg_;
200+
};
201+
164202
inline const char* PyProtoAPICapsuleName() {
165203
static const char kCapsuleName[] = "google.protobuf.pyext._message.proto_API";
166204
return kCapsuleName;

python/google/protobuf/pyext/message_module.cc

+82-33
Original file line numberDiff line numberDiff line change
@@ -158,54 +158,103 @@ google::protobuf::DynamicMessageFactory* GetFactory() {
158158
return factory;
159159
}
160160

161+
absl::StatusOr<google::protobuf::Message*> CreateNewMessage(PyObject* py_msg) {
162+
PyObject* pyd = PyObject_GetAttrString(py_msg, "DESCRIPTOR");
163+
if (pyd == nullptr) {
164+
return absl::InvalidArgumentError("py_msg has no attribute 'DESCRIPTOR'");
165+
}
166+
167+
PyObject* fn = PyObject_GetAttrString(pyd, "full_name");
168+
if (fn == nullptr) {
169+
return absl::InvalidArgumentError(
170+
"DESCRIPTOR has no attribute 'full_name'");
171+
}
172+
173+
const char* descriptor_full_name = PyUnicode_AsUTF8(fn);
174+
if (descriptor_full_name == nullptr) {
175+
return absl::InternalError("Fail to convert descriptor full name");
176+
}
177+
178+
PyObject* pyfile = PyObject_GetAttrString(pyd, "file");
179+
Py_DECREF(pyd);
180+
if (pyfile == nullptr) {
181+
return absl::InvalidArgumentError("DESCRIPTOR has no attribute 'file'");
182+
}
183+
auto gen_d = google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(
184+
descriptor_full_name);
185+
if (gen_d) {
186+
Py_DECREF(pyfile);
187+
Py_DECREF(fn);
188+
return google::protobuf::MessageFactory::generated_factory()
189+
->GetPrototype(gen_d)
190+
->New();
191+
}
192+
auto d = FindMessageDescriptor(pyfile, descriptor_full_name);
193+
Py_DECREF(pyfile);
194+
RETURN_IF_ERROR(d.status());
195+
Py_DECREF(fn);
196+
return GetFactory()->GetPrototype(*d)->New();
197+
}
198+
199+
bool CopyToOwnedMsg(google::protobuf::Message** copy, const google::protobuf::Message& message) {
200+
*copy = message.New();
201+
std::string wire;
202+
message.SerializeToString(&wire);
203+
(*copy)->ParseFromArray(wire.data(), wire.size());
204+
return true;
205+
}
206+
161207
// C++ API. Clients get at this via proto_api.h
162208
struct ApiImplementation : google::protobuf::python::PyProto_API {
163209
absl::StatusOr<google::protobuf::python::PythonMessageMutator> GetClearedMessageMutator(
164210
PyObject* py_msg) const override {
165211
if (PyObject_TypeCheck(py_msg, google::protobuf::python::CMessage_Type)) {
166212
google::protobuf::Message* message =
167213
google::protobuf::python::PyMessage_GetMutableMessagePointer(py_msg);
214+
if (message == nullptr) {
215+
return absl::InternalError(
216+
"Fail to get message pointer. The message "
217+
"may already had a reference.");
218+
}
168219
message->Clear();
169220
return CreatePythonMessageMutator(nullptr, message, py_msg);
170221
}
171-
PyObject* pyd = PyObject_GetAttrString(py_msg, "DESCRIPTOR");
172-
if (pyd == nullptr) {
173-
return absl::InvalidArgumentError("py_msg has no attribute 'DESCRIPTOR'");
174-
}
175222

176-
PyObject* fn = PyObject_GetAttrString(pyd, "full_name");
177-
if (fn == nullptr) {
178-
return absl::InvalidArgumentError(
179-
"DESCRIPTOR has no attribute 'full_name'");
180-
}
223+
auto msg = CreateNewMessage(py_msg);
224+
RETURN_IF_ERROR(msg.status());
225+
return CreatePythonMessageMutator(*msg, *msg, py_msg);
226+
}
181227

182-
const char* descriptor_full_name = PyUnicode_AsUTF8(fn);
183-
if (descriptor_full_name == nullptr) {
184-
return absl::InternalError("Fail to convert descriptor full name");
228+
absl::StatusOr<google::protobuf::python::PythonConstMessagePointer>
229+
GetConstMessagePointer(PyObject* py_msg) const override {
230+
if (PyObject_TypeCheck(py_msg, google::protobuf::python::CMessage_Type)) {
231+
const google::protobuf::Message* message =
232+
google::protobuf::python::PyMessage_GetMessagePointer(py_msg);
233+
google::protobuf::Message* owned_msg = nullptr;
234+
ABSL_DCHECK(CopyToOwnedMsg(&owned_msg, *message));
235+
return CreatePythonConstMessagePointer(owned_msg, message, py_msg);
185236
}
186-
187-
PyObject* pyfile = PyObject_GetAttrString(pyd, "file");
188-
Py_DECREF(pyd);
189-
if (pyfile == nullptr) {
190-
return absl::InvalidArgumentError("DESCRIPTOR has no attribute 'file'");
237+
auto msg = CreateNewMessage(py_msg);
238+
RETURN_IF_ERROR(msg.status());
239+
PyObject* serialized_pb(
240+
PyObject_CallMethod(py_msg, "SerializeToString", nullptr));
241+
if (serialized_pb == nullptr) {
242+
return absl::InternalError("Fail to serialize py_msg");
191243
}
192-
auto gen_d =
193-
google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(
194-
descriptor_full_name);
195-
if (gen_d) {
196-
Py_DECREF(pyfile);
197-
Py_DECREF(fn);
198-
google::protobuf::Message* msg = google::protobuf::MessageFactory::generated_factory()
199-
->GetPrototype(gen_d)
200-
->New();
201-
return CreatePythonMessageMutator(msg, msg, py_msg);
244+
char* data;
245+
Py_ssize_t len;
246+
if (PyBytes_AsStringAndSize(serialized_pb, &data, &len) < 0) {
247+
Py_DECREF(serialized_pb);
248+
return absl::InternalError(
249+
"Fail to get bytes from py_msg serialized data");
202250
}
203-
auto d = FindMessageDescriptor(pyfile, descriptor_full_name);
204-
Py_DECREF(pyfile);
205-
RETURN_IF_ERROR(d.status());
206-
Py_DECREF(fn);
207-
google::protobuf::Message* msg = GetFactory()->GetPrototype(*d)->New();
208-
return CreatePythonMessageMutator(msg, msg, py_msg);
251+
if (!(*msg)->ParseFromArray(data, len)) {
252+
Py_DECREF(serialized_pb);
253+
return absl::InternalError(
254+
"Couldn't parse py_message to google::protobuf::Message*!");
255+
}
256+
Py_DECREF(serialized_pb);
257+
return CreatePythonConstMessagePointer(*msg, *msg, py_msg);
209258
}
210259

211260
const google::protobuf::Message* GetMessagePointer(PyObject* msg) const override {

0 commit comments

Comments
 (0)