Merge "Extract base::UnixSocketRaw"
diff --git a/include/perfetto/base/unix_socket.h b/include/perfetto/base/unix_socket.h
index 4a613a0..88b3c8d 100644
--- a/include/perfetto/base/unix_socket.h
+++ b/include/perfetto/base/unix_socket.h
@@ -28,46 +28,87 @@
#include "perfetto/base/utils.h"
#include "perfetto/base/weak_ptr.h"
-#include <sys/socket.h>
-#include <sys/un.h>
+struct msghdr;
namespace perfetto {
namespace base {
class TaskRunner;
-ssize_t SockSend(int fd,
- const void* msg,
- size_t len,
- const int* send_fds,
- size_t num_fds);
+// Use arbitrarily high values to avoid that some code accidentally ends up
+// assuming that these enum values match the sysroot's SOCK_xxx defines rather
+// than using GetUnixSockType().
+enum class SockType { kStream = 100, kDgram, kSeqPacket };
-ssize_t SockReceive(int fd,
- void* msg,
- size_t len,
- base::ScopedFile* fd_vec,
- size_t max_files);
+// UnixSocketRaw is a basic wrapper around UNIX sockets. It exposes wrapper
+// methods that take care of most common pitfalls (e.g., marking fd as
+// O_CLOEXEC, avoiding SIGPIPE, properly handling partial writes). It is used as
+// a building block for the more sophisticated UnixSocket class.
+class UnixSocketRaw {
+ public:
+ // Creates a new unconnected unix socket. Check the result with operator bool().
+ static UnixSocketRaw CreateMayFail(SockType t) { return UnixSocketRaw(t); }
-bool MakeSockAddr(const std::string& socket_name,
- sockaddr_un* addr,
- socklen_t* addr_size);
+ // Creates a pair of connected sockets (both invalid if socketpair() fails).
+ static std::pair<UnixSocketRaw, UnixSocketRaw> CreatePair(SockType);
-base::ScopedFile CreateSocket();
+ // Creates an empty instance that owns no fd (operator bool() is false).
+ UnixSocketRaw();
-// Update msghdr so subsequent sendmsg will send data that remains after n bytes
-// have already been sent.
-// This should not be used, it's exported for test use only.
-void ShiftMsgHdr(size_t n, struct msghdr* msg);
+ // Creates a unix socket adopting an existing file descriptor. This is
+ // typically used to inherit fds from init via environment variables.
+ UnixSocketRaw(ScopedFile, SockType);
-// Re-enter sendmsg until all the data has been sent or an error occurs.
-//
-// TODO(fmayer): Figure out how to do timeouts here for heapprofd.
-ssize_t SendMsgAll(int sockfd, struct msghdr* msg, int flags);
+ ~UnixSocketRaw() = default;
+ UnixSocketRaw(UnixSocketRaw&&) noexcept = default;
+ UnixSocketRaw& operator=(UnixSocketRaw&&) = default;
-// A non-blocking UNIX domain socket in SOCK_STREAM mode. Allows also to
-// transfer file descriptors. None of the methods in this class are blocking.
-// The main design goal is API simplicity and strong guarantees on the
-// EventListener callbacks, in order to avoid ending in some undefined state.
+ bool Bind(const std::string& socket_name);  // false on failure, errno set.
+ bool Listen();  // listen(2) with a SOMAXCONN backlog.
+ bool Connect(const std::string& socket_name);  // EINPROGRESS counts as success.
+ bool SetTxTimeout(uint32_t timeout_ms);  // Sets SO_SNDTIMEO.
+ void Shutdown();  // shutdown(SHUT_RDWR), then drops the fd.
+ void SetBlocking(bool);  // Toggles O_NONBLOCK.
+ bool IsBlocking() const;
+ SockType type() const { return type_; }
+ int fd() const { return *fd_; }
+ explicit operator bool() const { return !!fd_; }
+
+ ScopedFile ReleaseFd() { return std::move(fd_); }  // Caller takes ownership.
+
+ ssize_t Send(const void* msg,  // Goes through SendMsgAll().
+ size_t len,
+ const int* send_fds = nullptr,
+ size_t num_fds = 0);
+
+ // Re-enter sendmsg until all the data has been sent or an error occurs.
+ // TODO(fmayer): Figure out how to do timeouts here for heapprofd.
+ ssize_t SendMsgAll(struct msghdr* msg);  // DCHECKs IsBlocking().
+
+ ssize_t Receive(void* msg,  // Returns recvmsg()'s result.
+ size_t len,
+ ScopedFile* fd_vec = nullptr,
+ size_t max_files = 0);
+
+ // Exposed for testing only.
+ // Update msghdr so subsequent sendmsg will send data that remains after n
+ // bytes have already been sent.
+ static void ShiftMsgHdr(size_t n, struct msghdr* msg);
+
+ private:
+ explicit UnixSocketRaw(SockType);
+
+ UnixSocketRaw(const UnixSocketRaw&) = delete;
+ UnixSocketRaw& operator=(const UnixSocketRaw&) = delete;
+
+ ScopedFile fd_;
+ SockType type_{SockType::kStream};
+};
+
+// A non-blocking UNIX domain socket. Allows also to transfer file descriptors.
+// None of the methods in this class are blocking.
+// The main design goal is making strong guarantees on the EventListener
+// callbacks, in order to avoid ending in some undefined state.
// In case of any error it will aggressively just shut down the socket and
// notify the failure with OnConnect(false) or OnDisconnect() depending on the
// state of the socket (see below).
@@ -143,26 +184,23 @@
// is_listening() == false and last_error() will contain the failure reason.
+ // Returns nullptr, however, if the socket cannot be bound to |socket_name|.
static std::unique_ptr<UnixSocket> Listen(const std::string& socket_name,
EventListener*,
- base::TaskRunner*);
+ TaskRunner*,
+ SockType = SockType::kStream);  // Also the type of accepted sockets.
-// Attaches to a pre-existing socket. The socket must have been created in
-// SOCK_STREAM mode and the caller must have called bind() on it.
+ // Attaches to a pre-existing socket. The socket must have been created with
+ // the given SockType and the caller must have called bind() on it.
- static std::unique_ptr<UnixSocket> Listen(base::ScopedFile socket_fd,
+ static std::unique_ptr<UnixSocket> Listen(ScopedFile,
EventListener*,
- base::TaskRunner*);
+ TaskRunner*,
+ SockType = SockType::kStream);
// Creates a Unix domain socket and connects to the listening endpoint.
// Returns always an instance. EventListener::OnConnect(bool success) will
// be called always, whether the connection succeeded or not.
static std::unique_ptr<UnixSocket> Connect(const std::string& socket_name,
EventListener*,
- base::TaskRunner*);
-
- // Creates a Unix domain socket and binds it to |socket_name| (see comment
- // of Listen() above for the format). This file descriptor is suitable to be
- // passed to Listen(ScopedFile, ...). Returns the file descriptor, or -1 in
- // case of failure.
- static base::ScopedFile CreateAndBind(const std::string& socket_name);
+ TaskRunner*,
+ SockType = SockType::kStream);
// This class gives the hard guarantee that no callback is called on the
// passed EventListener immediately after the object has been destroyed.
@@ -184,15 +222,23 @@
// DO NOT PASS kNonBlocking, it is broken.
bool Send(const void* msg,
size_t len,
- int send_fd = -1,
- BlockingMode blocking = BlockingMode::kNonBlocking);
- bool Send(const void* msg,
- size_t len,
const int* send_fds,
size_t num_fds,
BlockingMode blocking = BlockingMode::kNonBlocking);
- bool Send(const std::string& msg,
- BlockingMode blockimg = BlockingMode::kNonBlocking);
+
+ inline bool Send(const void* msg,  // send_fd == -1 means "no fd".
+ size_t len,
+ int send_fd = -1,
+ BlockingMode blocking = BlockingMode::kNonBlocking) {
+ if (send_fd != -1)
+ return Send(msg, len, &send_fd, 1, blocking);
+ return Send(msg, len, nullptr, 0, blocking);
+ }
+
+ inline bool Send(const std::string& msg,  // Sends the trailing NUL too.
+ BlockingMode blocking = BlockingMode::kNonBlocking) {
+ return Send(msg.c_str(), msg.size() + 1, -1, blocking);
+ }
// Returns the number of bytes (<= |len|) written in |msg| or 0 if there
// is no data in the buffer to read or an error occurs (in which case a
@@ -200,11 +246,11 @@
// If the ScopedFile pointer is not null and a FD is received, it moves the
// received FD into that. If a FD is received but the ScopedFile pointer is
// null, the FD will be automatically closed.
- size_t Receive(void* msg, size_t len);
- size_t Receive(void* msg,
- size_t len,
- base::ScopedFile*,
- size_t max_files = 1);
+ size_t Receive(void* msg, size_t len, ScopedFile*, size_t max_files = 1);
+
+ inline size_t Receive(void* msg, size_t len) {  // Accepts no fds.
+ return Receive(msg, len, nullptr, 0);
+ }
// Only for tests. This is slower than Receive() as it requires a heap
// allocation and a copy for the std::string. Guarantees that the returned
@@ -214,7 +260,7 @@
bool is_connected() const { return state_ == State::kConnected; }
bool is_listening() const { return state_ == State::kListening; }
- int fd() const { return fd_.get(); }
+ int fd() const { return sock_raw_.fd(); }
int last_error() const { return last_error_; }
// User ID of the peer, as returned by the kernel. If the client disconnects
@@ -239,20 +285,19 @@
#endif
private:
- UnixSocket(EventListener*, base::TaskRunner*);
- UnixSocket(EventListener*, base::TaskRunner*, base::ScopedFile, State);
+ UnixSocket(EventListener*, TaskRunner*, SockType);
+ UnixSocket(EventListener*, TaskRunner*, ScopedFile, State, SockType);
UnixSocket(const UnixSocket&) = delete;
UnixSocket& operator=(const UnixSocket&) = delete;
// Called once by the corresponding public static factory methods.
void DoConnect(const std::string& socket_name);
void ReadPeerCredentials();
- void SetBlockingIO(bool is_blocking);
void OnEvent();
void NotifyConnectionState(bool success);
- base::ScopedFile fd_;
+ UnixSocketRaw sock_raw_;  // Owns the fd; invalid after Shutdown().
State state_ = State::kDisconnected;
int last_error_ = 0;
uid_t peer_uid_ = kInvalidUid;
@@ -260,9 +305,9 @@
PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
pid_t peer_pid_ = kInvalidPid;
#endif
- EventListener* event_listener_;
- base::TaskRunner* task_runner_;
- base::WeakPtrFactory<UnixSocket> weak_ptr_factory_;
+ EventListener* const event_listener_;
+ TaskRunner* const task_runner_;
+ WeakPtrFactory<UnixSocket> weak_ptr_factory_;
};
} // namespace base
diff --git a/src/base/unix_socket.cc b/src/base/unix_socket.cc
index ee83b75..2604703 100644
--- a/src/base/unix_socket.cc
+++ b/src/base/unix_socket.cc
@@ -41,6 +41,12 @@
namespace perfetto {
namespace base {
+// The CMSG_* macros use NULL instead of nullptr. The warning is suppressed up
+// to the matching "#pragma GCC diagnostic pop" further below.
+#pragma GCC diagnostic push
+#if !PERFETTO_BUILDFLAG(PERFETTO_OS_MACOSX)
+#pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant"
+#endif
+
namespace {
// MSG_NOSIGNAL is not supported on Mac OS X, but in that case the socket is
-// created with SO_NOSIGPIPE (See InitializeSocket()).
+// created with SO_NOSIGPIPE (see UnixSocketRaw(ScopedFile, SockType)).
@@ -56,15 +62,45 @@
#else
using CBufLenType = socklen_t;
#endif
+
+inline int GetUnixSockType(SockType type) {  // SockType -> sysroot SOCK_* value.
+ switch (type) {
+ case SockType::kStream:
+ return SOCK_STREAM;
+ case SockType::kDgram:
+ return SOCK_DGRAM;
+ case SockType::kSeqPacket:
+ return SOCK_SEQPACKET;
+ }
+ PERFETTO_CHECK(false);  // Not reached for valid SockType values.
}
-// The CMSG_* macros use NULL instead of nullptr.
-#pragma GCC diagnostic push
-#if !PERFETTO_BUILDFLAG(PERFETTO_OS_MACOSX)
-#pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant"
-#endif
+bool MakeSockAddr(const std::string& socket_name,  // Leading '@' becomes '\0' (Linux abstract socket).
+ sockaddr_un* addr,
+ socklen_t* addr_size) {
+ memset(addr, 0, sizeof(*addr));
+ const size_t name_len = socket_name.size();
+ if (name_len >= sizeof(addr->sun_path)) {
+ errno = ENAMETOOLONG;
+ return false;
+ }
+ memcpy(addr->sun_path, socket_name.data(), name_len);
+ if (addr->sun_path[0] == '@')
+ addr->sun_path[0] = '\0';
+ addr->sun_family = AF_UNIX;
+ *addr_size = static_cast<socklen_t>(  // Length counts one trailing '\0'.
+ __builtin_offsetof(sockaddr_un, sun_path) + name_len + 1);
+ return true;
+}
-void ShiftMsgHdr(size_t n, struct msghdr* msg) {
+} // namespace
+
+// +-----------------------+
+// | UnixSocketRaw methods |
+// +-----------------------+
+
+// static
+void UnixSocketRaw::ShiftMsgHdr(size_t n, struct msghdr* msg) {  // Drops the first n bytes from msg's iovecs.
using LenType = decltype(msg->msg_iovlen); // Mac and Linux don't agree.
for (LenType i = 0; i < msg->msg_iovlen; ++i) {
struct iovec* vec = &msg->msg_iov[i];
@@ -85,6 +121,92 @@
msg->msg_iov = nullptr;
}
+// static
+std::pair<UnixSocketRaw, UnixSocketRaw> UnixSocketRaw::CreatePair(SockType t) {
+ int fds[2];
+ if (socketpair(AF_UNIX, GetUnixSockType(t), 0, fds) != 0)
+ return std::make_pair(UnixSocketRaw(), UnixSocketRaw());  // Both invalid.
+
+ return std::make_pair(UnixSocketRaw(ScopedFile(fds[0]), t),
+ UnixSocketRaw(ScopedFile(fds[1]), t));
+}
+
+UnixSocketRaw::UnixSocketRaw() = default;
+
+// Backs CreateMayFail(). If socket() fails, |fd_| is left empty, so
+// operator bool() returns false and errno still holds socket()'s error.
+// Delegating to the adopting ctor would instead abort the process via its
+// PERFETTO_CHECK(fd_) and defeat the "MayFail" contract.
+UnixSocketRaw::UnixSocketRaw(SockType type) : type_(type) {
+ ScopedFile fd(socket(AF_UNIX, GetUnixSockType(type), 0));
+ if (!fd)
+ return;
+ *this = UnixSocketRaw(std::move(fd), type);
+}
+
+UnixSocketRaw::UnixSocketRaw(ScopedFile fd, SockType type)  // Aborts if |fd| is invalid.
+ : fd_(std::move(fd)), type_(type) {
+ PERFETTO_CHECK(fd_);
+#if PERFETTO_BUILDFLAG(PERFETTO_OS_MACOSX)
+ const int no_sigpipe = 1;  // macOS has no MSG_NOSIGNAL.
+ setsockopt(*fd_, SOL_SOCKET, SO_NOSIGPIPE, &no_sigpipe, sizeof(no_sigpipe));
+#endif
+
+ // There is no reason why a socket should outlive the process in case of
+ // exec() by default, this is just working around a broken unix design.
+ int fcntl_res = fcntl(*fd_, F_SETFD, FD_CLOEXEC);
+ PERFETTO_CHECK(fcntl_res == 0);
+}
+
+// Clears O_NONBLOCK when |is_blocking| is true, sets it otherwise.
+void UnixSocketRaw::SetBlocking(bool is_blocking) {
+ PERFETTO_DCHECK(fd_);
+ int flags = fcntl(*fd_, F_GETFL, 0);
+ PERFETTO_CHECK(flags != -1);  // Don't feed -1 back into F_SETFL below.
+ if (!is_blocking) {
+ flags |= O_NONBLOCK;
+ } else {
+ flags &= ~static_cast<int>(O_NONBLOCK);
+ }
+ // fcntl() returns an int (-1 on failure); keep it as an int, not a bool.
+ int fcntl_res = fcntl(*fd_, F_SETFL, flags);
+ PERFETTO_CHECK(fcntl_res == 0);
+}
+
+bool UnixSocketRaw::IsBlocking() const {  // True iff O_NONBLOCK is clear.
+ PERFETTO_DCHECK(fd_);
+ return (fcntl(*fd_, F_GETFL, 0) & O_NONBLOCK) == 0;
+}
+
+bool UnixSocketRaw::Bind(const std::string& socket_name) {
+ PERFETTO_DCHECK(fd_);
+ sockaddr_un addr;
+ socklen_t addr_size;
+ if (!MakeSockAddr(socket_name, &addr, &addr_size))  // errno = ENAMETOOLONG.
+ return false;
+
+ if (bind(*fd_, reinterpret_cast<sockaddr*>(&addr), addr_size)) {
+ PERFETTO_DPLOG("bind()");
+ return false;
+ }
+
+ return true;
+}
+
+// Puts the (already bound) socket in listening mode with a SOMAXCONN backlog.
+bool UnixSocketRaw::Listen() {
+ PERFETTO_DCHECK(fd_);
+ return listen(*fd_, SOMAXCONN) == 0;
+}
+
+bool UnixSocketRaw::Connect(const std::string& socket_name) {
+ PERFETTO_DCHECK(fd_);
+ sockaddr_un addr;
+ socklen_t addr_size;
+ if (!MakeSockAddr(socket_name, &addr, &addr_size))
+ return false;
+
+ int res = PERFETTO_EINTR(
+ connect(*fd_, reinterpret_cast<sockaddr*>(&addr), addr_size));
+ if (res && errno != EINPROGRESS)  // A connect() still in progress counts as success.
+ return false;
+
+ return true;
+}
+
+void UnixSocketRaw::Shutdown() {
+ shutdown(*fd_, SHUT_RDWR);  // NOTE(review): unlike siblings, no PERFETTO_DCHECK(fd_) here.
+ fd_.reset();  // The instance owns no fd afterwards.
+}
+
// For the interested reader, Linux kernel dive to verify this is not only a
// theoretical possibility: sock_stream_sendmsg, if sock_alloc_send_pskb returns
// NULL [1] (which it does when it gets interrupted [2]), returns early with the
@@ -93,13 +215,14 @@
// [1]:
// https://elixir.bootlin.com/linux/v4.18.10/source/net/unix/af_unix.c#L1872
// [2]: https://elixir.bootlin.com/linux/v4.18.10/source/net/core/sock.c#L2101
-ssize_t SendMsgAll(int sockfd, struct msghdr* msg, int flags) {
+ssize_t UnixSocketRaw::SendMsgAll(struct msghdr* msg) {  // On EAGAIN/EWOULDBLOCK returns the bytes sent so far.
// This does not make sense on non-blocking sockets.
- PERFETTO_DCHECK((fcntl(sockfd, F_GETFL, 0) & O_NONBLOCK) == 0);
+ PERFETTO_DCHECK(fd_);
+ PERFETTO_DCHECK(IsBlocking());
ssize_t total_sent = 0;
while (msg->msg_iov) {
- ssize_t sent = PERFETTO_EINTR(sendmsg(sockfd, msg, flags));
+ ssize_t sent = PERFETTO_EINTR(sendmsg(*fd_, msg, kNoSigPipe));
if (sent <= 0) {
if (sent == -1 && (errno == EAGAIN || errno == EWOULDBLOCK))
return total_sent;
@@ -114,11 +237,11 @@
return total_sent;
}
-ssize_t SockSend(int fd,
- const void* msg,
- size_t len,
- const int* send_fds,
- size_t num_fds) {
+ssize_t UnixSocketRaw::Send(const void* msg,  // Single iovec, sent via SendMsgAll().
+ size_t len,
+ const int* send_fds,
+ size_t num_fds) {
+ PERFETTO_DCHECK(fd_);
msghdr msg_hdr = {};
iovec iov = {const_cast<void*>(msg), len};
msg_hdr.msg_iov = &iov;
@@ -142,14 +265,14 @@
// msg_hdr.msg_controllen would need to be adjusted, see "man 3 cmsg".
}
- return SendMsgAll(fd, &msg_hdr, kNoSigPipe);
+ return SendMsgAll(&msg_hdr);
}
-ssize_t SockReceive(int fd,
- void* msg,
- size_t len,
- ScopedFile* fd_vec,
- size_t max_files) {
+ssize_t UnixSocketRaw::Receive(void* msg,  // Returns recvmsg()'s result; on failure errno is recvmsg()'s.
+ size_t len,
+ ScopedFile* fd_vec,
+ size_t max_files) {
+ PERFETTO_DCHECK(fd_);
msghdr msg_hdr = {};
iovec iov = {msg, len};
msg_hdr.msg_iov = &iov;
@@ -162,7 +285,7 @@
static_cast<CBufLenType>(CMSG_SPACE(max_files * sizeof(int)));
PERFETTO_CHECK(msg_hdr.msg_controllen <= sizeof(control_buf));
}
- const ssize_t sz = PERFETTO_EINTR(recvmsg(fd, &msg_hdr, kNoSigPipe));
+ const ssize_t sz = PERFETTO_EINTR(recvmsg(*fd_, &msg_hdr, kNoSigPipe));
if (sz <= 0) {
return sz;
}
@@ -201,88 +324,75 @@
return sz;
}
+bool UnixSocketRaw::SetTxTimeout(uint32_t timeout_ms) {  // Sets SO_SNDTIMEO; false if setsockopt() fails.
+ PERFETTO_DCHECK(fd_);
+ struct timeval timeout {};
+ uint32_t timeout_sec = timeout_ms / 1000;
+ timeout.tv_sec = static_cast<decltype(timeout.tv_sec)>(timeout_sec);
+ timeout.tv_usec = static_cast<decltype(timeout.tv_usec)>(  // Sub-second remainder, in microseconds.
+ (timeout_ms - (timeout_sec * 1000)) * 1000);
+
+ return setsockopt(*fd_, SOL_SOCKET, SO_SNDTIMEO,
+ reinterpret_cast<const char*>(&timeout),
+ sizeof(timeout)) == 0;
+}
+
#pragma GCC diagnostic pop
-bool MakeSockAddr(const std::string& socket_name,
- sockaddr_un* addr,
- socklen_t* addr_size) {
- memset(addr, 0, sizeof(*addr));
- const size_t name_len = socket_name.size();
- if (name_len >= sizeof(addr->sun_path)) {
- errno = ENAMETOOLONG;
- return false;
- }
- memcpy(addr->sun_path, socket_name.data(), name_len);
- if (addr->sun_path[0] == '@')
- addr->sun_path[0] = '\0';
- addr->sun_family = AF_UNIX;
- *addr_size = static_cast<socklen_t>(
- __builtin_offsetof(sockaddr_un, sun_path) + name_len + 1);
- return true;
-}
-
-ScopedFile CreateSocket() {
- return ScopedFile(socket(AF_UNIX, SOCK_STREAM, 0));
-}
+// +--------------------+
+// | UnixSocket methods |
+// +--------------------+
// TODO(primiano): Add ThreadChecker to methods of this class.
// static
-ScopedFile UnixSocket::CreateAndBind(const std::string& socket_name) {
- ScopedFile fd = CreateSocket();
- if (!fd)
- return fd;
-
- sockaddr_un addr;
- socklen_t addr_size;
- if (!MakeSockAddr(socket_name, &addr, &addr_size)) {
- return ScopedFile();
- }
-
- if (bind(*fd, reinterpret_cast<sockaddr*>(&addr), addr_size)) {
- PERFETTO_DPLOG("bind()");
- return ScopedFile();
- }
-
- return fd;
-}
-
-// static
std::unique_ptr<UnixSocket> UnixSocket::Listen(const std::string& socket_name,
EventListener* event_listener,
- TaskRunner* task_runner) {
+ TaskRunner* task_runner,
+ SockType sock_type) {
+ // Failing to create or bind() the socket yields nullptr; listen() failures
+ // are instead reported by the returned instance (kListening ctor branch).
+ auto sock_raw = UnixSocketRaw::CreateMayFail(sock_type);
+ if (!sock_raw || !sock_raw.Bind(socket_name))
+ return nullptr;
+
// Forward the call to the Listen() overload below.
- return Listen(CreateAndBind(socket_name), event_listener, task_runner);
+ // |sock_type| must be forwarded too: the overload defaults to kStream, which
+ // would mislabel the listener and every socket it accepts.
+ return Listen(sock_raw.ReleaseFd(), event_listener, task_runner, sock_type);
}
// static
-std::unique_ptr<UnixSocket> UnixSocket::Listen(ScopedFile socket_fd,
+std::unique_ptr<UnixSocket> UnixSocket::Listen(ScopedFile fd,
EventListener* event_listener,
- TaskRunner* task_runner) {
- std::unique_ptr<UnixSocket> sock(new UnixSocket(
- event_listener, task_runner, std::move(socket_fd), State::kListening));
- return sock;
+ TaskRunner* task_runner,
+ SockType sock_type) {
+ return std::unique_ptr<UnixSocket>(
+ new UnixSocket(event_listener, task_runner, std::move(fd),
+ State::kListening, sock_type));  // The kListening ctor branch calls listen().
}
// static
std::unique_ptr<UnixSocket> UnixSocket::Connect(const std::string& socket_name,
EventListener* event_listener,
- TaskRunner* task_runner) {
- std::unique_ptr<UnixSocket> sock(new UnixSocket(event_listener, task_runner));
+ TaskRunner* task_runner,
+ SockType sock_type) {
+ std::unique_ptr<UnixSocket> sock(
+ new UnixSocket(event_listener, task_runner, sock_type));  // Unconnected socket of |sock_type|.
sock->DoConnect(socket_name);
return sock;
}
-UnixSocket::UnixSocket(EventListener* event_listener, TaskRunner* task_runner)
+UnixSocket::UnixSocket(EventListener* event_listener,
+ TaskRunner* task_runner,
+ SockType sock_type)  // Type of the socket created in the kDisconnected branch.
: UnixSocket(event_listener,
task_runner,
ScopedFile(),
- State::kDisconnected) {}
+ State::kDisconnected,
+ sock_type) {}
UnixSocket::UnixSocket(EventListener* event_listener,
TaskRunner* task_runner,
ScopedFile adopt_fd,
- State adopt_state)
+ State adopt_state,
+ SockType sock_type)
: event_listener_(event_listener),
task_runner_(task_runner),
weak_ptr_factory_(this) {
@@ -290,15 +400,15 @@
if (adopt_state == State::kDisconnected) {
// We get here from the default ctor().
PERFETTO_DCHECK(!adopt_fd);
- fd_ = CreateSocket();
- if (!fd_) {
+ sock_raw_ = UnixSocketRaw::CreateMayFail(sock_type);  // On failure errno is read just below.
+ if (!sock_raw_) {
last_error_ = errno;
return;
}
} else if (adopt_state == State::kConnected) {
// We get here from OnNewIncomingConnection().
PERFETTO_DCHECK(adopt_fd);
- fd_ = std::move(adopt_fd);
+ sock_raw_ = UnixSocketRaw(std::move(adopt_fd), sock_type);
state_ = State::kConnected;
ReadPeerCredentials();
} else if (adopt_state == State::kListening) {
@@ -310,8 +420,8 @@
return;
}
- fd_ = std::move(adopt_fd);
- if (listen(*fd_, SOMAXCONN)) {
+ sock_raw_ = UnixSocketRaw(std::move(adopt_fd), sock_type);  // Sets FD_CLOEXEC (and SO_NOSIGPIPE on macOS).
+ if (!sock_raw_.Listen()) {
last_error_ = errno;
PERFETTO_DPLOG("listen()");
return;
@@ -321,22 +431,13 @@
PERFETTO_FATAL("Unexpected adopt_state"); // Unfeasible.
}
- PERFETTO_DCHECK(fd_);
+ PERFETTO_CHECK(sock_raw_);  // Every branch reaching here holds a valid socket.
last_error_ = 0;
-#if PERFETTO_BUILDFLAG(PERFETTO_OS_MACOSX)
- const int no_sigpipe = 1;
- setsockopt(*fd_, SOL_SOCKET, SO_NOSIGPIPE, &no_sigpipe, sizeof(no_sigpipe));
-#endif
- // There is no reason why a socket should outlive the process in case of
- // exec() by default, this is just working around a broken unix design.
- int fcntl_res = fcntl(*fd_, F_SETFD, FD_CLOEXEC);
- PERFETTO_CHECK(fcntl_res == 0);
-
- SetBlockingIO(false);
+ sock_raw_.SetBlocking(false);  // Send(kBlocking) flips this only temporarily.
WeakPtr<UnixSocket> weak_ptr = weak_ptr_factory_.GetWeakPtr();
- task_runner_->AddFileDescriptorWatch(*fd_, [weak_ptr]() {
+ task_runner_->AddFileDescriptorWatch(sock_raw_.fd(), [weak_ptr] {
if (weak_ptr)
weak_ptr->OnEvent();
});
@@ -352,39 +453,32 @@
PERFETTO_DCHECK(state_ == State::kDisconnected);
// This is the only thing that can gracefully fail in the ctor.
- if (!fd_)
+ if (!sock_raw_)
return NotifyConnectionState(false);
- sockaddr_un addr;
- socklen_t addr_size;
- if (!MakeSockAddr(socket_name, &addr, &addr_size)) {
+ if (!sock_raw_.Connect(socket_name)) {  // Also fails for names too long for sun_path.
last_error_ = errno;
return NotifyConnectionState(false);
}
- int res = PERFETTO_EINTR(
- connect(*fd_, reinterpret_cast<sockaddr*>(&addr), addr_size));
- if (res && errno != EINPROGRESS) {
- last_error_ = errno;
- return NotifyConnectionState(false);
- }
-
- // At this point either |res| == 0 (the connect() succeeded) or started
- // asynchronously (EINPROGRESS).
+ // At this point either connect() succeeded or started asynchronously
+ // (errno = EINPROGRESS).
last_error_ = 0;
state_ = State::kConnecting;
// Even if the socket is non-blocking, connecting to a UNIX socket can be
- // acknowledged straight away rather than returning EINPROGRESS. In this case
- // just trigger an OnEvent without waiting for the FD watch. That will poll
- // the SO_ERROR and evolve the state into either kConnected or kDisconnected.
- if (res == 0) {
- WeakPtr<UnixSocket> weak_ptr = weak_ptr_factory_.GetWeakPtr();
- task_runner_->PostTask([weak_ptr]() {
- if (weak_ptr)
- weak_ptr->OnEvent();
- });
- }
+ // acknowledged straight away rather than returning EINPROGRESS.
+ // The decision here is to deal with the two cases uniformly, at the cost of
+ // delaying the straight-away-connect() case by one task, to avoid depending
+ // on implementation details of UNIX socket on the various OSes.
+ // Posting the OnEvent() below emulates a wakeup of the FD watch. OnEvent(),
+ // which knows how to deal with spurious wakeups, will poll the SO_ERROR and
+ // evolve, if necessary, the state into either kConnected or kDisconnected.
+ WeakPtr<UnixSocket> weak_ptr = weak_ptr_factory_.GetWeakPtr();  // Guards the task against this object's destruction.
+ task_runner_->PostTask([weak_ptr] {
+ if (weak_ptr)
+ weak_ptr->OnEvent();
+ });
}
void UnixSocket::ReadPeerCredentials() {
@@ -392,14 +486,15 @@
PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
struct ucred user_cred;
socklen_t len = sizeof(user_cred);
- int res = getsockopt(*fd_, SOL_SOCKET, SO_PEERCRED, &user_cred, &len);
+ int fd = sock_raw_.fd();  // Our end of the connection (see the kConnected ctor branch).
+ int res = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &user_cred, &len);
PERFETTO_CHECK(res == 0);
peer_uid_ = user_cred.uid;
peer_pid_ = user_cred.pid;
#else
struct xucred user_cred;
socklen_t len = sizeof(user_cred);
- int res = getsockopt(*fd_, 0, LOCAL_PEERCRED, &user_cred, &len);
+ int res = getsockopt(sock_raw_.fd(), 0, LOCAL_PEERCRED, &user_cred, &len);
PERFETTO_CHECK(res == 0 && user_cred.cr_version == XUCRED_VERSION);
peer_uid_ = static_cast<uid_t>(user_cred.cr_uid);
// There is no pid in the LOCAL_PEERCREDS for MacOS / FreeBSD.
@@ -414,10 +509,11 @@
return event_listener_->OnDataAvailable(this);
if (state_ == State::kConnecting) {
- PERFETTO_DCHECK(fd_);
+ PERFETTO_DCHECK(sock_raw_);
int sock_err = EINVAL;
socklen_t err_len = sizeof(sock_err);
- int res = getsockopt(*fd_, SOL_SOCKET, SO_ERROR, &sock_err, &err_len);
+ int res =
+ getsockopt(sock_raw_.fd(), SOL_SOCKET, SO_ERROR, &sock_err, &err_len);  // Outcome of the pending connect().
if (res == 0 && sock_err == EINPROGRESS)
return; // Not connected yet, just a spurious FD watch wakeup.
if (res == 0 && sock_err == 0) {
@@ -437,30 +533,18 @@
for (;;) {
sockaddr_un cli_addr = {};
socklen_t size = sizeof(cli_addr);
- ScopedFile new_fd(PERFETTO_EINTR(
- accept(*fd_, reinterpret_cast<sockaddr*>(&cli_addr), &size)));
+ ScopedFile new_fd(PERFETTO_EINTR(accept(
+ sock_raw_.fd(), reinterpret_cast<sockaddr*>(&cli_addr), &size)));
if (!new_fd)
return;
- std::unique_ptr<UnixSocket> new_sock(new UnixSocket(
- event_listener_, task_runner_, std::move(new_fd), State::kConnected));
+ std::unique_ptr<UnixSocket> new_sock(
+ new UnixSocket(event_listener_, task_runner_, std::move(new_fd),
+ State::kConnected, sock_raw_.type()));  // Accepted sockets inherit the listener's type.
event_listener_->OnNewIncomingConnection(this, std::move(new_sock));
}
}
}
-bool UnixSocket::Send(const std::string& msg, BlockingMode blocking) {
- return Send(msg.c_str(), msg.size() + 1, -1, blocking);
-}
-
-bool UnixSocket::Send(const void* msg,
- size_t len,
- int send_fd,
- BlockingMode blocking_mode) {
- if (send_fd != -1)
- return Send(msg, len, &send_fd, 1, blocking_mode);
- return Send(msg, len, nullptr, 0, blocking_mode);
-}
-
bool UnixSocket::Send(const void* msg,
size_t len,
const int* send_fds,
@@ -476,10 +560,11 @@
}
if (blocking_mode == BlockingMode::kBlocking)
- SetBlockingIO(true);
- const ssize_t sz = SockSend(*fd_, msg, len, send_fds, num_fds);
+ sock_raw_.SetBlocking(true);
+ const ssize_t sz = sock_raw_.Send(msg, len, send_fds, num_fds);  // NOTE(review): SendMsgAll() returns bytes sent so far (>= 0) on EAGAIN, so the sz < 0 EAGAIN branch below looks unreachable - confirm.
+ int saved_errno = errno;  // Snapshot before SetBlocking()'s fcntl() calls.
if (blocking_mode == BlockingMode::kBlocking)
- SetBlockingIO(false);
+ sock_raw_.SetBlocking(false);
if (sz == static_cast<ssize_t>(len)) {
last_error_ = 0;
@@ -490,7 +575,7 @@
// endpoint disconnected in the middle of the read, and we managed to send
// only a portion of the buffer. In this case we should just give up.
- if (sz < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
+ if (sz < 0 && (saved_errno == EAGAIN || saved_errno == EWOULDBLOCK)) {
// A genuine out-of-buffer. The client should retry or give up.
// Man pages specify that EAGAIN and EWOULDBLOCK have the same semantic here
// and clients should check for both.
@@ -500,7 +585,7 @@
// Either the the other endpoint disconnect (ECONNRESET) or some other error
// happened.
- last_error_ = errno;
+ last_error_ = saved_errno;
PERFETTO_DPLOG("sendmsg() failed");
Shutdown(true);
return false;
@@ -510,30 +595,25 @@
WeakPtr<UnixSocket> weak_ptr = weak_ptr_factory_.GetWeakPtr();
if (notify) {
if (state_ == State::kConnected) {
- task_runner_->PostTask([weak_ptr]() {
+ task_runner_->PostTask([weak_ptr] {
if (weak_ptr)
weak_ptr->event_listener_->OnDisconnect(weak_ptr.get());
});
} else if (state_ == State::kConnecting) {
- task_runner_->PostTask([weak_ptr]() {
+ task_runner_->PostTask([weak_ptr] {
if (weak_ptr)
weak_ptr->event_listener_->OnConnect(weak_ptr.get(), false);
});
}
}
- if (fd_) {
- shutdown(*fd_, SHUT_RDWR);
- task_runner_->RemoveFileDescriptorWatch(*fd_);
- fd_.reset();
+ if (sock_raw_) {
+ task_runner_->RemoveFileDescriptorWatch(sock_raw_.fd());  // Unwatch before the fd is dropped below.
+ sock_raw_.Shutdown();
}
state_ = State::kDisconnected;
}
-size_t UnixSocket::Receive(void* msg, size_t len) {
- return Receive(msg, len, nullptr, 0);
-}
-
size_t UnixSocket::Receive(void* msg,
size_t len,
ScopedFile* fd_vec,
@@ -543,7 +623,7 @@
return 0;
}
- const ssize_t sz = SockReceive(*fd_, msg, len, fd_vec, max_files);
+ const ssize_t sz = sock_raw_.Receive(msg, len, fd_vec, max_files);  // On failure errno is still recvmsg()'s.
if (sz < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
last_error_ = EAGAIN;
return 0;
@@ -570,23 +650,12 @@
Shutdown(false);
WeakPtr<UnixSocket> weak_ptr = weak_ptr_factory_.GetWeakPtr();
- task_runner_->PostTask([weak_ptr, success]() {
+ task_runner_->PostTask([weak_ptr, success] {  // Dropped if this object is destroyed first.
if (weak_ptr)
weak_ptr->event_listener_->OnConnect(weak_ptr.get(), success);
});
}
-void UnixSocket::SetBlockingIO(bool is_blocking) {
- int flags = fcntl(*fd_, F_GETFL, 0);
- if (!is_blocking) {
- flags |= O_NONBLOCK;
- } else {
- flags &= ~static_cast<int>(O_NONBLOCK);
- }
- bool fcntl_res = fcntl(fd(), F_SETFL, flags);
- PERFETTO_CHECK(fcntl_res == 0);
-}
-
UnixSocket::EventListener::~EventListener() {}
void UnixSocket::EventListener::OnNewIncomingConnection(
UnixSocket*,
diff --git a/src/base/unix_socket_unittest.cc b/src/base/unix_socket_unittest.cc
index feff166..e1fea62 100644
--- a/src/base/unix_socket_unittest.cc
+++ b/src/base/unix_socket_unittest.cc
@@ -18,7 +18,9 @@
#include <signal.h>
#include <sys/mman.h>
-
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
#include <list>
#include <thread>
@@ -265,7 +267,9 @@
}
TEST_F(UnixSocketTest, ListenWithPassedFileDescriptor) {
- auto fd = UnixSocket::CreateAndBind(kSocketName);
+ auto sock_raw = UnixSocketRaw::CreateMayFail(SockType::kStream);
+ ASSERT_TRUE(sock_raw);  // Bind() DCHECKs a valid fd; fail the test cleanly instead.
+ ASSERT_TRUE(sock_raw.Bind(kSocketName));
+ auto fd = sock_raw.ReleaseFd();
auto srv = UnixSocket::Listen(std::move(fd), &event_listener_, &task_runner_);
ASSERT_TRUE(srv->is_listening());
@@ -566,7 +570,7 @@
hdr.msg_iov = iov;
hdr.msg_iovlen = base::ArraySize(iov);
- ShiftMsgHdr(1, &hdr);
+ UnixSocketRaw::ShiftMsgHdr(1, &hdr);  // Partially consume iov[0].
EXPECT_NE(hdr.msg_iov, nullptr);
EXPECT_EQ(hdr.msg_iov[0].iov_base, &hello[1]);
EXPECT_EQ(hdr.msg_iov[1].iov_base, &world[0]);
@@ -574,13 +578,13 @@
EXPECT_STREQ(reinterpret_cast<char*>(hdr.msg_iov[0].iov_base), "ello");
EXPECT_EQ(iov[0].iov_len, base::ArraySize(hello) - 1);
- ShiftMsgHdr(base::ArraySize(hello) - 1, &hdr);
+ UnixSocketRaw::ShiftMsgHdr(base::ArraySize(hello) - 1, &hdr);  // Finish iov[0]; iov[1] becomes the head.
EXPECT_EQ(hdr.msg_iov, &iov[1]);
EXPECT_EQ(hdr.msg_iovlen, 1);
EXPECT_STREQ(reinterpret_cast<char*>(hdr.msg_iov[0].iov_base), world);
EXPECT_EQ(hdr.msg_iov[0].iov_len, base::ArraySize(world));
- ShiftMsgHdr(base::ArraySize(world), &hdr);
+ UnixSocketRaw::ShiftMsgHdr(base::ArraySize(world), &hdr);  // Consume iov[1]; nothing left.
EXPECT_EQ(hdr.msg_iov, nullptr);
EXPECT_EQ(hdr.msg_iovlen, 0);
}
@@ -600,13 +604,13 @@
hdr.msg_iov = iov;
hdr.msg_iovlen = base::ArraySize(iov);
- ShiftMsgHdr(base::ArraySize(hello) + 1, &hdr);
+ UnixSocketRaw::ShiftMsgHdr(base::ArraySize(hello) + 1, &hdr);  // Skip iov[0] and 1 byte of iov[1].
EXPECT_NE(hdr.msg_iov, nullptr);
EXPECT_EQ(hdr.msg_iovlen, 1);
EXPECT_STREQ(reinterpret_cast<char*>(hdr.msg_iov[0].iov_base), "orld");
EXPECT_EQ(hdr.msg_iov[0].iov_len, base::ArraySize(world) - 1);
- ShiftMsgHdr(base::ArraySize(world) - 1, &hdr);
+ UnixSocketRaw::ShiftMsgHdr(base::ArraySize(world) - 1, &hdr);  // Consume the rest.
EXPECT_EQ(hdr.msg_iov, nullptr);
EXPECT_EQ(hdr.msg_iovlen, 0);
}
@@ -626,7 +630,8 @@
hdr.msg_iov = iov;
hdr.msg_iovlen = base::ArraySize(iov);
- ShiftMsgHdr(base::ArraySize(world) + base::ArraySize(hello), &hdr);
+ UnixSocketRaw::ShiftMsgHdr(base::ArraySize(world) + base::ArraySize(hello),
+ &hdr);  // Consume everything in one call.
EXPECT_EQ(hdr.msg_iov, nullptr);
EXPECT_EQ(hdr.msg_iovlen, 0);
}
@@ -638,17 +643,18 @@
}
TEST_F(UnixSocketTest, PartialSendMsgAll) {
- int sv[2];
- ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, sv), 0);
- base::ScopedFile send_socket(sv[0]);
- base::ScopedFile recv_socket(sv[1]);
+ UnixSocketRaw send_sock;
+ UnixSocketRaw recv_sock;
+ std::tie(send_sock, recv_sock) = UnixSocketRaw::CreatePair(SockType::kStream);  // NOTE(review): std::tie needs <tuple>, not among the includes shown here.
+ ASSERT_TRUE(send_sock);
+ ASSERT_TRUE(recv_sock);
// Set bufsize to minimum.
int bufsize = 1024;
- ASSERT_EQ(setsockopt(*send_socket, SOL_SOCKET, SO_SNDBUF, &bufsize,
+ ASSERT_EQ(setsockopt(send_sock.fd(), SOL_SOCKET, SO_SNDBUF, &bufsize,
sizeof(bufsize)),
0);
- ASSERT_EQ(setsockopt(*recv_socket, SOL_SOCKET, SO_RCVBUF, &bufsize,
+ ASSERT_EQ(setsockopt(recv_sock.fd(), SOL_SOCKET, SO_RCVBUF, &bufsize,
sizeof(bufsize)),
0);
@@ -674,8 +680,8 @@
rollback(&oldact);
auto blocked_thread = pthread_self();
- std::thread th([blocked_thread, &recv_socket, &recv_buf] {
- ssize_t rd = PERFETTO_EINTR(read(*recv_socket, recv_buf, 1));
+ std::thread th([blocked_thread, &recv_sock, &recv_buf] {
+ ssize_t rd = PERFETTO_EINTR(read(recv_sock.fd(), recv_buf, 1));
ASSERT_EQ(rd, 1);
// We are now sure the other thread is in sendmsg, interrupt send.
ASSERT_EQ(pthread_kill(blocked_thread, SIGWINCH), 0);
@@ -683,7 +689,7 @@
size_t offset = 1;
while (offset < sizeof(recv_buf)) {
rd = PERFETTO_EINTR(
- read(*recv_socket, recv_buf + offset, sizeof(recv_buf) - offset));
+ read(recv_sock.fd(), recv_buf + offset, sizeof(recv_buf) - offset));
ASSERT_GE(rd, 0);
offset += static_cast<size_t>(rd);
}
@@ -703,8 +709,8 @@
hdr.msg_iov = iov;
hdr.msg_iovlen = base::ArraySize(iov);
- ASSERT_EQ(SendMsgAll(*send_socket, &hdr, 0), sizeof(send_buf));
- send_socket.reset();
+ ASSERT_EQ(send_sock.SendMsgAll(&hdr), sizeof(send_buf));  // SendMsgAll() DCHECKs that send_sock is blocking.
+ send_sock.Shutdown();  // shutdown(SHUT_RDWR), then drops the fd.
th.join();
// Make sure the re-entry logic was actually triggered.
ASSERT_EQ(hdr.msg_iov, nullptr);
diff --git a/src/profiling/memory/client.cc b/src/profiling/memory/client.cc
index 358aed6..6ad47eb 100644
--- a/src/profiling/memory/client.cc
+++ b/src/profiling/memory/client.cc
@@ -17,9 +17,7 @@
#include "src/profiling/memory/client.h"
#include <inttypes.h>
-#include <sys/socket.h>
#include <sys/syscall.h>
-#include <sys/un.h>
#include <unistd.h>
#include <atomic>
@@ -46,27 +44,20 @@
namespace profiling {
namespace {
-constexpr struct timeval kSendTimeout = {1 /* s */, 0 /* us */};
+constexpr uint32_t kSendTimeoutMs = 1000;
constexpr std::chrono::seconds kLockTimeout{1};
-std::vector<base::ScopedFile> ConnectPool(const std::string& sock_name,
- size_t n) {
- sockaddr_un addr;
- socklen_t addr_size;
- if (!base::MakeSockAddr(sock_name, &addr, &addr_size))
- return {};
-
- std::vector<base::ScopedFile> res;
+std::vector<base::UnixSocketRaw> ConnectPool(const std::string& sock_name,
+ size_t n) {
+ std::vector<base::UnixSocketRaw> res;
res.reserve(n);
for (size_t i = 0; i < n; ++i) {
- auto sock = base::CreateSocket();
- if (connect(*sock, reinterpret_cast<sockaddr*>(&addr), addr_size) == -1) {
+ auto sock = base::UnixSocketRaw::CreateMayFail(base::SockType::kStream);
+ if (!sock || !sock.Connect(sock_name)) {
PERFETTO_PLOG("Failed to connect to %s", sock_name.c_str());
continue;
}
- if (setsockopt(*sock, SOL_SOCKET, SO_SNDTIMEO,
- reinterpret_cast<const char*>(&kSendTimeout),
- sizeof(kSendTimeout)) != 0) {
+ if (!sock.SetTxTimeout(kSendTimeoutMs)) {
PERFETTO_PLOG("Failed to set timeout for %s", sock_name.c_str());
continue;
}
@@ -125,36 +116,36 @@
msg.record_type = RecordType::Free;
free_page_.num_entries = offset_;
msg.free_header = &free_page_;
- BorrowedSocket fd(pool->Borrow());
- if (!fd || !SendWireMessage(*fd, msg)) {
+ BorrowedSocket sock(pool->Borrow());
+ if (!sock || !SendWireMessage(sock.get(), msg)) {
PERFETTO_ELOG("Failed to send wire message");
- fd.Close();
+ sock.Shutdown();
return false;
}
return true;
}
-SocketPool::SocketPool(std::vector<base::ScopedFile> sockets)
+SocketPool::SocketPool(std::vector<base::UnixSocketRaw> sockets)
: sockets_(std::move(sockets)), available_sockets_(sockets_.size()) {}
BorrowedSocket SocketPool::Borrow() {
std::unique_lock<std::timed_mutex> l(mutex_, kLockTimeout);
if (!l.owns_lock())
- return {base::ScopedFile(), nullptr};
+ return {base::UnixSocketRaw(), nullptr};
cv_.wait(l, [this] {
return available_sockets_ > 0 || dead_sockets_ == sockets_.size() ||
shutdown_;
});
if (dead_sockets_ == sockets_.size() || shutdown_) {
- return {base::ScopedFile(), nullptr};
+ return {base::UnixSocketRaw(), nullptr};
}
PERFETTO_CHECK(available_sockets_ > 0);
return {std::move(sockets_[--available_sockets_]), this};
}
-void SocketPool::Return(base::ScopedFile sock) {
+void SocketPool::Return(base::UnixSocketRaw sock) {
std::unique_lock<std::timed_mutex> l(mutex_, kLockTimeout);
if (!l.owns_lock())
return;
@@ -179,7 +170,7 @@
if (!l.owns_lock())
return;
for (size_t i = 0; i < available_sockets_; ++i)
- sockets_[i].reset();
+ sockets_[i].Shutdown();
dead_sockets_ += available_sockets_;
available_sockets_ = 0;
shutdown_ = true;
@@ -202,7 +193,7 @@
return stackaddr + stacksize;
}
-Client::Client(std::vector<base::ScopedFile> socks)
+Client::Client(std::vector<base::UnixSocketRaw> socks)
: pthread_key_(ThreadLocalSamplingData::KeyDestructor),
socket_pool_(std::move(socks)),
main_thread_stack_base_(FindMainThreadStack()) {
@@ -218,16 +209,16 @@
int fds[2];
fds[0] = *maps;
fds[1] = *mem;
- auto fd = socket_pool_.Borrow();
- if (!fd)
+ auto sock = socket_pool_.Borrow();
+ if (!sock)
return;
// Send an empty record to transfer fds for /proc/self/maps and
// /proc/self/mem.
- if (base::SockSend(*fd, &size, sizeof(size), fds, 2) != sizeof(size)) {
+ if (sock->Send(&size, sizeof(size), fds, 2) != sizeof(size)) {
PERFETTO_DFATAL("Failed to send file descriptors.");
return;
}
- if (recv(*fd, &client_config_, sizeof(client_config_), 0) !=
+ if (sock->Receive(&client_config_, sizeof(client_config_)) !=
sizeof(client_config_)) {
PERFETTO_DFATAL("Failed to receive client config.");
return;
@@ -294,10 +285,10 @@
msg.payload = const_cast<char*>(stacktop);
msg.payload_size = static_cast<size_t>(stack_size);
- BorrowedSocket fd = socket_pool_.Borrow();
- if (!fd || !SendWireMessage(*fd, msg)) {
+ BorrowedSocket sock = socket_pool_.Borrow();
+ if (!sock || !SendWireMessage(sock.get(), msg)) {
PERFETTO_DFATAL("Failed to send wire message.");
- fd.Close();
+ sock.Shutdown();
Shutdown();
}
}
diff --git a/src/profiling/memory/client.h b/src/profiling/memory/client.h
index 0396902..36d220e 100644
--- a/src/profiling/memory/client.h
+++ b/src/profiling/memory/client.h
@@ -24,7 +24,7 @@
#include <mutex>
#include <vector>
-#include "perfetto/base/scoped_file.h"
+#include "perfetto/base/unix_socket.h"
#include "src/profiling/memory/wire_protocol.h"
namespace perfetto {
@@ -35,7 +35,7 @@
class SocketPool {
public:
friend class BorrowedSocket;
- SocketPool(std::vector<base::ScopedFile> sockets);
+ SocketPool(std::vector<base::UnixSocketRaw> sockets);
BorrowedSocket Borrow();
void Shutdown();
@@ -43,10 +43,10 @@
private:
bool shutdown_ = false;
- void Return(base::ScopedFile fd);
+ void Return(base::UnixSocketRaw);
std::timed_mutex mutex_;
std::condition_variable_any cv_;
- std::vector<base::ScopedFile> sockets_;
+ std::vector<base::UnixSocketRaw> sockets_;
size_t available_sockets_;
size_t dead_sockets_ = 0;
};
@@ -56,30 +56,26 @@
public:
BorrowedSocket(const BorrowedSocket&) = delete;
BorrowedSocket& operator=(const BorrowedSocket&) = delete;
- BorrowedSocket(BorrowedSocket&& other) noexcept {
- fd_ = std::move(other.fd_);
- socket_pool_ = other.socket_pool_;
+ BorrowedSocket(BorrowedSocket&& other) noexcept
+ : sock_(std::move(other.sock_)), socket_pool_(other.socket_pool_) {
other.socket_pool_ = nullptr;
}
- BorrowedSocket(base::ScopedFile fd, SocketPool* socket_pool)
- : fd_(std::move(fd)), socket_pool_(socket_pool) {}
+ BorrowedSocket(base::UnixSocketRaw sock, SocketPool* socket_pool)
+ : sock_(std::move(sock)), socket_pool_(socket_pool) {}
~BorrowedSocket() {
if (socket_pool_ != nullptr)
- socket_pool_->Return(std::move(fd_));
+ socket_pool_->Return(std::move(sock_));
}
- int operator*() { return get(); }
-
- int get() { return *fd_; }
-
- void Close() { fd_.reset(); }
-
- operator bool() const { return !!fd_; }
+ base::UnixSocketRaw* operator->() { return &sock_; }
+ base::UnixSocketRaw* get() { return &sock_; }
+ void Shutdown() { sock_.Shutdown(); }
+ explicit operator bool() const { return !!sock_; }
private:
- base::ScopedFile fd_;
+ base::UnixSocketRaw sock_;
SocketPool* socket_pool_ = nullptr;
};
@@ -130,7 +126,7 @@
// This is created and owned by the malloc hooks.
class Client {
public:
- Client(std::vector<base::ScopedFile> sockets);
+ Client(std::vector<base::UnixSocketRaw> sockets);
Client(const std::string& sock_name, size_t conns);
void RecordMalloc(uint64_t alloc_size,
uint64_t total_size,
diff --git a/src/profiling/memory/client_unittest.cc b/src/profiling/memory/client_unittest.cc
index 7e8e726..62185aa 100644
--- a/src/profiling/memory/client_unittest.cc
+++ b/src/profiling/memory/client_unittest.cc
@@ -17,6 +17,7 @@
#include "src/profiling/memory/client.h"
#include "gtest/gtest.h"
+#include "perfetto/base/unix_socket.h"
#include <thread>
@@ -24,34 +25,41 @@
namespace profiling {
namespace {
+base::UnixSocketRaw CreateSocket() {
+ auto sock = base::UnixSocketRaw::CreateMayFail(base::SockType::kStream);
+ PERFETTO_CHECK(sock);
+ return sock;
+}
+
TEST(SocketPoolTest, Basic) {
- std::vector<base::ScopedFile> files;
- files.emplace_back(base::OpenFile("/dev/null", O_RDONLY));
- SocketPool pool(std::move(files));
+ std::vector<base::UnixSocketRaw> socks;
+ socks.emplace_back(CreateSocket());
+ SocketPool pool(std::move(socks));
BorrowedSocket sock = pool.Borrow();
}
+
TEST(SocketPoolTest, Close) {
- std::vector<base::ScopedFile> files;
- files.emplace_back(base::OpenFile("/dev/null", O_RDONLY));
- SocketPool pool(std::move(files));
+ std::vector<base::UnixSocketRaw> socks;
+ socks.emplace_back(CreateSocket());
+ SocketPool pool(std::move(socks));
BorrowedSocket sock = pool.Borrow();
- sock.Close();
+ sock.Shutdown();
}
TEST(SocketPoolTest, Multiple) {
- std::vector<base::ScopedFile> files;
- files.emplace_back(base::OpenFile("/dev/null", O_RDONLY));
- files.emplace_back(base::OpenFile("/dev/null", O_RDONLY));
- SocketPool pool(std::move(files));
+ std::vector<base::UnixSocketRaw> socks;
+ socks.emplace_back(CreateSocket());
+ socks.emplace_back(CreateSocket());
+ SocketPool pool(std::move(socks));
BorrowedSocket sock = pool.Borrow();
BorrowedSocket sock_2 = pool.Borrow();
}
TEST(SocketPoolTest, Blocked) {
- std::vector<base::ScopedFile> files;
- files.emplace_back(base::OpenFile("/dev/null", O_RDONLY));
- SocketPool pool(std::move(files));
- BorrowedSocket sock = pool.Borrow();
+ std::vector<base::UnixSocketRaw> socks;
+ socks.emplace_back(CreateSocket());
+ SocketPool pool(std::move(socks));
+ BorrowedSocket sock = pool.Borrow(); // Takes the socket above.
std::thread t([&pool] { pool.Borrow(); });
{
// Return fd to unblock thread.
@@ -61,23 +69,23 @@
}
TEST(SocketPoolTest, BlockedClose) {
- std::vector<base::ScopedFile> files;
- files.emplace_back(base::OpenFile("/dev/null", O_RDONLY));
- SocketPool pool(std::move(files));
+ std::vector<base::UnixSocketRaw> socks;
+ socks.emplace_back(CreateSocket());
+ SocketPool pool(std::move(socks));
BorrowedSocket sock = pool.Borrow();
std::thread t([&pool] { pool.Borrow(); });
{
// Return fd to unblock thread.
BorrowedSocket temp = std::move(sock);
- temp.Close();
+ temp.Shutdown();
}
t.join();
}
TEST(SocketPoolTest, MultipleBlocked) {
- std::vector<base::ScopedFile> files;
- files.emplace_back(base::OpenFile("/dev/null", O_RDONLY));
- SocketPool pool(std::move(files));
+ std::vector<base::UnixSocketRaw> socks;
+ socks.emplace_back(CreateSocket());
+ SocketPool pool(std::move(socks));
BorrowedSocket sock = pool.Borrow();
std::thread t([&pool] { pool.Borrow(); });
std::thread t2([&pool] { pool.Borrow(); });
@@ -90,25 +98,25 @@
}
TEST(SocketPoolTest, MultipleBlockedClose) {
- std::vector<base::ScopedFile> files;
- files.emplace_back(base::OpenFile("/dev/null", O_RDONLY));
- SocketPool pool(std::move(files));
+ std::vector<base::UnixSocketRaw> socks;
+ socks.emplace_back(CreateSocket());
+ SocketPool pool(std::move(socks));
BorrowedSocket sock = pool.Borrow();
std::thread t([&pool] { pool.Borrow(); });
std::thread t2([&pool] { pool.Borrow(); });
{
// Return fd to unblock thread.
BorrowedSocket temp = std::move(sock);
- temp.Close();
+ temp.Shutdown();
}
t.join();
t2.join();
}
TEST(FreePageTest, ShutdownSocketPool) {
- std::vector<base::ScopedFile> files;
- files.emplace_back(base::OpenFile("/dev/null", O_RDONLY));
- SocketPool pool(std::move(files));
+ std::vector<base::UnixSocketRaw> socks;
+ socks.emplace_back(CreateSocket());
+ SocketPool pool(std::move(socks));
pool.Shutdown();
FreePage p;
p.Add(0, 1, &pool);
diff --git a/src/profiling/memory/wire_protocol.cc b/src/profiling/memory/wire_protocol.cc
index 2ef5b41..43ae7de 100644
--- a/src/profiling/memory/wire_protocol.cc
+++ b/src/profiling/memory/wire_protocol.cc
@@ -37,7 +37,7 @@
}
} // namespace
-bool SendWireMessage(int sock, const WireMessage& msg) {
+bool SendWireMessage(base::UnixSocketRaw* sock, const WireMessage& msg) {
uint64_t total_size;
struct iovec iovecs[4] = {};
// TODO(fmayer): Maye pack these two.
@@ -72,7 +72,7 @@
total_size = iovecs[1].iov_len + iovecs[2].iov_len;
}
- ssize_t sent = base::SendMsgAll(sock, &hdr, MSG_NOSIGNAL);
+ ssize_t sent = sock->SendMsgAll(&hdr);
return sent == static_cast<ssize_t>(total_size + sizeof(total_size));
}
diff --git a/src/profiling/memory/wire_protocol.h b/src/profiling/memory/wire_protocol.h
index 045b983..00db8e1 100644
--- a/src/profiling/memory/wire_protocol.h
+++ b/src/profiling/memory/wire_protocol.h
@@ -30,6 +30,11 @@
#include <unwindstack/UserX86_64.h>
namespace perfetto {
+
+namespace base {
+class UnixSocketRaw;
+}
+
namespace profiling {
// Types needed for the wire format used for communication between the client
@@ -115,7 +120,7 @@
size_t payload_size;
};
-bool SendWireMessage(int sock, const WireMessage& msg);
+bool SendWireMessage(base::UnixSocketRaw*, const WireMessage& msg);
// Parse message received over the wire.
// |buf| has to outlive |out|.
diff --git a/src/profiling/memory/wire_protocol_unittest.cc b/src/profiling/memory/wire_protocol_unittest.cc
index 58d73af..4182fa0 100644
--- a/src/profiling/memory/wire_protocol_unittest.cc
+++ b/src/profiling/memory/wire_protocol_unittest.cc
@@ -17,6 +17,7 @@
#include "src/profiling/memory/wire_protocol.h"
#include "perfetto/base/logging.h"
#include "perfetto/base/scoped_file.h"
+#include "perfetto/base/unix_socket.h"
#include "src/profiling/memory/record_reader.h"
#include <sys/socket.h>
@@ -53,13 +54,13 @@
namespace {
-RecordReader::Record ReceiveAll(int sock) {
+RecordReader::Record ReceiveAll(base::UnixSocketRaw* sock) {
RecordReader record_reader;
RecordReader::Record record;
bool received = false;
while (!received) {
RecordReader::ReceiveBuffer buf = record_reader.BeginReceive();
- ssize_t rd = PERFETTO_EINTR(read(sock, buf.data, buf.size));
+ ssize_t rd = sock->Receive(buf.data, buf.size);
PERFETTO_CHECK(rd > 0);
auto status = record_reader.EndReceive(static_cast<size_t>(rd), &record);
switch (status) {
@@ -93,13 +94,15 @@
msg.payload = payload;
msg.payload_size = sizeof(payload);
- int sv[2];
- ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, sv), 0);
- base::ScopedFile send_sock(sv[0]);
- base::ScopedFile recv_sock(sv[1]);
- ASSERT_TRUE(SendWireMessage(*send_sock, msg));
+ base::UnixSocketRaw send_sock;
+ base::UnixSocketRaw recv_sock;
+ std::tie(send_sock, recv_sock) =
+ base::UnixSocketRaw::CreatePair(base::SockType::kStream);
+ ASSERT_TRUE(send_sock);
+ ASSERT_TRUE(recv_sock);
+ ASSERT_TRUE(SendWireMessage(&send_sock, msg));
- RecordReader::Record record = ReceiveAll(*recv_sock);
+ RecordReader::Record record = ReceiveAll(&recv_sock);
WireMessage recv_msg;
ASSERT_TRUE(ReceiveWireMessage(reinterpret_cast<char*>(record.data.get()),
@@ -123,11 +126,13 @@
int sv[2];
ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, sv), 0);
- base::ScopedFile send_sock(sv[0]);
- base::ScopedFile recv_sock(sv[1]);
- ASSERT_TRUE(SendWireMessage(*send_sock, msg));
+ base::UnixSocketRaw send_sock(base::ScopedFile(sv[0]),
+ base::SockType::kStream);
+ base::UnixSocketRaw recv_sock(base::ScopedFile(sv[1]),
+ base::SockType::kStream);
+ ASSERT_TRUE(SendWireMessage(&send_sock, msg));
- RecordReader::Record record = ReceiveAll(*recv_sock);
+ RecordReader::Record record = ReceiveAll(&recv_sock);
WireMessage recv_msg;
ASSERT_TRUE(ReceiveWireMessage(reinterpret_cast<char*>(record.data.get()),