diff --git a/Cargo.lock b/Cargo.lock index f53ad5474a..6201dc138f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1420,6 +1420,7 @@ dependencies = [ "enum-as-inner", "log", "pin-project", + "rustls-tokio-stream", "serde", "socket2 0.5.4", "tokio", @@ -4460,9 +4461,9 @@ dependencies = [ [[package]] name = "rustls-tokio-stream" -version = "0.2.9" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55cae64d5219dfdd7f2d18dda421a2137ebdd63be6d0dc53d7836003f224f3d0" +checksum = "897937c68ff975d028e8cc07bc887f2d5a9ec2bc952549f40db9a91dc557974c" dependencies = [ "futures", "rustls", diff --git a/Cargo.toml b/Cargo.toml index cb2db83e67..3d1fd01c14 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -130,7 +130,7 @@ ring = "^0.17.0" rusqlite = { version = "=0.29.0", features = ["unlock_notify", "bundled"] } rustls = "0.21.8" rustls-pemfile = "1.0.0" -rustls-tokio-stream = "=0.2.9" +rustls-tokio-stream = "=0.2.16" rustls-webpki = "0.101.4" webpki-roots = "0.25.2" scopeguard = "1.2.0" diff --git a/cli/tests/integration/cert_tests.rs b/cli/tests/integration/cert_tests.rs index 20bf4d80d0..f53ce9ceef 100644 --- a/cli/tests/integration/cert_tests.rs +++ b/cli/tests/integration/cert_tests.rs @@ -250,13 +250,12 @@ async fn listen_tls_alpn() { let tcp_stream = tokio::net::TcpStream::connect("localhost:4504") .await .unwrap(); - let mut tls_stream = TlsStream::new_client_side(tcp_stream, cfg, hostname); + let mut tls_stream = + TlsStream::new_client_side(tcp_stream, cfg, hostname, None); - tls_stream.handshake().await.unwrap(); + let handshake = tls_stream.handshake().await.unwrap(); - let (_, rustls_connection) = tls_stream.get_ref(); - let alpn = rustls_connection.alpn_protocol().unwrap(); - assert_eq!(alpn, b"foobar"); + assert_eq!(handshake.alpn, Some(b"foobar".to_vec())); let status = child.wait().unwrap(); assert!(status.success()); @@ -300,13 +299,11 @@ async fn listen_tls_alpn_fail() { let tcp_stream = tokio::net::TcpStream::connect("localhost:4505") .await .unwrap(); - let mut tls_stream = TlsStream::new_client_side(tcp_stream, cfg, hostname); + let mut tls_stream = + TlsStream::new_client_side(tcp_stream, cfg, hostname, None); tls_stream.handshake().await.unwrap_err(); - let (_, rustls_connection) = tls_stream.get_ref(); - assert!(rustls_connection.alpn_protocol().is_none()); - let status = child.wait().unwrap(); assert!(status.success()); } diff --git a/cli/tests/testdata/run/websocket_test.ts b/cli/tests/testdata/run/websocket_test.ts index b759d042d1..74a369d55f 100644 --- a/cli/tests/testdata/run/websocket_test.ts +++ b/cli/tests/testdata/run/websocket_test.ts @@ -159,11 +159,9 @@ Deno.test("websocket error", async () => { ws.onopen = () => fail(); ws.onerror = (err) => { assert(err instanceof ErrorEvent); - - // Error message got changed because we don't use warp in test_util assertEquals( err.message, - "NetworkError: failed to connect to WebSocket: invalid data", + "NetworkError: failed to connect to WebSocket: received corrupt message of type InvalidContentType", ); promise1.resolve(); }; diff --git a/cli/tests/unit/tls_test.ts b/cli/tests/unit/tls_test.ts index 8162c53b56..31e24d5477 100644 --- a/cli/tests/unit/tls_test.ts +++ b/cli/tests/unit/tls_test.ts @@ -11,6 +11,7 @@ import { } from "./test_util.ts"; import { BufReader, BufWriter } from "../../../test_util/std/io/mod.ts"; import { readAll } from "../../../test_util/std/streams/read_all.ts"; +import { writeAll } from "../../../test_util/std/streams/write_all.ts"; import { TextProtoReader } from "../testdata/run/textproto.ts"; const encoder = new TextEncoder(); @@ -538,15 +539,23 @@ Deno.test( }, ); +const largeAmount = 1 << 20 /* 1 MB */; + async function sendAlotReceiveNothing(conn: Deno.Conn) { // Start receive op. const readBuf = new Uint8Array(1024); const readPromise = conn.read(readBuf); + const timeout = setTimeout(() => { + throw new Error("Failed to send buffer in a reasonable amount of time"); + }, 10_000); + // Send 1 MB of data. - const writeBuf = new Uint8Array(1 << 20 /* 1 MB */); + const writeBuf = new Uint8Array(largeAmount); writeBuf.fill(42); - await conn.write(writeBuf); + await writeAll(conn, writeBuf); + + clearTimeout(timeout); // Send EOF. await conn.closeWrite(); @@ -564,14 +573,29 @@ async function sendAlotReceiveNothing(conn: Deno.Conn) { async function receiveAlotSendNothing(conn: Deno.Conn) { const readBuf = new Uint8Array(1024); let n: number | null; + let nread = 0; + + const timeout = setTimeout(() => { + throw new Error( + `Failed to read buffer in a reasonable amount of time (got ${nread}/${largeAmount})`, + ); + }, 10_000); // Receive 1 MB of data. - for (let nread = 0; nread < 1 << 20 /* 1 MB */; nread += n!) { - n = await conn.read(readBuf); - assertStrictEquals(typeof n, "number"); - assert(n! > 0); - assertStrictEquals(readBuf[0], 42); + try { + for (; nread < largeAmount; nread += n!) { + n = await conn.read(readBuf); + assertStrictEquals(typeof n, "number"); + assert(n! > 0); + assertStrictEquals(readBuf[0], 42); + } + } catch (e) { + throw new Error( + `Got an error (${e.message}) after reading ${nread}/${largeAmount} bytes`, + { cause: e }, + ); } + clearTimeout(timeout); // Close the connection, without sending anything at all. conn.close(); @@ -623,7 +647,7 @@ async function sendReceiveEmptyBuf(conn: Deno.Conn) { await assertRejects(async () => { await conn.write(byteBuf); - }, Deno.errors.BrokenPipe); + }, Deno.errors.NotConnected); n = await conn.write(emptyBuf); assertStrictEquals(n, 0); @@ -841,13 +865,12 @@ async function tlsWithTcpFailureTestImpl( tcpForwardingInterruptPromise2.resolve(); break; case "shutdown": - // Receiving a TCP FIN packet without receiving a TLS CloseNotify - // alert is not the expected mode of operation, but it is not a - // problem either, so it should be treated as if the TLS session was - // gracefully closed. await Promise.all([ tcpConn1.closeWrite(), - await receiveEof(tlsConn1), + await assertRejects( + () => receiveEof(tlsConn1), + Deno.errors.UnexpectedEof, + ), await tlsConn1.closeWrite(), await receiveEof(tlsConn2), ]); @@ -1036,8 +1059,8 @@ function createHttpsListener(port: number): Deno.Listener { ); // Send response. - await conn.write(resHead); - await conn.write(resBody); + await writeAll(conn, resHead); + await writeAll(conn, resBody); // Close TCP connection. conn.close(); @@ -1046,12 +1069,14 @@ function createHttpsListener(port: number): Deno.Listener { } async function curl(url: string): Promise { - const { success, code, stdout } = await new Deno.Command("curl", { + const { success, code, stdout, stderr } = await new Deno.Command("curl", { args: ["--insecure", url], }).output(); if (!success) { - throw new Error(`curl ${url} failed: ${code}`); + throw new Error( + `curl ${url} failed: ${code}:\n${new TextDecoder().decode(stderr)}`, + ); } return new TextDecoder().decode(stdout); } @@ -1276,8 +1301,7 @@ Deno.test( // Begin sending a 10mb blob over the TLS connection. const whole = new Uint8Array(10 << 20); // 10mb. whole.fill(42); - const sendPromise = conn1.write(whole); - + const sendPromise = writeAll(conn1, whole); // Set up the other end to receive half of the large blob. const half = new Uint8Array(whole.byteLength / 2); const receivePromise = readFull(conn2, half); @@ -1294,7 +1318,7 @@ Deno.test( // Receive second half of large blob. Wait for the send promise and check it. assertEquals(await readFull(conn2, half), half.length); - assertEquals(await sendPromise, whole.length); + await sendPromise; await conn1.handshake(); await conn2.handshake(); @@ -1352,7 +1376,7 @@ Deno.test( await assertRejects( () => conn.handshake(), Deno.errors.InvalidData, - "UnknownIssuer", + "invalid peer certificate: UnknownIssuer", ); conn.close(); } diff --git a/ext/http/http_next.rs b/ext/http/http_next.rs index f42275b0ec..98fbd1f8bc 100644 --- a/ext/http/http_next.rs +++ b/ext/http/http_next.rs @@ -850,15 +850,15 @@ fn serve_https( }); spawn( async { - io.handshake().await?; + let handshake = io.handshake().await?; // If the client specifically negotiates a protocol, we will use it. If not, we'll auto-detect // based on the prefix bytes - let handshake = io.get_ref().1.alpn_protocol(); - if handshake == Some(TLS_ALPN_HTTP_2) { + let handshake = handshake.alpn; + if Some(TLS_ALPN_HTTP_2) == handshake.as_deref() { serve_http2_unconditional(io, svc, listen_cancel_handle) .await .map_err(|e| e.into()) - } else if handshake == Some(TLS_ALPN_HTTP_11) { + } else if Some(TLS_ALPN_HTTP_11) == handshake.as_deref() { serve_http11_unconditional(io, svc, listen_cancel_handle) .await .map_err(|e| e.into()) diff --git a/ext/net/Cargo.toml b/ext/net/Cargo.toml index 20dde87181..c9c30b9a2e 100644 --- a/ext/net/Cargo.toml +++ b/ext/net/Cargo.toml @@ -21,6 +21,7 @@ deno_tls.workspace = true enum-as-inner = "=0.5.1" log.workspace = true pin-project.workspace = true +rustls-tokio-stream.workspace = true serde.workspace = true socket2.workspace = true tokio.workspace = true diff --git a/ext/net/ops_tls.rs b/ext/net/ops_tls.rs index 26ec48fba2..8c64744328 100644 --- a/ext/net/ops_tls.rs +++ b/ext/net/ops_tls.rs @@ -14,22 +14,9 @@ use deno_core::error::generic_error; use deno_core::error::invalid_hostname; use deno_core::error::type_error; use deno_core::error::AnyError; -use deno_core::futures::future::poll_fn; -use deno_core::futures::ready; -use deno_core::futures::task::noop_waker_ref; -use deno_core::futures::task::AtomicWaker; -use deno_core::futures::task::Context; -use deno_core::futures::task::Poll; -use deno_core::futures::task::RawWaker; -use deno_core::futures::task::RawWakerVTable; -use deno_core::futures::task::Waker; use deno_core::op2; - -use deno_core::parking_lot::Mutex; -use deno_core::unsync::spawn; use deno_core::AsyncRefCell; use deno_core::AsyncResult; -use deno_core::ByteString; use deno_core::CancelHandle; use deno_core::CancelTryFuture; use deno_core::OpState; @@ -40,17 +27,13 @@ use deno_tls::create_client_config; use deno_tls::load_certs; use deno_tls::load_private_keys; use deno_tls::rustls::Certificate; -use deno_tls::rustls::ClientConfig; -use deno_tls::rustls::ClientConnection; -use deno_tls::rustls::Connection; use deno_tls::rustls::PrivateKey; use deno_tls::rustls::ServerConfig; -use deno_tls::rustls::ServerConnection; use deno_tls::rustls::ServerName; use deno_tls::SocketUse; -use io::Error; use io::Read; -use io::Write; +use rustls_tokio_stream::TlsStreamRead; +use rustls_tokio_stream::TlsStreamWrite; use serde::Deserialize; use socket2::Domain; use socket2::Socket; @@ -63,632 +46,31 @@ use std::fs::File; use std::io; use std::io::BufReader; use std::io::ErrorKind; -use std::net::SocketAddr; +use std::num::NonZeroUsize; use std::path::Path; -use std::pin::Pin; use std::rc::Rc; use std::sync::Arc; -use std::sync::Weak; -use tokio::io::AsyncRead; use tokio::io::AsyncReadExt; -use tokio::io::AsyncWrite; use tokio::io::AsyncWriteExt; -use tokio::io::ReadBuf; use tokio::net::TcpListener; use tokio::net::TcpStream; -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -enum Flow { - Handshake, - Read, - Write, -} +pub use rustls_tokio_stream::TlsStream; -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -enum State { - StreamOpen, - StreamClosed, - TlsClosing, - TlsClosed, - TcpClosed, -} - -pub struct TlsStream(Option); - -impl TlsStream { - fn new(tcp: TcpStream, mut tls: Connection) -> Self { - tls.set_buffer_limit(None); - - let inner = TlsStreamInner { - tcp, - tls, - rd_state: State::StreamOpen, - wr_state: State::StreamOpen, - }; - Self(Some(inner)) - } - - pub fn new_client_side( - tcp: TcpStream, - tls_config: Arc, - server_name: ServerName, - ) -> Self { - let tls = ClientConnection::new(tls_config, server_name).unwrap(); - Self::new(tcp, Connection::Client(tls)) - } - - pub fn new_client_side_from( - tcp: TcpStream, - connection: ClientConnection, - ) -> Self { - Self::new(tcp, Connection::Client(connection)) - } - - pub fn new_server_side( - tcp: TcpStream, - tls_config: Arc, - ) -> Self { - let tls = ServerConnection::new(tls_config).unwrap(); - Self::new(tcp, Connection::Server(tls)) - } - - pub fn new_server_side_from( - tcp: TcpStream, - connection: ServerConnection, - ) -> Self { - Self::new(tcp, Connection::Server(connection)) - } - - pub fn into_split(self) -> (ReadHalf, WriteHalf) { - let shared = Shared::new(self); - let rd = ReadHalf { - shared: shared.clone(), - }; - let wr = WriteHalf { shared }; - (rd, wr) - } - - /// Convenience method to match [`TcpStream`]. - pub fn peer_addr(&self) -> Result { - self.0.as_ref().unwrap().tcp.peer_addr() - } - - /// Convenience method to match [`TcpStream`]. - pub fn local_addr(&self) -> Result { - self.0.as_ref().unwrap().tcp.local_addr() - } - - /// Tokio-rustls compatibility: returns a reference to the underlying TCP - /// stream, and a reference to the Rustls `Connection` object. - pub fn get_ref(&self) -> (&TcpStream, &Connection) { - let inner = self.0.as_ref().unwrap(); - (&inner.tcp, &inner.tls) - } - - fn inner_mut(&mut self) -> &mut TlsStreamInner { - self.0.as_mut().unwrap() - } - - pub async fn handshake(&mut self) -> io::Result<()> { - poll_fn(|cx| self.inner_mut().poll_handshake(cx)).await - } - - fn poll_handshake(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner_mut().poll_handshake(cx) - } - - fn get_alpn_protocol(&mut self) -> Option { - self.inner_mut().tls.alpn_protocol().map(|s| s.into()) - } -} - -impl AsyncRead for TlsStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - self.inner_mut().poll_read(cx, buf) - } -} - -impl AsyncWrite for TlsStream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.inner_mut().poll_write(cx, buf) - } - - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.inner_mut().poll_io(cx, Flow::Write) - // The underlying TCP stream does not need to be flushed. - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.inner_mut().poll_shutdown(cx) - } -} - -impl Drop for TlsStream { - fn drop(&mut self) { - let mut inner = self.0.take().unwrap(); - - let mut cx = Context::from_waker(noop_waker_ref()); - let use_linger_task = inner.poll_close(&mut cx).is_pending(); - - if use_linger_task { - spawn(poll_fn(move |cx| inner.poll_close(cx))); - } else if cfg!(debug_assertions) { - spawn(async {}); // Spawn dummy task to detect missing runtime. - } - } -} - -pub struct TlsStreamInner { - tls: Connection, - tcp: TcpStream, - rd_state: State, - wr_state: State, -} - -impl TlsStreamInner { - fn poll_io( - &mut self, - cx: &mut Context<'_>, - flow: Flow, - ) -> Poll> { - loop { - let wr_ready = loop { - match self.wr_state { - _ if self.tls.is_handshaking() && !self.tls.wants_write() => { - break true; - } - _ if self.tls.is_handshaking() => {} - State::StreamOpen if !self.tls.wants_write() => break true, - State::StreamClosed => { - // Rustls will enqueue the 'CloseNotify' alert and send it after - // flushing the data that is already in the queue. - self.tls.send_close_notify(); - self.wr_state = State::TlsClosing; - continue; - } - State::TlsClosing if !self.tls.wants_write() => { - self.wr_state = State::TlsClosed; - continue; - } - // If a 'CloseNotify' alert sent by the remote end has been received, - // shut down the underlying TCP socket. Otherwise, consider polling - // done for the moment. - State::TlsClosed if self.rd_state < State::TlsClosed => break true, - State::TlsClosed - if Pin::new(&mut self.tcp).poll_shutdown(cx)?.is_pending() => - { - break false; - } - State::TlsClosed => { - self.wr_state = State::TcpClosed; - continue; - } - State::TcpClosed => break true, - _ => {} - } - - // Write ciphertext to the TCP socket. - let mut wrapped_tcp = ImplementWriteTrait(&mut self.tcp); - match self.tls.write_tls(&mut wrapped_tcp) { - Ok(0) => {} // Wait until the socket has enough buffer space. - Ok(_) => continue, // Try to send more more data immediately. - Err(err) if err.kind() == ErrorKind::WouldBlock => unreachable!(), - Err(err) => return Poll::Ready(Err(err)), - } - - // Poll whether there is space in the socket send buffer so we can flush - // the remaining outgoing ciphertext. - if self.tcp.poll_write_ready(cx)?.is_pending() { - break false; - } - }; - - let rd_ready = loop { - // Interpret and decrypt unprocessed TLS protocol data. - let tls_state = self - .tls - .process_new_packets() - .map_err(|e| Error::new(ErrorKind::InvalidData, e))?; - - match self.rd_state { - State::TcpClosed if self.tls.is_handshaking() => { - let err = Error::new(ErrorKind::UnexpectedEof, "tls handshake eof"); - return Poll::Ready(Err(err)); - } - _ if self.tls.is_handshaking() && !self.tls.wants_read() => { - break true; - } - _ if self.tls.is_handshaking() => {} - State::StreamOpen if tls_state.plaintext_bytes_to_read() > 0 => { - break true; - } - State::StreamOpen if tls_state.peer_has_closed() => { - self.rd_state = State::TlsClosed; - continue; - } - State::StreamOpen => {} - State::StreamClosed if tls_state.plaintext_bytes_to_read() > 0 => { - // Rustls has more incoming cleartext buffered up, but the TLS - // session is closing so this data will never be processed by the - // application layer. Just like what would happen if this were a raw - // TCP stream, don't gracefully end the TLS session, but abort it. - return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset))); - } - State::StreamClosed => {} - State::TlsClosed if self.wr_state == State::TcpClosed => { - // Keep trying to read from the TCP connection until the remote end - // closes it gracefully. - } - State::TlsClosed => break true, - State::TcpClosed => break true, - _ => unreachable!(), - } - - // Try to read more TLS protocol data from the TCP socket. - let mut wrapped_tcp = ImplementReadTrait(&mut self.tcp); - match self.tls.read_tls(&mut wrapped_tcp) { - Ok(0) => { - self.rd_state = State::TcpClosed; - continue; - } - Ok(_) => continue, - Err(err) if err.kind() == ErrorKind::WouldBlock => {} - Err(err) => return Poll::Ready(Err(err)), - } - - // Get notified when more ciphertext becomes available to read from the - // TCP socket. - if self.tcp.poll_read_ready(cx)?.is_pending() { - break false; - } - }; - - if wr_ready { - if self.rd_state >= State::TlsClosed - && self.wr_state >= State::TlsClosed - && self.wr_state < State::TcpClosed - { - continue; - } - if self.tls.wants_write() { - continue; - } - } - - let io_ready = match flow { - _ if self.tls.is_handshaking() => false, - Flow::Handshake => true, - Flow::Read => rd_ready, - Flow::Write => wr_ready, - }; - return match io_ready { - false => Poll::Pending, - true => Poll::Ready(Ok(())), - }; - } - } - - fn poll_handshake(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.tls.is_handshaking() { - ready!(self.poll_io(cx, Flow::Handshake))?; - } - Poll::Ready(Ok(())) - } - - fn poll_read( - &mut self, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - ready!(self.poll_io(cx, Flow::Read))?; - - if self.rd_state == State::StreamOpen { - // TODO(bartlomieju): - #[allow(clippy::undocumented_unsafe_blocks)] - let buf_slice = - unsafe { &mut *(buf.unfilled_mut() as *mut [_] as *mut [u8]) }; - let bytes_read = self.tls.reader().read(buf_slice)?; - assert_ne!(bytes_read, 0); - // TODO(bartlomieju): - #[allow(clippy::undocumented_unsafe_blocks)] - unsafe { - buf.assume_init(bytes_read) - }; - buf.advance(bytes_read); - } - - Poll::Ready(Ok(())) - } - - fn poll_write( - &mut self, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if buf.is_empty() { - // Tokio-rustls compatibility: a zero byte write always succeeds. - Poll::Ready(Ok(0)) - } else if self.wr_state == State::StreamOpen { - // Flush Rustls' ciphertext send queue. - ready!(self.poll_io(cx, Flow::Write))?; - - // Copy data from `buf` to the Rustls cleartext send queue. - let bytes_written = self.tls.writer().write(buf)?; - assert_ne!(bytes_written, 0); - - // Try to flush as much ciphertext as possible. However, since we just - // handed off at least some bytes to rustls, so we can't return - // `Poll::Pending()` any more: this would tell the caller that it should - // try to send those bytes again. - let _ = self.poll_io(cx, Flow::Write)?; - - Poll::Ready(Ok(bytes_written)) - } else { - // Return error if stream has been shut down for writing. - Poll::Ready(Err(ErrorKind::BrokenPipe.into())) - } - } - - fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.wr_state == State::StreamOpen { - self.wr_state = State::StreamClosed; - } - - ready!(self.poll_io(cx, Flow::Write))?; - - // At minimum, a TLS 'CloseNotify' alert should have been sent. - assert!(self.wr_state >= State::TlsClosed); - // If we received a TLS 'CloseNotify' alert from the remote end - // already, the TCP socket should be shut down at this point. - assert!( - self.rd_state < State::TlsClosed || self.wr_state == State::TcpClosed - ); - - Poll::Ready(Ok(())) - } - - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.rd_state == State::StreamOpen { - self.rd_state = State::StreamClosed; - } - - // Wait for the handshake to complete. - ready!(self.poll_io(cx, Flow::Handshake))?; - // Send TLS 'CloseNotify' alert. - ready!(self.poll_shutdown(cx))?; - // Wait for 'CloseNotify', shut down TCP stream, wait for TCP FIN packet. - ready!(self.poll_io(cx, Flow::Read))?; - - assert_eq!(self.rd_state, State::TcpClosed); - assert_eq!(self.wr_state, State::TcpClosed); - - Poll::Ready(Ok(())) - } -} - -pub struct ReadHalf { - shared: Arc, -} - -impl ReadHalf { - pub fn reunite(self, wr: WriteHalf) -> TlsStream { - assert!(Arc::ptr_eq(&self.shared, &wr.shared)); - drop(wr); // Drop `wr`, so only one strong reference to `shared` remains. - - Arc::try_unwrap(self.shared) - .unwrap_or_else(|_| panic!("Arc::::try_unwrap() failed")) - .tls_stream - .into_inner() - } -} - -impl AsyncRead for ReadHalf { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - self - .shared - .poll_with_shared_waker(cx, Flow::Read, move |tls, cx| { - tls.poll_read(cx, buf) - }) - } -} - -pub struct WriteHalf { - shared: Arc, -} - -impl WriteHalf { - pub async fn handshake(&mut self) -> io::Result<()> { - poll_fn(|cx| { - self - .shared - .poll_with_shared_waker(cx, Flow::Write, |mut tls, cx| { - tls.poll_handshake(cx) - }) - }) - .await - } - - fn get_alpn_protocol(&mut self) -> Option { - self.shared.get_alpn_protocol() - } -} - -impl AsyncWrite for WriteHalf { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self - .shared - .poll_with_shared_waker(cx, Flow::Write, move |tls, cx| { - tls.poll_write(cx, buf) - }) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self - .shared - .poll_with_shared_waker(cx, Flow::Write, |tls, cx| tls.poll_flush(cx)) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self - .shared - .poll_with_shared_waker(cx, Flow::Write, |tls, cx| tls.poll_shutdown(cx)) - } -} - -struct Shared { - tls_stream: Mutex, - rd_waker: AtomicWaker, - wr_waker: AtomicWaker, -} - -impl Shared { - fn new(tls_stream: TlsStream) -> Arc { - let self_ = Self { - tls_stream: Mutex::new(tls_stream), - rd_waker: AtomicWaker::new(), - wr_waker: AtomicWaker::new(), - }; - Arc::new(self_) - } - - fn poll_with_shared_waker( - self: &Arc, - cx: &mut Context<'_>, - flow: Flow, - mut f: impl FnMut(Pin<&mut TlsStream>, &mut Context<'_>) -> R, - ) -> R { - match flow { - Flow::Handshake => unreachable!(), - Flow::Read => self.rd_waker.register(cx.waker()), - Flow::Write => self.wr_waker.register(cx.waker()), - } - - let shared_waker = self.new_shared_waker(); - let mut cx = Context::from_waker(&shared_waker); - - let mut tls_stream = self.tls_stream.lock(); - f(Pin::new(&mut tls_stream), &mut cx) - } - - const SHARED_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new( - Self::clone_shared_waker, - Self::wake_shared_waker, - Self::wake_shared_waker_by_ref, - Self::drop_shared_waker, - ); - - fn new_shared_waker(self: &Arc) -> Waker { - let self_weak = Arc::downgrade(self); - let self_ptr = self_weak.into_raw() as *const (); - let raw_waker = RawWaker::new(self_ptr, &Self::SHARED_WAKER_VTABLE); - // TODO(bartlomieju): - #[allow(clippy::undocumented_unsafe_blocks)] - unsafe { - Waker::from_raw(raw_waker) - } - } - - fn clone_shared_waker(self_ptr: *const ()) -> RawWaker { - // TODO(bartlomieju): - #[allow(clippy::undocumented_unsafe_blocks)] - let self_weak = unsafe { Weak::from_raw(self_ptr as *const Self) }; - let ptr1 = self_weak.clone().into_raw(); - let ptr2 = self_weak.into_raw(); - assert!(ptr1 == ptr2); - RawWaker::new(self_ptr, &Self::SHARED_WAKER_VTABLE) - } - - fn wake_shared_waker(self_ptr: *const ()) { - Self::wake_shared_waker_by_ref(self_ptr); - Self::drop_shared_waker(self_ptr); - } - - fn wake_shared_waker_by_ref(self_ptr: *const ()) { - // TODO(bartlomieju): - #[allow(clippy::undocumented_unsafe_blocks)] - let self_weak = unsafe { Weak::from_raw(self_ptr as *const Self) }; - if let Some(self_arc) = Weak::upgrade(&self_weak) { - self_arc.rd_waker.wake(); - self_arc.wr_waker.wake(); - } - let _ = self_weak.into_raw(); - } - - fn drop_shared_waker(self_ptr: *const ()) { - // TODO(bartlomieju): - #[allow(clippy::undocumented_unsafe_blocks)] - let _ = unsafe { Weak::from_raw(self_ptr as *const Self) }; - } - - fn get_alpn_protocol(self: &Arc) -> Option { - let mut tls_stream = self.tls_stream.lock(); - tls_stream.get_alpn_protocol() - } -} - -struct ImplementReadTrait<'a, T>(&'a mut T); - -impl Read for ImplementReadTrait<'_, TcpStream> { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.try_read(buf) - } -} - -struct ImplementWriteTrait<'a, T>(&'a mut T); - -impl Write for ImplementWriteTrait<'_, TcpStream> { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self.0.try_write(buf) { - Ok(n) => Ok(n), - Err(err) if err.kind() == ErrorKind::WouldBlock => Ok(0), - Err(err) => Err(err), - } - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} +pub(crate) const TLS_BUFFER_SIZE: Option = + NonZeroUsize::new(65536); #[derive(Debug)] pub struct TlsStreamResource { - rd: AsyncRefCell, - wr: AsyncRefCell, + rd: AsyncRefCell, + wr: AsyncRefCell, // `None` when a TLS handshake hasn't been done. handshake_info: RefCell>, cancel_handle: CancelHandle, // Only read and handshake ops get canceled. } impl TlsStreamResource { - pub fn new((rd, wr): (ReadHalf, WriteHalf)) -> Self { + pub fn new((rd, wr): (TlsStreamRead, TlsStreamWrite)) -> Self { Self { rd: rd.into(), wr: wr.into(), @@ -697,7 +79,7 @@ impl TlsStreamResource { } } - pub fn into_inner(self) -> (ReadHalf, WriteHalf) { + pub fn into_inner(self) -> (TlsStreamRead, TlsStreamWrite) { (self.rd.into_inner(), self.wr.into_inner()) } @@ -707,12 +89,10 @@ impl TlsStreamResource { ) -> Result { let mut rd = RcRef::map(&self, |r| &r.rd).borrow_mut().await; let cancel_handle = RcRef::map(&self, |r| &r.cancel_handle); - let nread = rd.read(data).try_or_cancel(cancel_handle).await?; - Ok(nread) + Ok(rd.read(data).try_or_cancel(cancel_handle).await?) } pub async fn write(self: Rc, data: &[u8]) -> Result { - self.handshake().await?; let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await; let nwritten = wr.write(data).await?; wr.flush().await?; @@ -720,7 +100,6 @@ impl TlsStreamResource { } pub async fn shutdown(self: Rc) -> Result<(), AnyError> { - self.handshake().await?; let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await; wr.shutdown().await?; Ok(()) @@ -735,9 +114,9 @@ impl TlsStreamResource { let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await; let cancel_handle = RcRef::map(self, |r| &r.cancel_handle); - wr.handshake().try_or_cancel(cancel_handle).await?; + let handshake = wr.handshake().try_or_cancel(cancel_handle).await?; - let alpn_protocol = wr.get_alpn_protocol(); + let alpn_protocol = handshake.alpn.map(|alpn| alpn.into()); let tls_info = TlsHandshakeInfo { alpn_protocol }; self.handshake_info.replace(Some(tls_info.clone())); Ok(tls_info) @@ -849,9 +228,12 @@ where } let tls_config = Arc::new(tls_config); - - let tls_stream = - TlsStream::new_client_side(tcp_stream, tls_config, hostname_dns); + let tls_stream = TlsStream::new_client_side( + tcp_stream, + tls_config, + hostname_dns, + TLS_BUFFER_SIZE, + ); let rid = { let mut state_ = state.borrow_mut(); @@ -950,8 +332,12 @@ where let tls_config = Arc::new(tls_config); - let tls_stream = - TlsStream::new_client_side(tcp_stream, tls_config, hostname_dns); + let tls_stream = TlsStream::new_client_side( + tcp_stream, + tls_config, + hostname_dns, + TLS_BUFFER_SIZE, + ); let rid = { let mut state_ = state.borrow_mut(); @@ -1136,8 +522,11 @@ pub async fn op_net_accept_tls( let local_addr = tcp_stream.local_addr()?; - let tls_stream = - TlsStream::new_server_side(tcp_stream, resource.tls_config.clone()); + let tls_stream = TlsStream::new_server_side( + tcp_stream, + resource.tls_config.clone(), + TLS_BUFFER_SIZE, + ); let rid = { let mut state_ = state.borrow_mut(); diff --git a/ext/net/raw.rs b/ext/net/raw.rs index 0c92c46707..9bdea4191a 100644 --- a/ext/net/raw.rs +++ b/ext/net/raw.rs @@ -4,8 +4,8 @@ use crate::io::TcpStreamResource; use crate::io::UnixStreamResource; use crate::ops::TcpListenerResource; use crate::ops_tls::TlsListenerResource; -use crate::ops_tls::TlsStream; use crate::ops_tls::TlsStreamResource; +use crate::ops_tls::TLS_BUFFER_SIZE; #[cfg(unix)] use crate::ops_unix::UnixListenerResource; use deno_core::error::bad_resource; @@ -15,6 +15,7 @@ use deno_core::ResourceId; use deno_core::ResourceTable; use deno_tls::rustls::ServerConfig; use pin_project::pin_project; +use rustls_tokio_stream::TlsStream; use std::rc::Rc; use std::sync::Arc; use tokio::net::TcpStream; @@ -187,7 +188,11 @@ impl NetworkStreamListener { } Self::Tls(tcp, config) => { let (stream, _addr) = tcp.accept().await?; - NetworkStream::Tls(TlsStream::new_server_side(stream, config.clone())) + NetworkStream::Tls(TlsStream::new_server_side( + stream, + config.clone(), + TLS_BUFFER_SIZE, + )) } #[cfg(unix)] Self::Unix(unix) => { @@ -242,7 +247,7 @@ pub fn take_network_stream_resource( let resource = Rc::try_unwrap(resource_rc) .map_err(|_| bad_resource("TLS stream is currently in use"))?; let (read_half, write_half) = resource.into_inner(); - let tls_stream = read_half.reunite(write_half); + let tls_stream = read_half.unsplit(write_half); return Ok(NetworkStream::Tls(tls_stream)); } diff --git a/runtime/ops/http.rs b/runtime/ops/http.rs index f0f0510daf..fc66c9fabf 100644 --- a/runtime/ops/http.rs +++ b/runtime/ops/http.rs @@ -65,8 +65,8 @@ fn op_http_start( let resource = Rc::try_unwrap(resource_rc) .map_err(|_| bad_resource("TLS stream is currently in use"))?; let (read_half, write_half) = resource.into_inner(); - let tls_stream = read_half.reunite(write_half); - let addr = tls_stream.get_ref().0.local_addr()?; + let tls_stream = read_half.unsplit(write_half); + let addr = tls_stream.local_addr()?; return http_create_conn_resource(state, tls_stream, addr, "https"); }