mirror of
https://github.com/denoland/deno.git
synced 2024-12-22 15:24:46 -05:00
feat(ext/http): h2c for http/2 (#18817)
This implements HTTP/2 prior-knowledge connections, allowing clients to request HTTP/2 over plaintext or TLS-without-ALPN connections. If a client requests a specific protocol via ALPN (`h2` or `http/1.1`), however, the protocol is forced and must be used.
This commit is contained in:
parent
0e97fa4d5f
commit
bb74e75a04
5 changed files with 520 additions and 38 deletions
|
@ -15,6 +15,7 @@ import {
|
|||
deferred,
|
||||
fail,
|
||||
} from "./test_util.ts";
|
||||
import { consoleSize } from "../../../runtime/js/40_tty.js";
|
||||
|
||||
function createOnErrorCb(ac: AbortController): (err: unknown) => Response {
|
||||
return (err) => {
|
||||
|
@ -2709,3 +2710,80 @@ function isProhibitedForTrailer(key: string): boolean {
|
|||
const s = new Set(["transfer-encoding", "content-length", "trailer"]);
|
||||
return s.has(key.toLowerCase());
|
||||
}
|
||||
|
||||
Deno.test(
|
||||
{ permissions: { net: true, run: true } },
|
||||
async function httpServeCurlH2C() {
|
||||
const ac = new AbortController();
|
||||
const server = Deno.serve(
|
||||
() => new Response("hello world!"),
|
||||
{ signal: ac.signal },
|
||||
);
|
||||
|
||||
assertEquals(
|
||||
"hello world!",
|
||||
await curlRequest(["http://localhost:8000/path"]),
|
||||
);
|
||||
assertEquals(
|
||||
"hello world!",
|
||||
await curlRequest(["http://localhost:8000/path", "--http2"]),
|
||||
);
|
||||
assertEquals(
|
||||
"hello world!",
|
||||
await curlRequest([
|
||||
"http://localhost:8000/path",
|
||||
"--http2",
|
||||
"--http2-prior-knowledge",
|
||||
]),
|
||||
);
|
||||
|
||||
ac.abort();
|
||||
await server;
|
||||
},
|
||||
);
|
||||
|
||||
Deno.test(
|
||||
{ permissions: { net: true, run: true, read: true } },
|
||||
async function httpsServeCurlH2C() {
|
||||
const ac = new AbortController();
|
||||
const server = Deno.serve(
|
||||
() => new Response("hello world!"),
|
||||
{
|
||||
signal: ac.signal,
|
||||
cert: Deno.readTextFileSync("cli/tests/testdata/tls/localhost.crt"),
|
||||
key: Deno.readTextFileSync("cli/tests/testdata/tls/localhost.key"),
|
||||
},
|
||||
);
|
||||
|
||||
assertEquals(
|
||||
"hello world!",
|
||||
await curlRequest(["https://localhost:9000/path", "-k"]),
|
||||
);
|
||||
assertEquals(
|
||||
"hello world!",
|
||||
await curlRequest(["https://localhost:9000/path", "-k", "--http2"]),
|
||||
);
|
||||
assertEquals(
|
||||
"hello world!",
|
||||
await curlRequest([
|
||||
"https://localhost:9000/path",
|
||||
"-k",
|
||||
"--http2",
|
||||
"--http2-prior-knowledge",
|
||||
]),
|
||||
);
|
||||
|
||||
ac.abort();
|
||||
await server;
|
||||
},
|
||||
);
|
||||
|
||||
async function curlRequest(args: string[]) {
|
||||
const { success, stdout } = await new Deno.Command("curl", {
|
||||
args,
|
||||
stdout: "piped",
|
||||
stderr: "null",
|
||||
}).output();
|
||||
assert(success);
|
||||
return new TextDecoder().decode(stdout);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
|
||||
use crate::extract_network_stream;
|
||||
use crate::network_buffered_stream::NetworkStreamPrefixCheck;
|
||||
use crate::request_body::HttpRequestBody;
|
||||
use crate::request_properties::DefaultHttpRequestProperties;
|
||||
use crate::request_properties::HttpConnectionProperties;
|
||||
|
@ -36,6 +37,7 @@ use hyper1::http::HeaderValue;
|
|||
use hyper1::server::conn::http1;
|
||||
use hyper1::server::conn::http2;
|
||||
use hyper1::service::service_fn;
|
||||
use hyper1::service::HttpService;
|
||||
use hyper1::upgrade::OnUpgrade;
|
||||
use hyper1::StatusCode;
|
||||
use pin_project::pin_project;
|
||||
|
@ -56,6 +58,37 @@ use tokio::task::JoinHandle;
|
|||
type Request = hyper1::Request<Incoming>;
|
||||
type Response = hyper1::Response<ResponseBytes>;
|
||||
|
||||
/// All HTTP/2 connections start with this byte string.
|
||||
///
|
||||
/// In HTTP/2, each endpoint is required to send a connection preface as a final confirmation
|
||||
/// of the protocol in use and to establish the initial settings for the HTTP/2 connection. The
|
||||
/// client and server each send a different connection preface.
|
||||
///
|
||||
/// The client connection preface starts with a sequence of 24 octets, which in hex notation is:
|
||||
///
|
||||
/// 0x505249202a20485454502f322e300d0a0d0a534d0d0a0d0a
|
||||
///
|
||||
/// That is, the connection preface starts with the string PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n). This sequence
|
||||
/// MUST be followed by a SETTINGS frame (Section 6.5), which MAY be empty.
|
||||
const HTTP2_PREFIX: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
|
||||
|
||||
/// ALPN negotation for "h2"
|
||||
const TLS_ALPN_HTTP_2: &[u8] = b"h2";
|
||||
|
||||
/// ALPN negotation for "http/1.1"
|
||||
const TLS_ALPN_HTTP_11: &[u8] = b"http/1.1";
|
||||
|
||||
/// Name a trait for streams we can serve HTTP over.
|
||||
trait HttpServeStream:
|
||||
tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static
|
||||
{
|
||||
}
|
||||
impl<
|
||||
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
> HttpServeStream for S
|
||||
{
|
||||
}
|
||||
|
||||
pub struct HttpSlabRecord {
|
||||
request_info: HttpConnectionProperties,
|
||||
request_parts: Parts,
|
||||
|
@ -514,6 +547,44 @@ impl<F: Future<Output = ()>> Future for SlabFuture<F> {
|
|||
}
|
||||
}
|
||||
|
||||
fn serve_http11_unconditional(
|
||||
io: impl HttpServeStream,
|
||||
svc: impl HttpService<Incoming, ResBody = ResponseBytes> + 'static,
|
||||
cancel: RcRef<CancelHandle>,
|
||||
) -> impl Future<Output = Result<(), AnyError>> + 'static {
|
||||
let conn = http1::Builder::new()
|
||||
.keep_alive(true)
|
||||
.serve_connection(io, svc);
|
||||
|
||||
conn
|
||||
.with_upgrades()
|
||||
.map_err(AnyError::from)
|
||||
.try_or_cancel(cancel)
|
||||
}
|
||||
|
||||
fn serve_http2_unconditional(
|
||||
io: impl HttpServeStream,
|
||||
svc: impl HttpService<Incoming, ResBody = ResponseBytes> + 'static,
|
||||
cancel: RcRef<CancelHandle>,
|
||||
) -> impl Future<Output = Result<(), AnyError>> + 'static {
|
||||
let conn = http2::Builder::new(LocalExecutor).serve_connection(io, svc);
|
||||
conn.map_err(AnyError::from).try_or_cancel(cancel)
|
||||
}
|
||||
|
||||
async fn serve_http2_autodetect(
|
||||
io: impl HttpServeStream,
|
||||
svc: impl HttpService<Incoming, ResBody = ResponseBytes> + 'static,
|
||||
cancel: RcRef<CancelHandle>,
|
||||
) -> Result<(), AnyError> {
|
||||
let prefix = NetworkStreamPrefixCheck::new(io, HTTP2_PREFIX);
|
||||
let (matches, io) = prefix.match_prefix().await?;
|
||||
if matches {
|
||||
serve_http2_unconditional(io, svc, cancel).await
|
||||
} else {
|
||||
serve_http11_unconditional(io, svc, cancel).await
|
||||
}
|
||||
}
|
||||
|
||||
fn serve_https(
|
||||
mut io: TlsStream,
|
||||
request_info: HttpConnectionProperties,
|
||||
|
@ -526,28 +597,21 @@ fn serve_https(
|
|||
});
|
||||
spawn_local(async {
|
||||
io.handshake().await?;
|
||||
// If the client specifically negotiates a protocol, we will use it. If not, we'll auto-detect
|
||||
// based on the prefix bytes
|
||||
let handshake = io.get_ref().1.alpn_protocol();
|
||||
// h2
|
||||
if handshake == Some(&[104, 50]) {
|
||||
let conn = http2::Builder::new(LocalExecutor).serve_connection(io, svc);
|
||||
|
||||
conn.map_err(AnyError::from).try_or_cancel(cancel).await
|
||||
if handshake == Some(TLS_ALPN_HTTP_2) {
|
||||
serve_http2_unconditional(io, svc, cancel).await
|
||||
} else if handshake == Some(TLS_ALPN_HTTP_11) {
|
||||
serve_http11_unconditional(io, svc, cancel).await
|
||||
} else {
|
||||
let conn = http1::Builder::new()
|
||||
.keep_alive(true)
|
||||
.serve_connection(io, svc);
|
||||
|
||||
conn
|
||||
.with_upgrades()
|
||||
.map_err(AnyError::from)
|
||||
.try_or_cancel(cancel)
|
||||
.await
|
||||
serve_http2_autodetect(io, svc, cancel).await
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn serve_http(
|
||||
io: impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
io: impl HttpServeStream,
|
||||
request_info: HttpConnectionProperties,
|
||||
cancel: RcRef<CancelHandle>,
|
||||
tx: tokio::sync::mpsc::Sender<usize>,
|
||||
|
@ -556,16 +620,7 @@ fn serve_http(
|
|||
let svc = service_fn(move |req: Request| {
|
||||
new_slab_future(req, request_info.clone(), tx.clone())
|
||||
});
|
||||
spawn_local(async {
|
||||
let conn = http1::Builder::new()
|
||||
.keep_alive(true)
|
||||
.serve_connection(io, svc);
|
||||
conn
|
||||
.with_upgrades()
|
||||
.map_err(AnyError::from)
|
||||
.try_or_cancel(cancel)
|
||||
.await
|
||||
})
|
||||
spawn_local(serve_http2_autodetect(io, svc, cancel))
|
||||
}
|
||||
|
||||
fn serve_http_on(
|
||||
|
@ -702,7 +757,7 @@ pub fn op_serve_http_on(
|
|||
AsyncRefCell::new(rx),
|
||||
));
|
||||
|
||||
let handle = serve_http_on(
|
||||
let handle: JoinHandle<Result<(), deno_core::anyhow::Error>> = serve_http_on(
|
||||
network_stream,
|
||||
&listen_properties,
|
||||
resource.cancel_handle(),
|
||||
|
|
|
@ -73,11 +73,13 @@ use tokio::io::AsyncWriteExt;
|
|||
use tokio::task::spawn_local;
|
||||
use websocket_upgrade::WebSocketUpgrade;
|
||||
|
||||
use crate::network_buffered_stream::NetworkBufferedStream;
|
||||
use crate::reader_stream::ExternallyAbortableReaderStream;
|
||||
use crate::reader_stream::ShutdownHandle;
|
||||
|
||||
pub mod compressible;
|
||||
mod http_next;
|
||||
mod network_buffered_stream;
|
||||
mod reader_stream;
|
||||
mod request_body;
|
||||
mod request_properties;
|
||||
|
@ -1251,22 +1253,66 @@ impl CanDowncastUpgrade for hyper::upgrade::Upgraded {
|
|||
}
|
||||
}
|
||||
|
||||
fn maybe_extract_network_stream<
|
||||
T: Into<NetworkStream> + AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
U: CanDowncastUpgrade,
|
||||
>(
|
||||
upgraded: U,
|
||||
) -> Result<(NetworkStream, Bytes), U> {
|
||||
let upgraded = match upgraded.downcast::<T>() {
|
||||
Ok((stream, bytes)) => return Ok((stream.into(), bytes)),
|
||||
Err(x) => x,
|
||||
};
|
||||
|
||||
match upgraded.downcast::<NetworkBufferedStream<T>>() {
|
||||
Ok((stream, upgraded_bytes)) => {
|
||||
// Both the upgrade and the stream might have unread bytes
|
||||
let (io, stream_bytes) = stream.into_inner();
|
||||
let bytes = match (stream_bytes.is_empty(), upgraded_bytes.is_empty()) {
|
||||
(false, false) => Bytes::default(),
|
||||
(true, false) => upgraded_bytes,
|
||||
(false, true) => stream_bytes,
|
||||
(true, true) => {
|
||||
// The upgraded bytes come first as they have already been read
|
||||
let mut v = upgraded_bytes.to_vec();
|
||||
v.append(&mut stream_bytes.to_vec());
|
||||
Bytes::from(v)
|
||||
}
|
||||
};
|
||||
Ok((io.into(), bytes))
|
||||
}
|
||||
Err(x) => Err(x),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_network_stream<U: CanDowncastUpgrade>(
|
||||
upgraded: U,
|
||||
) -> (NetworkStream, Bytes) {
|
||||
let upgraded = match upgraded.downcast::<tokio::net::TcpStream>() {
|
||||
Ok((stream, bytes)) => return (NetworkStream::Tcp(stream), bytes),
|
||||
Err(x) => x,
|
||||
};
|
||||
let upgraded = match upgraded.downcast::<deno_net::ops_tls::TlsStream>() {
|
||||
Ok((stream, bytes)) => return (NetworkStream::Tls(stream), bytes),
|
||||
Err(x) => x,
|
||||
};
|
||||
let upgraded =
|
||||
match maybe_extract_network_stream::<tokio::net::TcpStream, _>(upgraded) {
|
||||
Ok(res) => return res,
|
||||
Err(x) => x,
|
||||
};
|
||||
let upgraded =
|
||||
match maybe_extract_network_stream::<deno_net::ops_tls::TlsStream, _>(
|
||||
upgraded,
|
||||
) {
|
||||
Ok(res) => return res,
|
||||
Err(x) => x,
|
||||
};
|
||||
#[cfg(unix)]
|
||||
let upgraded = match upgraded.downcast::<tokio::net::UnixStream>() {
|
||||
Ok((stream, bytes)) => return (NetworkStream::Unix(stream), bytes),
|
||||
Err(x) => x,
|
||||
};
|
||||
let upgraded =
|
||||
match maybe_extract_network_stream::<tokio::net::UnixStream, _>(upgraded) {
|
||||
Ok(res) => return res,
|
||||
Err(x) => x,
|
||||
};
|
||||
let upgraded =
|
||||
match maybe_extract_network_stream::<NetworkStream, _>(upgraded) {
|
||||
Ok(res) => return res,
|
||||
Err(x) => x,
|
||||
};
|
||||
|
||||
// TODO(mmastrac): HTTP/2 websockets may yield an un-downgradable type
|
||||
drop(upgraded);
|
||||
unreachable!("unexpected stream type");
|
||||
}
|
||||
|
|
284
ext/http/network_buffered_stream.rs
Normal file
284
ext/http/network_buffered_stream.rs
Normal file
|
@ -0,0 +1,284 @@
|
|||
// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
|
||||
|
||||
use bytes::Bytes;
|
||||
use deno_core::futures::future::poll_fn;
|
||||
use deno_core::futures::ready;
|
||||
use std::io;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::pin::Pin;
|
||||
use std::task::Poll;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::io::ReadBuf;
|
||||
|
||||
const MAX_PREFIX_SIZE: usize = 256;
|
||||
|
||||
pub struct NetworkStreamPrefixCheck<S: AsyncRead + Unpin> {
|
||||
io: S,
|
||||
prefix: &'static [u8],
|
||||
buffer: [MaybeUninit<u8>; MAX_PREFIX_SIZE * 2],
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> NetworkStreamPrefixCheck<S> {
|
||||
pub fn new(io: S, prefix: &'static [u8]) -> Self {
|
||||
debug_assert!(prefix.len() < MAX_PREFIX_SIZE);
|
||||
Self {
|
||||
io,
|
||||
prefix,
|
||||
buffer: [MaybeUninit::<u8>::uninit(); MAX_PREFIX_SIZE * 2],
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a [`NetworkBufferedStream`], rewound with the bytes we read to determine what
|
||||
// type of stream this is.
|
||||
pub async fn match_prefix(
|
||||
self,
|
||||
) -> io::Result<(bool, NetworkBufferedStream<S>)> {
|
||||
let mut buffer = self.buffer;
|
||||
let mut readbuf = ReadBuf::uninit(&mut buffer);
|
||||
let mut io = self.io;
|
||||
let prefix = self.prefix;
|
||||
loop {
|
||||
enum State {
|
||||
Unknown,
|
||||
Matched,
|
||||
NotMatched,
|
||||
}
|
||||
|
||||
let state = poll_fn(|cx| {
|
||||
let filled_len = readbuf.filled().len();
|
||||
let res = ready!(Pin::new(&mut io).poll_read(cx, &mut readbuf));
|
||||
if let Err(e) = res {
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
let filled = readbuf.filled();
|
||||
let new_len = filled.len();
|
||||
if new_len == filled_len {
|
||||
// Empty read, no match
|
||||
return Poll::Ready(Ok(State::NotMatched));
|
||||
} else if new_len < prefix.len() {
|
||||
// Read less than prefix, make sure we're still matching the prefix (early exit)
|
||||
if !prefix.starts_with(filled) {
|
||||
return Poll::Ready(Ok(State::NotMatched));
|
||||
}
|
||||
} else if new_len >= prefix.len() {
|
||||
// We have enough to determine
|
||||
if filled.starts_with(prefix) {
|
||||
return Poll::Ready(Ok(State::Matched));
|
||||
} else {
|
||||
return Poll::Ready(Ok(State::NotMatched));
|
||||
}
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(State::Unknown))
|
||||
})
|
||||
.await?;
|
||||
|
||||
match state {
|
||||
State::Unknown => continue,
|
||||
State::Matched => {
|
||||
let initialized_len = readbuf.filled().len();
|
||||
return Ok((
|
||||
true,
|
||||
NetworkBufferedStream::new(io, buffer, initialized_len),
|
||||
));
|
||||
}
|
||||
State::NotMatched => {
|
||||
let initialized_len = readbuf.filled().len();
|
||||
return Ok((
|
||||
false,
|
||||
NetworkBufferedStream::new(io, buffer, initialized_len),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NetworkBufferedStream<S: AsyncRead + Unpin> {
|
||||
io: S,
|
||||
initialized_len: usize,
|
||||
prefix_offset: usize,
|
||||
prefix: [MaybeUninit<u8>; MAX_PREFIX_SIZE * 2],
|
||||
prefix_read: bool,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> NetworkBufferedStream<S> {
|
||||
fn new(
|
||||
io: S,
|
||||
prefix: [MaybeUninit<u8>; MAX_PREFIX_SIZE * 2],
|
||||
initialized_len: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
io,
|
||||
initialized_len,
|
||||
prefix_offset: 0,
|
||||
prefix,
|
||||
prefix_read: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn current_slice(&self) -> &[u8] {
|
||||
// We trust that these bytes are initialized properly
|
||||
let slice = &self.prefix[self.prefix_offset..self.initialized_len];
|
||||
|
||||
// This guarantee comes from slice_assume_init_ref (we can't use that until it's stable)
|
||||
|
||||
// SAFETY: casting `slice` to a `*const [T]` is safe since the caller guarantees that
|
||||
// `slice` is initialized, and `MaybeUninit` is guaranteed to have the same layout as `T`.
|
||||
// The pointer obtained is valid since it refers to memory owned by `slice` which is a
|
||||
// reference and thus guaranteed to be valid for reads.
|
||||
|
||||
unsafe { &*(slice as *const [_] as *const [u8]) as _ }
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> (S, Bytes) {
|
||||
let bytes = Bytes::copy_from_slice(self.current_slice());
|
||||
(self.io, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> AsyncRead for NetworkBufferedStream<S> {
|
||||
// From hyper's Rewind (https://github.com/hyperium/hyper), MIT License, Copyright (c) Sean McArthur
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
if !self.prefix_read {
|
||||
let prefix = self.current_slice();
|
||||
|
||||
// If there are no remaining bytes, let the bytes get dropped.
|
||||
if !prefix.is_empty() {
|
||||
let copy_len = std::cmp::min(prefix.len(), buf.remaining());
|
||||
buf.put_slice(&prefix[..copy_len]);
|
||||
self.prefix_offset += copy_len;
|
||||
|
||||
return Poll::Ready(Ok(()));
|
||||
} else {
|
||||
self.prefix_read = true;
|
||||
}
|
||||
}
|
||||
Pin::new(&mut self.io).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite
|
||||
for NetworkBufferedStream<S>
|
||||
{
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
||||
Pin::new(&mut self.io).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
Pin::new(&mut self.io).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
Pin::new(&mut self.io).poll_shutdown(cx)
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.io.is_write_vectored()
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
||||
Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
struct YieldsOneByteAtATime(&'static [u8]);
|
||||
|
||||
impl AsyncRead for YieldsOneByteAtATime {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if let Some((head, tail)) = self.as_mut().0.split_first() {
|
||||
self.as_mut().0 = tail;
|
||||
let dest = buf.initialize_unfilled_to(1);
|
||||
dest[0] = *head;
|
||||
buf.advance(1);
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn test(
|
||||
io: impl AsyncRead + Unpin,
|
||||
prefix: &'static [u8],
|
||||
expect_match: bool,
|
||||
expect_string: &'static str,
|
||||
) -> io::Result<()> {
|
||||
let (matches, mut io) = NetworkStreamPrefixCheck::new(io, prefix)
|
||||
.match_prefix()
|
||||
.await?;
|
||||
assert_eq!(matches, expect_match);
|
||||
let mut s = String::new();
|
||||
Pin::new(&mut io).read_to_string(&mut s).await?;
|
||||
assert_eq!(s, expect_string);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn matches_prefix_simple() -> io::Result<()> {
|
||||
let buf = b"prefix match".as_slice();
|
||||
test(buf, b"prefix", true, "prefix match").await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn matches_prefix_exact() -> io::Result<()> {
|
||||
let buf = b"prefix".as_slice();
|
||||
test(buf, b"prefix", true, "prefix").await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn not_matches_prefix_simple() -> io::Result<()> {
|
||||
let buf = b"prefill match".as_slice();
|
||||
test(buf, b"prefix", false, "prefill match").await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn not_matches_prefix_short() -> io::Result<()> {
|
||||
let buf = b"nope".as_slice();
|
||||
test(buf, b"prefix", false, "nope").await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn not_matches_prefix_empty() -> io::Result<()> {
|
||||
let buf = b"".as_slice();
|
||||
test(buf, b"prefix", false, "").await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn matches_one_byte_at_a_time() -> io::Result<()> {
|
||||
let buf = YieldsOneByteAtATime(b"prefix");
|
||||
test(buf, b"prefix", true, "prefix").await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn not_matches_one_byte_at_a_time() -> io::Result<()> {
|
||||
let buf = YieldsOneByteAtATime(b"prefill");
|
||||
test(buf, b"prefix", false, "prefill").await
|
||||
}
|
||||
}
|
|
@ -30,6 +30,25 @@ pub enum NetworkStream {
|
|||
Unix(#[pin] UnixStream),
|
||||
}
|
||||
|
||||
impl From<TcpStream> for NetworkStream {
|
||||
fn from(value: TcpStream) -> Self {
|
||||
NetworkStream::Tcp(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TlsStream> for NetworkStream {
|
||||
fn from(value: TlsStream) -> Self {
|
||||
NetworkStream::Tls(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl From<UnixStream> for NetworkStream {
|
||||
fn from(value: UnixStream) -> Self {
|
||||
NetworkStream::Unix(value)
|
||||
}
|
||||
}
|
||||
|
||||
/// A raw stream of one of the types handled by this extension.
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum NetworkStreamType {
|
||||
|
|
Loading…
Reference in a new issue