blob: e01d1342f7b45b3a2e8786411de5627548b0703b [file] [log] [blame]
Hong Shind7087fc2024-06-27 11:04:25 -07001// Protocol Buffers - Google's data interchange format
2// Copyright 2023 Google LLC. All rights reserved.
3//
4// Use of this source code is governed by a BSD-style
5// license that can be found in the LICENSE file or at
6// https://developers.google.com/open-source/licenses/bsd
7
8#ifndef PROTOBUF_HPB_HPB_H_
9#define PROTOBUF_HPB_HPB_H_
10
11#include <type_traits>
12#include <vector>
13
14#include "absl/status/status.h"
15#include "absl/status/statusor.h"
16#include "upb/base/status.hpp"
17#include "upb/mem/arena.hpp"
18#include "upb/message/copy.h"
19#include "upb/mini_table/extension.h"
20#include "upb/wire/decode.h"
21#include "upb/wire/encode.h"
22
23namespace protos {
24
25using Arena = ::upb::Arena;
26class ExtensionRegistry;
27
// Proxy<T> is the view type that Ptr<T> wraps: T::Proxy for a non-const T,
// and T::CProxy (looked up on the unqualified type) for a const-qualified T.
template <typename T>
using Proxy = std::conditional_t<std::is_const_v<T>,
                                 typename std::remove_const_t<T>::CProxy,
                                 typename T::Proxy>;
32
// Pointer-like handle giving convenient access to a message's Proxy (for a
// mutable T) or CProxy (for a const T) view.
//
// Ptr<Message> and Ptr<const Message> store the proxy by value (`p_`), so a
// Ptr can be copied and rebound like a pointer (Ptr<const T> plays the role of
// `const T*`) while callers avoid handling the Proxy types directly. The
// original note describes those Proxy types as non-copyable for users;
// NOTE(review): Ptr nonetheless copies them, presumably via privileged
// access — confirm in the generated Proxy classes.
template <typename T>
class Ptr final {
 public:
  // A Ptr cannot be default-constructed; build it from a message pointer, a
  // proxy, or another Ptr.
  Ptr() = delete;

  // Implicit conversions
  // From a message pointer, e.g. `Ptr<Foo> p = &foo;`.
  Ptr(T* m) : p_(m) {}  // NOLINT
  // From a pointer to a proxy; the pointed-to proxy is copied into `p_`.
  Ptr(const Proxy<T>* p) : p_(*p) {}  // NOLINT
  // From a proxy value.
  Ptr(Proxy<T> p) : p_(p) {}  // NOLINT
  Ptr(const Ptr& m) = default;

  // Assignment hands both proxies to Proxy<T>::Rebind (presumably re-pointing
  // this handle at v's message rather than copying message contents —
  // confirm in Proxy). The `&` ref-qualifier allows assigning only to lvalue
  // Ptrs.
  Ptr& operator=(Ptr v) & {
    Proxy<T>::Rebind(p_, v.p_);
    return *this;
  }

  // Returns a copy of the held proxy.
  Proxy<T> operator*() const { return p_; }
  // Member access through the held proxy. The const_cast lets a const Ptr
  // still hand out a non-const Proxy<T>*, i.e. constness of the handle does
  // not propagate to the message; message constness is expressed through T.
  Proxy<T>* operator->() const {
    return const_cast<Proxy<T>*>(std::addressof(p_));
  }

  // Implicit Ptr<T> -> Ptr<const T> conversion, enabled only for non-const T.
  // When T is already const, Ptr<const T> names this same class, which is
  // presumably what clang's -Wclass-conversion would report here (conversion
  // to its own type) — hence the local suppression.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wclass-conversion"
#endif
  template <typename U = T, std::enable_if_t<!std::is_const<U>::value, int> = 0>
  operator Ptr<const T>() const {
    Proxy<const T> p(p_);
    return Ptr<const T>(&p);
  }
#ifdef __clang__
#pragma clang diagnostic pop
#endif

 private:
  // Builds the proxy directly from raw upb pointers. Private, so usable only
  // from within Ptr and the friends declared below.
  Ptr(upb_Message* msg, upb_Arena* arena) : p_(msg, arena) {}  // NOLINT

  friend class Ptr<const T>;
  friend typename T::Access;

  // The Proxy (mutable T) or CProxy (const T) view this handle wraps.
  Proxy<T> p_;
};
80
Hong Shin096b1392024-07-03 10:04:09 -070081// Suppress -Wctad-maybe-unsupported with our manual deduction guide
82template <typename T>
83Ptr(T* m) -> Ptr<T>;
84
Hong Shind7087fc2024-06-27 11:04:25 -070085inline absl::string_view UpbStrToStringView(upb_StringView str) {
86 return absl::string_view(str.data, str.size);
87}
88
89// TODO: update bzl and move to upb runtime / protos.cc.
90inline upb_StringView UpbStrFromStringView(absl::string_view str,
91 upb_Arena* arena) {
92 const size_t str_size = str.size();
93 char* buffer = static_cast<char*>(upb_Arena_Malloc(arena, str_size));
94 memcpy(buffer, str.data(), str_size);
95 return upb_StringView_FromDataAndSize(buffer, str_size);
96}
97
98template <typename T>
99typename T::Proxy CreateMessage(::protos::Arena& arena) {
100 return typename T::Proxy(upb_Message_New(T::minitable(), arena.ptr()),
101 arena.ptr());
102}
103
// SourceLocation type used by the error helpers below. The begin/end marker
// comments are consumed by the source-export tooling (presumably selecting
// one variant per distribution — confirm); keep them and the text between
// them exactly as-is.
// begin:github_only
// // This type exists to work around an absl type that has not yet been
// // released.
// struct SourceLocation {
//   static SourceLocation current() { return {}; }
//   absl::string_view file_name() { return "<unknown>"; }
//   int line() { return 0; }
// };
// end:github_only

// begin:google_only
using SourceLocation = absl::SourceLocation;
// end:google_only

// Builders for the absl::Status values this API reports on failure. They are
// declared here and defined out of line; each takes a source location that
// defaults to SourceLocation::current() evaluated at the call site.
absl::Status MessageAllocationError(
    SourceLocation loc = SourceLocation::current());

// Returned by GetExtension when GetOrPromoteExtension fails; the caller passes
// the extension's number (upb_MiniTableExtension_Number).
absl::Status ExtensionNotFoundError(
    int extension_number, SourceLocation loc = SourceLocation::current());

// Returned by Parse<T>() with the non-OK upb_DecodeStatus from upb_Decode.
absl::Status MessageDecodeError(upb_DecodeStatus status,
                                SourceLocation loc = SourceLocation::current());

// Counterpart of MessageDecodeError for a non-OK upb_EncodeStatus.
absl::Status MessageEncodeError(upb_EncodeStatus status,
                                SourceLocation loc = SourceLocation::current());
129
130namespace internal {
131struct PrivateAccess {
132 template <typename T>
133 static auto* GetInternalMsg(T&& message) {
134 return message->msg();
135 }
136 template <typename T>
137 static auto Proxy(upb_Message* p, upb_Arena* arena) {
138 return typename T::Proxy(p, arena);
139 }
140 template <typename T>
141 static auto CProxy(const upb_Message* p, upb_Arena* arena) {
142 return typename T::CProxy(p, arena);
143 }
144 template <typename T>
145 static auto CreateMessage(upb_Arena* arena) {
146 return typename T::Proxy(upb_Message_New(T::minitable(), arena), arena);
147 }
148};
149
150template <typename T>
151auto* GetInternalMsg(T&& message) {
152 return PrivateAccess::GetInternalMsg(std::forward<T>(message));
153}
154
155template <typename T>
156T CreateMessage() {
157 return T();
158}
159
160template <typename T>
161typename T::Proxy CreateMessageProxy(upb_Message* msg, upb_Arena* arena) {
162 return typename T::Proxy(msg, arena);
163}
164
165template <typename T>
166typename T::CProxy CreateMessage(const upb_Message* msg, upb_Arena* arena) {
167 return PrivateAccess::CProxy<T>(msg, arena);
168}
169
170class ExtensionMiniTableProvider {
171 public:
172 constexpr explicit ExtensionMiniTableProvider(
173 const upb_MiniTableExtension* mini_table_ext)
174 : mini_table_ext_(mini_table_ext) {}
175 const upb_MiniTableExtension* mini_table_ext() const {
176 return mini_table_ext_;
177 }
178
179 private:
180 const upb_MiniTableExtension* mini_table_ext_;
181};
182
183// -------------------------------------------------------------------
184// ExtensionIdentifier
185// This is the type of actual extension objects. E.g. if you have:
186// extend Foo {
187// optional MyExtension bar = 1234;
188// }
189// then "bar" will be defined in C++ as:
190// ExtensionIdentifier<Foo, MyExtension> bar(&namespace_bar_ext);
191template <typename ExtendeeType, typename ExtensionType>
192class ExtensionIdentifier : public ExtensionMiniTableProvider {
193 public:
194 using Extension = ExtensionType;
195 using Extendee = ExtendeeType;
196
197 constexpr explicit ExtensionIdentifier(
198 const upb_MiniTableExtension* mini_table_ext)
199 : ExtensionMiniTableProvider(mini_table_ext) {}
200};
201
202template <typename T>
203upb_Arena* GetArena(Ptr<T> message) {
204 return static_cast<upb_Arena*>(message->GetInternalArena());
205}
206
207template <typename T>
208upb_Arena* GetArena(T* message) {
209 return static_cast<upb_Arena*>(message->GetInternalArena());
210}
211
212template <typename T>
213const upb_MiniTable* GetMiniTable(const T*) {
214 return T::minitable();
215}
216
217template <typename T>
218const upb_MiniTable* GetMiniTable(Ptr<T>) {
219 return T::minitable();
220}
221
// Runtime helpers on raw upb types, declared here and defined out of line;
// the templates in this header forward to them.

// Returns the upb registry held by `extension_registry` (ExtensionRegistry
// befriends this function). Parse passes the result to upb_Decode as its
// extension registry.
upb_ExtensionRegistry* GetUpbExtensions(
    const ExtensionRegistry& extension_registry);

// Backs protos::Serialize: serializes `message` (described by `mini_table`),
// forwarding that function's `options`. The returned view presumably points
// into memory allocated on `arena` — confirm in the definition.
absl::StatusOr<absl::string_view> Serialize(const upb_Message* message,
                                            const upb_MiniTable* mini_table,
                                            upb_Arena* arena, int options);

// Backs HasExtension. NOTE(review): judging by the name, this also reports
// extensions still held as unknown fields — confirm.
bool HasExtensionOrUnknown(const upb_Message* msg,
                           const upb_MiniTableExtension* eid);

// Backs GetExtension, which reads `*value` only when this returns true. Takes
// a mutable message and an arena, presumably because it may "promote" an
// unknown-field encoding into a parsed message on `arena` — confirm.
bool GetOrPromoteExtension(upb_Message* msg, const upb_MiniTableExtension* eid,
                           upb_Arena* arena, upb_MessageValue* value);

// Backs protos::DeepCopy, which passes the target message's arena as `arena`.
// Presumably copies `source` into `target` — confirm.
void DeepCopy(upb_Message* target, const upb_Message* source,
              const upb_MiniTable* mini_table, upb_Arena* arena);

// Backs CloneMessage; the returned message is wrapped in a proxy bound to
// `arena`. Presumably allocates the copy on `arena` — confirm.
upb_Message* DeepClone(const upb_Message* source,
                       const upb_MiniTable* mini_table, upb_Arena* arena);

// Backs SetExtension for rvalue extension values: receives the extension
// message together with its own arena (`extension_arena`) in addition to the
// target message's arena.
absl::Status MoveExtension(upb_Message* message, upb_Arena* message_arena,
                           const upb_MiniTableExtension* ext,
                           upb_Message* extension, upb_Arena* extension_arena);

// Backs SetExtension for `const Extension&` and Ptr<Extension> values.
absl::Status SetExtension(upb_Message* message, upb_Arena* message_arena,
                          const upb_MiniTableExtension* ext,
                          const upb_Message* extension);
248
// RemovePtr<X>::type is the message type behind the two handle forms this API
// accepts: Ptr<T> and T*. The primary template is declared but never defined,
// so RemovePtrT<X> is ill-formed for any other X.
template <typename T>
struct RemovePtr;

template <typename T>
struct RemovePtr<Ptr<T>> {
  using type = T;
};

template <typename T>
struct RemovePtr<T*> {
  using type = T;
};

template <typename T>
using RemovePtrT = typename RemovePtr<T>::type;

// PtrOrRaw<T> is simply T, but is only well-formed when T is Ptr<U> or U* with
// a non-const U. Because the alias yields T itself, a parameter declared as
// `PtrOrRaw<T>` still lets T be deduced from the argument (ClearMessage relies
// on this); `std::enable_if_t<..., T>` would instead be a non-deduced context.
template <typename T, typename U = RemovePtrT<T>,
          typename = std::enable_if_t<!std::is_const_v<U>>>
using PtrOrRaw = T;
268
269} // namespace internal
270
271template <typename T>
272void DeepCopy(Ptr<const T> source_message, Ptr<T> target_message) {
273 static_assert(!std::is_const_v<T>);
274 ::protos::internal::DeepCopy(
275 internal::GetInternalMsg(target_message),
276 internal::GetInternalMsg(source_message), T::minitable(),
277 static_cast<upb_Arena*>(target_message->GetInternalArena()));
278}
279
280template <typename T>
281typename T::Proxy CloneMessage(Ptr<T> message, upb_Arena* arena) {
282 return internal::PrivateAccess::Proxy<T>(
283 ::protos::internal::DeepClone(internal::GetInternalMsg(message),
284 T::minitable(), arena),
285 arena);
286}
287
288template <typename T>
289void DeepCopy(Ptr<const T> source_message, T* target_message) {
290 static_assert(!std::is_const_v<T>);
291 DeepCopy(source_message, protos::Ptr(target_message));
292}
293
294template <typename T>
295void DeepCopy(const T* source_message, Ptr<T> target_message) {
296 static_assert(!std::is_const_v<T>);
297 DeepCopy(protos::Ptr(source_message), target_message);
298}
299
300template <typename T>
301void DeepCopy(const T* source_message, T* target_message) {
302 static_assert(!std::is_const_v<T>);
303 DeepCopy(protos::Ptr(source_message), protos::Ptr(target_message));
304}
305
306template <typename T>
307void ClearMessage(internal::PtrOrRaw<T> message) {
308 auto ptr = Ptr(message);
309 auto minitable = internal::GetMiniTable(ptr);
310 upb_Message_Clear(internal::GetInternalMsg(ptr), minitable);
311}
312
313class ExtensionRegistry {
314 public:
315 ExtensionRegistry(
316 const std::vector<const ::protos::internal::ExtensionMiniTableProvider*>&
317 extensions,
318 const upb::Arena& arena)
319 : registry_(upb_ExtensionRegistry_New(arena.ptr())) {
320 if (registry_) {
321 for (const auto& ext_provider : extensions) {
322 const auto* ext = ext_provider->mini_table_ext();
323 bool success = upb_ExtensionRegistry_AddArray(registry_, &ext, 1);
324 if (!success) {
325 registry_ = nullptr;
326 break;
327 }
328 }
329 }
330 }
331
332 private:
333 friend upb_ExtensionRegistry* ::protos::internal::GetUpbExtensions(
334 const ExtensionRegistry& extension_registry);
335 upb_ExtensionRegistry* registry_;
336};
337
// SFINAE helper for defaulted template parameters: well-formed (void) only
// when T derives from T::Access and T::ExtendableType also derives from
// T::Access — presumably true exactly for generated message classes.
template <typename T>
using EnableIfProtosClass = std::enable_if_t<
    std::is_base_of_v<typename T::Access, T> &&
    std::is_base_of_v<typename T::Access, typename T::ExtendableType>>;

// SFINAE helper: well-formed (void) only when T is not const-qualified.
template <typename T>
using EnableIfMutableProto = std::enable_if_t<!std::is_const_v<T>>;
345
346template <typename T, typename Extendee, typename Extension,
347 typename = EnableIfProtosClass<T>>
348ABSL_MUST_USE_RESULT bool HasExtension(
349 Ptr<T> message,
350 const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
351 return ::protos::internal::HasExtensionOrUnknown(
352 ::protos::internal::GetInternalMsg(message), id.mini_table_ext());
353}
354
355template <typename T, typename Extendee, typename Extension,
356 typename = EnableIfProtosClass<T>>
357ABSL_MUST_USE_RESULT bool HasExtension(
358 const T* message,
359 const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
360 return HasExtension(protos::Ptr(message), id);
361}
362
363template <typename T, typename Extension, typename = EnableIfProtosClass<T>,
364 typename = EnableIfMutableProto<T>>
365void ClearExtension(
366 Ptr<T> message,
367 const ::protos::internal::ExtensionIdentifier<T, Extension>& id) {
368 static_assert(!std::is_const_v<T>, "");
369 upb_Message_ClearExtension(internal::GetInternalMsg(message),
370 id.mini_table_ext());
371}
372
373template <typename T, typename Extension, typename = EnableIfProtosClass<T>>
374void ClearExtension(
375 T* message,
376 const ::protos::internal::ExtensionIdentifier<T, Extension>& id) {
377 ClearExtension(::protos::Ptr(message), id);
378}
379
380template <typename T, typename Extension, typename = EnableIfProtosClass<T>,
381 typename = EnableIfMutableProto<T>>
382absl::Status SetExtension(
383 Ptr<T> message,
384 const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
385 const Extension& value) {
386 static_assert(!std::is_const_v<T>);
387 auto* message_arena = static_cast<upb_Arena*>(message->GetInternalArena());
388 return ::protos::internal::SetExtension(internal::GetInternalMsg(message),
389 message_arena, id.mini_table_ext(),
390 internal::GetInternalMsg(&value));
391}
392
393template <typename T, typename Extension, typename = EnableIfProtosClass<T>,
394 typename = EnableIfMutableProto<T>>
395absl::Status SetExtension(
396 Ptr<T> message,
397 const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
398 Ptr<Extension> value) {
399 static_assert(!std::is_const_v<T>);
400 auto* message_arena = static_cast<upb_Arena*>(message->GetInternalArena());
401 return ::protos::internal::SetExtension(internal::GetInternalMsg(message),
402 message_arena, id.mini_table_ext(),
403 internal::GetInternalMsg(value));
404}
405
406template <typename T, typename Extension, typename = EnableIfProtosClass<T>,
407 typename = EnableIfMutableProto<T>>
408absl::Status SetExtension(
409 Ptr<T> message,
410 const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
411 Extension&& value) {
412 Extension ext = std::move(value);
413 static_assert(!std::is_const_v<T>);
414 auto* message_arena = static_cast<upb_Arena*>(message->GetInternalArena());
415 auto* extension_arena = static_cast<upb_Arena*>(ext.GetInternalArena());
416 return ::protos::internal::MoveExtension(
417 internal::GetInternalMsg(message), message_arena, id.mini_table_ext(),
418 internal::GetInternalMsg(&ext), extension_arena);
419}
420
421template <typename T, typename Extension, typename = EnableIfProtosClass<T>>
422absl::Status SetExtension(
423 T* message, const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
424 const Extension& value) {
425 return ::protos::SetExtension(::protos::Ptr(message), id, value);
426}
427
428template <typename T, typename Extension, typename = EnableIfProtosClass<T>>
429absl::Status SetExtension(
430 T* message, const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
431 Extension&& value) {
432 return ::protos::SetExtension(::protos::Ptr(message), id,
433 std::forward<Extension>(value));
434}
435
436template <typename T, typename Extension, typename = EnableIfProtosClass<T>>
437absl::Status SetExtension(
438 T* message, const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
439 Ptr<Extension> value) {
440 return ::protos::SetExtension(::protos::Ptr(message), id, value);
441}
442
443template <typename T, typename Extendee, typename Extension,
444 typename = EnableIfProtosClass<T>>
445absl::StatusOr<Ptr<const Extension>> GetExtension(
446 Ptr<T> message,
447 const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
448 // TODO: Fix const correctness issues.
449 upb_MessageValue value;
450 const bool ok = ::protos::internal::GetOrPromoteExtension(
451 const_cast<upb_Message*>(internal::GetInternalMsg(message)),
452 id.mini_table_ext(), ::protos::internal::GetArena(message), &value);
453 if (!ok) {
454 return ExtensionNotFoundError(
455 upb_MiniTableExtension_Number(id.mini_table_ext()));
456 }
457 return Ptr<const Extension>(::protos::internal::CreateMessage<Extension>(
458 (upb_Message*)value.msg_val, ::protos::internal::GetArena(message)));
459}
460
461template <typename T, typename Extendee, typename Extension,
462 typename = EnableIfProtosClass<T>>
463absl::StatusOr<Ptr<const Extension>> GetExtension(
464 const T* message,
465 const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
466 return GetExtension(protos::Ptr(message), id);
467}
468
469template <typename T>
470ABSL_MUST_USE_RESULT bool Parse(Ptr<T> message, absl::string_view bytes) {
471 static_assert(!std::is_const_v<T>);
472 upb_Message_Clear(internal::GetInternalMsg(message),
473 ::protos::internal::GetMiniTable(message));
474 auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
475 return upb_Decode(bytes.data(), bytes.size(),
476 internal::GetInternalMsg(message),
477 ::protos::internal::GetMiniTable(message),
478 /* extreg= */ nullptr, /* options= */ 0,
479 arena) == kUpb_DecodeStatus_Ok;
480}
481
482template <typename T>
483ABSL_MUST_USE_RESULT bool Parse(
484 Ptr<T> message, absl::string_view bytes,
485 const ::protos::ExtensionRegistry& extension_registry) {
486 static_assert(!std::is_const_v<T>);
487 upb_Message_Clear(internal::GetInternalMsg(message),
488 ::protos::internal::GetMiniTable(message));
489 auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
490 return upb_Decode(bytes.data(), bytes.size(),
491 internal::GetInternalMsg(message),
492 ::protos::internal::GetMiniTable(message),
493 /* extreg= */
494 ::protos::internal::GetUpbExtensions(extension_registry),
495 /* options= */ 0, arena) == kUpb_DecodeStatus_Ok;
496}
497
498template <typename T>
499ABSL_MUST_USE_RESULT bool Parse(
500 T* message, absl::string_view bytes,
501 const ::protos::ExtensionRegistry& extension_registry) {
502 static_assert(!std::is_const_v<T>);
503 return Parse(protos::Ptr(message, bytes, extension_registry));
504}
505
506template <typename T>
507ABSL_MUST_USE_RESULT bool Parse(T* message, absl::string_view bytes) {
508 static_assert(!std::is_const_v<T>);
509 upb_Message_Clear(internal::GetInternalMsg(message),
510 ::protos::internal::GetMiniTable(message));
511 auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
512 return upb_Decode(bytes.data(), bytes.size(),
513 internal::GetInternalMsg(message),
514 ::protos::internal::GetMiniTable(message),
515 /* extreg= */ nullptr, /* options= */ 0,
516 arena) == kUpb_DecodeStatus_Ok;
517}
518
519template <typename T>
520absl::StatusOr<T> Parse(absl::string_view bytes, int options = 0) {
521 T message;
522 auto* arena = static_cast<upb_Arena*>(message.GetInternalArena());
523 upb_DecodeStatus status =
524 upb_Decode(bytes.data(), bytes.size(), message.msg(),
525 ::protos::internal::GetMiniTable(&message),
526 /* extreg= */ nullptr, /* options= */ 0, arena);
527 if (status == kUpb_DecodeStatus_Ok) {
528 return message;
529 }
530 return MessageDecodeError(status);
531}
532
533template <typename T>
534absl::StatusOr<T> Parse(absl::string_view bytes,
535 const ::protos::ExtensionRegistry& extension_registry,
536 int options = 0) {
537 T message;
538 auto* arena = static_cast<upb_Arena*>(message.GetInternalArena());
539 upb_DecodeStatus status =
540 upb_Decode(bytes.data(), bytes.size(), message.msg(),
541 ::protos::internal::GetMiniTable(&message),
542 ::protos::internal::GetUpbExtensions(extension_registry),
543 /* options= */ 0, arena);
544 if (status == kUpb_DecodeStatus_Ok) {
545 return message;
546 }
547 return MessageDecodeError(status);
548}
549
550template <typename T>
551absl::StatusOr<absl::string_view> Serialize(const T* message, upb::Arena& arena,
552 int options = 0) {
553 return ::protos::internal::Serialize(
554 internal::GetInternalMsg(message),
555 ::protos::internal::GetMiniTable(message), arena.ptr(), options);
556}
557
558template <typename T>
559absl::StatusOr<absl::string_view> Serialize(Ptr<T> message, upb::Arena& arena,
560 int options = 0) {
561 return ::protos::internal::Serialize(
562 internal::GetInternalMsg(message),
563 ::protos::internal::GetMiniTable(message), arena.ptr(), options);
564}
565
566} // namespace protos
567
568#endif // PROTOBUF_HPB_HPB_H_