tracing_service_impl_unittest: Avoid internals in ScrapeOnDisconnect

This test is trying to simulate what happens with a remote producer when
it disconnects. The test however uses an in-process producer.

An in-process producer behaves quite differently (compared to a remote
producer). The shared memory arbiter (and the shared memory itself) is
owned by the producer endpoint. There's no way of disconnecting an
in-process producer without destroying its shared memory arbiter. But we
cannot destroy the shared memory arbiter, because the trace writer holds
references to it. We cannot destroy the trace writer, because it would
flush, and scraping will not be exercised.

This commit makes the test more realisting by using a "remote" producer:
* Producer `in_process` can be set to false.
* TestRefSharedMemory is a new type of shared memory that can be used to
  simulate two components having access to the same shared memory in
  tests.
* In order to simulate a disconnection, a ProxyProducerEndpoint can be
  configured to forward all the requests to a real proxy or to drop
  them, after disconnecting.

The test can now own the shared memory arbiter. TracingServiceImpl will
behave as it does with a remote producer and scrape its memory.

+ The production code doesn't contain test only code.
- There's a lot more test-only code required.
+ The test is more realistic.

Change-Id: I61121c6623f912c167876602e93ead39cd087c52
diff --git a/Android.bp b/Android.bp
index 0efdda3..e2f143f 100644
--- a/Android.bp
+++ b/Android.bp
@@ -14932,6 +14932,8 @@
         "src/tracing/test/fake_packet.cc",
         "src/tracing/test/mock_consumer.cc",
         "src/tracing/test/mock_producer.cc",
+        "src/tracing/test/proxy_producer_endpoint.cc",
+        "src/tracing/test/test_shared_memory.cc",
         "src/tracing/test/traced_value_test_support.cc",
     ],
 }
diff --git a/src/tracing/core/trace_writer_impl.h b/src/tracing/core/trace_writer_impl.h
index d72f15b..3ee4d01 100644
--- a/src/tracing/core/trace_writer_impl.h
+++ b/src/tracing/core/trace_writer_impl.h
@@ -67,10 +67,6 @@
     return protobuf_stream_writer_.written();
   }
 
-  void ResetChunkForTesting() {
-    cur_chunk_ = SharedMemoryABI::Chunk();
-    cur_chunk_packet_count_inflated_ = false;
-  }
   bool drop_packets_for_testing() const { return drop_packets_; }
 
  private:
diff --git a/src/tracing/service/tracing_service_impl_unittest.cc b/src/tracing/service/tracing_service_impl_unittest.cc
index 138f488..9032399 100644
--- a/src/tracing/service/tracing_service_impl_unittest.cc
+++ b/src/tracing/service/tracing_service_impl_unittest.cc
@@ -64,6 +64,7 @@
 #include "src/tracing/core/trace_writer_impl.h"
 #include "src/tracing/test/mock_consumer.h"
 #include "src/tracing/test/mock_producer.h"
+#include "src/tracing/test/proxy_producer_endpoint.h"
 #include "src/tracing/test/test_shared_memory.h"
 #include "test/gtest_and_gmock.h"
 
@@ -274,11 +275,6 @@
     return svc->GetProducer(producer_id)->inproc_shmem_arbiter_.get();
   }
 
-  std::unique_ptr<SharedMemoryArbiterImpl> StealShmemArbiterForProducer(
-      ProducerID producer_id) {
-    return std::move(svc->GetProducer(producer_id)->inproc_shmem_arbiter_);
-  }
-
   void SetTriggerWindowNs(int64_t window_ns) {
     svc->trigger_window_ns_ = window_ns;
   }
@@ -3362,8 +3358,19 @@
   consumer->Connect(svc.get());
 
   std::unique_ptr<MockProducer> producer = CreateMockProducer();
-  producer->Connect(svc.get(), "mock_producer");
-  ProducerID producer_id = *last_producer_id();
+
+  static constexpr size_t kShmSizeBytes = 1024 * 1024;
+  static constexpr size_t kShmPageSizeBytes = 4 * 1024;
+
+  TestSharedMemory::Factory factory;
+  auto shm = factory.CreateSharedMemory(kShmSizeBytes);
+
+  // Service should adopt the SMB provided by the producer.
+  producer->Connect(svc.get(), "mock_producer", /*uid=*/42, /*pid=*/1025,
+                    /*shared_memory_size_hint_bytes=*/0, kShmPageSizeBytes,
+                    TestRefSharedMemory::Create(shm.get()),
+                    /*in_process=*/false);
+
   producer->RegisterDataSource("data_source");
 
   TraceConfig trace_config;
@@ -3377,9 +3384,19 @@
   producer->WaitForDataSourceSetup("data_source");
   producer->WaitForDataSourceStart("data_source");
 
