diff --git a/Cargo.lock b/Cargo.lock index 4f34174c8c..fa7d00884c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -713,12 +713,12 @@ dependencies = [ "percent-encoding", "regex", "ring", + "rustls", "serde", "sys-info", "termcolor", "test_util", "tokio", - "tokio-rustls", "tokio-util", "trust-dns-proto", "trust-dns-resolver", diff --git a/cli/tests/integration_tests.rs b/cli/tests/integration_tests.rs index 56043930b6..5cca0d1ccf 100644 --- a/cli/tests/integration_tests.rs +++ b/cli/tests/integration_tests.rs @@ -5,7 +5,9 @@ use deno_core::serde_json; use deno_core::url; use deno_runtime::deno_fetch::reqwest; use deno_runtime::deno_websocket::tokio_tungstenite; -use rustls::Session; +use deno_runtime::ops::tls::rustls; +use deno_runtime::ops::tls::webpki; +use deno_runtime::ops::tls::TlsStream; use std::fs; use std::io::BufReader; use std::io::Cursor; @@ -14,8 +16,7 @@ use std::process::Command; use std::sync::Arc; use tempfile::TempDir; use test_util as util; -use tokio_rustls::rustls; -use tokio_rustls::webpki; +use tokio::task::LocalSet; #[test] fn js_unit_tests_lint() { @@ -6134,79 +6135,103 @@ console.log("finish"); #[tokio::test] async fn listen_tls_alpn() { - let child = util::deno_cmd() - .current_dir(util::root_path()) - .arg("run") - .arg("--unstable") - .arg("--quiet") - .arg("--allow-net") - .arg("--allow-read") - .arg("./cli/tests/listen_tls_alpn.ts") - .arg("4504") - .stdout(std::process::Stdio::piped()) - .spawn() - .unwrap(); - let mut stdout = child.stdout.unwrap(); - let mut buffer = [0; 5]; - let read = stdout.read(&mut buffer).unwrap(); - assert_eq!(read, 5); - let msg = std::str::from_utf8(&buffer).unwrap(); - assert_eq!(msg, "READY"); + // TLS streams require the presence of an ambient local task set to gracefully + // close dropped connections in the background. + LocalSet::new() + .run_until(async { + let mut child = util::deno_cmd() + .current_dir(util::root_path()) + .arg("run") + .arg("--unstable") + .arg("--quiet") + .arg("--allow-net") + .arg("--allow-read") + .arg("./cli/tests/listen_tls_alpn.ts") + .arg("4504") + .stdout(std::process::Stdio::piped()) + .spawn() + .unwrap(); + let stdout = child.stdout.as_mut().unwrap(); + let mut buffer = [0; 5]; + let read = stdout.read(&mut buffer).unwrap(); + assert_eq!(read, 5); + let msg = std::str::from_utf8(&buffer).unwrap(); + assert_eq!(msg, "READY"); - let mut cfg = rustls::ClientConfig::new(); - let reader = - &mut BufReader::new(Cursor::new(include_bytes!("./tls/RootCA.crt"))); - cfg.root_store.add_pem_file(reader).unwrap(); - cfg.alpn_protocols.push("foobar".as_bytes().to_vec()); + let mut cfg = rustls::ClientConfig::new(); + let reader = + &mut BufReader::new(Cursor::new(include_bytes!("./tls/RootCA.crt"))); + cfg.root_store.add_pem_file(reader).unwrap(); + cfg.alpn_protocols.push("foobar".as_bytes().to_vec()); + let cfg = Arc::new(cfg); - let tls_connector = tokio_rustls::TlsConnector::from(Arc::new(cfg)); - let hostname = webpki::DNSNameRef::try_from_ascii_str("localhost").unwrap(); - let stream = tokio::net::TcpStream::connect("localhost:4504") - .await - .unwrap(); + let hostname = + webpki::DNSNameRef::try_from_ascii_str("localhost").unwrap(); - let tls_stream = tls_connector.connect(hostname, stream).await.unwrap(); - let (_, session) = tls_stream.get_ref(); + let tcp_stream = tokio::net::TcpStream::connect("localhost:4504") + .await + .unwrap(); + let mut tls_stream = + TlsStream::new_client_side(tcp_stream, &cfg, hostname); + tls_stream.handshake().await.unwrap(); + let (_, session) = tls_stream.get_ref(); - let alpn = session.get_alpn_protocol().unwrap(); - assert_eq!(std::str::from_utf8(alpn).unwrap(), "foobar"); + let alpn = session.get_alpn_protocol().unwrap(); + assert_eq!(std::str::from_utf8(alpn).unwrap(), "foobar"); + + child.kill().unwrap(); + child.wait().unwrap(); + }) + .await; } #[tokio::test] async fn listen_tls_alpn_fail() { - let child = util::deno_cmd() - .current_dir(util::root_path()) - .arg("run") - .arg("--unstable") - .arg("--quiet") - .arg("--allow-net") - .arg("--allow-read") - .arg("./cli/tests/listen_tls_alpn.ts") - .arg("4505") - .stdout(std::process::Stdio::piped()) - .spawn() - .unwrap(); - let mut stdout = child.stdout.unwrap(); - let mut buffer = [0; 5]; - let read = stdout.read(&mut buffer).unwrap(); - assert_eq!(read, 5); - let msg = std::str::from_utf8(&buffer).unwrap(); - assert_eq!(msg, "READY"); + // TLS streams require the presence of an ambient local task set to gracefully + // close dropped connections in the background. + LocalSet::new() + .run_until(async { + let mut child = util::deno_cmd() + .current_dir(util::root_path()) + .arg("run") + .arg("--unstable") + .arg("--quiet") + .arg("--allow-net") + .arg("--allow-read") + .arg("./cli/tests/listen_tls_alpn.ts") + .arg("4505") + .stdout(std::process::Stdio::piped()) + .spawn() + .unwrap(); + let stdout = child.stdout.as_mut().unwrap(); + let mut buffer = [0; 5]; + let read = stdout.read(&mut buffer).unwrap(); + assert_eq!(read, 5); + let msg = std::str::from_utf8(&buffer).unwrap(); + assert_eq!(msg, "READY"); - let mut cfg = rustls::ClientConfig::new(); - let reader = - &mut BufReader::new(Cursor::new(include_bytes!("./tls/RootCA.crt"))); - cfg.root_store.add_pem_file(reader).unwrap(); - cfg.alpn_protocols.push("boofar".as_bytes().to_vec()); + let mut cfg = rustls::ClientConfig::new(); + let reader = + &mut BufReader::new(Cursor::new(include_bytes!("./tls/RootCA.crt"))); + cfg.root_store.add_pem_file(reader).unwrap(); + cfg.alpn_protocols.push("boofar".as_bytes().to_vec()); + let cfg = Arc::new(cfg); - let tls_connector = tokio_rustls::TlsConnector::from(Arc::new(cfg)); - let hostname = webpki::DNSNameRef::try_from_ascii_str("localhost").unwrap(); - let stream = tokio::net::TcpStream::connect("localhost:4505") - .await - .unwrap(); + let hostname = + webpki::DNSNameRef::try_from_ascii_str("localhost").unwrap(); - let tls_stream = tls_connector.connect(hostname, stream).await.unwrap(); - let (_, session) = tls_stream.get_ref(); + let tcp_stream = tokio::net::TcpStream::connect("localhost:4505") + .await + .unwrap(); + let mut tls_stream = + TlsStream::new_client_side(tcp_stream, &cfg, hostname); + tls_stream.handshake().await.unwrap(); + let (_, session) = tls_stream.get_ref(); - assert!(session.get_alpn_protocol().is_none()); + assert!(session.get_alpn_protocol().is_none()); + + child.kill().unwrap(); + child.wait().unwrap(); + }) + .await; } diff --git a/cli/tests/unit/tls_test.ts b/cli/tests/unit/tls_test.ts index 0528c80438..cedcf467de 100644 --- a/cli/tests/unit/tls_test.ts +++ b/cli/tests/unit/tls_test.ts @@ -2,9 +2,11 @@ import { assert, assertEquals, + assertNotEquals, assertStrictEquals, assertThrows, assertThrowsAsync, + Deferred, deferred, unitTest, } from "./test_util.ts"; @@ -14,6 +16,14 @@ import { TextProtoReader } from "../../../test_util/std/textproto/mod.ts"; const encoder = new TextEncoder(); const decoder = new TextDecoder(); +async function sleep(msec: number): Promise { + await new Promise((res, _rej) => setTimeout(res, msec)); +} + +function unreachable(): never { + throw new Error("Unreachable code reached"); +} + unitTest(async function connectTLSNoPerm(): Promise { await assertThrowsAsync(async () => { await Deno.connectTls({ hostname: "github.com", port: 443 }); @@ -201,7 +211,13 @@ unitTest( }, ); -async function tlsPair(port: number): Promise<[Deno.Conn, Deno.Conn]> { +let nextPort = 3501; +function getPort() { + return nextPort++; +} + +async function tlsPair(): Promise<[Deno.Conn, Deno.Conn]> { + const port = getPort(); const listener = Deno.listenTls({ hostname: "localhost", port, @@ -215,59 +231,169 @@ async function tlsPair(port: number): Promise<[Deno.Conn, Deno.Conn]> { port, certFile: "cli/tests/tls/RootCA.pem", }); - const connections = await Promise.all([acceptPromise, connectPromise]); + const endpoints = await Promise.all([acceptPromise, connectPromise]); listener.close(); - return connections; + return endpoints; } -async function sendCloseWrite(conn: Deno.Conn): Promise { - const buf = new Uint8Array(1024); - let n: number | null; +async function sendThenCloseWriteThenReceive( + conn: Deno.Conn, + chunkCount: number, + chunkSize: number, +): Promise { + const byteCount = chunkCount * chunkSize; + const buf = new Uint8Array(chunkSize); // Note: buf is size of _chunk_. + let n: number; - // Send 1. - n = await conn.write(new Uint8Array([1])); - assertStrictEquals(n, 1); + // Slowly send 42s. + buf.fill(42); + for (let remaining = byteCount; remaining > 0; remaining -= n) { + n = await conn.write(buf.subarray(0, remaining)); + assert(n >= 1); + await sleep(10); + } // Send EOF. await conn.closeWrite(); - // Receive 2. - n = await conn.read(buf); - assertStrictEquals(n, 1); - assertStrictEquals(buf[0], 2); + // Receive 69s. + for (let remaining = byteCount; remaining > 0; remaining -= n) { + buf.fill(0); + n = await conn.read(buf) as number; + assert(n >= 1); + assertStrictEquals(buf[0], 69); + assertStrictEquals(buf[n - 1], 69); + } conn.close(); } -async function receiveCloseWrite(conn: Deno.Conn): Promise { - const buf = new Uint8Array(1024); - let n: number | null; +async function receiveThenSend( + conn: Deno.Conn, + chunkCount: number, + chunkSize: number, +): Promise { + const byteCount = chunkCount * chunkSize; + const buf = new Uint8Array(byteCount); // Note: buf size equals `byteCount`. + let n: number; - // Receive 1. - n = await conn.read(buf); - assertStrictEquals(n, 1); - assertStrictEquals(buf[0], 1); + // Receive 42s. + for (let remaining = byteCount; remaining > 0; remaining -= n) { + buf.fill(0); + n = await conn.read(buf) as number; + assert(n >= 1); + assertStrictEquals(buf[0], 42); + assertStrictEquals(buf[n - 1], 42); + } - // Receive EOF. - n = await conn.read(buf); - assertStrictEquals(n, null); - - // Send 2. - n = await conn.write(new Uint8Array([2])); - assertStrictEquals(n, 1); + // Slowly send 69s. + buf.fill(69); + for (let remaining = byteCount; remaining > 0; remaining -= n) { + n = await conn.write(buf.subarray(0, remaining)); + assert(n >= 1); + await sleep(10); + } conn.close(); } +unitTest( + { perms: { read: true, net: true } }, + async function tlsServerStreamHalfCloseSendOneByte(): Promise { + const [serverConn, clientConn] = await tlsPair(); + await Promise.all([ + sendThenCloseWriteThenReceive(serverConn, 1, 1), + receiveThenSend(clientConn, 1, 1), + ]); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsClientStreamHalfCloseSendOneByte(): Promise { + const [serverConn, clientConn] = await tlsPair(); + await Promise.all([ + sendThenCloseWriteThenReceive(clientConn, 1, 1), + receiveThenSend(serverConn, 1, 1), + ]); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsServerStreamHalfCloseSendOneChunk(): Promise { + const [serverConn, clientConn] = await tlsPair(); + await Promise.all([ + sendThenCloseWriteThenReceive(serverConn, 1, 1 << 20 /* 1 MB */), + receiveThenSend(clientConn, 1, 1 << 20 /* 1 MB */), + ]); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsClientStreamHalfCloseSendOneChunk(): Promise { + const [serverConn, clientConn] = await tlsPair(); + await Promise.all([ + sendThenCloseWriteThenReceive(clientConn, 1, 1 << 20 /* 1 MB */), + receiveThenSend(serverConn, 1, 1 << 20 /* 1 MB */), + ]); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsServerStreamHalfCloseSendManyBytes(): Promise { + const [serverConn, clientConn] = await tlsPair(); + await Promise.all([ + sendThenCloseWriteThenReceive(serverConn, 100, 1), + receiveThenSend(clientConn, 100, 1), + ]); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsClientStreamHalfCloseSendManyBytes(): Promise { + const [serverConn, clientConn] = await tlsPair(); + await Promise.all([ + sendThenCloseWriteThenReceive(clientConn, 100, 1), + receiveThenSend(serverConn, 100, 1), + ]); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsServerStreamHalfCloseSendManyChunks(): Promise { + const [serverConn, clientConn] = await tlsPair(); + await Promise.all([ + sendThenCloseWriteThenReceive(serverConn, 100, 1 << 16 /* 64 kB */), + receiveThenSend(clientConn, 100, 1 << 16 /* 64 kB */), + ]); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsClientStreamHalfCloseSendManyChunks(): Promise { + const [serverConn, clientConn] = await tlsPair(); + await Promise.all([ + sendThenCloseWriteThenReceive(clientConn, 100, 1 << 16 /* 64 kB */), + receiveThenSend(serverConn, 100, 1 << 16 /* 64 kB */), + ]); + }, +); + async function sendAlotReceiveNothing(conn: Deno.Conn): Promise { // Start receive op. const readBuf = new Uint8Array(1024); const readPromise = conn.read(readBuf); // Send 1 MB of data. - const writeBuf = new Uint8Array(1 << 20); + const writeBuf = new Uint8Array(1 << 20 /* 1 MB */); writeBuf.fill(42); await conn.write(writeBuf); @@ -289,7 +415,7 @@ async function receiveAlotSendNothing(conn: Deno.Conn): Promise { let n: number | null; // Receive 1 MB of data. - for (let nread = 0; nread < 1 << 20; nread += n!) { + for (let nread = 0; nread < 1 << 20 /* 1 MB */; nread += n!) { n = await conn.read(readBuf); assertStrictEquals(typeof n, "number"); assert(n! > 0); @@ -300,32 +426,10 @@ async function receiveAlotSendNothing(conn: Deno.Conn): Promise { conn.close(); } -unitTest( - { perms: { read: true, net: true } }, - async function tlsServerStreamHalfClose(): Promise { - const [serverConn, clientConn] = await tlsPair(3501); - await Promise.all([ - sendCloseWrite(serverConn), - receiveCloseWrite(clientConn), - ]); - }, -); - -unitTest( - { perms: { read: true, net: true } }, - async function tlsClientStreamHalfClose(): Promise { - const [serverConn, clientConn] = await tlsPair(3502); - await Promise.all([ - sendCloseWrite(clientConn), - receiveCloseWrite(serverConn), - ]); - }, -); - unitTest( { perms: { read: true, net: true } }, async function tlsServerStreamCancelRead(): Promise { - const [serverConn, clientConn] = await tlsPair(3503); + const [serverConn, clientConn] = await tlsPair(); await Promise.all([ sendAlotReceiveNothing(serverConn), receiveAlotSendNothing(clientConn), @@ -336,7 +440,7 @@ unitTest( unitTest( { perms: { read: true, net: true } }, async function tlsClientStreamCancelRead(): Promise { - const [serverConn, clientConn] = await tlsPair(3504); + const [serverConn, clientConn] = await tlsPair(); await Promise.all([ sendAlotReceiveNothing(clientConn), receiveAlotSendNothing(serverConn), @@ -344,6 +448,493 @@ unitTest( }, ); +async function sendReceiveEmptyBuf(conn: Deno.Conn): Promise { + const byteBuf = new Uint8Array([1]); + const emptyBuf = new Uint8Array(0); + let n: number | null; + + n = await conn.write(emptyBuf); + assertStrictEquals(n, 0); + + n = await conn.read(emptyBuf); + assertStrictEquals(n, 0); + + n = await conn.write(byteBuf); + assertStrictEquals(n, 1); + + n = await conn.read(byteBuf); + assertStrictEquals(n, 1); + + await conn.closeWrite(); + + n = await conn.write(emptyBuf); + assertStrictEquals(n, 0); + + await assertThrowsAsync(async () => { + await conn.write(byteBuf); + }, Deno.errors.BrokenPipe); + + n = await conn.write(emptyBuf); + assertStrictEquals(n, 0); + + n = await conn.read(byteBuf); + assertStrictEquals(n, null); + + conn.close(); +} + +unitTest( + { perms: { read: true, net: true } }, + async function tlsStreamSendReceiveEmptyBuf(): Promise { + const [serverConn, clientConn] = await tlsPair(); + await Promise.all([ + sendReceiveEmptyBuf(serverConn), + sendReceiveEmptyBuf(clientConn), + ]); + }, +); + +function immediateClose(conn: Deno.Conn): Promise { + conn.close(); + return Promise.resolve(); +} + +async function closeWriteAndClose(conn: Deno.Conn): Promise { + await conn.closeWrite(); + + if (await conn.read(new Uint8Array(1)) !== null) { + throw new Error("did not expect to receive data on TLS stream"); + } + + conn.close(); +} + +unitTest( + { perms: { read: true, net: true } }, + async function tlsServerStreamImmediateClose(): Promise { + const [serverConn, clientConn] = await tlsPair(); + await Promise.all([ + immediateClose(serverConn), + closeWriteAndClose(clientConn), + ]); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsClientStreamImmediateClose(): Promise { + const [serverConn, clientConn] = await tlsPair(); + await Promise.all([ + closeWriteAndClose(serverConn), + immediateClose(clientConn), + ]); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsClientAndServerStreamImmediateClose(): Promise { + const [serverConn, clientConn] = await tlsPair(); + await Promise.all([ + immediateClose(serverConn), + immediateClose(clientConn), + ]); + }, +); + +async function tlsWithTcpFailureTestImpl( + phase: "handshake" | "traffic", + cipherByteCount: number, + failureMode: "corruption" | "shutdown", + reverse: boolean, +): Promise { + const tlsPort = getPort(); + const tlsListener = Deno.listenTls({ + hostname: "localhost", + port: tlsPort, + certFile: "cli/tests/tls/localhost.crt", + keyFile: "cli/tests/tls/localhost.key", + }); + + const tcpPort = getPort(); + const tcpListener = Deno.listen({ hostname: "localhost", port: tcpPort }); + + const [tlsServerConn, tcpServerConn] = await Promise.all([ + tlsListener.accept(), + Deno.connect({ hostname: "localhost", port: tlsPort }), + ]); + + const [tcpClientConn, tlsClientConn] = await Promise.all([ + tcpListener.accept(), + Deno.connectTls({ + hostname: "localhost", + port: tcpPort, + certFile: "cli/tests/tls/RootCA.crt", + }), + ]); + + tlsListener.close(); + tcpListener.close(); + + const { + tlsConn1, + tlsConn2, + tcpConn1, + tcpConn2, + } = reverse + ? { + tlsConn1: tlsClientConn, + tlsConn2: tlsServerConn, + tcpConn1: tcpClientConn, + tcpConn2: tcpServerConn, + } + : { + tlsConn1: tlsServerConn, + tlsConn2: tlsClientConn, + tcpConn1: tcpServerConn, + tcpConn2: tcpClientConn, + }; + + const tcpForwardingInterruptPromise1 = deferred(); + const tcpForwardingPromise1 = forwardBytes( + tcpConn2, + tcpConn1, + cipherByteCount, + tcpForwardingInterruptPromise1, + ); + + const tcpForwardingInterruptPromise2 = deferred(); + const tcpForwardingPromise2 = forwardBytes( + tcpConn1, + tcpConn2, + Infinity, + tcpForwardingInterruptPromise2, + ); + + switch (phase) { + case "handshake": { + let expectedError; + switch (failureMode) { + case "corruption": + expectedError = Deno.errors.InvalidData; + break; + case "shutdown": + expectedError = Deno.errors.UnexpectedEof; + break; + default: + unreachable(); + } + + const tlsTrafficPromise1 = Promise.all([ + assertThrowsAsync( + () => sendBytes(tlsConn1, 0x01, 1), + expectedError, + ), + assertThrowsAsync( + () => receiveBytes(tlsConn1, 0x02, 1), + expectedError, + ), + ]); + + const tlsTrafficPromise2 = Promise.all([ + assertThrowsAsync( + () => sendBytes(tlsConn2, 0x02, 1), + Deno.errors.UnexpectedEof, + ), + assertThrowsAsync( + () => receiveBytes(tlsConn2, 0x01, 1), + Deno.errors.UnexpectedEof, + ), + ]); + + await tcpForwardingPromise1; + + switch (failureMode) { + case "corruption": + await sendBytes(tcpConn1, 0xff, 1 << 14 /* 16 kB */); + break; + case "shutdown": + await tcpConn1.closeWrite(); + break; + default: + unreachable(); + } + await tlsTrafficPromise1; + + tcpForwardingInterruptPromise2.resolve(); + await tcpForwardingPromise2; + await tcpConn2.closeWrite(); + await tlsTrafficPromise2; + + break; + } + + case "traffic": { + await Promise.all([ + sendBytes(tlsConn2, 0x88, 8888), + receiveBytes(tlsConn1, 0x88, 8888), + sendBytes(tlsConn1, 0x99, 99999), + receiveBytes(tlsConn2, 0x99, 99999), + ]); + + tcpForwardingInterruptPromise1.resolve(); + await tcpForwardingPromise1; + + switch (failureMode) { + case "corruption": + await sendBytes(tcpConn1, 0xff, 1 << 14 /* 16 kB */); + await assertThrowsAsync( + () => receiveEof(tlsConn1), + Deno.errors.InvalidData, + ); + tcpForwardingInterruptPromise2.resolve(); + break; + case "shutdown": + // Receiving a TCP FIN packet without receiving a TLS CloseNotify + // alert is not the expected mode of operation, but it is not a + // problem either, so it should be treated as if the TLS session was + // gracefully closed. + await Promise.all([ + tcpConn1.closeWrite(), + await receiveEof(tlsConn1), + await tlsConn1.closeWrite(), + await receiveEof(tlsConn2), + ]); + break; + default: + unreachable(); + } + + await tcpForwardingPromise2; + + break; + } + + default: + unreachable(); + } + + tlsServerConn.close(); + tlsClientConn.close(); + tcpServerConn.close(); + tcpClientConn.close(); + + async function sendBytes( + conn: Deno.Conn, + byte: number, + count: number, + ): Promise { + let buf = new Uint8Array(1 << 12 /* 4 kB */); + buf.fill(byte); + + while (count > 0) { + buf = buf.subarray(0, Math.min(buf.length, count)); + const nwritten = await conn.write(buf); + assertStrictEquals(nwritten, buf.length); + count -= nwritten; + } + } + + async function receiveBytes( + conn: Deno.Conn, + byte: number, + count: number, + ): Promise { + let buf = new Uint8Array(1 << 12 /* 4 kB */); + while (count > 0) { + buf = buf.subarray(0, Math.min(buf.length, count)); + const r = await conn.read(buf); + assertNotEquals(r, null); + assert(buf.subarray(0, r!).every((b) => b === byte)); + count -= r!; + } + } + + async function receiveEof(conn: Deno.Conn) { + const buf = new Uint8Array(1); + const r = await conn.read(buf); + assertStrictEquals(r, null); + } + + async function forwardBytes( + source: Deno.Conn, + sink: Deno.Conn, + count: number, + interruptPromise: Deferred, + ): Promise { + let buf = new Uint8Array(1 << 12 /* 4 kB */); + while (count > 0) { + buf = buf.subarray(0, Math.min(buf.length, count)); + const nread = await Promise.race([source.read(buf), interruptPromise]); + if (nread == null) break; // Either EOF or interrupted. + const nwritten = await sink.write(buf.subarray(0, nread)); + assertStrictEquals(nread, nwritten); + count -= nwritten; + } + } +} + +unitTest( + { perms: { read: true, net: true } }, + async function tlsHandshakeWithTcpCorruptionImmediately() { + await tlsWithTcpFailureTestImpl("handshake", 0, "corruption", false); + await tlsWithTcpFailureTestImpl("handshake", 0, "corruption", true); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsHandshakeWithTcpShutdownImmediately() { + await tlsWithTcpFailureTestImpl("handshake", 0, "shutdown", false); + await tlsWithTcpFailureTestImpl("handshake", 0, "shutdown", true); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsHandshakeWithTcpCorruptionAfter70Bytes() { + await tlsWithTcpFailureTestImpl("handshake", 76, "corruption", false); + await tlsWithTcpFailureTestImpl("handshake", 78, "corruption", true); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsHandshakeWithTcpShutdownAfter70bytes() { + await tlsWithTcpFailureTestImpl("handshake", 77, "shutdown", false); + await tlsWithTcpFailureTestImpl("handshake", 79, "shutdown", true); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsHandshakeWithTcpCorruptionAfter200Bytes() { + await tlsWithTcpFailureTestImpl("handshake", 200, "corruption", false); + await tlsWithTcpFailureTestImpl("handshake", 202, "corruption", true); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsHandshakeWithTcpShutdownAfter200bytes() { + await tlsWithTcpFailureTestImpl("handshake", 201, "shutdown", false); + await tlsWithTcpFailureTestImpl("handshake", 203, "shutdown", true); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsTrafficWithTcpCorruption() { + await tlsWithTcpFailureTestImpl("traffic", Infinity, "corruption", false); + await tlsWithTcpFailureTestImpl("traffic", Infinity, "corruption", true); + }, +); + +unitTest( + { perms: { read: true, net: true } }, + async function tlsTrafficWithTcpShutdown() { + await tlsWithTcpFailureTestImpl("traffic", Infinity, "shutdown", false); + await tlsWithTcpFailureTestImpl("traffic", Infinity, "shutdown", true); + }, +); + +function createHttpsListener(port: number): Deno.Listener { + // Query format: `curl --insecure https://localhost:8443/z/12345` + // The server returns a response consisting of 12345 times the letter 'z'. + const listener = Deno.listenTls({ + hostname: "localhost", + port, + certFile: "./cli/tests/tls/localhost.crt", + keyFile: "./cli/tests/tls/localhost.key", + }); + + serve(listener); + return listener; + + async function serve(listener: Deno.Listener) { + for await (const conn of listener) { + const EOL = "\r\n"; + + // Read GET request plus headers. + const buf = new Uint8Array(1 << 12 /* 4 kB */); + const decoder = new TextDecoder(); + let req = ""; + while (!req.endsWith(EOL + EOL)) { + const n = await conn.read(buf); + if (n === null) throw new Error("Unexpected EOF"); + req += decoder.decode(buf.subarray(0, n)); + } + + // Parse GET request. + const { filler, count, version } = + /^GET \/(?[^\/]+)\/(?\d+) HTTP\/(?1\.\d)\r\n/ + .exec(req)!.groups as { + filler: string; + count: string; + version: string; + }; + + // Generate response. + const resBody = new TextEncoder().encode(filler.repeat(+count)); + const resHead = new TextEncoder().encode( + [ + `HTTP/${version} 200 OK`, + `Content-Length: ${resBody.length}`, + "Content-Type: text/plain", + ].join(EOL) + EOL + EOL, + ); + + // Send response. + await conn.write(resHead); + await conn.write(resBody); + + // Close TCP connection. + conn.close(); + } + } +} + +async function curl(url: string): Promise { + const curl = Deno.run({ + cmd: ["curl", "--insecure", url], + stdout: "piped", + }); + + try { + const [status, output] = await Promise.all([curl.status(), curl.output()]); + if (!status.success) { + throw new Error(`curl ${url} failed: ${status.code}`); + } + return new TextDecoder().decode(output); + } finally { + curl.close(); + } +} + +unitTest( + { perms: { read: true, net: true, run: true } }, + async function curlFakeHttpsServer(): Promise { + const port = getPort(); + const listener = createHttpsListener(port); + + const res1 = await curl(`https://localhost:${port}/d/1`); + assertStrictEquals(res1, "d"); + + const res2 = await curl(`https://localhost:${port}/e/12345`); + assertStrictEquals(res2, "e".repeat(12345)); + + const count3 = 1 << 17; // 128 kB. + const res3 = await curl(`https://localhost:${port}/n/${count3}`); + assertStrictEquals(res3, "n".repeat(count3)); + + const count4 = 12345678; + const res4 = await curl(`https://localhost:${port}/o/${count4}`); + assertStrictEquals(res4, "o".repeat(count4)); + + listener.close(); + }, +); + unitTest( { perms: { read: true, net: true } }, async function startTls(): Promise { diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml index 4cefa23ee8..4ca0539db5 100644 --- a/runtime/Cargo.toml +++ b/runtime/Cargo.toml @@ -64,12 +64,12 @@ notify = "5.0.0-pre.7" percent-encoding = "2.1.0" regex = "1.4.3" ring = "0.16.20" +rustls = "0.19.0" serde = { version = "1.0.125", features = ["derive"] } sys-info = "0.9.0" termcolor = "1.1.2" tokio = { version = "1.4.0", features = ["full"] } tokio-util = { version = "0.6", features = ["io"] } -tokio-rustls = "0.22.0" uuid = { version = "0.8.2", features = ["v4"] } webpki = "0.21.4" webpki-roots = "0.21.1" diff --git a/runtime/ops/http.rs b/runtime/ops/http.rs index 3642a0ac33..e4ba2db2a5 100644 --- a/runtime/ops/http.rs +++ b/runtime/ops/http.rs @@ -1,7 +1,8 @@ // Copyright 2018-2021 the Deno authors. All rights reserved. MIT license. use crate::ops::io::TcpStreamResource; -use crate::ops::io::TlsServerStreamResource; +use crate::ops::io::TlsStreamResource; +use crate::ops::tls::TlsStream; use deno_core::error::bad_resource_id; use deno_core::error::null_opbuf; use deno_core::error::type_error; @@ -43,7 +44,6 @@ use std::task::Poll; use tokio::io::AsyncReadExt; use tokio::net::TcpStream; use tokio::sync::oneshot; -use tokio_rustls::server::TlsStream; use tokio_util::io::StreamReader; pub fn init() -> Extension { @@ -100,7 +100,7 @@ impl HyperService> for Service { enum ConnType { Tcp(Rc>>), - Tls(Rc, Service, LocalExecutor>>>), + Tls(Rc>>), } struct ConnResource { @@ -305,12 +305,12 @@ fn op_http_start( if let Some(resource_rc) = state .resource_table - .take::(tcp_stream_rid) + .take::(tcp_stream_rid) { let resource = Rc::try_unwrap(resource_rc) .expect("Only a single use of this resource should happen"); let (read_half, write_half) = resource.into_inner(); - let tls_stream = read_half.unsplit(write_half); + let tls_stream = read_half.reunite(write_half); let addr = tls_stream.get_ref().0.local_addr()?; let hyper_connection = Http::new() diff --git a/runtime/ops/io.rs b/runtime/ops/io.rs index c7faa73d7f..d9f21e1f54 100644 --- a/runtime/ops/io.rs +++ b/runtime/ops/io.rs @@ -1,5 +1,6 @@ // Copyright 2018-2021 the Deno authors. All rights reserved. MIT license. +use crate::ops::tls; use deno_core::error::null_opbuf; use deno_core::error::resource_unavailable; use deno_core::error::AnyError; @@ -21,17 +22,12 @@ use std::cell::RefCell; use std::io::Read; use std::io::Write; use std::rc::Rc; -use tokio::io::split; use tokio::io::AsyncRead; use tokio::io::AsyncReadExt; use tokio::io::AsyncWrite; use tokio::io::AsyncWriteExt; -use tokio::io::ReadHalf; -use tokio::io::WriteHalf; use tokio::net::tcp; -use tokio::net::TcpStream; use tokio::process; -use tokio_rustls as tls; #[cfg(unix)] use std::os::unix::io::FromRawFd; @@ -306,18 +302,6 @@ where } } -pub type FullDuplexSplitResource = - FullDuplexResource, WriteHalf>; - -impl From for FullDuplexSplitResource -where - S: AsyncRead + AsyncWrite + 'static, -{ - fn from(stream: S) -> Self { - Self::new(split(stream)) - } -} - pub type ChildStdinResource = WriteOnlyResource; impl Resource for ChildStdinResource { @@ -363,25 +347,11 @@ impl Resource for TcpStreamResource { } } -pub type TlsClientStreamResource = - FullDuplexSplitResource>; +pub type TlsStreamResource = FullDuplexResource; -impl Resource for TlsClientStreamResource { +impl Resource for TlsStreamResource { fn name(&self) -> Cow { - "tlsClientStream".into() - } - - fn close(self: Rc) { - self.cancel_read_ops(); - } -} - -pub type TlsServerStreamResource = - FullDuplexSplitResource>; - -impl Resource for TlsServerStreamResource { - fn name(&self) -> Cow { - "tlsServerStream".into() + "tlsStream".into() } fn close(self: Rc) { @@ -572,9 +542,7 @@ async fn op_read_async( s.read(buf).await? } else if let Some(s) = resource.downcast_rc::() { s.read(buf).await? - } else if let Some(s) = resource.downcast_rc::() { - s.read(buf).await? - } else if let Some(s) = resource.downcast_rc::() { + } else if let Some(s) = resource.downcast_rc::() { s.read(buf).await? } else if let Some(s) = resource.downcast_rc::() { s.read(buf).await? @@ -616,9 +584,7 @@ async fn op_write_async( s.write(buf).await? } else if let Some(s) = resource.downcast_rc::() { s.write(buf).await? - } else if let Some(s) = resource.downcast_rc::() { - s.write(buf).await? - } else if let Some(s) = resource.downcast_rc::() { + } else if let Some(s) = resource.downcast_rc::() { s.write(buf).await? } else if let Some(s) = resource.downcast_rc::() { s.write(buf).await? @@ -644,9 +610,7 @@ async fn op_shutdown( s.shutdown().await?; } else if let Some(s) = resource.downcast_rc::() { s.shutdown().await?; - } else if let Some(s) = resource.downcast_rc::() { - s.shutdown().await?; - } else if let Some(s) = resource.downcast_rc::() { + } else if let Some(s) = resource.downcast_rc::() { s.shutdown().await?; } else if let Some(s) = resource.downcast_rc::() { s.shutdown().await?; diff --git a/runtime/ops/tls.rs b/runtime/ops/tls.rs index 9143a86fa2..c3f554856d 100644 --- a/runtime/ops/tls.rs +++ b/runtime/ops/tls.rs @@ -1,11 +1,13 @@ // Copyright 2018-2021 the Deno authors. All rights reserved. MIT license. -use super::io::TcpStreamResource; -use super::io::TlsClientStreamResource; -use super::io::TlsServerStreamResource; -use super::net::IpAddr; -use super::net::OpAddr; -use super::net::OpConn; +pub use rustls; +pub use webpki; + +use crate::ops::io::TcpStreamResource; +use crate::ops::io::TlsStreamResource; +use crate::ops::net::IpAddr; +use crate::ops::net::OpAddr; +use crate::ops::net::OpConn; use crate::permissions::Permissions; use crate::resolve_addr::resolve_addr; use crate::resolve_addr::resolve_addr_sync; @@ -15,6 +17,15 @@ use deno_core::error::custom_error; use deno_core::error::generic_error; use deno_core::error::invalid_hostname; use deno_core::error::AnyError; +use deno_core::futures::future::poll_fn; +use deno_core::futures::ready; +use deno_core::futures::task::noop_waker_ref; +use deno_core::futures::task::AtomicWaker; +use deno_core::futures::task::Context; +use deno_core::futures::task::Poll; +use deno_core::futures::task::RawWaker; +use deno_core::futures::task::RawWakerVTable; +use deno_core::futures::task::Waker; use deno_core::op_async; use deno_core::op_sync; use deno_core::AsyncRefCell; @@ -25,27 +36,44 @@ use deno_core::OpState; use deno_core::RcRef; use deno_core::Resource; use deno_core::ResourceId; +use io::Error; +use io::Read; +use io::Write; +use rustls::internal::pemfile::certs; +use rustls::internal::pemfile::pkcs8_private_keys; +use rustls::internal::pemfile::rsa_private_keys; +use rustls::Certificate; +use rustls::ClientConfig; +use rustls::ClientSession; +use rustls::NoClientAuth; +use rustls::PrivateKey; +use rustls::ServerConfig; +use rustls::ServerSession; +use rustls::Session; +use rustls::StoresClientSessions; use serde::Deserialize; use std::borrow::Cow; use std::cell::RefCell; use std::collections::HashMap; use std::convert::From; use std::fs::File; +use std::io; use std::io::BufReader; +use std::io::ErrorKind; +use std::ops::Deref; +use std::ops::DerefMut; use std::path::Path; +use std::pin::Pin; use std::rc::Rc; use std::sync::Arc; use std::sync::Mutex; +use std::sync::Weak; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; use tokio::net::TcpListener; use tokio::net::TcpStream; -use tokio_rustls::{rustls::ClientConfig, TlsConnector}; -use tokio_rustls::{ - rustls::{ - internal::pemfile::{certs, pkcs8_private_keys, rsa_private_keys}, - Certificate, NoClientAuth, PrivateKey, ServerConfig, StoresClientSessions, - }, - TlsAcceptor, -}; +use tokio::task::spawn_local; use webpki::DNSNameRef; lazy_static::lazy_static! { @@ -73,6 +101,567 @@ impl StoresClientSessions for ClientSessionMemoryCache { } } +#[derive(Debug)] +enum TlsSession { + Client(ClientSession), + Server(ServerSession), +} + +impl Deref for TlsSession { + type Target = dyn Session; + + fn deref(&self) -> &Self::Target { + match self { + TlsSession::Client(client_session) => client_session, + TlsSession::Server(server_session) => server_session, + } + } +} + +impl DerefMut for TlsSession { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + TlsSession::Client(client_session) => client_session, + TlsSession::Server(server_session) => server_session, + } + } +} + +impl From for TlsSession { + fn from(client_session: ClientSession) -> Self { + TlsSession::Client(client_session) + } +} + +impl From for TlsSession { + fn from(server_session: ServerSession) -> Self { + TlsSession::Server(server_session) + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +enum Flow { + Read, + Write, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +enum State { + StreamOpen, + StreamClosed, + TlsClosing, + TlsClosed, + TcpClosed, +} + +#[derive(Debug)] +pub struct TlsStream(Option); + +impl TlsStream { + fn new(tcp: TcpStream, tls: TlsSession) -> Self { + let inner = TlsStreamInner { + tcp, + tls, + rd_state: State::StreamOpen, + wr_state: State::StreamOpen, + }; + Self(Some(inner)) + } + + pub fn new_client_side( + tcp: TcpStream, + tls_config: &Arc, + hostname: DNSNameRef, + ) -> Self { + let tls = TlsSession::Client(ClientSession::new(tls_config, hostname)); + Self::new(tcp, tls) + } + + pub fn new_server_side( + tcp: TcpStream, + tls_config: &Arc, + ) -> Self { + let tls = TlsSession::Server(ServerSession::new(tls_config)); + Self::new(tcp, tls) + } + + pub async fn handshake(&mut self) -> io::Result<()> { + poll_fn(|cx| self.inner_mut().poll_io(cx, Flow::Write)).await + } + + fn into_split(self) -> (ReadHalf, WriteHalf) { + let shared = Shared::new(self); + let rd = ReadHalf { + shared: shared.clone(), + }; + let wr = WriteHalf { shared }; + (rd, wr) + } + + /// Tokio-rustls compatibility: returns a reference to the underlying TCP + /// stream, and a reference to the Rustls `Session` object. + pub fn get_ref(&self) -> (&TcpStream, &dyn Session) { + let inner = self.0.as_ref().unwrap(); + (&inner.tcp, &*inner.tls) + } + + fn inner_mut(&mut self) -> &mut TlsStreamInner { + self.0.as_mut().unwrap() + } +} + +impl AsyncRead for TlsStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.inner_mut().poll_read(cx, buf) + } +} + +impl AsyncWrite for TlsStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.inner_mut().poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.inner_mut().poll_io(cx, Flow::Write) + // The underlying TCP stream does not need to be flushed. + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.inner_mut().poll_shutdown(cx) + } +} + +impl Drop for TlsStream { + fn drop(&mut self) { + let mut inner = self.0.take().unwrap(); + + let mut cx = Context::from_waker(noop_waker_ref()); + let use_linger_task = inner.poll_close(&mut cx).is_pending(); + + if use_linger_task { + spawn_local(poll_fn(move |cx| inner.poll_close(cx))); + } else if cfg!(debug_assertions) { + spawn_local(async {}); // Spawn dummy task to detect missing LocalSet. + } + } +} + +#[derive(Debug)] +pub struct TlsStreamInner { + tls: TlsSession, + tcp: TcpStream, + rd_state: State, + wr_state: State, +} + +impl TlsStreamInner { + fn poll_io( + &mut self, + cx: &mut Context<'_>, + flow: Flow, + ) -> Poll> { + loop { + let wr_ready = loop { + match self.wr_state { + _ if self.tls.is_handshaking() && !self.tls.wants_write() => { + break true; + } + _ if self.tls.is_handshaking() => {} + State::StreamOpen if !self.tls.wants_write() => break true, + State::StreamClosed => { + // Rustls will enqueue the 'CloseNotify' alert and send it after + // flusing the data that is already in the queue. + self.tls.send_close_notify(); + self.wr_state = State::TlsClosing; + continue; + } + State::TlsClosing if !self.tls.wants_write() => { + self.wr_state = State::TlsClosed; + continue; + } + // If a 'CloseNotify' alert sent by the remote end has been received, + // shut down the underlying TCP socket. Otherwise, consider polling + // done for the moment. + State::TlsClosed if self.rd_state < State::TlsClosed => break true, + State::TlsClosed + if Pin::new(&mut self.tcp).poll_shutdown(cx)?.is_pending() => + { + break false; + } + State::TlsClosed => { + self.wr_state = State::TcpClosed; + continue; + } + State::TcpClosed => break true, + _ => {} + } + + // Poll whether there is space in the socket send buffer so we can flush + // the remaining outgoing ciphertext. + if self.tcp.poll_write_ready(cx)?.is_pending() { + break false; + } + + // Write ciphertext to the TCP socket. + let mut wrapped_tcp = ImplementWriteTrait(&mut self.tcp); + match self.tls.write_tls(&mut wrapped_tcp) { + Ok(0) => unreachable!(), + Ok(_) => {} + Err(err) if err.kind() == ErrorKind::WouldBlock => {} + Err(err) => return Poll::Ready(Err(err)), + } + }; + + let rd_ready = loop { + match self.rd_state { + State::TcpClosed if self.tls.is_handshaking() => { + let err = Error::new(ErrorKind::UnexpectedEof, "tls handshake eof"); + return Poll::Ready(Err(err)); + } + _ if self.tls.is_handshaking() && !self.tls.wants_read() => { + break true; + } + _ if self.tls.is_handshaking() => {} + State::StreamOpen if !self.tls.wants_read() => break true, + State::StreamOpen => {} + State::StreamClosed if !self.tls.wants_read() => { + // Rustls has more incoming cleartext buffered up, but the TLS + // session is closing so this data will never be processed by the + // application layer. Just like what would happen if this were a raw + // TCP stream, don't gracefully end the TLS session, but abort it. + return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset))); + } + State::StreamClosed => {} + State::TlsClosed if self.wr_state == State::TcpClosed => { + // Wait for the remote end to gracefully close the TCP connection. + // TODO(piscisaureus): this is unnecessary; remove when stable. + } + _ => break true, + } + + if self.rd_state < State::TlsClosed { + // Do a zero-length plaintext read so we can detect the arrival of + // 'CloseNotify' messages, even if only the write half is open. + // Actually reading data from the socket is done in `poll_read()`. + match self.tls.read(&mut []) { + Ok(0) => {} + Err(err) if err.kind() == ErrorKind::ConnectionAborted => { + // `Session::read()` returns `ConnectionAborted` when a + // 'CloseNotify' alert has been received, which indicates that + // the remote peer wants to gracefully end the TLS session. + self.rd_state = State::TlsClosed; + continue; + } + Err(err) => return Poll::Ready(Err(err)), + _ => unreachable!(), + } + } + + // Poll whether more ciphertext is available in the socket receive + // buffer. + if self.tcp.poll_read_ready(cx)?.is_pending() { + break false; + } + + // Receive ciphertext from the socket. + let mut wrapped_tcp = ImplementReadTrait(&mut self.tcp); + match self.tls.read_tls(&mut wrapped_tcp) { + Ok(0) => self.rd_state = State::TcpClosed, + Ok(_) => self + .tls + .process_new_packets() + .map_err(|err| Error::new(ErrorKind::InvalidData, err))?, + Err(err) if err.kind() == ErrorKind::WouldBlock => {} + Err(err) => return Poll::Ready(Err(err)), + } + }; + + if wr_ready { + if self.rd_state >= State::TlsClosed + && self.wr_state >= State::TlsClosed + && self.wr_state < State::TcpClosed + { + continue; + } + if self.tls.wants_write() { + continue; + } + } + + let io_ready = match flow { + _ if self.tls.is_handshaking() => false, + Flow::Read => rd_ready, + Flow::Write => wr_ready, + }; + return match io_ready { + false => Poll::Pending, + true => Poll::Ready(Ok(())), + }; + } + } + + fn poll_read( + &mut self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + ready!(self.poll_io(cx, Flow::Read))?; + + if self.rd_state == State::StreamOpen { + let buf_slice = + unsafe { &mut *(buf.unfilled_mut() as *mut [_] as *mut [u8]) }; + let bytes_read = self.tls.read(buf_slice)?; + assert_ne!(bytes_read, 0); + unsafe { buf.assume_init(bytes_read) }; + buf.advance(bytes_read); + } + + Poll::Ready(Ok(())) + } + + fn poll_write( + &mut self, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if buf.is_empty() { + // Tokio-rustls compatibility: a zero byte write always succeeds. + Poll::Ready(Ok(0)) + } else if self.wr_state == State::StreamOpen { + // Flush Rustls' ciphertext send queue. + ready!(self.poll_io(cx, Flow::Write))?; + + // Copy data from `buf` to the Rustls cleartext send queue. + let bytes_written = self.tls.write(buf)?; + assert_ne!(bytes_written, 0); + + // Try to flush as much ciphertext as possible. However, since we just + // handed off at least some bytes to rustls, so we can't return + // `Poll::Pending()` any more: this would tell the caller that it should + // try to send those bytes again. + let _ = self.poll_io(cx, Flow::Write)?; + + Poll::Ready(Ok(bytes_written)) + } else { + // Return error if stream has been shut down for writing. + Poll::Ready(Err(ErrorKind::BrokenPipe.into())) + } + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.wr_state == State::StreamOpen { + self.wr_state = State::StreamClosed; + } + + ready!(self.poll_io(cx, Flow::Write))?; + + // At minimum, a TLS 'CloseNotify' alert should have been sent. + assert!(self.wr_state >= State::TlsClosed); + // If we received a TLS 'CloseNotify' alert from the remote end + // already, the TCP socket should be shut down at this point. + assert!( + self.rd_state < State::TlsClosed || self.wr_state == State::TcpClosed + ); + + Poll::Ready(Ok(())) + } + + fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.rd_state == State::StreamOpen { + self.rd_state = State::StreamClosed; + } + + // Send TLS 'CloseNotify' alert. + ready!(self.poll_shutdown(cx))?; + // Wait for 'CloseNotify', shut down TCP stream, wait for TCP FIN packet. + ready!(self.poll_io(cx, Flow::Read))?; + + assert_eq!(self.rd_state, State::TcpClosed); + assert_eq!(self.wr_state, State::TcpClosed); + + Poll::Ready(Ok(())) + } +} + +#[derive(Debug)] +pub struct ReadHalf { + shared: Arc, +} + +impl ReadHalf { + pub fn reunite(self, wr: WriteHalf) -> TlsStream { + assert!(Arc::ptr_eq(&self.shared, &wr.shared)); + drop(wr); // Drop `wr`, so only one strong reference to `shared` remains. + + Arc::try_unwrap(self.shared) + .unwrap_or_else(|_| panic!("Arc::::try_unwrap() failed")) + .tls_stream + .into_inner() + .unwrap() + } +} + +impl AsyncRead for ReadHalf { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self + .shared + .poll_with_shared_waker(cx, Flow::Read, move |tls, cx| { + tls.poll_read(cx, buf) + }) + } +} + +#[derive(Debug)] +pub struct WriteHalf { + shared: Arc, +} + +impl AsyncWrite for WriteHalf { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self + .shared + .poll_with_shared_waker(cx, Flow::Write, move |tls, cx| { + tls.poll_write(cx, buf) + }) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self + .shared + .poll_with_shared_waker(cx, Flow::Write, |tls, cx| tls.poll_flush(cx)) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self + .shared + .poll_with_shared_waker(cx, Flow::Write, |tls, cx| tls.poll_shutdown(cx)) + } +} + +#[derive(Debug)] +struct Shared { + tls_stream: Mutex, + rd_waker: AtomicWaker, + wr_waker: AtomicWaker, +} + +impl Shared { + fn new(tls_stream: TlsStream) -> Arc { + let self_ = Self { + tls_stream: Mutex::new(tls_stream), + rd_waker: AtomicWaker::new(), + wr_waker: AtomicWaker::new(), + }; + Arc::new(self_) + } + + fn poll_with_shared_waker( + self: &Arc, + cx: &mut Context<'_>, + flow: Flow, + mut f: impl FnMut(Pin<&mut TlsStream>, &mut Context<'_>) -> R, + ) -> R { + match flow { + Flow::Read => self.rd_waker.register(cx.waker()), + Flow::Write => self.wr_waker.register(cx.waker()), + } + + let shared_waker = self.new_shared_waker(); + let mut cx = Context::from_waker(&shared_waker); + + let mut tls_stream = self.tls_stream.lock().unwrap(); + f(Pin::new(&mut tls_stream), &mut cx) + } + + const SHARED_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new( + Self::clone_shared_waker, + Self::wake_shared_waker, + Self::wake_shared_waker_by_ref, + Self::drop_shared_waker, + ); + + fn new_shared_waker(self: &Arc) -> Waker { + let self_weak = Arc::downgrade(self); + let self_ptr = self_weak.into_raw() as *const (); + let raw_waker = RawWaker::new(self_ptr, &Self::SHARED_WAKER_VTABLE); + unsafe { Waker::from_raw(raw_waker) } + } + + fn clone_shared_waker(self_ptr: *const ()) -> RawWaker { + let self_weak = unsafe { Weak::from_raw(self_ptr as *const Self) }; + let ptr1 = self_weak.clone().into_raw(); + let ptr2 = self_weak.into_raw(); + assert!(ptr1 == ptr2); + RawWaker::new(self_ptr, &Self::SHARED_WAKER_VTABLE) + } + + fn wake_shared_waker(self_ptr: *const ()) { + Self::wake_shared_waker_by_ref(self_ptr); + Self::drop_shared_waker(self_ptr); + } + + fn wake_shared_waker_by_ref(self_ptr: *const ()) { + let self_weak = unsafe { Weak::from_raw(self_ptr as *const Self) }; + if let Some(self_arc) = Weak::upgrade(&self_weak) { + self_arc.rd_waker.wake(); + self_arc.wr_waker.wake(); + } + self_weak.into_raw(); + } + + fn drop_shared_waker(self_ptr: *const ()) { + let _ = unsafe { Weak::from_raw(self_ptr as *const Self) }; + } +} + +struct ImplementReadTrait<'a, T>(&'a mut T); + +impl Read for ImplementReadTrait<'_, TcpStream> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.try_read(buf) + } +} + +struct ImplementWriteTrait<'a, T>(&'a mut T); + +impl Write for ImplementWriteTrait<'_, TcpStream> { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.try_write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + pub fn init() -> Extension { Extension::builder() .ops(vec![ @@ -107,21 +696,25 @@ async fn op_start_tls( _: (), ) -> Result { let rid = args.rid; + let hostname = match &*args.hostname { + "" => "localhost", + n => n, + }; + let cert_file = args.cert_file.as_deref(); - let mut domain = args.hostname.as_str(); - if domain.is_empty() { - domain = "localhost"; - } { super::check_unstable2(&state, "Deno.startTls"); let mut s = state.borrow_mut(); let permissions = s.borrow_mut::(); - permissions.net.check(&(&domain, Some(0)))?; - if let Some(path) = &args.cert_file { - permissions.read.check(Path::new(&path))?; + permissions.net.check(&(hostname, Some(0)))?; + if let Some(path) = cert_file { + permissions.read.check(Path::new(path))?; } } + let hostname_dns = DNSNameRef::try_from_ascii_str(hostname) + .map_err(|_| invalid_hostname(hostname))?; + let resource_rc = state .borrow_mut() .resource_table @@ -134,28 +727,29 @@ async fn op_start_tls( let local_addr = tcp_stream.local_addr()?; let remote_addr = tcp_stream.peer_addr()?; - let mut config = ClientConfig::new(); - config.set_persistence(CLIENT_SESSION_MEMORY_CACHE.clone()); - config + + let mut tls_config = ClientConfig::new(); + tls_config.set_persistence(CLIENT_SESSION_MEMORY_CACHE.clone()); + tls_config .root_store .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); - if let Some(path) = args.cert_file { + if let Some(path) = cert_file { let key_file = File::open(path)?; let reader = &mut BufReader::new(key_file); - config.root_store.add_pem_file(reader).unwrap(); + tls_config.root_store.add_pem_file(reader).unwrap(); } + let tls_config = Arc::new(tls_config); - let tls_connector = TlsConnector::from(Arc::new(config)); - let dnsname = DNSNameRef::try_from_ascii_str(domain) - .map_err(|_| invalid_hostname(domain))?; - let tls_stream = tls_connector.connect(dnsname, tcp_stream).await?; + let tls_stream = + TlsStream::new_client_side(tcp_stream, &tls_config, hostname_dns); let rid = { let mut state_ = state.borrow_mut(); state_ .resource_table - .add(TlsClientStreamResource::from(tls_stream)) + .add(TlsStreamResource::new(tls_stream.into_split())) }; + Ok(OpConn { rid, local_addr: Some(OpAddr::Tcp(IpAddr { @@ -175,47 +769,55 @@ async fn op_connect_tls( _: (), ) -> Result { assert_eq!(args.transport, "tcp"); + let hostname = match &*args.hostname { + "" => "localhost", + n => n, + }; + let port = args.port; + let cert_file = args.cert_file.as_deref(); - let mut domain = args.hostname.as_str(); - if domain.is_empty() { - domain = "localhost"; - } { let mut s = state.borrow_mut(); let permissions = s.borrow_mut::(); - permissions.net.check(&(domain, Some(args.port)))?; - if let Some(path) = &args.cert_file { - permissions.read.check(Path::new(&path))?; + permissions.net.check(&(hostname, Some(port)))?; + if let Some(path) = cert_file { + permissions.read.check(Path::new(path))?; } } - let dnsname = DNSNameRef::try_from_ascii_str(domain) - .map_err(|_| invalid_hostname(domain))?; - let addr = resolve_addr(domain, args.port) + let hostname_dns = DNSNameRef::try_from_ascii_str(hostname) + .map_err(|_| invalid_hostname(hostname))?; + + let connect_addr = resolve_addr(hostname, port) .await? .next() .ok_or_else(|| generic_error("No resolved address found"))?; - let tcp_stream = TcpStream::connect(&addr).await?; + let tcp_stream = TcpStream::connect(connect_addr).await?; let local_addr = tcp_stream.local_addr()?; let remote_addr = tcp_stream.peer_addr()?; - let mut config = ClientConfig::new(); - config.set_persistence(CLIENT_SESSION_MEMORY_CACHE.clone()); - config + + let mut tls_config = ClientConfig::new(); + tls_config.set_persistence(CLIENT_SESSION_MEMORY_CACHE.clone()); + tls_config .root_store .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); - if let Some(path) = args.cert_file { + if let Some(path) = cert_file { let key_file = File::open(path)?; let reader = &mut BufReader::new(key_file); - config.root_store.add_pem_file(reader).unwrap(); + tls_config.root_store.add_pem_file(reader).unwrap(); } - let tls_connector = TlsConnector::from(Arc::new(config)); - let tls_stream = tls_connector.connect(dnsname, tcp_stream).await?; + let tls_config = Arc::new(tls_config); + + let tls_stream = + TlsStream::new_client_side(tcp_stream, &tls_config, hostname_dns); + let rid = { let mut state_ = state.borrow_mut(); state_ .resource_table - .add(TlsClientStreamResource::from(tls_stream)) + .add(TlsStreamResource::new(tls_stream.into_split())) }; + Ok(OpConn { rid, local_addr: Some(OpAddr::Tcp(IpAddr { @@ -284,9 +886,9 @@ fn load_keys(path: &str) -> Result, AnyError> { } pub struct TlsListenerResource { - listener: AsyncRefCell, - tls_acceptor: TlsAcceptor, - cancel: CancelHandle, + tcp_listener: AsyncRefCell, + tls_config: Arc, + cancel_handle: CancelHandle, } impl Resource for TlsListenerResource { @@ -295,7 +897,7 @@ impl Resource for TlsListenerResource { } fn close(self: Rc) { - self.cancel.cancel(); + self.cancel_handle.cancel(); } } @@ -316,36 +918,40 @@ fn op_listen_tls( _: (), ) -> Result { assert_eq!(args.transport, "tcp"); + let hostname = &*args.hostname; + let port = args.port; + let cert_file = &*args.cert_file; + let key_file = &*args.key_file; - let cert_file = args.cert_file; - let key_file = args.key_file; { let permissions = state.borrow_mut::(); - permissions.net.check(&(&args.hostname, Some(args.port)))?; - permissions.read.check(Path::new(&cert_file))?; - permissions.read.check(Path::new(&key_file))?; + permissions.net.check(&(hostname, Some(port)))?; + permissions.read.check(Path::new(cert_file))?; + permissions.read.check(Path::new(key_file))?; } - let mut config = ServerConfig::new(NoClientAuth::new()); + + let mut tls_config = ServerConfig::new(NoClientAuth::new()); if let Some(alpn_protocols) = args.alpn_protocols { super::check_unstable(state, "Deno.listenTls#alpn_protocols"); - config.alpn_protocols = + tls_config.alpn_protocols = alpn_protocols.into_iter().map(|s| s.into_bytes()).collect(); } - config - .set_single_cert(load_certs(&cert_file)?, load_keys(&key_file)?.remove(0)) + tls_config + .set_single_cert(load_certs(cert_file)?, load_keys(key_file)?.remove(0)) .expect("invalid key or certificate"); - let tls_acceptor = TlsAcceptor::from(Arc::new(config)); - let addr = resolve_addr_sync(&args.hostname, args.port)? + + let bind_addr = resolve_addr_sync(hostname, port)? .next() .ok_or_else(|| generic_error("No resolved address found"))?; - let std_listener = std::net::TcpListener::bind(&addr)?; + let std_listener = std::net::TcpListener::bind(bind_addr)?; std_listener.set_nonblocking(true)?; - let listener = TcpListener::from_std(std_listener)?; - let local_addr = listener.local_addr()?; + let tcp_listener = TcpListener::from_std(std_listener)?; + let local_addr = tcp_listener.local_addr()?; + let tls_listener_resource = TlsListenerResource { - listener: AsyncRefCell::new(listener), - tls_acceptor, - cancel: Default::default(), + tcp_listener: AsyncRefCell::new(tcp_listener), + tls_config: Arc::new(tls_config), + cancel_handle: Default::default(), }; let rid = state.resource_table.add(tls_listener_resource); @@ -370,38 +976,31 @@ async fn op_accept_tls( .resource_table .get::(rid) .ok_or_else(|| bad_resource("Listener has been closed"))?; - let listener = RcRef::map(&resource, |r| &r.listener) + + let cancel_handle = RcRef::map(&resource, |r| &r.cancel_handle); + let tcp_listener = RcRef::map(&resource, |r| &r.tcp_listener) .try_borrow_mut() .ok_or_else(|| custom_error("Busy", "Another accept task is ongoing"))?; - let cancel = RcRef::map(resource, |r| &r.cancel); - let (tcp_stream, _socket_addr) = - listener.accept().try_or_cancel(cancel).await.map_err(|e| { - // FIXME(bartlomieju): compatibility with current JS implementation - if let std::io::ErrorKind::Interrupted = e.kind() { - bad_resource("Listener has been closed") - } else { - e.into() + + let (tcp_stream, remote_addr) = + match tcp_listener.accept().try_or_cancel(&cancel_handle).await { + Ok(tuple) => tuple, + Err(err) if err.kind() == ErrorKind::Interrupted => { + // FIXME(bartlomieju): compatibility with current JS implementation. + return Err(bad_resource("Listener has been closed")); } - })?; + Err(err) => return Err(err.into()), + }; + let local_addr = tcp_stream.local_addr()?; - let remote_addr = tcp_stream.peer_addr()?; - let resource = state - .borrow() - .resource_table - .get::(rid) - .ok_or_else(|| bad_resource("Listener has been closed"))?; - let cancel = RcRef::map(&resource, |r| &r.cancel); - let tls_acceptor = resource.tls_acceptor.clone(); - let tls_stream = tls_acceptor - .accept(tcp_stream) - .try_or_cancel(cancel) - .await?; + + let tls_stream = TlsStream::new_server_side(tcp_stream, &resource.tls_config); let rid = { let mut state_ = state.borrow_mut(); state_ .resource_table - .add(TlsServerStreamResource::from(tls_stream)) + .add(TlsStreamResource::new(tls_stream.into_split())) }; Ok(OpConn {