base: Add websocket support to http_server.cc

Bug: 205274609
Change-Id: Ia8977e03a4ea3e758470c47ce946af60a42766ab
diff --git a/Android.bp b/Android.bp
index 1002d49..afc122e 100644
--- a/Android.bp
+++ b/Android.bp
@@ -6898,6 +6898,7 @@
     name: "perfetto_src_base_http_http",
     srcs: [
         "src/base/http/http_server.cc",
+        "src/base/http/sha1.cc",
     ],
 }
 
@@ -6906,6 +6907,7 @@
     name: "perfetto_src_base_http_unittests",
     srcs: [
         "src/base/http/http_server_unittest.cc",
+        "src/base/http/sha1_unittest.cc",
     ],
 }
 
diff --git a/BUILD b/BUILD
index 0a0b060..a0d0884 100644
--- a/BUILD
+++ b/BUILD
@@ -335,6 +335,7 @@
     name = "include_perfetto_ext_base_http_http",
     srcs = [
         "include/perfetto/ext/base/http/http_server.h",
+        "include/perfetto/ext/base/http/sha1.h",
     ],
 )
 
@@ -654,6 +655,7 @@
     name = "src_base_http_http",
     srcs = [
         "src/base/http/http_server.cc",
+        "src/base/http/sha1.cc",
     ],
     hdrs = [
         ":include_perfetto_base_base",
diff --git a/include/perfetto/ext/base/http/BUILD.gn b/include/perfetto/ext/base/http/BUILD.gn
index 78268b2..3e42607 100644
--- a/include/perfetto/ext/base/http/BUILD.gn
+++ b/include/perfetto/ext/base/http/BUILD.gn
@@ -13,6 +13,9 @@
 # limitations under the License.
 
 source_set("http") {
-  sources = [ "http_server.h" ]
+  sources = [
+    "http_server.h",
+    "sha1.h",
+  ]
   public_deps = [ "..:base" ]
 }
diff --git a/include/perfetto/ext/base/http/http_server.h b/include/perfetto/ext/base/http/http_server.h
index da5e626..c251061 100644
--- a/include/perfetto/ext/base/http/http_server.h
+++ b/include/perfetto/ext/base/http/http_server.h
@@ -47,6 +47,7 @@
   StringView uri;
   StringView origin;
   StringView body;
+  bool is_websocket_handshake = false;
 
  private:
   friend class HttpServer;
@@ -60,6 +61,24 @@
   size_t num_headers = 0;
 };
 
+struct WebsocketMessage {
+  explicit WebsocketMessage(HttpServerConnection* c) : conn(c) {}
+
+  HttpServerConnection* conn;
+
+  // Note: message boundaries are not respected in case of fragmentation.
+  // This websocket implementation preserves only the byte stream, but not the
+  // atomicity of inbound messages (like SOCK_STREAM, unlike SOCK_DGRAM).
+  // Holds onto the connection's |rxbuf|. This is valid only within the scope
+  // of the OnWebsocketMessage() callback.
+  StringView data;
+
+  // If false the payload contains binary data. If true it's supposed to contain
+  // text. Note that there is no guarantee this will be the case. This merely
+  // reflect the opcode that the client sets on each message.
+  bool is_text = false;
+};
+
 class HttpServerConnection {
  public:
   static constexpr size_t kOmitContentLength = static_cast<size_t>(-1);
@@ -86,6 +105,24 @@
     SendResponse(http_code, headers, content, true);
   }
 
+  // The metods below are only valid for websocket connections.
+
+  // Upgrade an existing connection to a websocket. This can be called only in
+  // the context of OnHttpRequest(req) if req.is_websocket_handshake == true.
+  // If the origin is not in the |allowed_origins_|, the request will fail with
+  // a 403 error (this is because there is no browser-side CORS support for
+  // websockets).
+  void UpgradeToWebsocket(const HttpRequest&);
+  void SendWebsocketMessage(const void* data, size_t len);
+  void SendWebsocketMessage(StringView sv) {
+    SendWebsocketMessage(sv.data(), sv.size());
+  }
+  void SendWebsocketFrame(uint8_t opcode,
+                          const void* payload,
+                          size_t payload_len);
+
+  bool is_websocket() const { return is_websocket_; }
+
  private:
   friend class HttpServer;
 
@@ -94,6 +131,7 @@
   std::unique_ptr<UnixSocket> sock;
   PagedMemory rxbuf;
   size_t rxbuf_used = 0;
+  bool is_websocket_ = false;
   bool headers_sent_ = false;
   size_t content_len_headers_ = 0;
   size_t content_len_actual_ = 0;