-  std::unique_ptr<TraceWriter> writer = producer->endpoint()->CreateTraceWriter(
-      tracing_session()->buffers_index[0]);
-  // Wait for TraceWriter to be registered.
+  auto client_producer_endpoint = std::make_unique<ProxyProducerEndpoint>();
+  client_producer_endpoint->set_backend(producer->endpoint());
+
+  auto shmem_arbiter = std::make_unique<SharedMemoryArbiterImpl>(
+      shm->start(), shm->size(), SharedMemoryABI::ShmemMode::kDefault,
+      kShmPageSizeBytes, client_producer_endpoint.get(), &task_runner);
+  shmem_arbiter->SetDirectSMBPatchingSupportedByService();
+
+  const auto* ds_inst = producer->GetDataSourceInstance("data_source");
+  ASSERT_NE(nullptr, ds_inst);
+  std::unique_ptr<TraceWriter> writer =
+      shmem_arbiter->CreateTraceWriter(ds_inst->target_buffer);
+  // Wait for the TraceWriter to be registered.
   task_runner.RunUntilIdle();
 
   // Write a few trace packets.
@@ -3388,9 +3405,8 @@
   writer->NewTracePacket()->set_for_testing()->set_str("payload3");
 
   // Disconnect the producer without committing the chunk. This should cause a
-  // scrape of the SMB. Avoid destroying the ShmemArbiter until writer is
-  // destroyed.
-  auto shmem_arbiter = StealShmemArbiterForProducer(producer_id);
+  // scrape of the SMB.
+  client_producer_endpoint->set_backend(nullptr);
   producer.reset();
 
   // Chunk with the packets should have been scraped.
@@ -3405,9 +3421,6 @@
                                          Property(&protos::gen::TestEvent::str,
                                                   Eq("payload3")))));
 
-  // Cleanup writer without causing a crash because the producer already went
-  // away.
-  static_cast<TraceWriterImpl*>(writer.get())->ResetChunkForTesting();
   writer.reset();
   shmem_arbiter.reset();
 
diff --git a/src/tracing/test/BUILD.gn b/src/tracing/test/BUILD.gn
index 47038c2..3acedef 100644
--- a/src/tracing/test/BUILD.gn
+++ b/src/tracing/test/BUILD.gn
@@ -41,6 +41,7 @@
     "aligned_buffer_test.h",
     "fake_packet.cc",
     "fake_packet.h",
+    "test_shared_memory.cc",
     "test_shared_memory.h",
     "traced_value_test_support.cc",
   ]
@@ -54,6 +55,8 @@
       "mock_producer.cc",
       "mock_producer.h",
       "mock_producer_endpoint.h",
+      "proxy_producer_endpoint.cc",
+      "proxy_producer_endpoint.h",
     ]
   }
 }
diff --git a/src/tracing/test/mock_producer.cc b/src/tracing/test/mock_producer.cc
index 501afc3..ac48a29 100644
--- a/src/tracing/test/mock_producer.cc
+++ b/src/tracing/test/mock_producer.cc
@@ -73,13 +73,15 @@
                            pid_t pid,
                            size_t shared_memory_size_hint_bytes,
                            size_t shared_memory_page_size_hint_bytes,
