1
0
Fork 0
mirror of https://github.com/denoland/deno.git synced 2025-01-12 00:54:02 -05:00

refactor(ext/websocket): Remove dep on tungstenite by reworking code (#18812)

This commit is contained in:
Matt Mastracci 2023-04-23 14:07:37 -06:00 committed by GitHub
parent c95477c49f
commit fafb2584ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 32 deletions

View file

@ -38,11 +38,12 @@ use std::future::Future;
use std::path::PathBuf; use std::path::PathBuf;
use std::rc::Rc; use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_rustls::rustls::RootCertStore; use tokio_rustls::rustls::RootCertStore;
use tokio_rustls::rustls::ServerName; use tokio_rustls::rustls::ServerName;
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
use tokio_tungstenite::MaybeTlsStream;
use fastwebsockets::CloseCode; use fastwebsockets::CloseCode;
use fastwebsockets::FragmentCollector; use fastwebsockets::FragmentCollector;
@ -129,6 +130,33 @@ pub struct CreateResponse {
extensions: String, extensions: String,
} }
async fn handshake<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
cancel_resource: Option<Rc<CancelHandle>>,
request: Request<Body>,
socket: S,
) -> Result<(WebSocket<WebSocketStream>, http::Response<Body>), AnyError> {
let client =
fastwebsockets::handshake::client(&LocalExecutor, request, socket);
let (upgraded, response) = if let Some(cancel_resource) = cancel_resource {
client.or_cancel(cancel_resource).await?
} else {
client.await
}
.map_err(|err| {
DomExceptionNetworkError::new(&format!(
"failed to connect to WebSocket: {err}"
))
})?;
let upgraded = upgraded.into_inner();
let stream =
WebSocketStream::new(stream::WsStreamKind::Upgraded(upgraded), None);
let stream = WebSocket::after_handshake(stream, Role::Client);
Ok((stream, response))
}
#[op] #[op]
pub async fn op_ws_create<WP>( pub async fn op_ws_create<WP>(
state: Rc<RefCell<OpState>>, state: Rc<RefCell<OpState>>,
@ -155,7 +183,7 @@ where
.borrow_mut() .borrow_mut()
.resource_table .resource_table
.get::<WsCancelResource>(cancel_rid)?; .get::<WsCancelResource>(cancel_rid)?;
Some(r) Some(r.0.clone())
} else { } else {
None None
}; };
@ -223,8 +251,8 @@ where
let addr = format!("{domain}:{port}"); let addr = format!("{domain}:{port}");
let tcp_socket = TcpStream::connect(addr).await?; let tcp_socket = TcpStream::connect(addr).await?;
let socket: MaybeTlsStream<TcpStream> = match uri.scheme_str() { let (stream, response) = match uri.scheme_str() {
Some("ws") => MaybeTlsStream::Plain(tcp_socket), Some("ws") => handshake(cancel_resource, request, tcp_socket).await?,
Some("wss") => { Some("wss") => {
let tls_config = create_client_config( let tls_config = create_client_config(
root_cert_store, root_cert_store,
@ -236,30 +264,11 @@ where
let dnsname = ServerName::try_from(domain.as_str()) let dnsname = ServerName::try_from(domain.as_str())
.map_err(|_| invalid_hostname(domain))?; .map_err(|_| invalid_hostname(domain))?;
let tls_socket = tls_connector.connect(dnsname, tcp_socket).await?; let tls_socket = tls_connector.connect(dnsname, tcp_socket).await?;
MaybeTlsStream::Rustls(tls_socket) handshake(cancel_resource, request, tls_socket).await?
} }
_ => unreachable!(), _ => unreachable!(),
}; };
let client =
fastwebsockets::handshake::client(&LocalExecutor, request, socket);
let (upgraded, response) = if let Some(cancel_resource) = cancel_resource {
client.or_cancel(cancel_resource.0.to_owned()).await?
} else {
client.await
}
.map_err(|err| {
DomExceptionNetworkError::new(&format!(
"failed to connect to WebSocket: {err}"
))
})?;
let inner = MaybeTlsStream::Plain(upgraded.into_inner());
let stream =
WebSocketStream::new(stream::WsStreamKind::Tungstenite(inner), None);
let stream = WebSocket::after_handshake(stream, Role::Client);
if let Some(cancel_rid) = cancel_handle { if let Some(cancel_rid) = cancel_handle {
state.borrow_mut().resource_table.close(cancel_rid).ok(); state.borrow_mut().resource_table.close(cancel_rid).ok();
} }

View file

@ -8,11 +8,10 @@ use std::task::Poll;
use tokio::io::AsyncRead; use tokio::io::AsyncRead;
use tokio::io::AsyncWrite; use tokio::io::AsyncWrite;
use tokio::io::ReadBuf; use tokio::io::ReadBuf;
use tokio_tungstenite::MaybeTlsStream;
// TODO(bartlomieju): remove this // TODO(bartlomieju): remove this
pub(crate) enum WsStreamKind { pub(crate) enum WsStreamKind {
Tungstenite(MaybeTlsStream<Upgraded>), Upgraded(Upgraded),
Network(NetworkStream), Network(NetworkStream),
} }
@ -54,7 +53,7 @@ impl AsyncRead for WebSocketStream {
} }
match &mut self.stream { match &mut self.stream {
WsStreamKind::Network(stream) => Pin::new(stream).poll_read(cx, buf), WsStreamKind::Network(stream) => Pin::new(stream).poll_read(cx, buf),
WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_read(cx, buf), WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_read(cx, buf),
} }
} }
} }
@ -67,7 +66,7 @@ impl AsyncWrite for WebSocketStream {
) -> std::task::Poll<Result<usize, std::io::Error>> { ) -> std::task::Poll<Result<usize, std::io::Error>> {
match &mut self.stream { match &mut self.stream {
WsStreamKind::Network(stream) => Pin::new(stream).poll_write(cx, buf), WsStreamKind::Network(stream) => Pin::new(stream).poll_write(cx, buf),
WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_write(cx, buf), WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_write(cx, buf),
} }
} }
@ -77,7 +76,7 @@ impl AsyncWrite for WebSocketStream {
) -> std::task::Poll<Result<(), std::io::Error>> { ) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut self.stream { match &mut self.stream {
WsStreamKind::Network(stream) => Pin::new(stream).poll_flush(cx), WsStreamKind::Network(stream) => Pin::new(stream).poll_flush(cx),
WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_flush(cx), WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_flush(cx),
} }
} }
@ -87,14 +86,14 @@ impl AsyncWrite for WebSocketStream {
) -> std::task::Poll<Result<(), std::io::Error>> { ) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut self.stream { match &mut self.stream {
WsStreamKind::Network(stream) => Pin::new(stream).poll_shutdown(cx), WsStreamKind::Network(stream) => Pin::new(stream).poll_shutdown(cx),
WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_shutdown(cx), WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_shutdown(cx),
} }
} }
fn is_write_vectored(&self) -> bool { fn is_write_vectored(&self) -> bool {
match &self.stream { match &self.stream {
WsStreamKind::Network(stream) => stream.is_write_vectored(), WsStreamKind::Network(stream) => stream.is_write_vectored(),
WsStreamKind::Tungstenite(stream) => stream.is_write_vectored(), WsStreamKind::Upgraded(stream) => stream.is_write_vectored(),
} }
} }
@ -107,7 +106,7 @@ impl AsyncWrite for WebSocketStream {
WsStreamKind::Network(stream) => { WsStreamKind::Network(stream) => {
Pin::new(stream).poll_write_vectored(cx, bufs) Pin::new(stream).poll_write_vectored(cx, bufs)
} }
WsStreamKind::Tungstenite(stream) => { WsStreamKind::Upgraded(stream) => {
Pin::new(stream).poll_write_vectored(cx, bufs) Pin::new(stream).poll_write_vectored(cx, bufs)
} }
} }