@@ -112,6 +150,7 @@
  public:
   virtual ~HttpRequestHandler();
   virtual void OnHttpRequest(const HttpRequest&) = 0;
+  virtual void OnWebsocketMessage(const WebsocketMessage&);
   virtual void OnHttpConnectionClosed(HttpServerConnection*);
 };
 
@@ -124,6 +163,7 @@
 
  private:
   size_t ParseOneHttpRequest(HttpServerConnection*);
+  size_t ParseOneWebsocketFrame(HttpServerConnection*);
   void HandleCorsPreflightRequest(const HttpRequest&);
   bool IsOriginAllowed(StringView);
 
diff --git a/include/perfetto/ext/base/http/sha1.h b/include/perfetto/ext/base/http/sha1.h
new file mode 100644
index 0000000..c583d69
--- /dev/null
+++ b/include/perfetto/ext/base/http/sha1.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INCLUDE_PERFETTO_EXT_BASE_HTTP_SHA1_H_
+#define INCLUDE_PERFETTO_EXT_BASE_HTTP_SHA1_H_
+
+#include <stddef.h>
+
+#include <array>
+#include <string>
+
+namespace perfetto {
+namespace base {
+
+constexpr size_t kSHA1Length = 20;
+using SHA1Digest = std::array<uint8_t, kSHA1Length>;
+
+SHA1Digest SHA1Hash(const std::string& str);
+SHA1Digest SHA1Hash(const void* data, size_t size);
+
+}  // namespace base
+}  // namespace perfetto
+
+#endif  // INCLUDE_PERFETTO_EXT_BASE_HTTP_SHA1_H_
diff --git a/src/base/http/BUILD.gn b/src/base/http/BUILD.gn
index b32f179..e086134 100644
--- a/src/base/http/BUILD.gn
+++ b/src/base/http/BUILD.gn
@@ -26,7 +26,10 @@
     "../../../include/perfetto/base",
     "../../../include/perfetto/ext/base/http",
   ]
-  sources = [ "http_server.cc" ]
+  sources = [
+    "http_server.cc",
+    "sha1.cc",
+  ]
 }
 
 perfetto_unittest_source_set("unittests") {
@@ -37,5 +40,8 @@
     "../../../gn:default_deps",
     "../../../gn:gtest_and_gmock",
   ]
-  sources = [ "http_server_unittest.cc" ]
+  sources = [
+    "http_server_unittest.cc",
+    "sha1_unittest.cc",
+  ]
 }
diff --git a/src/base/http/http_server.cc b/src/base/http/http_server.cc
index 42263cf..17a1652 100644
--- a/src/base/http/http_server.cc
+++ b/src/base/http/http_server.cc
@@ -15,9 +15,13 @@
  */
 #include "perfetto/ext/base/http/http_server.h"
 
+#include <cinttypes>
+
 #include <vector>
 
+#include "perfetto/ext/base/base64.h"
 #include "perfetto/ext/base/endian.h"
+#include "perfetto/ext/base/http/sha1.h"
 #include "perfetto/ext/base/string_utils.h"
 #include "perfetto/ext/base/string_view.h"
 
@@ -25,8 +29,22 @@
 namespace base {
 
 namespace {
-// 32 MiB payload + 128K for HTTP headers.
-constexpr size_t kMaxRequestSize = (32 * 1024 + 128) * 1024;
+constexpr size_t kMaxPayloadSize = 32 * 1024 * 1024;
+constexpr size_t kMaxRequestSize = kMaxPayloadSize + 4096;
+
+enum WebsocketOpcode : uint8_t {
+  kOpcodeContinuation = 0x0,
+  kOpcodeText = 0x1,
+  kOpcodeBinary = 0x2,
+  kOpcodeDataUnused = 0x3,
+  kOpcodeClose = 0x8,
+  kOpcodePing = 0x9,
+  kOpcodePong = 0xA,
+  kOpcodeControlUnused = 0xB,
+};
+
+// From https://datatracker.ietf.org/doc/html/rfc6455#section-1.3.
+constexpr char kWebsocketGuid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
 
 }  // namespace
 
@@ -103,7 +121,13 @@
   // At this point |rxbuf| can contain a partial HTTP request, a full one or
   // more (in case of HTTP Keepalive pipelining).
   for (;;) {
-    size_t bytes_consumed = ParseOneHttpRequest(conn);
+    size_t bytes_consumed;
+
+    if (conn->is_websocket()) {
+      bytes_consumed = ParseOneWebsocketFrame(conn);
+    } else {
+      bytes_consumed = ParseOneHttpRequest(conn);
+    }
 
     if (bytes_consumed == 0)
       break;
@@ -176,6 +200,8 @@
           conn->origin_allowed_ = hdr_value.ToStdString();
       } else if (hdr_name.CaseInsensitiveEq("connection")) {
         conn->keepalive_ = hdr_value.CaseInsensitiveEq("keep-alive");
+        http_req.is_websocket_handshake =
+            hdr_value.CaseInsensitiveEq("upgrade");
       }
     }
   }