-                           std::unique_ptr<SharedMemory> shm) {
+                           std::unique_ptr<SharedMemory> shm,
+                           bool in_process) {
   producer_name_ = producer_name;
-  service_endpoint_ = svc->ConnectProducer(
-      this, ClientIdentity(uid, pid), producer_name,
-      shared_memory_size_hint_bytes,
-      /*in_process=*/true, TracingService::ProducerSMBScrapingMode::kDefault,
-      shared_memory_page_size_hint_bytes, std::move(shm));
+  service_endpoint_ =
+      svc->ConnectProducer(this, ClientIdentity(uid, pid), producer_name,
+                           shared_memory_size_hint_bytes,
+                           /*in_process=*/in_process,
+                           TracingService::ProducerSMBScrapingMode::kDefault,
+                           shared_memory_page_size_hint_bytes, std::move(shm));
   auto checkpoint_name = "on_producer_connect_" + producer_name;
   auto on_connect = task_runner_->CreateCheckpoint(checkpoint_name);
   EXPECT_CALL(*this, OnConnect()).WillOnce(Invoke(on_connect));
diff --git a/src/tracing/test/mock_producer.h b/src/tracing/test/mock_producer.h
index f3fab40..de33bee 100644
--- a/src/tracing/test/mock_producer.h
+++ b/src/tracing/test/mock_producer.h
@@ -51,7 +51,8 @@
                pid_t pid = 1025,
                size_t shared_memory_size_hint_bytes = 0,
                size_t shared_memory_page_size_hint_bytes = 0,
-               std::unique_ptr<SharedMemory> shm = nullptr);
+               std::unique_ptr<SharedMemory> shm = nullptr,
+               bool in_process = true);
   void RegisterDataSource(const std::string& name,
                           bool ack_stop = false,
                           bool ack_start = false,
diff --git a/src/tracing/test/proxy_producer_endpoint.cc b/src/tracing/test/proxy_producer_endpoint.cc
new file mode 100644
index 0000000..31dca1e
--- /dev/null
+++ b/src/tracing/test/proxy_producer_endpoint.cc
@@ -0,0 +1,134 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/tracing/test/proxy_producer_endpoint.h"
+
+#include "perfetto/ext/tracing/core/trace_writer.h"
+
+namespace perfetto {
+
+ProxyProducerEndpoint::~ProxyProducerEndpoint() = default;
+
+void ProxyProducerEndpoint::Disconnect() {
+  if (!backend_) {
+    return;
+  }
+  backend_->Disconnect();
+}
+void ProxyProducerEndpoint::RegisterDataSource(
+    const DataSourceDescriptor& dsd) {
+  if (!backend_) {
+    return;
+  }
+  backend_->RegisterDataSource(dsd);
+}
+void ProxyProducerEndpoint::UpdateDataSource(const DataSourceDescriptor& dsd) {
+  if (!backend_) {
+    return;
+  }
+  backend_->UpdateDataSource(dsd);
+}
+void ProxyProducerEndpoint::UnregisterDataSource(const std::string& name) {
+  if (!backend_) {
+    return;
+  }
+  backend_->UnregisterDataSource(name);
+}
+void ProxyProducerEndpoint::RegisterTraceWriter(uint32_t writer_id,
+                                                uint32_t target_buffer) {
+  if (!backend_) {
+    return;
+  }
+  backend_->RegisterTraceWriter(writer_id, target_buffer);
+}
+void ProxyProducerEndpoint::UnregisterTraceWriter(uint32_t writer_id) {
+  if (!backend_) {
+    return;
+  }
+  backend_->UnregisterTraceWriter(writer_id);
+}
+void ProxyProducerEndpoint::CommitData(const CommitDataRequest& req,
+                                       CommitDataCallback callback) {
+  if (!backend_) {
+    return;
+  }
+  backend_->CommitData(req, callback);
+}
+SharedMemory* ProxyProducerEndpoint::shared_memory() const {
+  if (!backend_) {
+    return nullptr;
+  }
+  return backend_->shared_memory();
+}
+size_t ProxyProducerEndpoint::shared_buffer_page_size_kb() const {
+  if (!backend_) {
+    return 0;
+  }
+  return backend_->shared_buffer_page_size_kb();
+}
+std::unique_ptr<TraceWriter> ProxyProducerEndpoint::CreateTraceWriter(
+    BufferID target_buffer,
+    BufferExhaustedPolicy buffer_exhausted_policy) {
+  if (!backend_) {
+    return nullptr;
+  }
+  return backend_->CreateTraceWriter(target_buffer, buffer_exhausted_policy);
+}
+SharedMemoryArbiter* ProxyProducerEndpoint::MaybeSharedMemoryArbiter() {
+  if (!backend_) {
+    return nullptr;
+  }
+  return backend_->MaybeSharedMemoryArbiter();
+}
+bool ProxyProducerEndpoint::IsShmemProvidedByProducer() const {
+  if (!backend_) {
+    return false;
+  }
+  return backend_->IsShmemProvidedByProducer();
+}
+void ProxyProducerEndpoint::NotifyFlushComplete(FlushRequestID id) {
+  if (!backend_) {
+    return;
+  }
+  backend_->NotifyFlushComplete(id);
+}
+void ProxyProducerEndpoint::NotifyDataSourceStarted(DataSourceInstanceID id) {
+  if (!backend_) {
+    return;
+  }
+  backend_->NotifyDataSourceStarted(id);
+}
+void ProxyProducerEndpoint::NotifyDataSourceStopped(DataSourceInstanceID id) {
+  if (!backend_) {
+    return;
+  }
+  backend_->NotifyDataSourceStopped(id);
+}
+void ProxyProducerEndpoint::ActivateTriggers(
+    const std::vector<std::string>& triggers) {
+  if (!backend_) {
+    return;
+  }
+  backend_->ActivateTriggers(triggers);
+}
+void ProxyProducerEndpoint::Sync(std::function<void()> callback) {
+  if (!backend_) {
+    return;
+  }
+  backend_->Sync(callback);
+}
+
+}  // namespace perfetto
diff --git a/src/tracing/test/proxy_producer_endpoint.h b/src/tracing/test/proxy_producer_endpoint.h
new file mode 100644
index 0000000..708151f
--- /dev/null
+++ b/src/tracing/test/proxy_producer_endpoint.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_TRACING_TEST_PROXY_PRODUCER_ENDPOINT_H_
+#define SRC_TRACING_TEST_PROXY_PRODUCER_ENDPOINT_H_
+
+#include "perfetto/ext/tracing/core/tracing_service.h"
+
+namespace perfetto {
+
+// A "proxy" ProducerEndpoint that forwards all the requests to a real
+// (`backend_`) ProducerEndpoint endpoint or drops them if (`backend_`) is
+// nullptr.
+class ProxyProducerEndpoint : public ProducerEndpoint {
+ public:
+  ~ProxyProducerEndpoint() override;
+
+  // `backend` is not owned.
+  void set_backend(ProducerEndpoint* backend) { backend_ = backend; }
+
+  ProducerEndpoint* backend() const { return backend_; }
+
+  // Begin ProducerEndpoint implementation
+  void Disconnect() override;
+  void RegisterDataSource(const DataSourceDescriptor&) override;
+  void UpdateDataSource(const DataSourceDescriptor&) override;
+  void UnregisterDataSource(const std::string& name) override;
+  void RegisterTraceWriter(uint32_t writer_id, uint32_t target_buffer) override;
+  void UnregisterTraceWriter(uint32_t writer_id) override;
+  void CommitData(const CommitDataRequest&,
+                  CommitDataCallback callback = {}) override;
+  SharedMemory* shared_memory() const override;
+  size_t shared_buffer_page_size_kb() const override;
+  std::unique_ptr<TraceWriter> CreateTraceWriter(
+      BufferID target_buffer,
+      BufferExhaustedPolicy buffer_exhausted_policy =
+          BufferExhaustedPolicy::kDefault) override;
+  SharedMemoryArbiter* MaybeSharedMemoryArbiter() override;
+  bool IsShmemProvidedByProducer() const override;
+  void NotifyFlushComplete(FlushRequestID) override;
+  void NotifyDataSourceStarted(DataSourceInstanceID) override;
+  void NotifyDataSourceStopped(DataSourceInstanceID) override;
+  void ActivateTriggers(const std::vector<std::string>&) override;
+  void Sync(std::function<void()> callback) override;
+  // End ProducerEndpoint implementation
+
+ private:
+  ProducerEndpoint* backend_ = nullptr;
+};
+
+}  // namespace perfetto
+
+#endif  // SRC_TRACING_TEST_PROXY_PRODUCER_ENDPOINT_H_
diff --git a/src/tracing/test/test_shared_memory.cc b/src/tracing/test/test_shared_memory.cc
new file mode 100644
index 0000000..509c8e0
--- /dev/null
+++ b/src/tracing/test/test_shared_memory.cc
@@ -0,0 +1,30 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/tracing/test/test_shared_memory.h"
+
+namespace perfetto {
+
+TestRefSharedMemory::~TestRefSharedMemory() = default;
+
+const void* TestRefSharedMemory::start() const {
+  return start_;
+}
+size_t TestRefSharedMemory::size() const {
+  return size_;
+}
+
+}  // namespace perfetto
diff --git a/src/tracing/test/test_shared_memory.h b/src/tracing/test/test_shared_memory.h
index e932ad6..2c62b9b 100644
--- a/src/tracing/test/test_shared_memory.h
+++ b/src/tracing/test/test_shared_memory.h
@@ -21,7 +21,6 @@
 
 #include <memory>
 
-#include "perfetto/ext/base/paged_memory.h"
 #include "perfetto/ext/tracing/core/shared_memory.h"
 #include "src/tracing/core/in_process_shared_memory.h"
 
@@ -31,6 +30,33 @@
 // (just a wrapper around malloc() that fits the SharedMemory API).
 using TestSharedMemory = InProcessSharedMemory;
 
+// An implementation of the SharedMemory that doesn't own any memory, but just
+// points to memory owned by another SharedMemory.
+//
+// This is useful to test two components that own separate SharedMemory that
+// really point to the same memory underneath without setting up real posix
+// shared memory.
+class TestRefSharedMemory : public SharedMemory {
+ public:
+  // N.B. `*mem` must outlive `*this`.
+  explicit TestRefSharedMemory(SharedMemory* mem)
+      : start_(mem->start()), size_(mem->size()) {}
+  ~TestRefSharedMemory() override;
+
+  static std::unique_ptr<TestRefSharedMemory> Create(SharedMemory* mem) {
+    return std::make_unique<TestRefSharedMemory>(mem);
+  }
+
+  // SharedMemory implementation.
+  using SharedMemory::start;  // Equal priority to const and non-const versions
+  const void* start() const override;
+  size_t size() const override;
+
+ private:
+  void* start_;
+  size_t size_;
+};
+
 }  // namespace perfetto
 
 #endif  // SRC_TRACING_TEST_TEST_SHARED_MEMORY_H_