base: Add websocket support to http_server.cc
Bug: 205274609
Change-Id: Ia8977e03a4ea3e758470c47ce946af60a42766ab
diff --git a/Android.bp b/Android.bp
index 1002d49..afc122e 100644
--- a/Android.bp
+++ b/Android.bp
@@ -6898,6 +6898,7 @@
name: "perfetto_src_base_http_http",
srcs: [
"src/base/http/http_server.cc",
+ "src/base/http/sha1.cc",
],
}
@@ -6906,6 +6907,7 @@
name: "perfetto_src_base_http_unittests",
srcs: [
"src/base/http/http_server_unittest.cc",
+ "src/base/http/sha1_unittest.cc",
],
}
diff --git a/BUILD b/BUILD
index 0a0b060..a0d0884 100644
--- a/BUILD
+++ b/BUILD
@@ -335,6 +335,7 @@
name = "include_perfetto_ext_base_http_http",
srcs = [
"include/perfetto/ext/base/http/http_server.h",
+ "include/perfetto/ext/base/http/sha1.h",
],
)
@@ -654,6 +655,7 @@
name = "src_base_http_http",
srcs = [
"src/base/http/http_server.cc",
+ "src/base/http/sha1.cc",
],
hdrs = [
":include_perfetto_base_base",
diff --git a/include/perfetto/ext/base/http/BUILD.gn b/include/perfetto/ext/base/http/BUILD.gn
index 78268b2..3e42607 100644
--- a/include/perfetto/ext/base/http/BUILD.gn
+++ b/include/perfetto/ext/base/http/BUILD.gn
@@ -13,6 +13,9 @@
# limitations under the License.
source_set("http") {
- sources = [ "http_server.h" ]
+ sources = [
+ "http_server.h",
+ "sha1.h",
+ ]
public_deps = [ "..:base" ]
}
diff --git a/include/perfetto/ext/base/http/http_server.h b/include/perfetto/ext/base/http/http_server.h
index da5e626..c251061 100644
--- a/include/perfetto/ext/base/http/http_server.h
+++ b/include/perfetto/ext/base/http/http_server.h
@@ -47,6 +47,7 @@
StringView uri;
StringView origin;
StringView body;
+ bool is_websocket_handshake = false;
private:
friend class HttpServer;
@@ -60,6 +61,24 @@
size_t num_headers = 0;
};
+struct WebsocketMessage {
+ explicit WebsocketMessage(HttpServerConnection* c) : conn(c) {}
+
+ HttpServerConnection* conn;
+
+ // Note: message boundaries are not respected in case of fragmentation.
+ // This websocket implementation preserves only the byte stream, but not the
+ // atomicity of inbound messages (like SOCK_STREAM, unlike SOCK_DGRAM).
+ // Holds onto the connection's |rxbuf|. This is valid only within the scope
+ // of the OnWebsocketMessage() callback.
+ StringView data;
+
+ // If false the payload contains binary data. If true it's supposed to contain
+ // text. Note that there is no guarantee this will be the case. This merely
+ // reflect the opcode that the client sets on each message.
+ bool is_text = false;
+};
+
class HttpServerConnection {
public:
static constexpr size_t kOmitContentLength = static_cast<size_t>(-1);
@@ -86,6 +105,24 @@
SendResponse(http_code, headers, content, true);
}
+ // The metods below are only valid for websocket connections.
+
+ // Upgrade an existing connection to a websocket. This can be called only in
+ // the context of OnHttpRequest(req) if req.is_websocket_handshake == true.
+ // If the origin is not in the |allowed_origins_|, the request will fail with
+ // a 403 error (this is because there is no browser-side CORS support for
+ // websockets).
+ void UpgradeToWebsocket(const HttpRequest&);
+ void SendWebsocketMessage(const void* data, size_t len);
+ void SendWebsocketMessage(StringView sv) {
+ SendWebsocketMessage(sv.data(), sv.size());
+ }
+ void SendWebsocketFrame(uint8_t opcode,
+ const void* payload,
+ size_t payload_len);
+
+ bool is_websocket() const { return is_websocket_; }
+
private:
friend class HttpServer;
@@ -94,6 +131,7 @@
std::unique_ptr<UnixSocket> sock;
PagedMemory rxbuf;
size_t rxbuf_used = 0;
+ bool is_websocket_ = false;
bool headers_sent_ = false;
size_t content_len_headers_ = 0;
size_t content_len_actual_ = 0;
@@ -112,6 +150,7 @@
public:
virtual ~HttpRequestHandler();
virtual void OnHttpRequest(const HttpRequest&) = 0;
+ virtual void OnWebsocketMessage(const WebsocketMessage&);
virtual void OnHttpConnectionClosed(HttpServerConnection*);
};
@@ -124,6 +163,7 @@
private:
size_t ParseOneHttpRequest(HttpServerConnection*);
+ size_t ParseOneWebsocketFrame(HttpServerConnection*);
void HandleCorsPreflightRequest(const HttpRequest&);
bool IsOriginAllowed(StringView);
diff --git a/include/perfetto/ext/base/http/sha1.h b/include/perfetto/ext/base/http/sha1.h
new file mode 100644
index 0000000..c583d69
--- /dev/null
+++ b/include/perfetto/ext/base/http/sha1.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INCLUDE_PERFETTO_EXT_BASE_HTTP_SHA1_H_
+#define INCLUDE_PERFETTO_EXT_BASE_HTTP_SHA1_H_
+
+#include <stddef.h>
+
+#include <array>
+#include <string>
+
+namespace perfetto {
+namespace base {
+
+constexpr size_t kSHA1Length = 20;
+using SHA1Digest = std::array<uint8_t, kSHA1Length>;
+
+SHA1Digest SHA1Hash(const std::string& str);
+SHA1Digest SHA1Hash(const void* data, size_t size);
+
+} // namespace base
+} // namespace perfetto
+
+#endif // INCLUDE_PERFETTO_EXT_BASE_HTTP_SHA1_H_
diff --git a/src/base/http/BUILD.gn b/src/base/http/BUILD.gn
index b32f179..e086134 100644
--- a/src/base/http/BUILD.gn
+++ b/src/base/http/BUILD.gn
@@ -26,7 +26,10 @@
"../../../include/perfetto/base",
"../../../include/perfetto/ext/base/http",
]
- sources = [ "http_server.cc" ]
+ sources = [
+ "http_server.cc",
+ "sha1.cc",
+ ]
}
perfetto_unittest_source_set("unittests") {
@@ -37,5 +40,8 @@
"../../../gn:default_deps",
"../../../gn:gtest_and_gmock",
]
- sources = [ "http_server_unittest.cc" ]
+ sources = [
+ "http_server_unittest.cc",
+ "sha1_unittest.cc",
+ ]
}
diff --git a/src/base/http/http_server.cc b/src/base/http/http_server.cc
index 42263cf..17a1652 100644
--- a/src/base/http/http_server.cc
+++ b/src/base/http/http_server.cc
@@ -15,9 +15,13 @@
*/
#include "perfetto/ext/base/http/http_server.h"
+#include <cinttypes>
+
#include <vector>
+#include "perfetto/ext/base/base64.h"
#include "perfetto/ext/base/endian.h"
+#include "perfetto/ext/base/http/sha1.h"
#include "perfetto/ext/base/string_utils.h"
#include "perfetto/ext/base/string_view.h"
@@ -25,8 +29,22 @@
namespace base {
namespace {
-// 32 MiB payload + 128K for HTTP headers.
-constexpr size_t kMaxRequestSize = (32 * 1024 + 128) * 1024;
+constexpr size_t kMaxPayloadSize = 32 * 1024 * 1024;
+constexpr size_t kMaxRequestSize = kMaxPayloadSize + 4096;
+
+enum WebsocketOpcode : uint8_t {
+ kOpcodeContinuation = 0x0,
+ kOpcodeText = 0x1,
+ kOpcodeBinary = 0x2,
+ kOpcodeDataUnused = 0x3,
+ kOpcodeClose = 0x8,
+ kOpcodePing = 0x9,
+ kOpcodePong = 0xA,
+ kOpcodeControlUnused = 0xB,
+};
+
+// From https://datatracker.ietf.org/doc/html/rfc6455#section-1.3.
+constexpr char kWebsocketGuid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
} // namespace
@@ -103,7 +121,13 @@
// At this point |rxbuf| can contain a partial HTTP request, a full one or
// more (in case of HTTP Keepalive pipelining).
for (;;) {
- size_t bytes_consumed = ParseOneHttpRequest(conn);
+ size_t bytes_consumed;
+
+ if (conn->is_websocket()) {
+ bytes_consumed = ParseOneWebsocketFrame(conn);
+ } else {
+ bytes_consumed = ParseOneHttpRequest(conn);
+ }
if (bytes_consumed == 0)
break;
@@ -176,6 +200,8 @@
conn->origin_allowed_ = hdr_value.ToStdString();
} else if (hdr_name.CaseInsensitiveEq("connection")) {
conn->keepalive_ = hdr_value.CaseInsensitiveEq("keep-alive");
+ http_req.is_websocket_handshake =
+ hdr_value.CaseInsensitiveEq("upgrade");
}
}
}
@@ -185,7 +211,8 @@
PERFETTO_CHECK(buf_view.size() <= conn->rxbuf_used);
const size_t headers_size = conn->rxbuf_used - buf_view.size();
- if (body_size + headers_size >= kMaxRequestSize) {
+ if (body_size + headers_size >= kMaxRequestSize ||
+ body_size > kMaxPayloadSize) {
conn->SendResponseAndClose("413 Payload Too Large");
return 0;
}
@@ -246,11 +273,185 @@
return false;
}
+void HttpServerConnection::UpgradeToWebsocket(const HttpRequest& req) {
+ PERFETTO_CHECK(req.is_websocket_handshake);
+
+ // |origin_allowed_| is set to the req.origin only if it's in the allowlist.
+ if (origin_allowed_.empty())
+ return SendResponseAndClose("403 Forbidden", {}, "Origin not allowed");
+
+ auto ws_ver = req.GetHeader("sec-webSocket-version").value_or(StringView());
+ auto ws_key = req.GetHeader("sec-webSocket-key").value_or(StringView());
+
+ if (!ws_ver.CaseInsensitiveEq("13"))
+ return SendResponseAndClose("505 HTTP Version Not Supported", {});
+
+ if (ws_key.size() != 24) {
+ // The nonce must be a base64 encoded 16 bytes value (24 after base64).
+ return SendResponseAndClose("400 Bad Request", {});
+ }
+
+ // From https://datatracker.ietf.org/doc/html/rfc6455#section-1.3 :
+ // For this header field, the server has to take the value (as present
+ // in the header field, e.g., the base64-encoded [RFC4648] version minus
+ // any leading and trailing whitespace) and concatenate this with the
+ // Globally Unique Identifier (GUID, [RFC4122]) "258EAFA5-E914-47DA-
+ // 95CA-C5AB0DC85B11" in string form, which is unlikely to be used by
+ // network endpoints that do not understand the WebSocket Protocol. A
+ // SHA-1 hash (160 bits) [FIPS.180-3], base64-encoded (see Section 4 of
+ // [RFC4648]), of this concatenation is then returned in the server's
+ // handshake.
+ StackString<128> signed_nonce("%.*s%s", static_cast<int>(ws_key.size()),
+ ws_key.data(), kWebsocketGuid);
+ auto digest = SHA1Hash(signed_nonce.c_str(), signed_nonce.len());
+ std::string digest_b64 = Base64Encode(digest.data(), digest.size());
+
+ StackString<128> accept_hdr("Sec-WebSocket-Accept: %s", digest_b64.c_str());
+
+ std::initializer_list<const char*> headers = {
+ "Upgrade: websocket", //
+ "Connection: Upgrade", //
+ accept_hdr.c_str(), //
+ };
+ PERFETTO_DLOG("[HTTP] Handshaking WebSocket for %.*s",
+ static_cast<int>(req.uri.size()), req.uri.data());
+ for (const char* hdr : headers)
+ PERFETTO_DLOG("> %s", hdr);
+
+ SendResponseHeaders("101 Switching Protocols", headers,
+ HttpServerConnection::kOmitContentLength);
+
+ is_websocket_ = true;
+}
+
+size_t HttpServer::ParseOneWebsocketFrame(HttpServerConnection* conn) {
+ auto* rxbuf = reinterpret_cast<uint8_t*>(conn->rxbuf.Get());
+ const size_t frame_size = conn->rxbuf_used;
+ uint8_t* rd = rxbuf;
+ uint8_t* const end = rxbuf + frame_size;
+
+ auto avail = [&] {
+ PERFETTO_CHECK(rd <= end);
+ return static_cast<size_t>(end - rd);
+ };
+
+ // From https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 :
+ // 0 1 2 3
+ // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ // +-+-+-+-+-------+-+-------------+-------------------------------+
+ // |F|R|R|R| opcode|M| Payload len | Extended payload length |
+ // |I|S|S|S| (4) |A| (7) | (16/64) |
+ // |N|V|V|V| |S| | (if payload len==126/127) |
+ // | |1|2|3| |K| | |
+ // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+ // | Extended payload length continued, if payload len == 127 |
+ // + - - - - - - - - - - - - - - - +-------------------------------+
+ // | |Masking-key, if MASK set to 1 |
+ // +-------------------------------+-------------------------------+
+ // | Masking-key (continued) | Payload Data |
+ // +-------------------------------- - - - - - - - - - - - - - - - +
+ // : Payload Data continued ... :
+ // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
+ // | Payload Data continued ... |
+ // +---------------------------------------------------------------+
+
+ if (avail() < 2)
+ return 0; // Can't even decode the frame header. Wait for more data.
+
+ uint8_t h0 = *(rd++);
+ uint8_t h1 = *(rd++);
+ const bool fin = !!(h0 & 0x80); // This bit is set if this frame is the last
+ // data to complete this message.
+ const uint8_t opcode = h0 & 0x0F;
+
+ const bool has_mask = !!(h1 & 0x80);
+ uint64_t payload_len_u64 = (h1 & 0x7F);
+ uint8_t extended_payload_size = 0;
+ if (payload_len_u64 == 126) {
+ extended_payload_size = 2;
+ } else if (payload_len_u64 == 127) {
+ extended_payload_size = 8;
+ }
+
+ if (extended_payload_size > 0) {
+ if (avail() < extended_payload_size)
+ return 0; // Not enough data to read the extended header.
+ payload_len_u64 = 0;
+ for (uint8_t i = 0; i < extended_payload_size; ++i) {
+ payload_len_u64 <<= 8;
+ payload_len_u64 |= *(rd++);
+ }
+ }
+
+ if (payload_len_u64 >= kMaxPayloadSize) {
+ PERFETTO_ELOG("[HTTP] Websocket payload too big (%" PRIu64 " > %zu)",
+ payload_len_u64, kMaxPayloadSize);
+ conn->Close();
+ return 0;
+ }
+ const size_t payload_len = static_cast<size_t>(payload_len_u64);
+
+ if (!has_mask) {
+ // https://datatracker.ietf.org/doc/html/rfc6455#section-5.1
+ // The server MUST close the connection upon receiving a frame that is
+ // not masked.
+ PERFETTO_ELOG("[HTTP] Websocket inbound frames must be masked");
+ conn->Close();
+ return 0;
+ }
+
+ uint8_t mask[4];
+ if (avail() < sizeof(mask))
+ return 0; // Not enough data to read the masking key.
+ memcpy(mask, rd, sizeof(mask));
+ rd += sizeof(mask);
+
+ PERFETTO_DLOG(
+ "[HTTP] Websocket fin=%d opcode=%u, payload_len=%zu (avail=%zu), "
+ "mask=%02x%02x%02x%02x",
+ fin, opcode, payload_len, avail(), mask[0], mask[1], mask[2], mask[3]);
+
+ if (avail() < payload_len)
+ return 0; // Not enouh data to read the payload.
+ uint8_t* const payload_start = rd;
+
+ // Unmask the payload.
+ for (uint32_t i = 0; i < payload_len; ++i)
+ payload_start[i] ^= mask[i % sizeof(mask)];
+
+ if (opcode == kOpcodePing) {
+ PERFETTO_DLOG("[HTTP] Websocket PING");
+ conn->SendWebsocketFrame(kOpcodePong, payload_start, payload_len);
+ } else if (opcode == kOpcodeBinary || opcode == kOpcodeText ||
+ opcode == kOpcodeContinuation) {
+ // We do NOT handle fragmentation. We propagate all fragments as individual
+ // messages, breaking the message-oriented nature of websockets. We do this
+ // because in all our use cases we need only a byte stream without caring
+ // about message boundaries.
+ // If we wanted to support fragmentation, we'd have to stash
+ // kOpcodeContinuation messages in a buffer, until we FIN bit is set.
+ // When loading traces with trace processor, the messages can be up to
+ // 32MB big (SLICE_SIZE in trace_stream.ts). The double-buffering would
+ // slow down significantly trace loading with no benefits.
+ WebsocketMessage msg(conn);
+ msg.data =
+ StringView(reinterpret_cast<const char*>(payload_start), payload_len);
+ msg.is_text = opcode == kOpcodeText;
+ req_handler_->OnWebsocketMessage(msg);
+ } else if (opcode == kOpcodeClose) {
+ conn->Close();
+ } else {
+ PERFETTO_LOG("Unsupported WebSocket opcode: %d", opcode);
+ }
+ return static_cast<size_t>(rd - rxbuf) + payload_len;
+}
+
void HttpServerConnection::SendResponseHeaders(
const char* http_code,
std::initializer_list<const char*> headers,
size_t content_length) {
PERFETTO_CHECK(!headers_sent_);
+ PERFETTO_CHECK(!is_websocket_);
headers_sent_ = true;
std::vector<char> resp_hdr;
resp_hdr.reserve(512);
@@ -296,6 +497,7 @@
}
void HttpServerConnection::SendResponseBody(const void* data, size_t len) {
+ PERFETTO_CHECK(!is_websocket_);
if (data == nullptr) {
PERFETTO_DCHECK(len == 0);
return;
@@ -323,6 +525,39 @@
Close();
}
+void HttpServerConnection::SendWebsocketMessage(const void* data, size_t len) {
+ SendWebsocketFrame(kOpcodeBinary, data, len);
+}
+
+void HttpServerConnection::SendWebsocketFrame(uint8_t opcode,
+ const void* payload,
+ size_t payload_len) {
+ PERFETTO_CHECK(is_websocket_);
+
+ uint8_t hdr[10]{};
+ uint32_t hdr_len = 0;
+
+ hdr[0] = opcode | 0x80 /* FIN=1, no fragmentation */;
+ if (payload_len < 126) {
+ hdr_len = 2;
+ hdr[1] = static_cast<uint8_t>(payload_len);
+ } else if (payload_len < 0xffff) {
+ hdr_len = 4;
+ hdr[1] = 126; // Special value: Header extends for 2 bytes.
+ uint16_t len_be = HostToBE16(static_cast<uint16_t>(payload_len));
+ memcpy(&hdr[2], &len_be, sizeof(len_be));
+ } else {
+ hdr_len = 10;
+ hdr[1] = 127; // Special value: Header extends for 4 bytes.
+ uint64_t len_be = HostToBE64(payload_len);
+ memcpy(&hdr[2], &len_be, sizeof(len_be));
+ }
+
+ sock->Send(hdr, hdr_len);
+ if (payload && payload_len > 0)
+ sock->Send(payload, payload_len);
+}
+
HttpServerConnection::HttpServerConnection(std::unique_ptr<UnixSocket> s)
: sock(std::move(s)), rxbuf(PagedMemory::Allocate(kMaxRequestSize)) {}
@@ -337,6 +572,7 @@
}
HttpRequestHandler::~HttpRequestHandler() = default;
+void HttpRequestHandler::OnWebsocketMessage(const WebsocketMessage&) {}
void HttpRequestHandler::OnHttpConnectionClosed(HttpServerConnection*) {}
} // namespace base
diff --git a/src/base/http/http_server_unittest.cc b/src/base/http/http_server_unittest.cc
index 38eaa08..9726ab6 100644
--- a/src/base/http/http_server_unittest.cc
+++ b/src/base/http/http_server_unittest.cc
@@ -39,6 +39,7 @@
public:
MOCK_METHOD1(OnHttpRequest, void(const HttpRequest&));
MOCK_METHOD1(OnHttpConnectionClosed, void(HttpServerConnection*));
+ MOCK_METHOD1(OnWebsocketMessage, void(const WebsocketMessage&));
};
class HttpCli {
@@ -59,7 +60,7 @@
sock.SendStr(body);
}
- std::string RecvAndWaitConnClose() {
+ std::string Recv(size_t min_bytes) {
static int n = 0;
auto checkpoint_name = "rx_" + std::to_string(n++);
auto checkpoint = task_runner_->CreateCheckpoint(checkpoint_name);
@@ -69,15 +70,17 @@
char buf[1024]{};
auto rsize = PERFETTO_EINTR(sock.Receive(buf, sizeof(buf)));
ASSERT_GE(rsize, 0);
- if (rsize == 0)
- checkpoint();
rxbuf.append(buf, static_cast<size_t>(rsize));
+ if (rsize == 0 || (min_bytes && rxbuf.length() >= min_bytes))
+ checkpoint();
});
task_runner_->RunUntilCheckpoint(checkpoint_name);
task_runner_->RemoveFileDescriptorWatch(sock.fd());
return rxbuf;
}
+ std::string RecvAndWaitConnClose() { return Recv(0); }
+
TestTaskRunner* task_runner_;
UnixSocketRaw sock;
};
@@ -103,6 +106,7 @@
req.GetHeader("X-header").value_or("N/A").ToStdString());
EXPECT_EQ("foo",
req.GetHeader("X-header2").value_or("N/A").ToStdString());
+ EXPECT_FALSE(req.is_websocket_handshake);
req.conn->SendResponseAndClose("200 OK", {}, "<html>");
}));
EXPECT_CALL(handler_, OnHttpConnectionClosed(_)).Times(kIterations);
@@ -213,6 +217,105 @@
EXPECT_EQ(cli.RecvAndWaitConnClose(), expected_response);
}
+TEST_F(HttpServerTest, Websocket) {
+ srv_.AddAllowedOrigin("http://foo.com");
+ srv_.AddAllowedOrigin("http://websocket.com");
+ for (int rep = 0; rep < 3; rep++) {
+ HttpCli cli(&task_runner_);
+ EXPECT_CALL(handler_, OnHttpRequest(_))
+ .WillOnce(Invoke([&](const HttpRequest& req) {
+ EXPECT_EQ(req.uri.ToStdString(), "/websocket");
+ EXPECT_EQ(req.method.ToStdString(), "GET");
+ EXPECT_EQ(req.origin.ToStdString(), "http://websocket.com");
+ EXPECT_TRUE(req.is_websocket_handshake);
+ req.conn->UpgradeToWebsocket(req);
+ }));
+
+ cli.SendHttpReq({
+ "GET /websocket HTTP/1.1", //
+ "Origin: http://websocket.com", //
+ "Connection: upgrade", //
+ "Upgrade: websocket", //
+ "Sec-WebSocket-Version: 13", //
+ "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", //
+ });
+ std::string expected_resp =
+ "HTTP/1.1 101 Switching Protocols\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"
+ "Access-Control-Allow-Origin: http://websocket.com\r\n"
+ "Vary: Origin\r\n"
+ "\r\n";
+ EXPECT_EQ(cli.Recv(expected_resp.size()), expected_resp);
+
+ for (int i = 0; i < 3; i++) {
+ EXPECT_CALL(handler_, OnWebsocketMessage(_))
+ .WillOnce(Invoke([i](const WebsocketMessage& msg) {
+ EXPECT_EQ(msg.data.ToStdString(), "test message");
+ StackString<6> resp("PONG%d", i);
+ msg.conn->SendWebsocketMessage(resp.c_str(), resp.len());
+ }));
+
+ // A frame from a real tcpdump capture:
+ // 1... .... = Fin: True
+ // .000 .... = Reserved: 0x0
+ // .... 0001 = Opcode: Text (1)
+ // 1... .... = Mask: True
+ // .000 1100 = Payload length: 12
+ // Masking-Key: e17e8eb9
+ // Masked payload: "test message"
+ cli.sock.SendStr(
+ "\x81\x8c\xe1\x7e\x8e\xb9\x95\x1b\xfd\xcd\xc1\x13\xeb\xca\x92\x1f\xe9"
+ "\xdc");
+ EXPECT_EQ(cli.Recv(2 + 5), "\x82\x05PONG" + std::to_string(i));
+ }
+
+ cli.sock.Shutdown();
+ auto checkpoint_name = "ws_close_" + std::to_string(rep);
+ auto ws_close = task_runner_.CreateCheckpoint(checkpoint_name);
+ EXPECT_CALL(handler_, OnHttpConnectionClosed(_))
+ .WillOnce(InvokeWithoutArgs(ws_close));
+ task_runner_.RunUntilCheckpoint(checkpoint_name);
+ }
+}
+
+TEST_F(HttpServerTest, Websocket_OriginNotAllowed) {
+ srv_.AddAllowedOrigin("http://websocket.com");
+ srv_.AddAllowedOrigin("http://notallowed.commando");
+ srv_.AddAllowedOrigin("http://iamnotallowed.com");
+ srv_.AddAllowedOrigin("iamnotallowed.com");
+ // The origin must match in full, including scheme. This won't match.
+ srv_.AddAllowedOrigin("notallowed.com");
+
+ HttpCli cli(&task_runner_);
+ EXPECT_CALL(handler_, OnHttpConnectionClosed(_));
+ EXPECT_CALL(handler_, OnHttpRequest(_))
+ .WillOnce(Invoke([&](const HttpRequest& req) {
+ EXPECT_EQ(req.origin.ToStdString(), "http://notallowed.com");
+ EXPECT_TRUE(req.is_websocket_handshake);
+ req.conn->UpgradeToWebsocket(req);
+ }));
+
+ cli.SendHttpReq({
+ "GET /websocket HTTP/1.1", //
+ "Origin: http://notallowed.com", //
+ "Connection: upgrade", //
+ "Upgrade: websocket", //
+ "Sec-WebSocket-Version: 13", //
+ "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", //
+ });
+ std::string expected_resp =
+ "HTTP/1.1 403 Forbidden\r\n"
+ "Content-Length: 18\r\n"
+ "Connection: close\r\n"
+ "\r\n"
+ "Origin not allowed";
+
+ EXPECT_EQ(cli.Recv(expected_resp.size()), expected_resp);
+ cli.sock.Shutdown();
+}
+
} // namespace
} // namespace base
} // namespace perfetto
diff --git a/src/base/http/sha1.cc b/src/base/http/sha1.cc
new file mode 100644
index 0000000..da3f753
--- /dev/null
+++ b/src/base/http/sha1.cc
@@ -0,0 +1,242 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "perfetto/ext/base/http/sha1.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+// From chrome_elf/sha1/sha1.cc.
+
+namespace perfetto {
+namespace base {
+
+namespace {
+
+inline uint32_t BSwap32(uint32_t x) {
+#if defined(__GNUC__)
+ return __builtin_bswap32(x);
+#elif defined(_MSC_VER)
+ return _byteswap_ulong(val);
+#else
+ return (((x & 0xff000000u) >> 24) | ((x & 0x00ff0000u) >> 8) |
+ ((x & 0x0000ff00u) << 8) | ((x & 0x000000ffu) << 24));
+#endif
+}
+
+// Usage example:
+//
+// SecureHashAlgorithm sha;
+// while(there is data to hash)
+// sha.Update(moredata, size of data);
+// sha.Final();
+// memcpy(somewhere, sha.Digest(), 20);
+//
+// to reuse the instance of sha, call sha.Init();
+class SecureHashAlgorithm {
+ public:
+ SecureHashAlgorithm() { Init(); }
+
+ void Init();
+ void Update(const void* data, size_t nbytes);
+ void Final();
+
+ // 20 bytes of message digest.
+ const unsigned char* Digest() const {
+ return reinterpret_cast<const unsigned char*>(H);
+ }
+
+ private:
+ void Pad();
+ void Process();
+
+ uint32_t A, B, C, D, E;
+
+ uint32_t H[5];
+
+ union {
+ uint32_t W[80];
+ uint8_t M[64];
+ };
+
+ uint32_t cursor;
+ uint64_t l;
+};
+
+//------------------------------------------------------------------------------
+// Private functions
+//------------------------------------------------------------------------------
+
+// Identifier names follow notation in FIPS PUB 180-3, where you'll
+// also find a description of the algorithm:
+// http://csrc.nist.gov/publications/fips/fips180-3/fips180-3_final.pdf
+
+inline uint32_t f(uint32_t t, uint32_t B, uint32_t C, uint32_t D) {
+ if (t < 20) {
+ return (B & C) | ((~B) & D);
+ } else if (t < 40) {
+ return B ^ C ^ D;
+ } else if (t < 60) {
+ return (B & C) | (B & D) | (C & D);
+ } else {
+ return B ^ C ^ D;
+ }
+}
+
+inline uint32_t S(uint32_t n, uint32_t X) {
+ return (X << n) | (X >> (32 - n));
+}
+
+inline uint32_t K(uint32_t t) {
+ if (t < 20) {
+ return 0x5a827999;
+ } else if (t < 40) {
+ return 0x6ed9eba1;
+ } else if (t < 60) {
+ return 0x8f1bbcdc;
+ } else {
+ return 0xca62c1d6;
+ }
+}
+
+void SecureHashAlgorithm::Init() {
+ A = 0;
+ B = 0;
+ C = 0;
+ D = 0;
+ E = 0;
+ cursor = 0;
+ l = 0;
+ H[0] = 0x67452301;
+ H[1] = 0xefcdab89;
+ H[2] = 0x98badcfe;
+ H[3] = 0x10325476;
+ H[4] = 0xc3d2e1f0;
+}
+
+void SecureHashAlgorithm::Update(const void* data, size_t nbytes) {
+ const uint8_t* d = reinterpret_cast<const uint8_t*>(data);
+ while (nbytes--) {
+ M[cursor++] = *d++;
+ if (cursor >= 64)
+ Process();
+ l += 8;
+ }
+}
+
+void SecureHashAlgorithm::Final() {
+ Pad();
+ Process();
+
+ for (size_t t = 0; t < 5; ++t)
+ H[t] = BSwap32(H[t]);
+}
+
+void SecureHashAlgorithm::Process() {
+ uint32_t t;
+
+ // Each a...e corresponds to a section in the FIPS 180-3 algorithm.
+
+ // a.
+ //
+ // W and M are in a union, so no need to memcpy.
+ // memcpy(W, M, sizeof(M));
+ for (t = 0; t < 16; ++t)
+ W[t] = BSwap32(W[t]);
+
+ // b.
+ for (t = 16; t < 80; ++t)
+ W[t] = S(1, W[t - 3] ^ W[t - 8] ^ W[t - 14] ^ W[t - 16]);
+
+ // c.
+ A = H[0];
+ B = H[1];
+ C = H[2];
+ D = H[3];
+ E = H[4];
+
+ // d.
+ for (t = 0; t < 80; ++t) {
+ uint32_t TEMP = S(5, A) + f(t, B, C, D) + E + W[t] + K(t);
+ E = D;
+ D = C;
+ C = S(30, B);
+ B = A;
+ A = TEMP;
+ }
+
+ // e.
+ H[0] += A;
+ H[1] += B;
+ H[2] += C;
+ H[3] += D;
+ H[4] += E;
+
+ cursor = 0;
+}
+
+void SecureHashAlgorithm::Pad() {
+ M[cursor++] = 0x80;
+
+ if (cursor > 64 - 8) {
+ // pad out to next block
+ while (cursor < 64)
+ M[cursor++] = 0;
+
+ Process();
+ }
+
+ while (cursor < 64 - 8)
+ M[cursor++] = 0;
+
+ M[cursor++] = (l >> 56) & 0xff;
+ M[cursor++] = (l >> 48) & 0xff;
+ M[cursor++] = (l >> 40) & 0xff;
+ M[cursor++] = (l >> 32) & 0xff;
+ M[cursor++] = (l >> 24) & 0xff;
+ M[cursor++] = (l >> 16) & 0xff;
+ M[cursor++] = (l >> 8) & 0xff;
+ M[cursor++] = l & 0xff;
+}
+
+// Computes the SHA-1 hash of the |len| bytes in |data| and puts the hash
+// in |hash|. |hash| must be kSHA1Length bytes long.
+void SHA1HashBytes(const unsigned char* data, size_t len, unsigned char* hash) {
+ SecureHashAlgorithm sha;
+ sha.Update(data, len);
+ sha.Final();
+
+ ::memcpy(hash, sha.Digest(), kSHA1Length);
+}
+
+} // namespace
+
+//------------------------------------------------------------------------------
+// Public functions
+//------------------------------------------------------------------------------
+SHA1Digest SHA1Hash(const void* data, size_t size) {
+ SHA1Digest digest;
+ SHA1HashBytes(static_cast<const unsigned char*>(data), size,
+ reinterpret_cast<unsigned char*>(&digest[0]));
+ return digest;
+}
+
+SHA1Digest SHA1Hash(const std::string& str) {
+ return SHA1Hash(str.data(), str.size());
+}
+
+} // namespace base
+} // namespace perfetto
diff --git a/src/base/http/sha1_unittest.cc b/src/base/http/sha1_unittest.cc
new file mode 100644
index 0000000..269afc1
--- /dev/null
+++ b/src/base/http/sha1_unittest.cc
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "perfetto/ext/base/http/sha1.h"
+
+#include <string>
+
+#include "perfetto/ext/base/string_view.h"
+#include "test/gtest_and_gmock.h"
+
+namespace perfetto {
+namespace base {
+namespace {
+
+using testing::ElementsAreArray;
+
+TEST(SHA1Test, Hash) {
+ EXPECT_THAT(SHA1Hash(""), ElementsAreArray<uint8_t>(
+ {0xda, 0x39, 0xa3, 0xee, 0x5e, 0x6b, 0x4b,
+ 0x0d, 0x32, 0x55, 0xbf, 0xef, 0x95, 0x60,
+ 0x18, 0x90, 0xaf, 0xd8, 0x07, 0x09}));
+
+ EXPECT_THAT(SHA1Hash("abc"), ElementsAreArray<uint8_t>(
+ {0xa9, 0x99, 0x3e, 0x36, 0x47, 0x06, 0x81,
+ 0x6a, 0xba, 0x3e, 0x25, 0x71, 0x78, 0x50,
+ 0xc2, 0x6c, 0x9c, 0xd0, 0xd8, 0x9d}));
+
+ EXPECT_THAT(
+ SHA1Hash("abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"),
+ ElementsAreArray<uint8_t>({0x84, 0x98, 0x3e, 0x44, 0x1c, 0x3b, 0xd2,
+ 0x6e, 0xba, 0xae, 0x4a, 0xa1, 0xf9, 0x51,
+ 0x29, 0xe5, 0xe5, 0x46, 0x70, 0xf1}));
+
+ EXPECT_THAT(
+ SHA1Hash(std::string(1000000, 'a')),
+ ElementsAreArray<uint8_t>({0x34, 0xaa, 0x97, 0x3c, 0xd4, 0xc4, 0xda,
+ 0xa4, 0xf6, 0x1e, 0xeb, 0x2b, 0xdb, 0xad,
+ 0x27, 0x31, 0x65, 0x34, 0x01, 0x6f}));
+}
+
+} // namespace
+} // namespace base
+} // namespace perfetto