@@ -185,7 +211,8 @@
   PERFETTO_CHECK(buf_view.size() <= conn->rxbuf_used);
   const size_t headers_size = conn->rxbuf_used - buf_view.size();
 
-  if (body_size + headers_size >= kMaxRequestSize) {
+  if (body_size + headers_size >= kMaxRequestSize ||
+      body_size > kMaxPayloadSize) {
     conn->SendResponseAndClose("413 Payload Too Large");
     return 0;
   }
@@ -246,11 +273,185 @@
   return false;
 }
 
+void HttpServerConnection::UpgradeToWebsocket(const HttpRequest& req) {
+  PERFETTO_CHECK(req.is_websocket_handshake);
+
+  // |origin_allowed_| is set to the req.origin only if it's in the allowlist.
+  if (origin_allowed_.empty())
+    return SendResponseAndClose("403 Forbidden", {}, "Origin not allowed");
+
+  auto ws_ver = req.GetHeader("sec-webSocket-version").value_or(StringView());
+  auto ws_key = req.GetHeader("sec-webSocket-key").value_or(StringView());
+
+  if (!ws_ver.CaseInsensitiveEq("13"))
+    return SendResponseAndClose("505 HTTP Version Not Supported", {});
+
+  if (ws_key.size() != 24) {
+    // The nonce must be a base64 encoded 16 bytes value (24 after base64).
+    return SendResponseAndClose("400 Bad Request", {});
+  }
+
+  // From https://datatracker.ietf.org/doc/html/rfc6455#section-1.3 :
+  // For this header field, the server has to take the value (as present
+  // in the header field, e.g., the base64-encoded [RFC4648] version minus
+  // any leading and trailing whitespace) and concatenate this with the
+  // Globally Unique Identifier (GUID, [RFC4122]) "258EAFA5-E914-47DA-
+  // 95CA-C5AB0DC85B11" in string form, which is unlikely to be used by
+  // network endpoints that do not understand the WebSocket Protocol.  A
+  // SHA-1 hash (160 bits) [FIPS.180-3], base64-encoded (see Section 4 of
+  // [RFC4648]), of this concatenation is then returned in the server's
+  // handshake.
+  StackString<128> signed_nonce("%.*s%s", static_cast<int>(ws_key.size()),
+                                ws_key.data(), kWebsocketGuid);
+  auto digest = SHA1Hash(signed_nonce.c_str(), signed_nonce.len());
+  std::string digest_b64 = Base64Encode(digest.data(), digest.size());
+
+  StackString<128> accept_hdr("Sec-WebSocket-Accept: %s", digest_b64.c_str());
+
+  std::initializer_list<const char*> headers = {
+      "Upgrade: websocket",   //
+      "Connection: Upgrade",  //
+      accept_hdr.c_str(),     //
+  };
+  PERFETTO_DLOG("[HTTP] Handshaking WebSocket for %.*s",
+                static_cast<int>(req.uri.size()), req.uri.data());
+  for (const char* hdr : headers)
+    PERFETTO_DLOG("> %s", hdr);
+
+  SendResponseHeaders("101 Switching Protocols", headers,
+                      HttpServerConnection::kOmitContentLength);
+
+  is_websocket_ = true;
+}
+
+size_t HttpServer::ParseOneWebsocketFrame(HttpServerConnection* conn) {
+  auto* rxbuf = reinterpret_cast<uint8_t*>(conn->rxbuf.Get());
+  const size_t frame_size = conn->rxbuf_used;
+  uint8_t* rd = rxbuf;
+  uint8_t* const end = rxbuf + frame_size;
+
+  auto avail = [&] {
+    PERFETTO_CHECK(rd <= end);
+    return static_cast<size_t>(end - rd);
+  };
+
+  // From https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 :
+  //   0                   1                   2                   3
+  //   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+  //  +-+-+-+-+-------+-+-------------+-------------------------------+
+  //  |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
+  //  |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
+  //  |N|V|V|V|       |S|             |   (if payload len==126/127)   |
+  //  | |1|2|3|       |K|             |                               |
+  //  +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+  //  |     Extended payload length continued, if payload len == 127  |
+  //  + - - - - - - - - - - - - - - - +-------------------------------+
+  //  |                               |Masking-key, if MASK set to 1  |
+  //  +-------------------------------+-------------------------------+
+  //  | Masking-key (continued)       |          Payload Data         |
+  //  +-------------------------------- - - - - - - - - - - - - - - - +
+  //  :                     Payload Data continued ...                :
+  //  + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
+  //  |                     Payload Data continued ...                |
+  //  +---------------------------------------------------------------+
+
+  if (avail() < 2)
+    return 0;  // Can't even decode the frame header. Wait for more data.
+
+  uint8_t h0 = *(rd++);
+  uint8_t h1 = *(rd++);
+  const bool fin = !!(h0 & 0x80);  // This bit is set if this frame is the last
+                                   // data to complete this message.
+  const uint8_t opcode = h0 & 0x0F;
+
+  const bool has_mask = !!(h1 & 0x80);
+  uint64_t payload_len_u64 = (h1 & 0x7F);
+  uint8_t extended_payload_size = 0;
+  if (payload_len_u64 == 126) {
+    extended_payload_size = 2;
+  } else if (payload_len_u64 == 127) {
+    extended_payload_size = 8;
+  }
+
+  if (extended_payload_size > 0) {
+    if (avail() < extended_payload_size)
+      return 0;  // Not enough data to read the extended header.
+    payload_len_u64 = 0;
+    for (uint8_t i = 0; i < extended_payload_size; ++i) {
+      payload_len_u64 <<= 8;
+      payload_len_u64 |= *(rd++);
+    }
+  }
+
+  if (payload_len_u64 >= kMaxPayloadSize) {
+    PERFETTO_ELOG("[HTTP] Websocket payload too big (%" PRIu64 " > %zu)",
+                  payload_len_u64, kMaxPayloadSize);
+    conn->Close();
+    return 0;
+  }
+  const size_t payload_len = static_cast<size_t>(payload_len_u64);
+
+  if (!has_mask) {
+    // https://datatracker.ietf.org/doc/html/rfc6455#section-5.1
+    // The server MUST close the connection upon receiving a frame that is
+    // not masked.
+    PERFETTO_ELOG("[HTTP] Websocket inbound frames must be masked");
+    conn->Close();
+    return 0;
+  }
+
+  uint8_t mask[4];
+  if (avail() < sizeof(mask))
+    return 0;  // Not enough data to read the masking key.
+  memcpy(mask, rd, sizeof(mask));
+  rd += sizeof(mask);
+
+  PERFETTO_DLOG(
+      "[HTTP] Websocket fin=%d opcode=%u, payload_len=%zu (avail=%zu), "
+      "mask=%02x%02x%02x%02x",
+      fin, opcode, payload_len, avail(), mask[0], mask[1], mask[2], mask[3]);
+
+  if (avail() < payload_len)
+    return 0;  // Not enouh data to read the payload.
+  uint8_t* const payload_start = rd;
+
+  // Unmask the payload.
+  for (uint32_t i = 0; i < payload_len; ++i)
+    payload_start[i] ^= mask[i % sizeof(mask)];
+
+  if (opcode == kOpcodePing) {
+    PERFETTO_DLOG("[HTTP] Websocket PING");
+    conn->SendWebsocketFrame(kOpcodePong, payload_start, payload_len);
+  } else if (opcode == kOpcodeBinary || opcode == kOpcodeText ||
+             opcode == kOpcodeContinuation) {
+    // We do NOT handle fragmentation. We propagate all fragments as individual
+    // messages, breaking the message-oriented nature of websockets. We do this
+    // because in all our use cases we need only a byte stream without caring
+    // about message boundaries.
+    // If we wanted to support fragmentation, we'd have to stash
+    // kOpcodeContinuation messages in a buffer, until we FIN bit is set.
+    // When loading traces with trace processor, the messages can be up to
+    // 32MB big (SLICE_SIZE in trace_stream.ts). The double-buffering would
+    // slow down significantly trace loading with no benefits.
+    WebsocketMessage msg(conn);
+    msg.data =
+        StringView(reinterpret_cast<const char*>(payload_start), payload_len);
+    msg.is_text = opcode == kOpcodeText;
+    req_handler_->OnWebsocketMessage(msg);
+  } else if (opcode == kOpcodeClose) {
+    conn->Close();
+  } else {
+    PERFETTO_LOG("Unsupported WebSocket opcode: %d", opcode);
+  }
+  return static_cast<size_t>(rd - rxbuf) + payload_len;
+}
+
 void HttpServerConnection::SendResponseHeaders(
     const char* http_code,
     std::initializer_list<const char*> headers,
     size_t content_length) {
   PERFETTO_CHECK(!headers_sent_);
+  PERFETTO_CHECK(!is_websocket_);
   headers_sent_ = true;
   std::vector<char> resp_hdr;
   resp_hdr.reserve(512);
@@ -296,6 +497,7 @@
 }
 
 void HttpServerConnection::SendResponseBody(const void* data, size_t len) {
+  PERFETTO_CHECK(!is_websocket_);
   if (data == nullptr) {
     PERFETTO_DCHECK(len == 0);
     return;
@@ -323,6 +525,39 @@
     Close();
 }
 
+void HttpServerConnection::SendWebsocketMessage(const void* data, size_t len) {
+  SendWebsocketFrame(kOpcodeBinary, data, len);
+}
+
+void HttpServerConnection::SendWebsocketFrame(uint8_t opcode,
+                                              const void* payload,
+                                              size_t payload_len) {
+  PERFETTO_CHECK(is_websocket_);
+
+  uint8_t hdr[10]{};
+  uint32_t hdr_len = 0;
+
+  hdr[0] = opcode | 0x80 /* FIN=1, no fragmentation */;
+  if (payload_len < 126) {
+    hdr_len = 2;
+    hdr[1] = static_cast<uint8_t>(payload_len);
+  } else if (payload_len < 0xffff) {
+    hdr_len = 4;
+    hdr[1] = 126;  // Special value: Header extends for 2 bytes.
+    uint16_t len_be = HostToBE16(static_cast<uint16_t>(payload_len));
+    memcpy(&hdr[2], &len_be, sizeof(len_be));
+  } else {
+    hdr_len = 10;
+    hdr[1] = 127;  // Special value: Header extends for 4 bytes.
+    uint64_t len_be = HostToBE64(payload_len);
+    memcpy(&hdr[2], &len_be, sizeof(len_be));
+  }
+
+  sock->Send(hdr, hdr_len);
+  if (payload && payload_len > 0)
+    sock->Send(payload, payload_len);
+}
+
 HttpServerConnection::HttpServerConnection(std::unique_ptr<UnixSocket> s)
     : sock(std::move(s)), rxbuf(PagedMemory::Allocate(kMaxRequestSize)) {}
 
@@ -337,6 +572,7 @@
 }
 
 HttpRequestHandler::~HttpRequestHandler() = default;
