From bb74e75a049768c2949aa08de6752a16813b97de Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Mon, 24 Apr 2023 23:24:40 +0200 Subject: [PATCH] feat(ext/http): h2c for http/2 (#18817) This implements HTTP/2 prior-knowledge connections, allowing clients to request HTTP/2 over plaintext or TLS-without-ALPN connections. If a client requests a specific protocol via ALPN (`h2` or `http/1.1`), however, the protocol is forced and must be used. --- cli/tests/unit/serve_test.ts | 78 ++++++++ ext/http/http_next.rs | 107 ++++++++--- ext/http/lib.rs | 70 +++++-- ext/http/network_buffered_stream.rs | 284 ++++++++++++++++++++++++++++ ext/net/raw.rs | 19 ++ 5 files changed, 520 insertions(+), 38 deletions(-) create mode 100644 ext/http/network_buffered_stream.rs diff --git a/cli/tests/unit/serve_test.ts b/cli/tests/unit/serve_test.ts index 9268c7aab8..55b7c4590a 100644 --- a/cli/tests/unit/serve_test.ts +++ b/cli/tests/unit/serve_test.ts @@ -15,6 +15,7 @@ import { deferred, fail, } from "./test_util.ts"; +import { consoleSize } from "../../../runtime/js/40_tty.js"; function createOnErrorCb(ac: AbortController): (err: unknown) => Response { return (err) => { @@ -2709,3 +2710,80 @@ function isProhibitedForTrailer(key: string): boolean { const s = new Set(["transfer-encoding", "content-length", "trailer"]); return s.has(key.toLowerCase()); } + +Deno.test( + { permissions: { net: true, run: true } }, + async function httpServeCurlH2C() { + const ac = new AbortController(); + const server = Deno.serve( + () => new Response("hello world!"), + { signal: ac.signal }, + ); + + assertEquals( + "hello world!", + await curlRequest(["http://localhost:8000/path"]), + ); + assertEquals( + "hello world!", + await curlRequest(["http://localhost:8000/path", "--http2"]), + ); + assertEquals( + "hello world!", + await curlRequest([ + "http://localhost:8000/path", + "--http2", + "--http2-prior-knowledge", + ]), + ); + + ac.abort(); + await server; + }, +); + +Deno.test( + { permissions: { net: true, run: true, read: true } }, + async function httpsServeCurlH2C() { + const ac = new AbortController(); + const server = Deno.serve( + () => new Response("hello world!"), + { + signal: ac.signal, + cert: Deno.readTextFileSync("cli/tests/testdata/tls/localhost.crt"), + key: Deno.readTextFileSync("cli/tests/testdata/tls/localhost.key"), + }, + ); + + assertEquals( + "hello world!", + await curlRequest(["https://localhost:9000/path", "-k"]), + ); + assertEquals( + "hello world!", + await curlRequest(["https://localhost:9000/path", "-k", "--http2"]), + ); + assertEquals( + "hello world!", + await curlRequest([ + "https://localhost:9000/path", + "-k", + "--http2", + "--http2-prior-knowledge", + ]), + ); + + ac.abort(); + await server; + }, +); + +async function curlRequest(args: string[]) { + const { success, stdout } = await new Deno.Command("curl", { + args, + stdout: "piped", + stderr: "null", + }).output(); + assert(success); + return new TextDecoder().decode(stdout); +} diff --git a/ext/http/http_next.rs b/ext/http/http_next.rs index 47888f0a49..71f2a32b68 100644 --- a/ext/http/http_next.rs +++ b/ext/http/http_next.rs @@ -1,5 +1,6 @@ // Copyright 2018-2023 the Deno authors. All rights reserved. MIT license. use crate::extract_network_stream; +use crate::network_buffered_stream::NetworkStreamPrefixCheck; use crate::request_body::HttpRequestBody; use crate::request_properties::DefaultHttpRequestProperties; use crate::request_properties::HttpConnectionProperties; @@ -36,6 +37,7 @@ use hyper1::http::HeaderValue; use hyper1::server::conn::http1; use hyper1::server::conn::http2; use hyper1::service::service_fn; +use hyper1::service::HttpService; use hyper1::upgrade::OnUpgrade; use hyper1::StatusCode; use pin_project::pin_project; @@ -56,6 +58,37 @@ use tokio::task::JoinHandle; type Request = hyper1::Request; type Response = hyper1::Response; +/// All HTTP/2 connections start with this byte string. +/// +/// In HTTP/2, each endpoint is required to send a connection preface as a final confirmation +/// of the protocol in use and to establish the initial settings for the HTTP/2 connection. The +/// client and server each send a different connection preface. +/// +/// The client connection preface starts with a sequence of 24 octets, which in hex notation is: +/// +/// 0x505249202a20485454502f322e300d0a0d0a534d0d0a0d0a +/// +/// That is, the connection preface starts with the string PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n). This sequence +/// MUST be followed by a SETTINGS frame (Section 6.5), which MAY be empty. +const HTTP2_PREFIX: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + +/// ALPN negotation for "h2" +const TLS_ALPN_HTTP_2: &[u8] = b"h2"; + +/// ALPN negotation for "http/1.1" +const TLS_ALPN_HTTP_11: &[u8] = b"http/1.1"; + +/// Name a trait for streams we can serve HTTP over. +trait HttpServeStream: + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static +{ +} +impl< + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, + > HttpServeStream for S +{ +} + pub struct HttpSlabRecord { request_info: HttpConnectionProperties, request_parts: Parts, @@ -514,6 +547,44 @@ impl> Future for SlabFuture { } } +fn serve_http11_unconditional( + io: impl HttpServeStream, + svc: impl HttpService + 'static, + cancel: RcRef, +) -> impl Future> + 'static { + let conn = http1::Builder::new() + .keep_alive(true) + .serve_connection(io, svc); + + conn + .with_upgrades() + .map_err(AnyError::from) + .try_or_cancel(cancel) +} + +fn serve_http2_unconditional( + io: impl HttpServeStream, + svc: impl HttpService + 'static, + cancel: RcRef, +) -> impl Future> + 'static { + let conn = http2::Builder::new(LocalExecutor).serve_connection(io, svc); + conn.map_err(AnyError::from).try_or_cancel(cancel) +} + +async fn serve_http2_autodetect( + io: impl HttpServeStream, + svc: impl HttpService + 'static, + cancel: RcRef, +) -> Result<(), AnyError> { + let prefix = NetworkStreamPrefixCheck::new(io, HTTP2_PREFIX); + let (matches, io) = prefix.match_prefix().await?; + if matches { + serve_http2_unconditional(io, svc, cancel).await + } else { + serve_http11_unconditional(io, svc, cancel).await + } +} + fn serve_https( mut io: TlsStream, request_info: HttpConnectionProperties, @@ -526,28 +597,21 @@ fn serve_https( }); spawn_local(async { io.handshake().await?; + // If the client specifically negotiates a protocol, we will use it. If not, we'll auto-detect + // based on the prefix bytes let handshake = io.get_ref().1.alpn_protocol(); - // h2 - if handshake == Some(&[104, 50]) { - let conn = http2::Builder::new(LocalExecutor).serve_connection(io, svc); - - conn.map_err(AnyError::from).try_or_cancel(cancel).await + if handshake == Some(TLS_ALPN_HTTP_2) { + serve_http2_unconditional(io, svc, cancel).await + } else if handshake == Some(TLS_ALPN_HTTP_11) { + serve_http11_unconditional(io, svc, cancel).await } else { - let conn = http1::Builder::new() - .keep_alive(true) - .serve_connection(io, svc); - - conn - .with_upgrades() - .map_err(AnyError::from) - .try_or_cancel(cancel) - .await + serve_http2_autodetect(io, svc, cancel).await } }) } fn serve_http( - io: impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, + io: impl HttpServeStream, request_info: HttpConnectionProperties, cancel: RcRef, tx: tokio::sync::mpsc::Sender, @@ -556,16 +620,7 @@ fn serve_http( let svc = service_fn(move |req: Request| { new_slab_future(req, request_info.clone(), tx.clone()) }); - spawn_local(async { - let conn = http1::Builder::new() - .keep_alive(true) - .serve_connection(io, svc); - conn - .with_upgrades() - .map_err(AnyError::from) - .try_or_cancel(cancel) - .await - }) + spawn_local(serve_http2_autodetect(io, svc, cancel)) } fn serve_http_on( @@ -702,7 +757,7 @@ pub fn op_serve_http_on( AsyncRefCell::new(rx), )); - let handle = serve_http_on( + let handle: JoinHandle> = serve_http_on( network_stream, &listen_properties, resource.cancel_handle(), diff --git a/ext/http/lib.rs b/ext/http/lib.rs index 561b13885d..d5404d189a 100644 --- a/ext/http/lib.rs +++ b/ext/http/lib.rs @@ -73,11 +73,13 @@ use tokio::io::AsyncWriteExt; use tokio::task::spawn_local; use websocket_upgrade::WebSocketUpgrade; +use crate::network_buffered_stream::NetworkBufferedStream; use crate::reader_stream::ExternallyAbortableReaderStream; use crate::reader_stream::ShutdownHandle; pub mod compressible; mod http_next; +mod network_buffered_stream; mod reader_stream; mod request_body; mod request_properties; @@ -1251,22 +1253,66 @@ impl CanDowncastUpgrade for hyper::upgrade::Upgraded { } } +fn maybe_extract_network_stream< + T: Into + AsyncRead + AsyncWrite + Unpin + 'static, + U: CanDowncastUpgrade, +>( + upgraded: U, +) -> Result<(NetworkStream, Bytes), U> { + let upgraded = match upgraded.downcast::() { + Ok((stream, bytes)) => return Ok((stream.into(), bytes)), + Err(x) => x, + }; + + match upgraded.downcast::>() { + Ok((stream, upgraded_bytes)) => { + // Both the upgrade and the stream might have unread bytes + let (io, stream_bytes) = stream.into_inner(); + let bytes = match (stream_bytes.is_empty(), upgraded_bytes.is_empty()) { + (false, false) => Bytes::default(), + (true, false) => upgraded_bytes, + (false, true) => stream_bytes, + (true, true) => { + // The upgraded bytes come first as they have already been read + let mut v = upgraded_bytes.to_vec(); + v.append(&mut stream_bytes.to_vec()); + Bytes::from(v) + } + }; + Ok((io.into(), bytes)) + } + Err(x) => Err(x), + } +} + fn extract_network_stream( upgraded: U, ) -> (NetworkStream, Bytes) { - let upgraded = match upgraded.downcast::() { - Ok((stream, bytes)) => return (NetworkStream::Tcp(stream), bytes), - Err(x) => x, - }; - let upgraded = match upgraded.downcast::() { - Ok((stream, bytes)) => return (NetworkStream::Tls(stream), bytes), - Err(x) => x, - }; + let upgraded = + match maybe_extract_network_stream::(upgraded) { + Ok(res) => return res, + Err(x) => x, + }; + let upgraded = + match maybe_extract_network_stream::( + upgraded, + ) { + Ok(res) => return res, + Err(x) => x, + }; #[cfg(unix)] - let upgraded = match upgraded.downcast::() { - Ok((stream, bytes)) => return (NetworkStream::Unix(stream), bytes), - Err(x) => x, - }; + let upgraded = + match maybe_extract_network_stream::(upgraded) { + Ok(res) => return res, + Err(x) => x, + }; + let upgraded = + match maybe_extract_network_stream::(upgraded) { + Ok(res) => return res, + Err(x) => x, + }; + + // TODO(mmastrac): HTTP/2 websockets may yield an un-downgradable type drop(upgraded); unreachable!("unexpected stream type"); } diff --git a/ext/http/network_buffered_stream.rs b/ext/http/network_buffered_stream.rs new file mode 100644 index 0000000000..e4b2ee895d --- /dev/null +++ b/ext/http/network_buffered_stream.rs @@ -0,0 +1,284 @@ +// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license. + +use bytes::Bytes; +use deno_core::futures::future::poll_fn; +use deno_core::futures::ready; +use std::io; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::Poll; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; + +const MAX_PREFIX_SIZE: usize = 256; + +pub struct NetworkStreamPrefixCheck { + io: S, + prefix: &'static [u8], + buffer: [MaybeUninit; MAX_PREFIX_SIZE * 2], +} + +impl NetworkStreamPrefixCheck { + pub fn new(io: S, prefix: &'static [u8]) -> Self { + debug_assert!(prefix.len() < MAX_PREFIX_SIZE); + Self { + io, + prefix, + buffer: [MaybeUninit::::uninit(); MAX_PREFIX_SIZE * 2], + } + } + + // Returns a [`NetworkBufferedStream`], rewound with the bytes we read to determine what + // type of stream this is. + pub async fn match_prefix( + self, + ) -> io::Result<(bool, NetworkBufferedStream)> { + let mut buffer = self.buffer; + let mut readbuf = ReadBuf::uninit(&mut buffer); + let mut io = self.io; + let prefix = self.prefix; + loop { + enum State { + Unknown, + Matched, + NotMatched, + } + + let state = poll_fn(|cx| { + let filled_len = readbuf.filled().len(); + let res = ready!(Pin::new(&mut io).poll_read(cx, &mut readbuf)); + if let Err(e) = res { + return Poll::Ready(Err(e)); + } + let filled = readbuf.filled(); + let new_len = filled.len(); + if new_len == filled_len { + // Empty read, no match + return Poll::Ready(Ok(State::NotMatched)); + } else if new_len < prefix.len() { + // Read less than prefix, make sure we're still matching the prefix (early exit) + if !prefix.starts_with(filled) { + return Poll::Ready(Ok(State::NotMatched)); + } + } else if new_len >= prefix.len() { + // We have enough to determine + if filled.starts_with(prefix) { + return Poll::Ready(Ok(State::Matched)); + } else { + return Poll::Ready(Ok(State::NotMatched)); + } + } + + Poll::Ready(Ok(State::Unknown)) + }) + .await?; + + match state { + State::Unknown => continue, + State::Matched => { + let initialized_len = readbuf.filled().len(); + return Ok(( + true, + NetworkBufferedStream::new(io, buffer, initialized_len), + )); + } + State::NotMatched => { + let initialized_len = readbuf.filled().len(); + return Ok(( + false, + NetworkBufferedStream::new(io, buffer, initialized_len), + )); + } + } + } + } +} + +pub struct NetworkBufferedStream { + io: S, + initialized_len: usize, + prefix_offset: usize, + prefix: [MaybeUninit; MAX_PREFIX_SIZE * 2], + prefix_read: bool, +} + +impl NetworkBufferedStream { + fn new( + io: S, + prefix: [MaybeUninit; MAX_PREFIX_SIZE * 2], + initialized_len: usize, + ) -> Self { + Self { + io, + initialized_len, + prefix_offset: 0, + prefix, + prefix_read: false, + } + } + + fn current_slice(&self) -> &[u8] { + // We trust that these bytes are initialized properly + let slice = &self.prefix[self.prefix_offset..self.initialized_len]; + + // This guarantee comes from slice_assume_init_ref (we can't use that until it's stable) + + // SAFETY: casting `slice` to a `*const [T]` is safe since the caller guarantees that + // `slice` is initialized, and `MaybeUninit` is guaranteed to have the same layout as `T`. + // The pointer obtained is valid since it refers to memory owned by `slice` which is a + // reference and thus guaranteed to be valid for reads. + + unsafe { &*(slice as *const [_] as *const [u8]) as _ } + } + + pub fn into_inner(self) -> (S, Bytes) { + let bytes = Bytes::copy_from_slice(self.current_slice()); + (self.io, bytes) + } +} + +impl AsyncRead for NetworkBufferedStream { + // From hyper's Rewind (https://github.com/hyperium/hyper), MIT License, Copyright (c) Sean McArthur + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if !self.prefix_read { + let prefix = self.current_slice(); + + // If there are no remaining bytes, let the bytes get dropped. + if !prefix.is_empty() { + let copy_len = std::cmp::min(prefix.len(), buf.remaining()); + buf.put_slice(&prefix[..copy_len]); + self.prefix_offset += copy_len; + + return Poll::Ready(Ok(())); + } else { + self.prefix_read = true; + } + } + Pin::new(&mut self.io).poll_read(cx, buf) + } +} + +impl AsyncWrite + for NetworkBufferedStream +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + Pin::new(&mut self.io).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.io).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.io.is_write_vectored() + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll> { + Pin::new(&mut self.io).poll_write_vectored(cx, bufs) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::AsyncReadExt; + + struct YieldsOneByteAtATime(&'static [u8]); + + impl AsyncRead for YieldsOneByteAtATime { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some((head, tail)) = self.as_mut().0.split_first() { + self.as_mut().0 = tail; + let dest = buf.initialize_unfilled_to(1); + dest[0] = *head; + buf.advance(1); + } + Poll::Ready(Ok(())) + } + } + + async fn test( + io: impl AsyncRead + Unpin, + prefix: &'static [u8], + expect_match: bool, + expect_string: &'static str, + ) -> io::Result<()> { + let (matches, mut io) = NetworkStreamPrefixCheck::new(io, prefix) + .match_prefix() + .await?; + assert_eq!(matches, expect_match); + let mut s = String::new(); + Pin::new(&mut io).read_to_string(&mut s).await?; + assert_eq!(s, expect_string); + Ok(()) + } + + #[tokio::test] + async fn matches_prefix_simple() -> io::Result<()> { + let buf = b"prefix match".as_slice(); + test(buf, b"prefix", true, "prefix match").await + } + + #[tokio::test] + async fn matches_prefix_exact() -> io::Result<()> { + let buf = b"prefix".as_slice(); + test(buf, b"prefix", true, "prefix").await + } + + #[tokio::test] + async fn not_matches_prefix_simple() -> io::Result<()> { + let buf = b"prefill match".as_slice(); + test(buf, b"prefix", false, "prefill match").await + } + + #[tokio::test] + async fn not_matches_prefix_short() -> io::Result<()> { + let buf = b"nope".as_slice(); + test(buf, b"prefix", false, "nope").await + } + + #[tokio::test] + async fn not_matches_prefix_empty() -> io::Result<()> { + let buf = b"".as_slice(); + test(buf, b"prefix", false, "").await + } + + #[tokio::test] + async fn matches_one_byte_at_a_time() -> io::Result<()> { + let buf = YieldsOneByteAtATime(b"prefix"); + test(buf, b"prefix", true, "prefix").await + } + + #[tokio::test] + async fn not_matches_one_byte_at_a_time() -> io::Result<()> { + let buf = YieldsOneByteAtATime(b"prefill"); + test(buf, b"prefix", false, "prefill").await + } +} diff --git a/ext/net/raw.rs b/ext/net/raw.rs index 74cc10d630..3b50af41e0 100644 --- a/ext/net/raw.rs +++ b/ext/net/raw.rs @@ -30,6 +30,25 @@ pub enum NetworkStream { Unix(#[pin] UnixStream), } +impl From for NetworkStream { + fn from(value: TcpStream) -> Self { + NetworkStream::Tcp(value) + } +} + +impl From for NetworkStream { + fn from(value: TlsStream) -> Self { + NetworkStream::Tls(value) + } +} + +#[cfg(unix)] +impl From for NetworkStream { + fn from(value: UnixStream) -> Self { + NetworkStream::Unix(value) + } +} + /// A raw stream of one of the types handled by this extension. #[derive(Copy, Clone, PartialEq, Eq)] pub enum NetworkStreamType {