diff --git a/cli/tests/unit/tls_test.ts b/cli/tests/unit/tls_test.ts index b2382833f7..17f2312c0e 100644 --- a/cli/tests/unit/tls_test.ts +++ b/cli/tests/unit/tls_test.ts @@ -1121,3 +1121,139 @@ unitTest( conn.close(); }, ); + +// TODO(piscisaureus): use `TlsConn.handhake()` instead, once this is added to +// the public API in Deno 1.16. +function tlsHandshake(conn: Deno.Conn): Promise { + // deno-lint-ignore no-explicit-any + const opAsync = (Deno as any).core.opAsync; + return opAsync("op_tls_handshake", conn.rid); +} + +unitTest( + { permissions: { read: true, net: true } }, + async function tlsHandshakeSuccess() { + const hostname = "localhost"; + const port = getPort(); + + const listener = Deno.listenTls({ + hostname, + port, + certFile: "cli/tests/testdata/tls/localhost.crt", + keyFile: "cli/tests/testdata/tls/localhost.key", + }); + const acceptPromise = listener.accept(); + const connectPromise = Deno.connectTls({ + hostname, + port, + certFile: "cli/tests/testdata/tls/RootCA.crt", + }); + const [conn1, conn2] = await Promise.all([acceptPromise, connectPromise]); + listener.close(); + + await Promise.all([tlsHandshake(conn1), tlsHandshake(conn2)]); + + // Begin sending a 10mb blob over the TLS connection. + const whole = new Uint8Array(10 << 20); // 10mb. + whole.fill(42); + const sendPromise = conn1.write(whole); + + // Set up the other end to receive half of the large blob. + const half = new Uint8Array(whole.byteLength / 2); + const receivePromise = readFull(conn2, half); + + await tlsHandshake(conn1); + await tlsHandshake(conn2); + + // Finish receiving the first 5mb. + assertEquals(await receivePromise, half.length); + + // See that we can call `handshake()` in the middle of large reads and writes. + await tlsHandshake(conn1); + await tlsHandshake(conn2); + + // Receive second half of large blob. Wait for the send promise and check it. + assertEquals(await readFull(conn2, half), half.length); + assertEquals(await sendPromise, whole.length); + + await tlsHandshake(conn1); + await tlsHandshake(conn2); + + await conn1.closeWrite(); + await conn2.closeWrite(); + + await tlsHandshake(conn1); + await tlsHandshake(conn2); + + conn1.close(); + conn2.close(); + + async function readFull(conn: Deno.Conn, buf: Uint8Array) { + let offset, n; + for (offset = 0; offset < buf.length; offset += n) { + n = await conn.read(buf.subarray(offset, buf.length)); + assert(n != null && n > 0); + } + return offset; + } + }, +); + +unitTest( + { permissions: { read: true, net: true } }, + async function tlsHandshakeFailure() { + const hostname = "localhost"; + const port = getPort(); + + async function server() { + const listener = Deno.listenTls({ + hostname, + port, + certFile: "cli/tests/testdata/tls/localhost.crt", + keyFile: "cli/tests/testdata/tls/localhost.key", + }); + for await (const conn of listener) { + for (let i = 0; i < 10; i++) { + // Handshake fails because the client rejects the server certificate. + await assertRejects( + () => tlsHandshake(conn), + Deno.errors.InvalidData, + "BadCertificate", + ); + } + conn.close(); + break; + } + } + + async function connectTlsClient() { + const conn = await Deno.connectTls({ hostname, port }); + // Handshake fails because the server presents a self-signed certificate. + await assertRejects( + () => tlsHandshake(conn), + Deno.errors.InvalidData, + "UnknownIssuer", + ); + conn.close(); + } + + await Promise.all([server(), connectTlsClient()]); + + async function startTlsClient() { + const tcpConn = await Deno.connect({ hostname, port }); + const tlsConn = await Deno.startTls(tcpConn, { + hostname: "foo.land", + certFile: "cli/tests/testdata/tls/RootCA.crt", + }); + // Handshake fails because hostname doesn't match the certificate. + await assertRejects( + () => tlsHandshake(tlsConn), + Deno.errors.InvalidData, + "CertNotValidForName", + ); + tlsConn.close(); + } + + await Promise.all([server(), startTlsClient()]); + }, +); diff --git a/ext/net/README.md b/ext/net/README.md index cdd8923e1c..6377e79749 100644 --- a/ext/net/README.md +++ b/ext/net/README.md @@ -22,6 +22,7 @@ Following ops are provided: - "op_connect_tls" - "op_listen_tls" - "op_accept_tls" +- "op_tls_handshake" - "op_http_start" - "op_http_request_next" - "op_http_request_read" diff --git a/ext/net/io.rs b/ext/net/io.rs index 6a93b8cf6e..6cefbde2de 100644 --- a/ext/net/io.rs +++ b/ext/net/io.rs @@ -1,6 +1,6 @@ // Copyright 2018-2021 the Deno authors. All rights reserved. MIT license. -use crate::ops_tls as tls; +use crate::ops_tls::TlsStreamResource; use deno_core::error::not_supported; use deno_core::error::AnyError; use deno_core::op_async; @@ -114,18 +114,6 @@ impl Resource for TcpStreamResource { } } -pub type TlsStreamResource = FullDuplexResource; - -impl Resource for TlsStreamResource { - fn name(&self) -> Cow { - "tlsStream".into() - } - - fn close(self: Rc) { - self.cancel_read_ops(); - } -} - #[cfg(unix)] pub type UnixStreamResource = FullDuplexResource; diff --git a/ext/net/ops_tls.rs b/ext/net/ops_tls.rs index d6618440fb..129a702bcf 100644 --- a/ext/net/ops_tls.rs +++ b/ext/net/ops_tls.rs @@ -1,7 +1,6 @@ // Copyright 2018-2021 the Deno authors. All rights reserved. MIT license. use crate::io::TcpStreamResource; -use crate::io::TlsStreamResource; use crate::ops::IpAddr; use crate::ops::OpAddr; use crate::ops::OpConn; @@ -53,6 +52,7 @@ use io::Read; use io::Write; use serde::Deserialize; use std::borrow::Cow; +use std::cell::Cell; use std::cell::RefCell; use std::convert::From; use std::fs::File; @@ -67,7 +67,9 @@ use std::rc::Rc; use std::sync::Arc; use std::sync::Weak; use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; use tokio::io::AsyncWrite; +use tokio::io::AsyncWriteExt; use tokio::io::ReadBuf; use tokio::net::TcpListener; use tokio::net::TcpStream; @@ -113,6 +115,7 @@ impl From for TlsSession { #[derive(Copy, Clone, Debug, Eq, PartialEq)] enum Flow { + Handshake, Read, Write, } @@ -123,6 +126,7 @@ enum State { StreamClosed, TlsClosing, TlsClosed, + TlsError, TcpClosed, } @@ -157,10 +161,6 @@ impl TlsStream { Self::new(tcp, tls) } - pub async fn handshake(&mut self) -> io::Result<()> { - poll_fn(|cx| self.inner_mut().poll_io(cx, Flow::Write)).await - } - fn into_split(self) -> (ReadHalf, WriteHalf) { let shared = Shared::new(self); let rd = ReadHalf { @@ -180,6 +180,14 @@ impl TlsStream { fn inner_mut(&mut self) -> &mut TlsStreamInner { self.0.as_mut().unwrap() } + + pub async fn handshake(&mut self) -> io::Result<()> { + poll_fn(|cx| self.inner_mut().poll_handshake(cx)).await + } + + fn poll_handshake(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner_mut().poll_handshake(cx) + } } impl AsyncRead for TlsStream { @@ -282,20 +290,20 @@ impl TlsStreamInner { _ => {} } + // Write ciphertext to the TCP socket. + let mut wrapped_tcp = ImplementWriteTrait(&mut self.tcp); + match self.tls.write_tls(&mut wrapped_tcp) { + Ok(0) => {} // Wait until the socket has enough buffer space. + Ok(_) => continue, // Try to send more more data immediately. + Err(err) if err.kind() == ErrorKind::WouldBlock => unreachable!(), + Err(err) => return Poll::Ready(Err(err)), + } + // Poll whether there is space in the socket send buffer so we can flush // the remaining outgoing ciphertext. if self.tcp.poll_write_ready(cx)?.is_pending() { break false; } - - // Write ciphertext to the TCP socket. - let mut wrapped_tcp = ImplementWriteTrait(&mut self.tcp); - match self.tls.write_tls(&mut wrapped_tcp) { - Ok(0) => unreachable!(), - Ok(_) => {} - Err(err) if err.kind() == ErrorKind::WouldBlock => {} - Err(err) => return Poll::Ready(Err(err)), - } }; let rd_ready = loop { @@ -304,6 +312,7 @@ impl TlsStreamInner { let err = Error::new(ErrorKind::UnexpectedEof, "tls handshake eof"); return Poll::Ready(Err(err)); } + State::TlsError => {} _ if self.tls.is_handshaking() && !self.tls.wants_read() => { break true; } @@ -343,22 +352,36 @@ impl TlsStreamInner { } } - // Poll whether more ciphertext is available in the socket receive - // buffer. - if self.tcp.poll_read_ready(cx)?.is_pending() { - break false; + if self.rd_state != State::TlsError { + // Receive ciphertext from the socket. + let mut wrapped_tcp = ImplementReadTrait(&mut self.tcp); + match self.tls.read_tls(&mut wrapped_tcp) { + Ok(0) => { + // End of TCP stream. + self.rd_state = State::TcpClosed; + continue; + } + Err(err) if err.kind() == ErrorKind::WouldBlock => { + // Get notified when more ciphertext becomes available in the + // socket receive buffer. + if self.tcp.poll_read_ready(cx)?.is_pending() { + break false; + } else { + continue; + } + } + Err(err) => return Poll::Ready(Err(err)), + _ => {} + } } - // Receive ciphertext from the socket. - let mut wrapped_tcp = ImplementReadTrait(&mut self.tcp); - match self.tls.read_tls(&mut wrapped_tcp) { - Ok(0) => self.rd_state = State::TcpClosed, - Ok(_) => self - .tls - .process_new_packets() - .map_err(|err| Error::new(ErrorKind::InvalidData, err))?, - Err(err) if err.kind() == ErrorKind::WouldBlock => {} - Err(err) => return Poll::Ready(Err(err)), + // Interpret and decrypt TLS protocol data. + match self.tls.process_new_packets() { + Ok(_) => assert!(self.rd_state < State::TcpClosed), + Err(err) => { + self.rd_state = State::TlsError; + return Poll::Ready(Err(Error::new(ErrorKind::InvalidData, err))); + } } }; @@ -376,6 +399,7 @@ impl TlsStreamInner { let io_ready = match flow { _ if self.tls.is_handshaking() => false, + Flow::Handshake => true, Flow::Read => rd_ready, Flow::Write => wr_ready, }; @@ -386,6 +410,13 @@ impl TlsStreamInner { } } + fn poll_handshake(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.tls.is_handshaking() { + ready!(self.poll_io(cx, Flow::Handshake))?; + } + Poll::Ready(Ok(())) + } + fn poll_read( &mut self, cx: &mut Context<'_>, @@ -505,6 +536,19 @@ pub struct WriteHalf { shared: Arc, } +impl WriteHalf { + pub async fn handshake(&mut self) -> io::Result<()> { + poll_fn(|cx| { + self + .shared + .poll_with_shared_waker(cx, Flow::Write, |mut tls, cx| { + tls.poll_handshake(cx) + }) + }) + .await + } +} + impl AsyncWrite for WriteHalf { fn poll_write( self: Pin<&mut Self>, @@ -561,6 +605,7 @@ impl Shared { mut f: impl FnMut(Pin<&mut TlsStream>, &mut Context<'_>) -> R, ) -> R { match flow { + Flow::Handshake => unreachable!(), Flow::Read => self.rd_waker.register(cx.waker()), Flow::Write => self.wr_waker.register(cx.waker()), } @@ -625,7 +670,11 @@ struct ImplementWriteTrait<'a, T>(&'a mut T); impl Write for ImplementWriteTrait<'_, TcpStream> { fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.try_write(buf) + match self.0.try_write(buf) { + Ok(n) => Ok(n), + Err(err) if err.kind() == ErrorKind::WouldBlock => Ok(0), + Err(err) => Err(err), + } } fn flush(&mut self) -> io::Result<()> { @@ -639,9 +688,78 @@ pub fn init() -> Vec { ("op_connect_tls", op_async(op_connect_tls::

)), ("op_listen_tls", op_sync(op_listen_tls::

)), ("op_accept_tls", op_async(op_accept_tls)), + ("op_tls_handshake", op_async(op_tls_handshake)), ] } +#[derive(Debug)] +pub struct TlsStreamResource { + rd: AsyncRefCell, + wr: AsyncRefCell, + handshake_done: Cell, + cancel_handle: CancelHandle, // Only read and handshake ops get canceled. +} + +impl TlsStreamResource { + pub fn new((rd, wr): (ReadHalf, WriteHalf)) -> Self { + Self { + rd: rd.into(), + wr: wr.into(), + handshake_done: Cell::new(false), + cancel_handle: Default::default(), + } + } + + pub fn into_inner(self) -> (ReadHalf, WriteHalf) { + (self.rd.into_inner(), self.wr.into_inner()) + } + + pub async fn read( + self: &Rc, + buf: &mut [u8], + ) -> Result { + let mut rd = RcRef::map(self, |r| &r.rd).borrow_mut().await; + let cancel_handle = RcRef::map(self, |r| &r.cancel_handle); + let nread = rd.read(buf).try_or_cancel(cancel_handle).await?; + Ok(nread) + } + + pub async fn write(self: &Rc, buf: &[u8]) -> Result { + self.handshake().await?; + let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await; + let nwritten = wr.write(buf).await?; + wr.flush().await?; + Ok(nwritten) + } + + pub async fn shutdown(self: &Rc) -> Result<(), AnyError> { + self.handshake().await?; + let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await; + wr.shutdown().await?; + Ok(()) + } + + pub async fn handshake(self: &Rc) -> Result<(), AnyError> { + if !self.handshake_done.get() { + let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await; + let cancel_handle = RcRef::map(self, |r| &r.cancel_handle); + wr.handshake().try_or_cancel(cancel_handle).await?; + self.handshake_done.set(true); + } + Ok(()) + } +} + +impl Resource for TlsStreamResource { + fn name(&self) -> Cow { + "tlsStream".into() + } + + fn close(self: Rc) { + self.cancel_handle.cancel(); + } +} + #[derive(Deserialize)] #[serde(rename_all = "camelCase")] pub struct ConnectTlsArgs { @@ -1015,3 +1133,15 @@ async fn op_accept_tls( })), }) } + +async fn op_tls_handshake( + state: Rc>, + rid: ResourceId, + _: (), +) -> Result<(), AnyError> { + let resource = state + .borrow() + .resource_table + .get::(rid)?; + resource.handshake().await +} diff --git a/runtime/ops/http.rs b/runtime/ops/http.rs index 58783bbbc2..683dc1a576 100644 --- a/runtime/ops/http.rs +++ b/runtime/ops/http.rs @@ -7,7 +7,7 @@ use deno_core::Extension; use deno_core::OpState; use deno_core::ResourceId; use deno_net::io::TcpStreamResource; -use deno_net::io::TlsStreamResource; +use deno_net::ops_tls::TlsStreamResource; pub fn init() -> Extension { Extension::builder() diff --git a/runtime/ops/io.rs b/runtime/ops/io.rs index 0687fc397f..7eb27bd125 100644 --- a/runtime/ops/io.rs +++ b/runtime/ops/io.rs @@ -16,8 +16,8 @@ use deno_core::Resource; use deno_core::ResourceId; use deno_core::ZeroCopyBuf; use deno_net::io::TcpStreamResource; -use deno_net::io::TlsStreamResource; use deno_net::io::UnixStreamResource; +use deno_net::ops_tls::TlsStreamResource; use std::borrow::Cow; use std::cell::RefCell; use std::io::Read;