+void HttpRequestHandler::OnWebsocketMessage(const WebsocketMessage&) {}
 void HttpRequestHandler::OnHttpConnectionClosed(HttpServerConnection*) {}
 
 }  // namespace base
diff --git a/src/base/http/http_server_unittest.cc b/src/base/http/http_server_unittest.cc
index 38eaa08..9726ab6 100644
--- a/src/base/http/http_server_unittest.cc
+++ b/src/base/http/http_server_unittest.cc
@@ -39,6 +39,7 @@
  public:
   MOCK_METHOD1(OnHttpRequest, void(const HttpRequest&));
   MOCK_METHOD1(OnHttpConnectionClosed, void(HttpServerConnection*));
+  MOCK_METHOD1(OnWebsocketMessage, void(const WebsocketMessage&));
 };
 
 class HttpCli {
@@ -59,7 +60,7 @@
     sock.SendStr(body);
   }
 
-  std::string RecvAndWaitConnClose() {
+  std::string Recv(size_t min_bytes) {
     static int n = 0;
     auto checkpoint_name = "rx_" + std::to_string(n++);
     auto checkpoint = task_runner_->CreateCheckpoint(checkpoint_name);
@@ -69,15 +70,17 @@
       char buf[1024]{};
       auto rsize = PERFETTO_EINTR(sock.Receive(buf, sizeof(buf)));
       ASSERT_GE(rsize, 0);
-      if (rsize == 0)
-        checkpoint();
       rxbuf.append(buf, static_cast<size_t>(rsize));
+      if (rsize == 0 || (min_bytes && rxbuf.length() >= min_bytes))
+        checkpoint();
     });
     task_runner_->RunUntilCheckpoint(checkpoint_name);
     task_runner_->RemoveFileDescriptorWatch(sock.fd());
     return rxbuf;
   }
 
