traced_relay: add tests for the relay service

Bug: 284258446
Change-Id: Ibd9dd3c86d4dc05c1722633865c81cb0ee904eb3
diff --git a/Android.bp b/Android.bp
index 9b53a37..9f5a62b 100644
--- a/Android.bp
+++ b/Android.bp
@@ -2340,6 +2340,8 @@
         ":perfetto_src_traced_probes_statsd_client_statsd_client",
         ":perfetto_src_traced_probes_sys_stats_sys_stats",
         ":perfetto_src_traced_probes_system_info_system_info",
+        ":perfetto_src_traced_relay_integrationtests",
+        ":perfetto_src_traced_relay_lib",
         ":perfetto_src_tracing_client_api_without_backends",
         ":perfetto_src_tracing_common",
         ":perfetto_src_tracing_core_core",
@@ -11937,6 +11939,32 @@
     name: "perfetto_src_traced_probes_unittests",
 }
 
+// GN: //src/traced_relay:integrationtests
+filegroup {
+    name: "perfetto_src_traced_relay_integrationtests",
+    srcs: [
+        "src/traced_relay/relay_service_integrationtest.cc",
+    ],
+}
+
+// GN: //src/traced_relay:lib
+filegroup {
+    name: "perfetto_src_traced_relay_lib",
+    srcs: [
+        "src/traced_relay/relay_service.cc",
+        "src/traced_relay/socket_relay_handler.cc",
+    ],
+}
+
+// GN: //src/traced_relay:unittests
+filegroup {
+    name: "perfetto_src_traced_relay_unittests",
+    srcs: [
+        "src/traced_relay/relay_service_unittest.cc",
+        "src/traced_relay/socket_relay_handler_unittest.cc",
+    ],
+}
+
 // GN: //src/traced/service:service
 filegroup {
     name: "perfetto_src_traced_service_service",
@@ -12693,6 +12721,7 @@
         ":perfetto_src_ipc_client",
         ":perfetto_src_ipc_common",
         ":perfetto_src_ipc_host",
+        ":perfetto_src_ipc_perfetto_ipc",
         ":perfetto_src_ipc_test_messages_cpp_gen",
         ":perfetto_src_ipc_test_messages_ipc_gen",
         ":perfetto_src_ipc_unittests",
@@ -12880,6 +12909,8 @@
         ":perfetto_src_traced_probes_system_info_system_info",
         ":perfetto_src_traced_probes_system_info_unittests",
         ":perfetto_src_traced_probes_unittests",
+        ":perfetto_src_traced_relay_lib",
+        ":perfetto_src_traced_relay_unittests",
         ":perfetto_src_traced_service_service",
         ":perfetto_src_traced_service_unittests",
         ":perfetto_src_tracing_client_api_without_backends",
diff --git a/gn/perfetto_integrationtests.gni b/gn/perfetto_integrationtests.gni
index ff7bb7f..3e65e20 100644
--- a/gn/perfetto_integrationtests.gni
+++ b/gn/perfetto_integrationtests.gni
@@ -49,3 +49,7 @@
   perfetto_integrationtests_targets +=
       [ "src/trace_processor:integrationtests" ]
 }
+
+if (enable_perfetto_traced_relay) {
+  perfetto_integrationtests_targets += [ "src/traced_relay:integrationtests" ]
+}
diff --git a/gn/perfetto_unittests.gni b/gn/perfetto_unittests.gni
index cc51910..91e8d5d 100644
--- a/gn/perfetto_unittests.gni
+++ b/gn/perfetto_unittests.gni
@@ -80,3 +80,7 @@
     perfetto_unittests_targets += [ "src/bigtrace:unittests" ]
   }
 }
+
+if (enable_perfetto_traced_relay) {
+  perfetto_unittests_targets += [ "src/traced_relay:unittests" ]
+}
diff --git a/src/traced_relay/BUILD.gn b/src/traced_relay/BUILD.gn
index d4e27cc..6b49f94 100644
--- a/src/traced_relay/BUILD.gn
+++ b/src/traced_relay/BUILD.gn
@@ -14,6 +14,7 @@
 
 import("../../gn/perfetto.gni")
 import("../../gn/perfetto_component.gni")
+import("../../gn/test.gni")
 
 executable("traced_relay") {
   deps = [
@@ -45,3 +46,33 @@
     "//src/ipc:perfetto_ipc",
   ]
 }
