From d047cab14b754d20a43c7119e327b451440aaed9 Mon Sep 17 00:00:00 2001 From: Leo Kettmeir Date: Fri, 18 Oct 2024 12:30:46 -0700 Subject: [PATCH] refactor(ext/websocket): use concrete error type (#26226) --- Cargo.lock | 1 + ext/http/http_next.rs | 6 +- ext/http/lib.rs | 8 +- ext/websocket/Cargo.toml | 1 + ext/websocket/lib.rs | 189 +++++++++++++++------------ runtime/errors.rs | 45 ++++++- runtime/ops/web_worker/sync_fetch.rs | 19 ++- 7 files changed, 172 insertions(+), 97 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9b2baf79c2..28b789c0cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2225,6 +2225,7 @@ dependencies = [ "once_cell", "rustls-tokio-stream", "serde", + "thiserror", "tokio", ] diff --git a/ext/http/http_next.rs b/ext/http/http_next.rs index abc54c8994..7a6cbfa45e 100644 --- a/ext/http/http_next.rs +++ b/ext/http/http_next.rs @@ -246,7 +246,11 @@ pub async fn op_http_upgrade_websocket_next( // Stage 3: take the extracted raw network stream and upgrade it to a websocket, then return it let (stream, bytes) = extract_network_stream(upgraded); - ws_create_server_stream(&mut state.borrow_mut(), stream, bytes) + Ok(ws_create_server_stream( + &mut state.borrow_mut(), + stream, + bytes, + )) } #[op2(fast)] diff --git a/ext/http/lib.rs b/ext/http/lib.rs index 934f8a0024..5461713aa8 100644 --- a/ext/http/lib.rs +++ b/ext/http/lib.rs @@ -1053,9 +1053,11 @@ async fn op_http_upgrade_websocket( let (transport, bytes) = extract_network_stream(hyper_v014::upgrade::on(request).await?); - let ws_rid = - ws_create_server_stream(&mut state.borrow_mut(), transport, bytes)?; - Ok(ws_rid) + Ok(ws_create_server_stream( + &mut state.borrow_mut(), + transport, + bytes, + )) } // Needed so hyper can use non Send futures diff --git a/ext/websocket/Cargo.toml b/ext/websocket/Cargo.toml index b265ba78a2..ec359100d4 100644 --- a/ext/websocket/Cargo.toml +++ b/ext/websocket/Cargo.toml @@ -28,4 +28,5 @@ hyper-util.workspace = true once_cell.workspace = true rustls-tokio-stream.workspace = true serde.workspace = true +thiserror.workspace = true tokio.workspace = true diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs index b8043516b0..2a67ac5a17 100644 --- a/ext/websocket/lib.rs +++ b/ext/websocket/lib.rs @@ -1,10 +1,6 @@ // Copyright 2018-2024 the Deno authors. All rights reserved. MIT license. use crate::stream::WebSocketStream; use bytes::Bytes; -use deno_core::anyhow::bail; -use deno_core::error::invalid_hostname; -use deno_core::error::type_error; -use deno_core::error::AnyError; use deno_core::futures::TryFutureExt; use deno_core::op2; use deno_core::unsync::spawn; @@ -43,7 +39,6 @@ use serde::Serialize; use std::borrow::Cow; use std::cell::Cell; use std::cell::RefCell; -use std::fmt; use std::future::Future; use std::num::NonZeroUsize; use std::path::PathBuf; @@ -75,11 +70,33 @@ static USE_WRITEV: Lazy = Lazy::new(|| { false }); +#[derive(Debug, thiserror::Error)] +pub enum WebsocketError { + #[error(transparent)] + Url(url::ParseError), + #[error(transparent)] + Permission(deno_core::error::AnyError), + #[error(transparent)] + Resource(deno_core::error::AnyError), + #[error(transparent)] + Uri(#[from] http::uri::InvalidUri), + #[error("{0}")] + Io(#[from] std::io::Error), + #[error(transparent)] + WebSocket(#[from] fastwebsockets::WebSocketError), + #[error("failed to connect to WebSocket: {0}")] + ConnectionFailed(#[from] HandshakeError), + #[error(transparent)] + Canceled(#[from] deno_core::Canceled), +} + #[derive(Clone)] pub struct WsRootStoreProvider(Option>); impl WsRootStoreProvider { - pub fn get_or_try_init(&self) -> Result, AnyError> { + pub fn get_or_try_init( + &self, + ) -> Result, deno_core::error::AnyError> { Ok(match &self.0 { Some(provider) => Some(provider.get_or_try_init()?.clone()), None => None, @@ -95,7 +112,7 @@ pub trait WebSocketPermissions { &mut self, _url: &url::Url, _api_name: &str, - ) -> Result<(), AnyError>; + ) -> Result<(), deno_core::error::AnyError>; } impl WebSocketPermissions for deno_permissions::PermissionsContainer { @@ -104,7 +121,7 @@ impl WebSocketPermissions for deno_permissions::PermissionsContainer { &mut self, url: &url::Url, api_name: &str, - ) -> Result<(), AnyError> { + ) -> Result<(), deno_core::error::AnyError> { deno_permissions::PermissionsContainer::check_net_url(self, url, api_name) } } @@ -137,13 +154,17 @@ pub fn op_ws_check_permission_and_cancel_handle( #[string] api_name: String, #[string] url: String, cancel_handle: bool, -) -> Result, AnyError> +) -> Result, WebsocketError> where WP: WebSocketPermissions + 'static, { state .borrow_mut::() - .check_net_url(&url::Url::parse(&url)?, &api_name)?; + .check_net_url( + &url::Url::parse(&url).map_err(WebsocketError::Url)?, + &api_name, + ) + .map_err(WebsocketError::Permission)?; if cancel_handle { let rid = state @@ -163,16 +184,46 @@ pub struct CreateResponse { extensions: String, } +#[derive(Debug, thiserror::Error)] +pub enum HandshakeError { + #[error("Missing path in url")] + MissingPath, + #[error("Invalid status code {0}")] + InvalidStatusCode(StatusCode), + #[error(transparent)] + Http(#[from] http::Error), + #[error(transparent)] + WebSocket(#[from] fastwebsockets::WebSocketError), + #[error("Didn't receive h2 alpn, aborting connection")] + NoH2Alpn, + #[error(transparent)] + Rustls(#[from] deno_tls::rustls::Error), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + H2(#[from] h2::Error), + #[error("Invalid hostname: '{0}'")] + InvalidHostname(String), + #[error(transparent)] + RootStoreError(deno_core::error::AnyError), + #[error(transparent)] + Tls(deno_tls::TlsError), + #[error(transparent)] + HeaderName(#[from] http::header::InvalidHeaderName), + #[error(transparent)] + HeaderValue(#[from] http::header::InvalidHeaderValue), +} + async fn handshake_websocket( state: &Rc>, uri: &Uri, protocols: &str, headers: Option>, -) -> Result<(WebSocket, http::HeaderMap), AnyError> { +) -> Result<(WebSocket, http::HeaderMap), HandshakeError> { let mut request = Request::builder().method(Method::GET).uri( uri .path_and_query() - .ok_or(type_error("Missing path in url".to_string()))? + .ok_or(HandshakeError::MissingPath)? .as_str(), ); @@ -194,7 +245,9 @@ async fn handshake_websocket( request = populate_common_request_headers(request, &user_agent, protocols, &headers)?; - let request = request.body(http_body_util::Empty::new())?; + let request = request + .body(http_body_util::Empty::new()) + .map_err(HandshakeError::Http)?; let domain = &uri.host().unwrap().to_string(); let port = &uri.port_u16().unwrap_or(match uri.scheme_str() { Some("wss") => 443, @@ -231,7 +284,7 @@ async fn handshake_websocket( async fn handshake_http1_ws( request: Request>, addr: &String, -) -> Result<(WebSocket, http::HeaderMap), AnyError> { +) -> Result<(WebSocket, http::HeaderMap), HandshakeError> { let tcp_socket = TcpStream::connect(addr).await?; handshake_connection(request, tcp_socket).await } @@ -241,11 +294,11 @@ async fn handshake_http1_wss( request: Request>, domain: &str, addr: &str, -) -> Result<(WebSocket, http::HeaderMap), AnyError> { +) -> Result<(WebSocket, http::HeaderMap), HandshakeError> { let tcp_socket = TcpStream::connect(addr).await?; let tls_config = create_ws_client_config(state, SocketUse::Http1Only)?; let dnsname = ServerName::try_from(domain.to_string()) - .map_err(|_| invalid_hostname(domain))?; + .map_err(|_| HandshakeError::InvalidHostname(domain.to_string()))?; let mut tls_connector = TlsStream::new_client_side( tcp_socket, ClientConnection::new(tls_config.into(), dnsname)?, @@ -266,11 +319,11 @@ async fn handshake_http2_wss( domain: &str, headers: &Option>, addr: &str, -) -> Result<(WebSocket, http::HeaderMap), AnyError> { +) -> Result<(WebSocket, http::HeaderMap), HandshakeError> { let tcp_socket = TcpStream::connect(addr).await?; let tls_config = create_ws_client_config(state, SocketUse::Http2Only)?; let dnsname = ServerName::try_from(domain.to_string()) - .map_err(|_| invalid_hostname(domain))?; + .map_err(|_| HandshakeError::InvalidHostname(domain.to_string()))?; // We need to better expose the underlying errors here let mut tls_connector = TlsStream::new_client_side( tcp_socket, @@ -279,7 +332,7 @@ async fn handshake_http2_wss( ); let handshake = tls_connector.handshake().await?; if handshake.alpn.is_none() { - bail!("Didn't receive h2 alpn, aborting connection"); + return Err(HandshakeError::NoH2Alpn); } let h2 = h2::client::Builder::new(); let (mut send, conn) = h2.handshake::<_, Bytes>(tls_connector).await?; @@ -298,7 +351,7 @@ async fn handshake_http2_wss( let (resp, send) = send.send_request(request.body(())?, false)?; let resp = resp.await?; if resp.status() != StatusCode::OK { - bail!("Invalid status code: {}", resp.status()); + return Err(HandshakeError::InvalidStatusCode(resp.status())); } let (http::response::Parts { headers, .. }, recv) = resp.into_parts(); let mut stream = WebSocket::after_handshake( @@ -317,7 +370,7 @@ async fn handshake_connection< >( request: Request>, socket: S, -) -> Result<(WebSocket, http::HeaderMap), AnyError> { +) -> Result<(WebSocket, http::HeaderMap), HandshakeError> { let (upgraded, response) = fastwebsockets::handshake::client(&LocalExecutor, request, socket).await?; @@ -332,7 +385,7 @@ async fn handshake_connection< pub fn create_ws_client_config( state: &Rc>, socket_use: SocketUse, -) -> Result { +) -> Result { let unsafely_ignore_certificate_errors: Option> = state .borrow() .try_borrow::() @@ -340,7 +393,8 @@ pub fn create_ws_client_config( let root_cert_store = state .borrow() .borrow::() - .get_or_try_init()?; + .get_or_try_init() + .map_err(HandshakeError::RootStoreError)?; create_client_config( root_cert_store, @@ -349,7 +403,7 @@ pub fn create_ws_client_config( TlsKeys::Null, socket_use, ) - .map_err(|e| e.into()) + .map_err(HandshakeError::Tls) } /// Headers common to both http/1.1 and h2 requests. @@ -358,7 +412,7 @@ fn populate_common_request_headers( user_agent: &str, protocols: &str, headers: &Option>, -) -> Result { +) -> Result { request = request .header("User-Agent", user_agent) .header("Sec-WebSocket-Version", "13"); @@ -369,10 +423,8 @@ fn populate_common_request_headers( if let Some(headers) = headers { for (key, value) in headers { - let name = HeaderName::from_bytes(key) - .map_err(|err| type_error(err.to_string()))?; - let v = HeaderValue::from_bytes(value) - .map_err(|err| type_error(err.to_string()))?; + let name = HeaderName::from_bytes(key)?; + let v = HeaderValue::from_bytes(value)?; let is_disallowed_header = matches!( name, @@ -402,14 +454,17 @@ pub async fn op_ws_create( #[string] protocols: String, #[smi] cancel_handle: Option, #[serde] headers: Option>, -) -> Result +) -> Result where WP: WebSocketPermissions + 'static, { { let mut s = state.borrow_mut(); s.borrow_mut::() - .check_net_url(&url::Url::parse(&url)?, &api_name) + .check_net_url( + &url::Url::parse(&url).map_err(WebsocketError::Url)?, + &api_name, + ) .expect( "Permission check should have been done in op_ws_check_permission", ); @@ -419,7 +474,8 @@ where let r = state .borrow_mut() .resource_table - .get::(cancel_rid)?; + .get::(cancel_rid) + .map_err(WebsocketError::Resource)?; Some(r.0.clone()) } else { None @@ -428,15 +484,11 @@ where let uri: Uri = url.parse()?; let handshake = handshake_websocket(&state, &uri, &protocols, headers) - .map_err(|err| { - AnyError::from(DomExceptionNetworkError::new(&format!( - "failed to connect to WebSocket: {err}" - ))) - }); + .map_err(WebsocketError::ConnectionFailed); let (stream, response) = match cancel_resource { - Some(rc) => handshake.try_or_cancel(rc).await, - None => handshake.await, - }?; + Some(rc) => handshake.try_or_cancel(rc).await?, + None => handshake.await?, + }; if let Some(cancel_rid) = cancel_handle { if let Ok(res) = state.borrow_mut().resource_table.take_any(cancel_rid) { @@ -521,14 +573,12 @@ impl ServerWebSocket { self: &Rc, lock: AsyncMutFuture>>, frame: Frame<'_>, - ) -> Result<(), AnyError> { + ) -> Result<(), WebsocketError> { let mut ws = lock.await; if ws.is_closed() { return Ok(()); } - ws.write_frame(frame) - .await - .map_err(|err| type_error(err.to_string()))?; + ws.write_frame(frame).await?; Ok(()) } } @@ -543,7 +593,7 @@ pub fn ws_create_server_stream( state: &mut OpState, transport: NetworkStream, read_buf: Bytes, -) -> Result { +) -> ResourceId { let mut ws = WebSocket::after_handshake( WebSocketStream::new( stream::WsStreamKind::Network(transport), @@ -555,8 +605,7 @@ pub fn ws_create_server_stream( ws.set_auto_close(true); ws.set_auto_pong(true); - let rid = state.resource_table.add(ServerWebSocket::new(ws)); - Ok(rid) + state.resource_table.add(ServerWebSocket::new(ws)) } fn send_binary(state: &mut OpState, rid: ResourceId, data: &[u8]) { @@ -626,11 +675,12 @@ pub async fn op_ws_send_binary_async( state: Rc>, #[smi] rid: ResourceId, #[buffer] data: JsBuffer, -) -> Result<(), AnyError> { +) -> Result<(), WebsocketError> { let resource = state .borrow_mut() .resource_table - .get::(rid)?; + .get::(rid) + .map_err(WebsocketError::Resource)?; let data = data.to_vec(); let lock = resource.reserve_lock(); resource @@ -644,11 +694,12 @@ pub async fn op_ws_send_text_async( state: Rc>, #[smi] rid: ResourceId, #[string] data: String, -) -> Result<(), AnyError> { +) -> Result<(), WebsocketError> { let resource = state .borrow_mut() .resource_table - .get::(rid)?; + .get::(rid) + .map_err(WebsocketError::Resource)?; let lock = resource.reserve_lock(); resource .write_frame( @@ -678,11 +729,12 @@ pub fn op_ws_get_buffered_amount( pub async fn op_ws_send_ping( state: Rc>, #[smi] rid: ResourceId, -) -> Result<(), AnyError> { +) -> Result<(), WebsocketError> { let resource = state .borrow_mut() .resource_table - .get::(rid)?; + .get::(rid) + .map_err(WebsocketError::Resource)?; let lock = resource.reserve_lock(); resource .write_frame( @@ -698,7 +750,7 @@ pub async fn op_ws_close( #[smi] rid: ResourceId, #[smi] code: Option, #[string] reason: Option, -) -> Result<(), AnyError> { +) -> Result<(), WebsocketError> { let Ok(resource) = state .borrow_mut() .resource_table @@ -713,8 +765,7 @@ pub async fn op_ws_close( resource.closed.set(true); let lock = resource.reserve_lock(); - resource.write_frame(lock, frame).await?; - Ok(()) + resource.write_frame(lock, frame).await } #[op2] @@ -868,32 +919,6 @@ pub fn get_declaration() -> PathBuf { PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("lib.deno_websocket.d.ts") } -#[derive(Debug)] -pub struct DomExceptionNetworkError { - pub msg: String, -} - -impl DomExceptionNetworkError { - pub fn new(msg: &str) -> Self { - DomExceptionNetworkError { - msg: msg.to_string(), - } - } -} - -impl fmt::Display for DomExceptionNetworkError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.pad(&self.msg) - } -} - -impl std::error::Error for DomExceptionNetworkError {} - -pub fn get_network_error_class_name(e: &AnyError) -> Option<&'static str> { - e.downcast_ref::() - .map(|_| "DOMExceptionNetworkError") -} - // Needed so hyper can use non Send futures #[derive(Clone)] struct LocalExecutor; diff --git a/runtime/errors.rs b/runtime/errors.rs index 25fc664a57..0e551d1219 100644 --- a/runtime/errors.rs +++ b/runtime/errors.rs @@ -36,6 +36,8 @@ use deno_web::CompressionError; use deno_web::MessagePortError; use deno_web::StreamResourceError; use deno_web::WebError; +use deno_websocket::HandshakeError; +use deno_websocket::WebsocketError; use deno_webstorage::WebStorageError; use std::env; use std::error::Error; @@ -368,6 +370,43 @@ fn get_broadcast_channel_error(error: &BroadcastChannelError) -> &'static str { } } +fn get_websocket_error(error: &WebsocketError) -> &'static str { + match error { + WebsocketError::Permission(e) | WebsocketError::Resource(e) => { + get_error_class_name(e).unwrap_or("Error") + } + WebsocketError::Url(e) => get_url_parse_error_class(e), + WebsocketError::Io(e) => get_io_error_class(e), + WebsocketError::WebSocket(_) => "TypeError", + WebsocketError::ConnectionFailed(_) => "DOMExceptionNetworkError", + WebsocketError::Uri(_) => "Error", + WebsocketError::Canceled(e) => { + let io_err: io::Error = e.to_owned().into(); + get_io_error_class(&io_err) + } + } +} + +fn get_websocket_handshake_error(error: &HandshakeError) -> &'static str { + match error { + HandshakeError::RootStoreError(e) => { + get_error_class_name(e).unwrap_or("Error") + } + HandshakeError::Tls(e) => get_tls_error_class(e), + HandshakeError::MissingPath => "TypeError", + HandshakeError::Http(_) => "Error", + HandshakeError::InvalidHostname(_) => "TypeError", + HandshakeError::Io(e) => get_io_error_class(e), + HandshakeError::Rustls(_) => "Error", + HandshakeError::H2(_) => "Error", + HandshakeError::NoH2Alpn => "Error", + HandshakeError::InvalidStatusCode(_) => "Error", + HandshakeError::WebSocket(_) => "TypeError", + HandshakeError::HeaderName(_) => "TypeError", + HandshakeError::HeaderValue(_) => "TypeError", + } +} + fn get_fs_error(error: &FsOpsError) -> &'static str { match error { FsOpsError::Io(e) => get_io_error_class(e), @@ -482,7 +521,6 @@ fn get_net_map_error(error: &deno_net::io::MapError) -> &'static str { pub fn get_error_class_name(e: &AnyError) -> Option<&'static str> { deno_core::error::get_custom_error_class(e) .or_else(|| deno_webgpu::error::get_error_class_name(e)) - .or_else(|| deno_websocket::get_network_error_class_name(e)) .or_else(|| e.downcast_ref::().map(get_napi_error_class)) .or_else(|| e.downcast_ref::().map(get_web_error_class)) .or_else(|| { @@ -518,6 +556,11 @@ pub fn get_error_class_name(e: &AnyError) -> Option<&'static str> { .or_else(|| e.downcast_ref::().map(get_cron_error_class)) .or_else(|| e.downcast_ref::().map(get_canvas_error)) .or_else(|| e.downcast_ref::().map(get_cache_error)) + .or_else(|| e.downcast_ref::().map(get_websocket_error)) + .or_else(|| { + e.downcast_ref::() + .map(get_websocket_handshake_error) + }) .or_else(|| e.downcast_ref::().map(get_kv_error)) .or_else(|| e.downcast_ref::().map(get_net_error)) .or_else(|| { diff --git a/runtime/ops/web_worker/sync_fetch.rs b/runtime/ops/web_worker/sync_fetch.rs index 87fc558405..bd55a5fc8c 100644 --- a/runtime/ops/web_worker/sync_fetch.rs +++ b/runtime/ops/web_worker/sync_fetch.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use crate::web_worker::WebWorkerInternalHandle; use crate::web_worker::WebWorkerType; +use deno_core::error::custom_error; use deno_core::error::type_error; use deno_core::error::AnyError; use deno_core::futures::StreamExt; @@ -12,7 +13,6 @@ use deno_core::url::Url; use deno_core::OpState; use deno_fetch::data_url::DataUrl; use deno_web::BlobStore; -use deno_websocket::DomExceptionNetworkError; use http_body_util::BodyExt; use hyper::body::Bytes; use serde::Deserialize; @@ -151,17 +151,16 @@ pub fn op_worker_sync_fetch( match mime_type.as_deref() { Some("application/javascript" | "text/javascript") => {} Some(mime_type) => { - return Err( - DomExceptionNetworkError { - msg: format!("Invalid MIME type {mime_type:?}."), - } - .into(), - ) + return Err(custom_error( + "DOMExceptionNetworkError", + format!("Invalid MIME type {mime_type:?}."), + )) } None => { - return Err( - DomExceptionNetworkError::new("Missing MIME type.").into(), - ) + return Err(custom_error( + "DOMExceptionNetworkError", + "Missing MIME type.", + )) } } }