+  std::string RecvAndWaitConnClose() { return Recv(0); }
+
   TestTaskRunner* task_runner_;
   UnixSocketRaw sock;
 };
@@ -103,6 +106,7 @@
                   req.GetHeader("X-header").value_or("N/A").ToStdString());
         EXPECT_EQ("foo",
                   req.GetHeader("X-header2").value_or("N/A").ToStdString());
+        EXPECT_FALSE(req.is_websocket_handshake);
         req.conn->SendResponseAndClose("200 OK", {}, "<html>");
       }));
   EXPECT_CALL(handler_, OnHttpConnectionClosed(_)).Times(kIterations);
@@ -213,6 +217,105 @@
   EXPECT_EQ(cli.RecvAndWaitConnClose(), expected_response);
 }
 
+TEST_F(HttpServerTest, Websocket) {
+  srv_.AddAllowedOrigin("http://foo.com");
+  srv_.AddAllowedOrigin("http://websocket.com");
+  for (int rep = 0; rep < 3; rep++) {
+    HttpCli cli(&task_runner_);
+    EXPECT_CALL(handler_, OnHttpRequest(_))
+        .WillOnce(Invoke([&](const HttpRequest& req) {
+          EXPECT_EQ(req.uri.ToStdString(), "/websocket");
+          EXPECT_EQ(req.method.ToStdString(), "GET");
+          EXPECT_EQ(req.origin.ToStdString(), "http://websocket.com");
+          EXPECT_TRUE(req.is_websocket_handshake);
+          req.conn->UpgradeToWebsocket(req);
+        }));
+
+    cli.SendHttpReq({
+        "GET /websocket HTTP/1.1",                      //
+        "Origin: http://websocket.com",                 //
+        "Connection: upgrade",                          //
+        "Upgrade: websocket",                           //
+        "Sec-WebSocket-Version: 13",                    //
+        "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",  //
+    });
+    std::string expected_resp =
+        "HTTP/1.1 101 Switching Protocols\r\n"
+        "Upgrade: websocket\r\n"
+        "Connection: Upgrade\r\n"
+        "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"
+        "Access-Control-Allow-Origin: http://websocket.com\r\n"
+        "Vary: Origin\r\n"
+        "\r\n";
+    EXPECT_EQ(cli.Recv(expected_resp.size()), expected_resp);
+
+    for (int i = 0; i < 3; i++) {
+      EXPECT_CALL(handler_, OnWebsocketMessage(_))
+          .WillOnce(Invoke([i](const WebsocketMessage& msg) {
+            EXPECT_EQ(msg.data.ToStdString(), "test message");
+            StackString<6> resp("PONG%d", i);
+            msg.conn->SendWebsocketMessage(resp.c_str(), resp.len());
+          }));
+
+      // A frame from a real tcpdump capture:
+      //   1... .... = Fin: True
+      //   .000 .... = Reserved: 0x0
+      //   .... 0001 = Opcode: Text (1)
+      //   1... .... = Mask: True
+      //   .000 1100 = Payload length: 12
+      //   Masking-Key: e17e8eb9
+      //   Masked payload: "test message"
+      cli.sock.SendStr(
+          "\x81\x8c\xe1\x7e\x8e\xb9\x95\x1b\xfd\xcd\xc1\x13\xeb\xca\x92\x1f\xe9"
+          "\xdc");
+      EXPECT_EQ(cli.Recv(2 + 5), "\x82\x05PONG" + std::to_string(i));
+    }
+
+    cli.sock.Shutdown();
+    auto checkpoint_name = "ws_close_" + std::to_string(rep);
+    auto ws_close = task_runner_.CreateCheckpoint(checkpoint_name);
+    EXPECT_CALL(handler_, OnHttpConnectionClosed(_))
+        .WillOnce(InvokeWithoutArgs(ws_close));
+    task_runner_.RunUntilCheckpoint(checkpoint_name);
+  }
+}
+
+TEST_F(HttpServerTest, Websocket_OriginNotAllowed) {
+  srv_.AddAllowedOrigin("http://websocket.com");
+  srv_.AddAllowedOrigin("http://notallowed.commando");
+  srv_.AddAllowedOrigin("http://iamnotallowed.com");
+  srv_.AddAllowedOrigin("iamnotallowed.com");
+  // The origin must match in full, including scheme. This won't match.
+  srv_.AddAllowedOrigin("notallowed.com");
+
+  HttpCli cli(&task_runner_);
+  EXPECT_CALL(handler_, OnHttpConnectionClosed(_));
+  EXPECT_CALL(handler_, OnHttpRequest(_))
+      .WillOnce(Invoke([&](const HttpRequest& req) {
+        EXPECT_EQ(req.origin.ToStdString(), "http://notallowed.com");
+        EXPECT_TRUE(req.is_websocket_handshake);
+        req.conn->UpgradeToWebsocket(req);
+      }));
+
+  cli.SendHttpReq({
+      "GET /websocket HTTP/1.1",                      //
+      "Origin: http://notallowed.com",                //
+      "Connection: upgrade",                          //
+      "Upgrade: websocket",                           //
+      "Sec-WebSocket-Version: 13",                    //
+      "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",  //
+  });
+  std::string expected_resp =
+      "HTTP/1.1 403 Forbidden\r\n"
+      "Content-Length: 18\r\n"
+      "Connection: close\r\n"
+      "\r\n"
+      "Origin not allowed";
+
+  EXPECT_EQ(cli.Recv(expected_resp.size()), expected_resp);
+  cli.sock.Shutdown();
+}
+
 }  // namespace
 }  // namespace base
 }  // namespace perfetto