+
+perfetto_unittest_source_set("unittests") {
+  testonly = true
+  deps = [
+    ":lib",
+    "../../gn:default_deps",
+    "../../gn:gtest_and_gmock",
+    "../base",
+    "../base:test_support",
+    "../base/threading",
+    "//src/ipc:perfetto_ipc",
+  ]
+  sources = [
+    "relay_service_unittest.cc",
+    "socket_relay_handler_unittest.cc",
+  ]
+}
+
+source_set("integrationtests") {
+  testonly = true
+  deps = [
+    ":lib",
+    "../../gn:default_deps",
+    "../../gn:gtest_and_gmock",
+    "../../test:test_helper",
+    "../base",
+    "../base:test_support",
+  ]
+  sources = [ "relay_service_integrationtest.cc" ]
+}
diff --git a/src/traced_relay/relay_service_integrationtest.cc b/src/traced_relay/relay_service_integrationtest.cc
new file mode 100644
index 0000000..7e29041
--- /dev/null
+++ b/src/traced_relay/relay_service_integrationtest.cc
@@ -0,0 +1,109 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include "src/traced_relay/relay_service.h"
+
+#include "src/base/test/test_task_runner.h"
+#include "test/gtest_and_gmock.h"
+#include "test/test_helper.h"
+
+#include "protos/perfetto/config/test_config.gen.h"
+#include "protos/perfetto/config/trace_config.gen.h"
+#include "protos/perfetto/trace/test_event.gen.h"
+
+namespace perfetto {
+namespace {
+
+TEST(TracedRelayIntegrationTest, BasicCase) {
+  base::TestTaskRunner task_runner;
+
+  std::string sock_name;
+  {
+    // Set up a server UnixSocket to find an unused TCP port.
+    base::UnixSocket::EventListener event_listener;
+    auto srv = base::UnixSocket::Listen("127.0.0.1:0", &event_listener,
+                                        &task_runner, base::SockFamily::kInet,
+                                        base::SockType::kStream);
+    ASSERT_TRUE(srv->is_listening());
+    sock_name = srv->GetSockAddr();
+    // Shut down |srv| here to free the port. It's unlikely that the port will
+    // be taken by another process so quickly before we reach the code below.
+  }
+
+  TestHelper helper(&task_runner, TestHelper::Mode::kStartDaemons,
+                    sock_name.c_str());
+  ASSERT_EQ(helper.num_producers(), 1u);
+  helper.StartServiceIfRequired();
+
+  auto relay_service = std::make_unique<RelayService>(&task_runner);
+
+  relay_service->Start("@traced_relay", sock_name.c_str());
+
+  auto producer_connected =
+      task_runner.CreateCheckpoint("perfetto.FakeProducer.connected");
+  auto noop = []() {};
+  auto connected = [&]() { task_runner.PostTask(producer_connected); };
+
+  // We won't use the built-in fake producer and will start our own.
+  auto producer_thread = std::make_unique<FakeProducerThread>(
+      "@traced_relay", connected, noop, noop, "perfetto.FakeProducer");
+  producer_thread->Connect();
+  task_runner.RunUntilCheckpoint("perfetto.FakeProducer.connected");
+
+  helper.ConnectConsumer();
+  helper.WaitForConsumerConnect();
+
+  TraceConfig trace_config;
+  trace_config.add_buffers()->set_size_kb(1024);
+  trace_config.set_duration_ms(200);
+
+  static constexpr uint32_t kMsgSize = 1024;
+  static constexpr uint32_t kRandomSeed = 42;
+  // Enable the producer.
+  auto* ds_config = trace_config.add_data_sources()->mutable_config();
+  ds_config->set_name("perfetto.FakeProducer");
+  ds_config->set_target_buffer(0);
+  ds_config->mutable_for_testing()->set_seed(kRandomSeed);
+  ds_config->mutable_for_testing()->set_message_count(12);
+  ds_config->mutable_for_testing()->set_message_size(kMsgSize);
+  ds_config->mutable_for_testing()->set_send_batch_on_register(true);
+
+  helper.StartTracing(trace_config);
+  helper.WaitForTracingDisabled();
+
+  helper.ReadData();
+  helper.WaitForReadData();
+
+  const auto& packets = helper.trace();
+  ASSERT_EQ(packets.size(), 12u);
+
+  // The producer is connected from this process. The relay service will inject
+  // the SetPeerIdentity message using the pid and euid of the current process.
+  auto pid = static_cast<int32_t>(getpid());
+  auto uid = static_cast<int32_t>(geteuid());
+
+  std::minstd_rand0 rnd_engine(kRandomSeed);
+  for (const auto& packet : packets) {
+    ASSERT_TRUE(packet.has_for_testing());
+    ASSERT_EQ(packet.trusted_pid(), pid);
+    ASSERT_EQ(packet.trusted_uid(), uid);
+    ASSERT_EQ(packet.for_testing().seq_value(), rnd_engine());
+  }
+}
+
+}  // namespace
+}  // namespace perfetto
diff --git a/src/traced_relay/relay_service_unittest.cc b/src/traced_relay/relay_service_unittest.cc
new file mode 100644
index 0000000..649f552
--- /dev/null
+++ b/src/traced_relay/relay_service_unittest.cc
@@ -0,0 +1,130 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/traced_relay/relay_service.h"
+
+#include <memory>
+
+#include "perfetto/ext/base/unix_socket.h"
+#include "protos/perfetto/ipc/wire_protocol.gen.h"
+#include "src/base/test/test_task_runner.h"
+#include "src/ipc/buffered_frame_deserializer.h"
+#include "test/gtest_and_gmock.h"
+
+namespace perfetto {
+namespace {
+
+using ::testing::_;
+using ::testing::Invoke;
+
+class TestEventListener : public base::UnixSocket::EventListener {
+ public:
+  MOCK_METHOD(void, OnDataAvailable, (base::UnixSocket*), (override));
+  MOCK_METHOD(void, OnConnect, (base::UnixSocket*, bool), (override));
+  MOCK_METHOD(void, OnNewIncomingConnection, (base::UnixSocket*));
+
+  void OnNewIncomingConnection(
+      base::UnixSocket*,
+      std::unique_ptr<base::UnixSocket> new_connection) override {
+    // Need to keep |new_connection| alive.
+    client_connection_ = std::move(new_connection);
+    OnNewIncomingConnection(client_connection_.get());
+  }
+
+ private:
+  std::unique_ptr<base::UnixSocket> client_connection_;
+};
+
+// Exercises the relay service and also validates that the relay service injects
+// a SetPeerIdentity message:
+//
+// producer (client UnixSocket) <- @producer.sock -> relay service
+// <- 127.0.0.1.* -> tcp_server (listening UnixSocet).
+TEST(RelayServiceTest, SetPeerIdentity) {
+  base::TestTaskRunner task_runner;
+  auto relay_service = std::make_unique<RelayService>(&task_runner);
+
+  // Set up a server UnixSocket to find an unused TCP port.
+  // The TCP connection emulates the socket to the host traced.
+  TestEventListener tcp_listener;
+  auto tcp_server = base::UnixSocket::Listen(
+      "127.0.0.1:0", &tcp_listener, &task_runner, base::SockFamily::kInet,
+      base::SockType::kStream);
+  ASSERT_TRUE(tcp_server->is_listening());
+  auto tcp_sock_name = tcp_server->GetSockAddr();
+  auto* unix_sock_name =
+      "@producer.sock";  // Use abstract unix socket for server socket.
+
+  // Start the relay service.
+  relay_service->Start(unix_sock_name, tcp_sock_name.c_str());
+
+  // Emulates the producer connection.
+  TestEventListener producer_listener;
+  auto producer = base::UnixSocket::Connect(
+      unix_sock_name, &producer_listener, &task_runner, base::SockFamily::kUnix,
+      base::SockType::kStream);
+  auto producer_connected = task_runner.CreateCheckpoint("producer_connected");
+  EXPECT_CALL(producer_listener, OnConnect(_, _))
+      .WillOnce(Invoke([&](base::UnixSocket* s, bool conn) {
+        EXPECT_TRUE(conn);
+        EXPECT_EQ(s, producer.get());
+        producer_connected();
+      }));
+  task_runner.RunUntilCheckpoint("producer_connected");
+
+  // Add some producer data.
+  ipc::Frame test_frame;
+  test_frame.add_data_for_testing("test_data");
+  auto test_data = ipc::BufferedFrameDeserializer::Serialize(test_frame);
+  producer->SendStr(test_data);
+
+  base::UnixSocket* tcp_client_connection = nullptr;
+  auto tcp_client_connected =
+      task_runner.CreateCheckpoint("tcp_client_connected");
+  EXPECT_CALL(tcp_listener, OnNewIncomingConnection(_))
+      .WillOnce(Invoke([&](base::UnixSocket* client) {
+        tcp_client_connection = client;
+        tcp_client_connected();
+      }));
+  task_runner.RunUntilCheckpoint("tcp_client_connected");
+
+  // Asserts that we can receive the SetPeerIdentity message.
+  auto peer_identity_recv = task_runner.CreateCheckpoint("peer_identity_recv");
+  ipc::BufferedFrameDeserializer deserializer;
+  EXPECT_CALL(tcp_listener, OnDataAvailable(_))
+      .WillRepeatedly(Invoke([&](base::UnixSocket* tcp_conn) {
+        auto buf = deserializer.BeginReceive();
+        auto rsize = tcp_conn->Receive(buf.data, buf.size);
+        EXPECT_TRUE(deserializer.EndReceive(rsize));
+
+        auto frame = deserializer.PopNextFrame();
+        EXPECT_TRUE(frame->has_set_peer_identity());
+
+        const auto& set_peer_identity = frame->set_peer_identity();
+        EXPECT_EQ(set_peer_identity.pid(), getpid());
+        EXPECT_EQ(set_peer_identity.uid(), static_cast<int32_t>(geteuid()));
+
+        frame = deserializer.PopNextFrame();
+        EXPECT_EQ(1u, frame->data_for_testing().size());
+        EXPECT_EQ(std::string("test_data"), frame->data_for_testing()[0]);
+
+        peer_identity_recv();
+      }));
+  task_runner.RunUntilCheckpoint("peer_identity_recv");
+}
+
+}  // namespace
+}  // namespace perfetto
diff --git a/src/traced_relay/socket_relay_handler_unittest.cc b/src/traced_relay/socket_relay_handler_unittest.cc
new file mode 100644
index 0000000..a96308a
--- /dev/null
+++ b/src/traced_relay/socket_relay_handler_unittest.cc
@@ -0,0 +1,204 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/traced_relay/socket_relay_handler.h"
+
+#include <chrono>
+#include <cstring>
+#include <memory>
+#include <random>
+#include <string>
+#include <thread>
+#include <utility>
+
+#include "perfetto/ext/base/threading/thread_pool.h"
+#include "perfetto/ext/base/unix_socket.h"
+
+#include "test/gtest_and_gmock.h"
+
+using testing::Values;
+
+namespace perfetto {
+namespace {
+
+using RawSocketPair = std::pair<base::UnixSocketRaw, base::UnixSocketRaw>;
+using RngValueType = std::minstd_rand0::result_type;
+
+struct TestClient {
+  RawSocketPair endpoint_sockets;
+  std::minstd_rand0 data_prng;
+  std::thread client_thread;
+};
+
+class SocketRelayHandlerTest : public ::testing::TestWithParam<uint32_t> {
+ protected:
+  void SetUp() override {
+    socket_relay_handler_ = std::make_unique<SocketRelayHandler>();
+
+    for (uint32_t i = 0; i < GetParam(); i++) {
+      TestClient client{SetUpEndToEndSockets(), std::minstd_rand0(i), {}};
+      test_clients_.push_back(std::move(client));
+    }
+  }
+  void TearDown() override { socket_relay_handler_ = nullptr; }
+
+  RawSocketPair SetUpEndToEndSockets() {
+    // Creates 2 SocketPairs:
+    // sock1 <-> sock2 <-> SocketRelayHandler <-> sock3 <-> sock4.
+    // sock2 and sock3 are transferred to the SocketRelayHandler.
+    // We test by reading and writing bidirectionally using sock1 and sock4.
+    auto [sock1, sock2] = base::UnixSocketRaw::CreatePairPosix(
+        base::SockFamily::kUnix, base::SockType::kStream);
+    sock2.SetBlocking(false);
+
+    auto [sock3, sock4] = base::UnixSocketRaw::CreatePairPosix(
+        base::SockFamily::kUnix, base::SockType::kStream);
+    sock3.SetBlocking(false);
+
+    auto socket_pair = std::make_unique<SocketPair>();
+    socket_pair->first.sock = std::move(sock2);
+    socket_pair->second.sock = std::move(sock3);
+
+    socket_relay_handler_->AddSocketPair(std::move(socket_pair));
+
+    RawSocketPair endpoint_sockets;
+    endpoint_sockets.first = std::move(sock1);
+    endpoint_sockets.second = std::move(sock4);
+
+    return endpoint_sockets;
+  }
+
+  std::unique_ptr<SocketRelayHandler> socket_relay_handler_;
+  std::vector<TestClient> test_clients_;
+  // Use fewer receiver threads than sender threads.
+  base::ThreadPool receiver_thread_pool_{1 + GetParam() / 10};
+};
+
+TEST(SocketWithBufferTest, EnqueueDequeue) {
+  SocketWithBuffer socket_with_buffer;
+  // No data initially.
+  EXPECT_EQ(0u, socket_with_buffer.data_size());
+
+  // Has room for writing some bytes into.
+  std::string data = "12345678901234567890";
+  EXPECT_GT(socket_with_buffer.available_bytes(), data.size());
+
+  memcpy(socket_with_buffer.buffer(), data.data(), data.size());
+  socket_with_buffer.EnqueueData(data.size());
+  EXPECT_EQ(data.size(), socket_with_buffer.data_size());
+
+  // Dequeue some bytes.
+  socket_with_buffer.DequeueData(5);
+  EXPECT_EQ(socket_with_buffer.data_size(), data.size() - 5);
+  std::string buffered_data(reinterpret_cast<char*>(socket_with_buffer.data()),
+                            socket_with_buffer.data_size());
+  EXPECT_EQ(buffered_data, "678901234567890");
+}
+
+// Test the SocketRelayHander with randomized request and response data.
+TEST_P(SocketRelayHandlerTest, RandomizedRequestResponse) {
+  // The max message size in the number of RNG calls.
+  constexpr size_t kMaxMsgSizeRng = 1 << 20;
+
+  // Create the threads for sending and receiving data through the
+  // SocketRelayHandler.
+  for (auto& client : test_clients_) {
+    auto* thread_pool = &receiver_thread_pool_;
+
+    auto thread_func = [&client, thread_pool]() {
+      auto& rng = client.data_prng;
+
+      // The max number of requests.
+      const size_t num_requests = rng() % 50;
+
+      for (size_t j = 0; j < num_requests; j++) {
+        auto& send_endpoint = client.endpoint_sockets.first;
+        auto& receive_endpoint = client.endpoint_sockets.second;
+
+        auto req_size = rng() % kMaxMsgSizeRng;
+
+        // Generate the random request.
+        std::vector<RngValueType> request;
+        request.reserve(req_size);
+        for (size_t r = 0; r < req_size; r++) {
+          request.emplace_back(rng());
+        }
+
+        // Create a buffer for receiving the request.
+        std::vector<RngValueType> received_request(request.size());
+
+        std::mutex mutex;
+        std::condition_variable cv;
+        std::unique_lock<std::mutex> lock(mutex);
+        bool done = false;
+
+        // Blocking receive on the thread pool.
+        thread_pool->PostTask([&]() {
+          const size_t bytes_to_receive =
+              received_request.size() * sizeof(RngValueType);
+          uint8_t* receive_buffer =
+              reinterpret_cast<uint8_t*>(received_request.data());
+          size_t bytes_received = 0;
+
+          // Perform a blocking read until we received the expected bytes.
+          while (bytes_received < bytes_to_receive) {
+            ssize_t rsize = PERFETTO_EINTR(
+                receive_endpoint.Receive(receive_buffer + bytes_received,
+                                         bytes_to_receive - bytes_received));
+            if (rsize <= 0)
+              break;
+            bytes_received += static_cast<size_t>(rsize);
+
+            std::this_thread::yield();  // Adds some scheduling randomness.
+          }
+
+          std::lock_guard<std::mutex> inner_lock(mutex);
+          done = true;
+          cv.notify_one();
+        });
+
+        // Perform a blocking send of the request data.
+        PERFETTO_EINTR(send_endpoint.Send(
+            request.data(), request.size() * sizeof(RngValueType)));
+
+        // Wait until the request is fully received.
+        cv.wait(lock, [&done] { return done; });
+
+        // Check data integrity.
+        EXPECT_EQ(request, received_request);
+
+        // Add some randomness to timing.
+        std::this_thread::sleep_for(std::chrono::microseconds(rng() % 1000));
+
+        // Emulate the response by reversing the data flow direction.
+        std::swap(send_endpoint, receive_endpoint);
+      }
+    };
+
+    client.client_thread = std::thread(std::move(thread_func));
+  }
+
+  for (auto& client : test_clients_) {
+    client.client_thread.join();
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(ByConnections,
+                         SocketRelayHandlerTest,
+                         Values(1, 5, 50));
+
+}  // namespace
+}  // namespace perfetto