blob: ce58a840e42347381b930e2d0ec8053fe0a2fced [file] [log] [blame]
#include "google/protobuf/proto_api.h"
#include <Python.h>
#include <climits>
#include <memory>
#include <string>
#include "absl/log/absl_check.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/message.h"
namespace google {
namespace protobuf {
namespace python {
// Builds a mutator from an optionally owned C++ message (`owned_msg`), the
// C++ message to expose (`message`) and the Python message object (`py_msg`).
// A new reference to `py_msg` is taken here and dropped by the destructor.
PythonMessageMutator::PythonMessageMutator(Message* owned_msg, Message* message,
                                           PyObject* py_msg)
    : owned_msg_(owned_msg), message_(message), py_msg_(py_msg) {
  // Neither the Python object nor the exposed message may be null.
  ABSL_DCHECK(py_msg_ != nullptr);
  ABSL_DCHECK(message_ != nullptr);
  Py_INCREF(py_msg_);
}
// Move constructor: takes over `other`'s owned message, exposed message
// pointer and Python reference. `other` is left with null `message_` and
// `py_msg_`, so its destructor returns early without touching the Python
// object.
PythonMessageMutator::PythonMessageMutator(PythonMessageMutator&& other)
    // release() hands over the raw pointer (nullptr when nothing is owned;
    // this assumes owned_msg_ is a std::unique_ptr).
    : owned_msg_(other.owned_msg_.release()),
      message_(other.message_),
      py_msg_(other.py_msg_) {
  other.py_msg_ = nullptr;
  other.message_ = nullptr;
}
// Releases the mutator.
//
// If this instance still refers to a Python message (it was not moved from),
// no Python error is pending, and it owns a separate C++ message, the
// contents of `message_` are serialized and fed to the Python message's
// `ParseFromString` method. Finally the reference to the Python object taken
// in the constructor is dropped.
PythonMessageMutator::~PythonMessageMutator() {
  if (py_msg_ == nullptr) {
    // Moved-from instance: nothing to write back or release.
    return;
  }
  // The PyErr_Occurred() check is required because PyObject_CallMethod must
  // not be called while a Python error is already pending.
  if (!PyErr_Occurred() && owned_msg_ != nullptr) {
    std::string wire;
    message_->SerializePartialToString(&wire);
    PyObject* py_wire = PyBytes_FromStringAndSize(
        wire.data(), static_cast<Py_ssize_t>(wire.size()));
    // PyBytes_FromStringAndSize returns nullptr (with a Python error set)
    // when the bytes object cannot be created. Calling Py_DECREF on nullptr
    // would crash, so only proceed with a valid object.
    if (py_wire != nullptr) {
      PyObject* parse =
          PyObject_CallMethod(py_msg_, "ParseFromString", "O", py_wire);
      Py_DECREF(py_wire);
      // `parse` is nullptr when the call failed; the Python error then
      // remains set.
      Py_XDECREF(parse);
    }
  }
  Py_DECREF(py_msg_);
}
// Factory for PythonMessageMutator: forwards the three pointers unchanged to
// the PythonMessageMutator constructor and returns the result by value.
PythonMessageMutator PyProto_API::CreatePythonMessageMutator(
    Message* owned_msg, Message* msg, PyObject* py_msg) const {
  return PythonMessageMutator{owned_msg, msg, py_msg};
}
// Builds a read-only handle from an optionally owned C++ message
// (`owned_msg`), the const C++ message to expose (`message`) and the Python
// message object (`py_msg`). A new reference to `py_msg` is taken here and
// dropped by the destructor.
PythonConstMessagePointer::PythonConstMessagePointer(Message* owned_msg,
                                                     const Message* message,
                                                     PyObject* py_msg)
    : owned_msg_(owned_msg), message_(message), py_msg_(py_msg) {
  // Neither the Python object nor the exposed message may be null.
  ABSL_DCHECK(py_msg_ != nullptr);
  ABSL_DCHECK(message_ != nullptr);
  Py_INCREF(py_msg_);
}
// Move constructor: takes over `other`'s owned message, exposed message
// pointer and Python reference. `other` is left with all three members null,
// which is the state the destructor treats as "moved from".
PythonConstMessagePointer::PythonConstMessagePointer(
    PythonConstMessagePointer&& other)
    // release() hands over the raw pointer (nullptr when nothing is owned;
    // this assumes owned_msg_ is a std::unique_ptr).
    : owned_msg_(other.owned_msg_.release()),
      message_(other.message_),
      py_msg_(other.py_msg_) {
  other.py_msg_ = nullptr;
  other.message_ = nullptr;
}
bool PythonConstMessagePointer::NotChanged() {
ABSL_DCHECK(!PyErr_Occurred());
if (owned_msg_ == nullptr) {
return false;
}
// Skip the check if too large. Parse won't work
// for messages larger than 2 GB.
if (message_->ByteSizeLong() > INT_MAX) {
return true;
}
PyObject* py_serialized_pb(
PyObject_CallMethod(py_msg_, "SerializePartialToString", nullptr));
if (py_serialized_pb == nullptr) {
PyErr_Format(PyExc_ValueError, "Fail to serialize py_msg");
return false;
}
char* data;
Py_ssize_t len;
if (PyBytes_AsStringAndSize(py_serialized_pb, &data, &len) < 0) {
Py_DECREF(py_serialized_pb);
PyErr_Format(PyExc_ValueError, "Fail to get bytes from serialized data");
return false;
}
// Even if serialize python message deterministic above, the
// serialize result may still diff between languages. So parse to
// another c++ message for compare.
std::unique_ptr<google::protobuf::Message> parsed_msg(owned_msg_->New());
parsed_msg->ParsePartialFromArray(data, static_cast<int>(len));
std::string wire_other;
google::protobuf::io::StringOutputStream stream_other(&wire_other);
google::protobuf::io::CodedOutputStream output_other(&stream_other);
output_other.SetSerializationDeterministic(true);
parsed_msg->SerializePartialToCodedStream(&output_other);
output_other.Trim();
std::string wire;
google::protobuf::io::StringOutputStream stream(&wire);
google::protobuf::io::CodedOutputStream output(&stream);
output.SetSerializationDeterministic(true);
owned_msg_->SerializePartialToCodedStream(&output);
output.Trim();
if (wire == wire_other) {
Py_DECREF(py_serialized_pb);
return true;
}
PyErr_Format(PyExc_ValueError, "pymessage has been changed");
Py_DECREF(py_serialized_pb);
return false;
}
// Drops the reference to the Python message taken in the constructor.
// The ABSL_DCHECKs verify that a live instance has an owned message and that
// NotChanged() holds, and that a moved-from instance has all members null.
PythonConstMessagePointer::~PythonConstMessagePointer() {
  if (py_msg_ != nullptr) {
    // Live instance: validate, then release the Python object.
    ABSL_DCHECK(owned_msg_ != nullptr);
    ABSL_DCHECK(NotChanged());
    Py_DECREF(py_msg_);
    return;
  }
  // Moved-from instance: every member must already be null.
  ABSL_DCHECK(message_ == nullptr);
  ABSL_DCHECK(owned_msg_ == nullptr);
}
// Factory for PythonConstMessagePointer: forwards the three pointers
// unchanged to the PythonConstMessagePointer constructor and returns the
// result by value.
PythonConstMessagePointer PyProto_API::CreatePythonConstMessagePointer(
    Message* owned_msg, const Message* msg, PyObject* py_msg) const {
  return PythonConstMessagePointer{owned_msg, msg, py_msg};
}
} // namespace python
} // namespace protobuf
} // namespace google