diff --git a/src/base/http/sha1.cc b/src/base/http/sha1.cc
new file mode 100644
index 0000000..da3f753
--- /dev/null
+++ b/src/base/http/sha1.cc
@@ -0,0 +1,242 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "perfetto/ext/base/http/sha1.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+// From chrome_elf/sha1/sha1.cc.
+
+namespace perfetto {
+namespace base {
+
+namespace {
+
+inline uint32_t BSwap32(uint32_t x) {
+#if defined(__GNUC__)
+  return __builtin_bswap32(x);
+#elif defined(_MSC_VER)
+  return _byteswap_ulong(val);
+#else
+  return (((x & 0xff000000u) >> 24) | ((x & 0x00ff0000u) >> 8) |
+          ((x & 0x0000ff00u) << 8) | ((x & 0x000000ffu) << 24));
+#endif
+}
+
+// Usage example:
+//
+// SecureHashAlgorithm sha;
+// while(there is data to hash)
+//   sha.Update(moredata, size of data);
+// sha.Final();
+// memcpy(somewhere, sha.Digest(), 20);
+//
+// to reuse the instance of sha, call sha.Init();
+class SecureHashAlgorithm {
+ public:
+  SecureHashAlgorithm() { Init(); }
+
+  void Init();
+  void Update(const void* data, size_t nbytes);
+  void Final();
+
+  // 20 bytes of message digest.
+  const unsigned char* Digest() const {
+    return reinterpret_cast<const unsigned char*>(H);
+  }
+
+ private:
+  void Pad();
+  void Process();
+
+  uint32_t A, B, C, D, E;
+
+  uint32_t H[5];
+
+  union {
+    uint32_t W[80];
+    uint8_t M[64];
+  };
+
+  uint32_t cursor;
+  uint64_t l;
+};
+
+//------------------------------------------------------------------------------
+// Private functions
+//------------------------------------------------------------------------------
+
+// Identifier names follow notation in FIPS PUB 180-3, where you'll
+// also find a description of the algorithm:
+// http://csrc.nist.gov/publications/fips/fips180-3/fips180-3_final.pdf
+
+inline uint32_t f(uint32_t t, uint32_t B, uint32_t C, uint32_t D) {
+  if (t < 20) {
+    return (B & C) | ((~B) & D);
+  } else if (t < 40) {
+    return B ^ C ^ D;
+  } else if (t < 60) {
+    return (B & C) | (B & D) | (C & D);
+  } else {
+    return B ^ C ^ D;
+  }
+}
+
+inline uint32_t S(uint32_t n, uint32_t X) {
+  return (X << n) | (X >> (32 - n));
+}
+
+inline uint32_t K(uint32_t t) {
+  if (t < 20) {
+    return 0x5a827999;
+  } else if (t < 40) {
+    return 0x6ed9eba1;
+  } else if (t < 60) {
+    return 0x8f1bbcdc;
+  } else {
+    return 0xca62c1d6;
+  }
+}
+
+void SecureHashAlgorithm::Init() {
+  A = 0;
+  B = 0;
+  C = 0;
+  D = 0;
+  E = 0;
+  cursor = 0;
+  l = 0;
+  H[0] = 0x67452301;
+  H[1] = 0xefcdab89;
+  H[2] = 0x98badcfe;
+  H[3] = 0x10325476;
+  H[4] = 0xc3d2e1f0;
+}
+
+void SecureHashAlgorithm::Update(const void* data, size_t nbytes) {
+  const uint8_t* d = reinterpret_cast<const uint8_t*>(data);
+  while (nbytes--) {
+    M[cursor++] = *d++;
+    if (cursor >= 64)
+      Process();
+    l += 8;
+  }
+}
+
+void SecureHashAlgorithm::Final() {
+  Pad();
+  Process();
+
+  for (size_t t = 0; t < 5; ++t)
+    H[t] = BSwap32(H[t]);
+}
+
+void SecureHashAlgorithm::Process() {
+  uint32_t t;
+
+  // Each a...e corresponds to a section in the FIPS 180-3 algorithm.
+
+  // a.
+  //
+  // W and M are in a union, so no need to memcpy.
+  // memcpy(W, M, sizeof(M));
+  for (t = 0; t < 16; ++t)
+    W[t] = BSwap32(W[t]);
+
+  // b.
+  for (t = 16; t < 80; ++t)
+    W[t] = S(1, W[t - 3] ^ W[t - 8] ^ W[t - 14] ^ W[t - 16]);
+
+  // c.
+  A = H[0];
+  B = H[1];
+  C = H[2];
+  D = H[3];
+  E = H[4];
+
+  // d.
+  for (t = 0; t < 80; ++t) {
+    uint32_t TEMP = S(5, A) + f(t, B, C, D) + E + W[t] + K(t);
+    E = D;
+    D = C;
+    C = S(30, B);
+    B = A;
+    A = TEMP;
+  }
+
+  // e.
+  H[0] += A;
+  H[1] += B;
+  H[2] += C;
+  H[3] += D;
+  H[4] += E;
+
+  cursor = 0;
+}
+
+void SecureHashAlgorithm::Pad() {
+  M[cursor++] = 0x80;
+
+  if (cursor > 64 - 8) {
+    // pad out to next block
+    while (cursor < 64)
+      M[cursor++] = 0;
+
+    Process();
+  }
+
+  while (cursor < 64 - 8)
+    M[cursor++] = 0;
+
+  M[cursor++] = (l >> 56) & 0xff;
+  M[cursor++] = (l >> 48) & 0xff;
+  M[cursor++] = (l >> 40) & 0xff;
+  M[cursor++] = (l >> 32) & 0xff;
+  M[cursor++] = (l >> 24) & 0xff;
+  M[cursor++] = (l >> 16) & 0xff;
+  M[cursor++] = (l >> 8) & 0xff;
+  M[cursor++] = l & 0xff;
+}
+
+// Computes the SHA-1 hash of the |len| bytes in |data| and puts the hash
+// in |hash|. |hash| must be kSHA1Length bytes long.
+void SHA1HashBytes(const unsigned char* data, size_t len, unsigned char* hash) {
+  SecureHashAlgorithm sha;
+  sha.Update(data, len);
+  sha.Final();
+
+  ::memcpy(hash, sha.Digest(), kSHA1Length);
+}
+
+}  // namespace
+
+//------------------------------------------------------------------------------
+// Public functions
+//------------------------------------------------------------------------------
+SHA1Digest SHA1Hash(const void* data, size_t size) {
+  SHA1Digest digest;
+  SHA1HashBytes(static_cast<const unsigned char*>(data), size,
+                reinterpret_cast<unsigned char*>(&digest[0]));
+  return digest;
+}
+
+SHA1Digest SHA1Hash(const std::string& str) {
+  return SHA1Hash(str.data(), str.size());
+}
+
+}  // namespace base
+}  // namespace perfetto
diff --git a/src/base/http/sha1_unittest.cc b/src/base/http/sha1_unittest.cc
new file mode 100644
index 0000000..269afc1
--- /dev/null
+++ b/src/base/http/sha1_unittest.cc
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "perfetto/ext/base/http/sha1.h"
+
+#include <string>
+
+#include "perfetto/ext/base/string_view.h"
+#include "test/gtest_and_gmock.h"
+
+namespace perfetto {
+namespace base {
+namespace {
+
+using testing::ElementsAreArray;
+
+TEST(SHA1Test, Hash) {
+  EXPECT_THAT(SHA1Hash(""), ElementsAreArray<uint8_t>(
+                                {0xda, 0x39, 0xa3, 0xee, 0x5e, 0x6b, 0x4b,
+                                 0x0d, 0x32, 0x55, 0xbf, 0xef, 0x95, 0x60,
+                                 0x18, 0x90, 0xaf, 0xd8, 0x07, 0x09}));
+
+  EXPECT_THAT(SHA1Hash("abc"), ElementsAreArray<uint8_t>(
+                                   {0xa9, 0x99, 0x3e, 0x36, 0x47, 0x06, 0x81,
+                                    0x6a, 0xba, 0x3e, 0x25, 0x71, 0x78, 0x50,
+                                    0xc2, 0x6c, 0x9c, 0xd0, 0xd8, 0x9d}));
+
+  EXPECT_THAT(
+      SHA1Hash("abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"),
+      ElementsAreArray<uint8_t>({0x84, 0x98, 0x3e, 0x44, 0x1c, 0x3b, 0xd2,
+                                 0x6e, 0xba, 0xae, 0x4a, 0xa1, 0xf9, 0x51,
+                                 0x29, 0xe5, 0xe5, 0x46, 0x70, 0xf1}));
+
+  EXPECT_THAT(
+      SHA1Hash(std::string(1000000, 'a')),
+      ElementsAreArray<uint8_t>({0x34, 0xaa, 0x97, 0x3c, 0xd4, 0xc4, 0xda,
+                                 0xa4, 0xf6, 0x1e, 0xeb, 0x2b, 0xdb, 0xad,
+                                 0x27, 0x31, 0x65, 0x34, 0x01, 0x6f}));
+}
+
+}  // namespace
+}  // namespace base
+}  // namespace perfetto