VarInt Hacking with SLOP
PiperOrigin-RevId: 494835245
diff --git a/src/google/protobuf/extension_set_heavy.cc b/src/google/protobuf/extension_set_heavy.cc
index 7503d7d..8c51060 100644
--- a/src/google/protobuf/extension_set_heavy.cc
+++ b/src/google/protobuf/extension_set_heavy.cc
@@ -430,8 +430,8 @@
io::EpsCopyOutputStream stream(
target, MessageSetByteSize(),
io::CodedOutputStream::IsDefaultSerializationDeterministic());
- return InternalSerializeMessageSetWithCachedSizesToArray(extendee, target,
- &stream);
+ uint8_t* const end =
+ InternalSerializeMessageSetWithCachedSizesToArray(extendee, target, &stream);
+ // `stream` wraps the caller's array, so Finalize() flushes any bytes still
+ // staged in its internal buffer and returns the real end of the output.
+ return stream.Finalize(end);
}
} // namespace internal
diff --git a/src/google/protobuf/io/coded_stream.cc b/src/google/protobuf/io/coded_stream.cc
index f47c3fd..3f4027e 100644
--- a/src/google/protobuf/io/coded_stream.cc
+++ b/src/google/protobuf/io/coded_stream.cc
@@ -757,6 +757,43 @@
return s;
}
+// Copies the bytes still staged in buffer_ into the caller's array and returns
+// a pointer just past the last byte of output in the array.
+uint8_t* EpsCopyOutputStream::FlushArray(uint8_t* ptr) {
+ GOOGLE_DCHECK(stream_ == nullptr);
+ if (PROTOBUF_PREDICT_FALSE(had_error_)) return ptr;
+ if (buffer_end_ != nullptr) {
+ // Everything written since the last EnsureSpaceFallback()/ConsumeArray()
+ // sits in buffer_, possibly including up to kSlopBytes past end_.
+ ptrdiff_t bytes = ptr - buffer_;
+ GOOGLE_DCHECK_GE(bytes, 0);
+ GOOGLE_DCHECK_LE(bytes, 2 * kSlopBytes);
+ GOOGLE_DCHECK_LE(bytes, array_end_ - buffer_end_);
+ const ptrdiff_t avail = array_end_ - buffer_end_;
+ if (PROTOBUF_PREDICT_FALSE(bytes > avail)) {
+ // More bytes were produced than the array was sized for. The DCHECK above
+ // is compiled out of release builds, so clamp the copy instead of writing
+ // past the end of the caller's array, and record the failure.
+ had_error_ = true;
+ bytes = avail;
+ }
+ memcpy(buffer_end_, buffer_, static_cast<size_t>(bytes));
+ ptr = buffer_end_ + bytes;
+ buffer_end_ = nullptr;
+ }
+ return ptr;
+}
+
+// Reserves `size` contiguous bytes in the output array for a direct write.
+// Returns {destination, pointer to continue serializing from}, or
+// {nullptr, nullptr} if the bytes do not fit.
+std::pair<uint8_t*, uint8_t*> EpsCopyOutputStream::ConsumeArray(uint8_t* ptr,
+ int size) {
+ GOOGLE_DCHECK(array_end_ != nullptr);
+ // A negative size would move the pointers backwards and become a huge
+ // length in the callers' memcpy; treat it as not fitting.
+ if (PROTOBUF_PREDICT_FALSE(size < 0)) return {nullptr, nullptr};
+
+ if (buffer_end_ == nullptr) {
+ // Still writing straight into the array: only a bounds check is needed.
+ const ptrdiff_t avail = array_end_ - ptr;
+ if (size > avail) return {nullptr, nullptr};
+ return {ptr, ptr + size};
+ }
+ // Writes are staged in buffer_, which holds 2 * kSlopBytes; `ptr` may be up
+ // to kSlopBytes past end_ (== buffer_ + kSlopBytes).
+ GOOGLE_DCHECK(ptr >= buffer_ && ptr <= buffer_ + 2 * kSlopBytes);
+
+ const ptrdiff_t bytes = ptr - buffer_;
+ const ptrdiff_t avail = array_end_ - buffer_end_ - bytes;
+ if (size > avail) return {nullptr, nullptr};
+
+ // Move the staged bytes into the array, hand out the `size` bytes right
+ // after them, and restart staging at the beginning of buffer_.
+ memcpy(buffer_end_, buffer_, static_cast<size_t>(bytes));
+ ptr = buffer_end_ + bytes;
+ buffer_end_ += bytes + size;
+ end_ = buffer_ + kSlopBytes;
+
+ return {ptr, buffer_};
+}
+
uint8_t* EpsCopyOutputStream::Trim(uint8_t* ptr) {
if (had_error_) return ptr;
int s = Flush(ptr);
@@ -879,6 +916,19 @@
}
uint8_t* EpsCopyOutputStream::EnsureSpaceFallback(uint8_t* ptr) {
+ if (array_end_) {
+ // Array-bound stream: instead of pulling a new buffer from a
+ // ZeroCopyOutputStream, stage further writes in buffer_ and copy them into
+ // the array in chunks, so slop writes never run past array_end_.
+ if (PROTOBUF_PREDICT_FALSE(had_error_)) return buffer_;
+ if (buffer_end_ == nullptr) {
+ // First switch to buffer_: everything before `ptr` is already in place in
+ // the array, so that is where staged bytes will be copied to.
+ GOOGLE_DCHECK(ptr <= array_end_);
+ buffer_end_ = ptr;
+ } else {
+ // `ptr` may be up to kSlopBytes past end_ (== buffer_ + kSlopBytes), i.e.
+ // anywhere in buffer_, which holds 2 * kSlopBytes.
+ GOOGLE_DCHECK(ptr >= buffer_ && ptr <= buffer_ + 2 * kSlopBytes);
+ const int bytes = static_cast<int>(ptr - buffer_);
+ // Never copy past the caller's array, even if more bytes were produced
+ // than the precomputed size (DCHECKs vanish in release builds).
+ if (PROTOBUF_PREDICT_FALSE(bytes > array_end_ - buffer_end_)) {
+ return Error();
+ }
+ memcpy(buffer_end_, buffer_, bytes);
+ buffer_end_ += bytes;
+ }
+ end_ = buffer_ + kSlopBytes;
+ return buffer_;
+ }
do {
if (PROTOBUF_PREDICT_FALSE(had_error_)) return buffer_;
int overrun = ptr - end_;
@@ -892,6 +942,13 @@
uint8_t* EpsCopyOutputStream::WriteRawFallback(const void* data, int size,
uint8_t* ptr) {
+ if (array_end_ != nullptr) {
+ // Array-bound: reserve room for the whole payload in the destination array
+ // and copy it there in a single memcpy.
+ const std::pair<uint8_t*, uint8_t*> reserved = ConsumeArray(ptr, size);
+ uint8_t* const dst = reserved.first;
+ if (PROTOBUF_PREDICT_FALSE(dst == nullptr)) return Error();
+ memcpy(dst, data, size);
+ return reserved.second;
+ }
+
int s = GetSize(ptr);
while (s < size) {
std::memcpy(ptr, data, s);
@@ -965,16 +1022,16 @@
#endif
uint8_t* EpsCopyOutputStream::WriteCord(const absl::Cord& cord, uint8_t* ptr) {
- int s = GetSize(ptr);
if (stream_ == nullptr) {
- if (static_cast<int64_t>(cord.size()) <= s) {
- // Just copy it to the current buffer.
- return CopyCordToArray(cord, ptr);
- } else {
- return Error();
- }
- } else if (static_cast<int64_t>(cord.size()) <= s &&
- static_cast<int64_t>(cord.size()) < kMaxCordBytesToCopy) {
+ // Array-bound: ConsumeArray() takes an int size, so reject cords that do not
+ // fit in one rather than letting the cast wrap to a negative value.
+ if (PROTOBUF_PREDICT_FALSE(cord.size() > static_cast<size_t>(INT_MAX))) {
+ return Error();
+ }
+ auto ptrs = ConsumeArray(ptr, static_cast<int>(cord.size()));
+ if (PROTOBUF_PREDICT_FALSE(ptrs.first == nullptr)) return Error();
+ // Copy to the destination ConsumeArray() reserved. When bytes were staged
+ // in buffer_, that destination is inside the array, not at `ptr`.
+ CopyCordToArray(cord, ptrs.first);
+ return ptrs.second;
+ }
+ int s = GetSize(ptr);
+ if (static_cast<int64_t>(cord.size()) <= s &&
+ static_cast<int64_t>(cord.size()) < kMaxCordBytesToCopy) {
// Just copy it to the current buffer.
return CopyCordToArray(cord, ptr);
} else {
diff --git a/src/google/protobuf/io/coded_stream.h b/src/google/protobuf/io/coded_stream.h
index 98ff721..425189a 100644
--- a/src/google/protobuf/io/coded_stream.h
+++ b/src/google/protobuf/io/coded_stream.h
@@ -111,6 +111,7 @@
#include <assert.h>
+#include <algorithm>
#include <atomic>
#include <climits>
#include <cstddef>
@@ -655,10 +656,12 @@
- // pointed to the end of the array. When using this the total size is already
- // known, so no need to maintain the slop region.
+ // pointed to the end of the array. Note that the array end is tracked by
+ // array_end_: end_ itself stops min(kSlopBytes, size) bytes short of it so
+ // that slop writes stay inside the array. Bytes written past end_ are staged
+ // in buffer_ and copied into the array in chunks; callers must call
+ // Finalize() (which calls FlushArray()) to copy out the final chunk.
EpsCopyOutputStream(void* data, int size, bool deterministic)
- : end_(static_cast<uint8_t*>(data) + size),
+ : array_end_(static_cast<uint8_t*>(data) + size),
+ end_(static_cast<uint8_t*>(data) + size - std::min<int>(kSlopBytes, size)),
buffer_end_(nullptr),
stream_(nullptr),
is_serialization_deterministic_(deterministic) {}
// Initialize from stream but with the first buffer already given (eager).
EpsCopyOutputStream(void* data, int size, ZeroCopyOutputStream* stream,
@@ -667,10 +670,20 @@
*pp = SetInitialBuffer(data, size);
}
+ // Finalizes this instance. Invokes `Trim()` for stream-bound instances and
+ // `FlushArray()` for array-bound instances, and returns the pointer produced
+ // by whichever of the two was invoked.
+ uint8_t* Finalize(uint8_t* ptr) {
+ return stream_ ? Trim(ptr) : FlushArray(ptr);
+ }
+
// Flush everything that's written into the underlying ZeroCopyOutputStream
// and trims the underlying stream to the location of ptr.
uint8_t* Trim(uint8_t* ptr);
+ // Flushes any yet unwritten data into the array provided at construction.
+ // Returns a pointer directly beyond the last byte written into the array.
+ // Only for array-bound instances (it DCHECKs that there is no underlying
+ // stream); if an error was already recorded, `ptr` is returned unchanged.
+ uint8_t* FlushArray(uint8_t* ptr);
+
// After this it's guaranteed you can safely write kSlopBytes to ptr. This
// will never fail! The underlying stream can produce an error. Use HadError
// to check for errors.
@@ -831,6 +844,7 @@
private:
+ // One past the last byte of the caller's array. It is non-null only for
+ // instances built by the array constructor; the stream constructors
+ // presumably keep the nullptr default. A non-null value selects the array
+ // code paths in EnsureSpaceFallback() and WriteRawFallback().
+ uint8_t* array_end_ = nullptr;
uint8_t* end_;
uint8_t* buffer_end_ = buffer_;
uint8_t buffer_[2 * kSlopBytes];
@@ -840,6 +854,8 @@
bool is_serialization_deterministic_;
bool skip_check_consistency = false;
+ // Reserves `size` contiguous bytes of the output array for a direct write,
+ // first copying any bytes staged in buffer_ into the array. Returns
+ // {where to write, where to continue serializing}, or {nullptr, nullptr} if
+ // the bytes do not fit. Only valid for array-bound instances.
+ std::pair<uint8_t*, uint8_t*> ConsumeArray(uint8_t* ptr, int size);
+
uint8_t* EnsureSpaceFallback(uint8_t* ptr);
inline uint8_t* Next();
int Flush(uint8_t* ptr);
@@ -913,6 +929,13 @@
PROTOBUF_ALWAYS_INLINE static uint8_t* UnsafeVarint(T value, uint8_t* ptr) {
static_assert(std::is_unsigned<T>::value,
"Varint serialization must be unsigned");
+#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)
+ // Make the serialization guarantees observable to the sanitizers: there are
+ // always `kSlopBytes` writable at `ptr`, so a varint of up to 10 bytes can
+ // always be written here. Zeros are written because CodedStream relies on
+ // zero-initialized data in any slop buffers; the value does not matter, the
+ // point is only to trigger the sanitizers on an out-of-bounds `ptr`.
+ memset(ptr, 0, 10);
+#endif
while (PROTOBUF_PREDICT_FALSE(value >= 0x80)) {
*ptr = static_cast<uint8_t>(value | 0x80);
value >>= 7;
diff --git a/src/google/protobuf/message_lite.cc b/src/google/protobuf/message_lite.cc
index fc0f56a..b20c762 100644
--- a/src/google/protobuf/message_lite.cc
+++ b/src/google/protobuf/message_lite.cc
@@ -382,7 +382,7 @@
io::EpsCopyOutputStream out(
target, size,
io::CodedOutputStream::IsDefaultSerializationDeterministic());
uint8_t* res = msg._InternalSerialize(target, &out);
+ // `out` wraps `target`; Finalize() copies any bytes still staged in its
+ // internal buffer into `target` and returns the real end of the output.
+ res = out.Finalize(res);
GOOGLE_ABSL_DCHECK(target + size == res);
return res;
}
@@ -568,7 +568,7 @@
io::EpsCopyOutputStream out(
target, static_cast<int>(available.size()),
io::CodedOutputStream::IsDefaultSerializationDeterministic());
auto res = _InternalSerialize(target, &out);
+ // `out` wraps `target`; Finalize() copies any bytes still staged in its
+ // internal buffer into `target` and returns the real end of the output.
+ res = out.Finalize(res);
GOOGLE_ABSL_DCHECK_EQ(res, target + size);
buffer.IncreaseLengthBy(size);
output->Append(std::move(buffer));
diff --git a/src/google/protobuf/wire_format.h b/src/google/protobuf/wire_format.h
index 5abd813..d9ecb13 100644
--- a/src/google/protobuf/wire_format.h
+++ b/src/google/protobuf/wire_format.h
@@ -184,8 +184,9 @@
io::EpsCopyOutputStream stream(
target, static_cast<int>(ComputeUnknownFieldsSize(unknown_fields)),
io::CodedOutputStream::IsDefaultSerializationDeterministic());
- return InternalSerializeUnknownFieldsToArray(unknown_fields, target,
- &stream);
+ // Finalize() flushes any bytes still staged in `stream`'s internal buffer.
+ return stream.Finalize(
+ InternalSerializeUnknownFieldsToArray(unknown_fields, target, &stream));
}
static uint8_t* InternalSerializeUnknownFieldsToArray(
const UnknownFieldSet& unknown_fields, uint8_t* target,
diff --git a/src/google/protobuf/wire_format_lite.h b/src/google/protobuf/wire_format_lite.h
index b95098e..3a46042 100644
--- a/src/google/protobuf/wire_format_lite.h
+++ b/src/google/protobuf/wire_format_lite.h
@@ -658,7 +658,8 @@
static_cast<int>(2 * io::CodedOutputStream::VarintSize32(
static_cast<uint32_t>(field_number) << 3)),
io::CodedOutputStream::IsDefaultSerializationDeterministic());
- return InternalWriteGroup(field_number, value, target, &stream);
+ // Finalize() flushes any bytes still staged in `stream`'s internal buffer.
+ return stream.Finalize(
+ InternalWriteGroup(field_number, value, target, &stream));
}
PROTOBUF_NDEBUG_INLINE static uint8_t* WriteMessageToArray(
int field_number, const MessageLite& value, uint8_t* target) {
@@ -669,8 +670,9 @@
static_cast<uint32_t>(field_number) << 3) +
io::CodedOutputStream::VarintSize32(size)),
io::CodedOutputStream::IsDefaultSerializationDeterministic());
- return InternalWriteMessage(field_number, value, value.GetCachedSize(),
- target, &stream);
+ // Finalize() flushes any bytes still staged in `stream`'s internal buffer.
+ return stream.Finalize(InternalWriteMessage(
+ field_number, value, value.GetCachedSize(), target, &stream));
}
// Compute the byte size of a field. The XxSize() functions do NOT include