diff --git a/cli/tests/unit/websocket_test.ts b/cli/tests/unit/websocket_test.ts index 795d5ebc18..16384da400 100644 --- a/cli/tests/unit/websocket_test.ts +++ b/cli/tests/unit/websocket_test.ts @@ -150,7 +150,7 @@ Deno.test({ Deno.test( { sanitizeOps: false }, - function websocketConstructorWithPrototypePollusion() { + function websocketConstructorWithPrototypePollution() { const originalSymbolIterator = Array.prototype[Symbol.iterator]; try { Array.prototype[Symbol.iterator] = () => { diff --git a/ext/websocket/01_websocket.js b/ext/websocket/01_websocket.js index f6cb6599d8..e6b0053c63 100644 --- a/ext/websocket/01_websocket.js +++ b/ext/websocket/01_websocket.js @@ -53,6 +53,9 @@ const { op_ws_send_binary, op_ws_send_text, op_ws_next_event, + op_ws_get_buffer, + op_ws_get_buffer_as_string, + op_ws_get_error, op_ws_send_ping, op_ws_get_buffered_amount, } = core.ensureFastOps(); @@ -407,15 +410,16 @@ class WebSocket extends EventTarget { } async [_eventLoop]() { + const rid = this[_rid]; while (this[_readyState] !== CLOSED) { - const { 0: kind, 1: value } = await op_ws_next_event(this[_rid]); + const kind = await op_ws_next_event(rid); switch (kind) { case 0: { /* string */ this[_serverHandleIdleTimeout](); const event = new MessageEvent("message", { - data: value, + data: op_ws_get_buffer_as_string(rid), origin: this[_url], }); dispatch(this, event); @@ -424,12 +428,13 @@ class WebSocket extends EventTarget { case 1: { /* binary */ this[_serverHandleIdleTimeout](); + // deno-lint-ignore prefer-primordials + const buffer = op_ws_get_buffer(rid).buffer; let data; - if (this.binaryType === "blob") { - data = new Blob([value]); + data = new Blob([buffer]); } else { - data = value; + data = buffer; } const event = new MessageEvent("message", { @@ -450,13 +455,13 @@ class WebSocket extends EventTarget { this[_readyState] = CLOSED; const errorEv = new ErrorEvent("error", { - message: value, + message: op_ws_get_error(rid), }); this.dispatchEvent(errorEv); const closeEv = new CloseEvent("close"); this.dispatchEvent(closeEv); - core.tryClose(this[_rid]); + core.tryClose(rid); break; } default: { @@ -469,9 +474,9 @@ class WebSocket extends EventTarget { if (prevState === OPEN) { try { await op_ws_close( - this[_rid], + rid, code, - value, + op_ws_get_error(rid), ); } catch { // ignore failures @@ -481,10 +486,10 @@ class WebSocket extends EventTarget { const event = new CloseEvent("close", { wasClean: true, code: code, - reason: value, + reason: op_ws_get_error(rid), }); this.dispatchEvent(event); - core.tryClose(this[_rid]); + core.tryClose(rid); break; } } diff --git a/ext/websocket/02_websocketstream.js b/ext/websocket/02_websocketstream.js index be1001eb60..068fa3e1b9 100644 --- a/ext/websocket/02_websocketstream.js +++ b/ext/websocket/02_websocketstream.js @@ -37,6 +37,9 @@ const { op_ws_send_text_async, op_ws_send_binary_async, op_ws_next_event, + op_ws_get_buffer, + op_ws_get_buffer_as_string, + op_ws_get_error, op_ws_create, op_ws_close, } = core.ensureFastOps(); @@ -177,7 +180,7 @@ class WebSocketStream { PromisePrototypeThen( (async () => { while (true) { - const { 0: kind } = await op_ws_next_event(create.rid); + const kind = await op_ws_next_event(create.rid); if (kind > 5) { /* close */ @@ -239,14 +242,16 @@ class WebSocketStream { }, }); const pull = async (controller) => { - const { 0: kind, 1: value } = await op_ws_next_event(this[_rid]); + const kind = await op_ws_next_event(this[_rid]); switch (kind) { case 0: - case 1: { /* string */ + controller.enqueue(op_ws_get_buffer_as_string(this[_rid])); + break; + case 1: { /* binary */ - controller.enqueue(value); + controller.enqueue(op_ws_get_buffer(this[_rid])); break; } case 2: { @@ -255,7 +260,7 @@ class WebSocketStream { } case 3: { /* error */ - const err = new Error(value); + const err = new Error(op_ws_get_error(this[_rid])); this[_closed].reject(err); controller.error(err); core.tryClose(this[_rid]); @@ -271,7 +276,7 @@ class WebSocketStream { /* close */ this[_closed].resolve({ code: kind, - reason: value, + reason: op_ws_get_error(this[_rid]), }); core.tryClose(this[_rid]); break; @@ -289,7 +294,7 @@ class WebSocketStream { return pull(controller); } - this[_closed].resolve(value); + this[_closed].resolve(op_ws_get_error(this[_rid])); core.tryClose(this[_rid]); } }; diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs index b492be0c02..1df71abaa4 100644 --- a/ext/websocket/lib.rs +++ b/ext/websocket/lib.rs @@ -14,7 +14,6 @@ use deno_core::OpState; use deno_core::RcRef; use deno_core::Resource; use deno_core::ResourceId; -use deno_core::StringOrBuffer; use deno_core::ZeroCopyBuf; use deno_net::raw::NetworkStream; use deno_tls::create_client_config; @@ -290,15 +289,8 @@ where state.borrow_mut().resource_table.close(cancel_rid).ok(); } - let resource = ServerWebSocket { - buffered: Cell::new(0), - errored: Cell::new(None), - ws: AsyncRefCell::new(FragmentCollector::new(stream)), - closed: Cell::new(false), - tx_lock: AsyncRefCell::new(()), - }; let mut state = state.borrow_mut(); - let rid = state.resource_table.add(resource); + let rid = state.resource_table.add(ServerWebSocket::new(stream)); let protocol = match response.headers().get("Sec-WebSocket-Protocol") { Some(header) => header.to_str().unwrap(), @@ -323,18 +315,43 @@ pub enum MessageKind { Binary = 1, Pong = 2, Error = 3, - Closed = 4, + ClosedDefault = 1005, } +/// To avoid locks, we keep as much as we can inside of [`Cell`]s. pub struct ServerWebSocket { buffered: Cell, - errored: Cell>, - ws: AsyncRefCell>, + error: Cell>, + errored: Cell, closed: Cell, + buffer: Cell>>, + ws: AsyncRefCell>, tx_lock: AsyncRefCell<()>, } impl ServerWebSocket { + fn new(ws: WebSocket) -> Self { + Self { + buffered: Cell::new(0), + error: Cell::new(None), + errored: Cell::new(false), + closed: Cell::new(false), + buffer: Cell::new(None), + ws: AsyncRefCell::new(FragmentCollector::new(ws)), + tx_lock: AsyncRefCell::new(()), + } + } + + fn set_error(&self, error: Option) { + if let Some(error) = error { + self.error.set(Some(error)); + self.errored.set(true); + } else { + self.error.set(None); + self.errored.set(false); + } + } + #[inline] pub async fn write_frame( self: &Rc, @@ -374,15 +391,7 @@ pub fn ws_create_server_stream( ws.set_auto_close(true); ws.set_auto_pong(true); - let ws_resource = ServerWebSocket { - buffered: Cell::new(0), - errored: Cell::new(None), - ws: AsyncRefCell::new(FragmentCollector::new(ws)), - closed: Cell::new(false), - tx_lock: AsyncRefCell::new(()), - }; - - let rid = state.resource_table.add(ws_resource); + let rid = state.resource_table.add(ServerWebSocket::new(ws)); Ok(rid) } @@ -401,7 +410,7 @@ pub fn op_ws_send_binary( .write_frame(Frame::new(true, OpCode::Binary, None, data)) .await { - resource.errored.set(Some(err)); + resource.set_error(Some(err.to_string())); } else { resource.buffered.set(resource.buffered.get() - len); } @@ -418,7 +427,7 @@ pub fn op_ws_send_text(state: &mut OpState, rid: ResourceId, data: String) { .write_frame(Frame::new(true, OpCode::Text, None, data.into_bytes())) .await { - resource.errored.set(Some(err)); + resource.set_error(Some(err.to_string())); } else { resource.buffered.set(resource.buffered.get() - len); } @@ -514,18 +523,47 @@ pub async fn op_ws_close( Ok(()) } +#[op] +pub fn op_ws_get_buffer(state: &mut OpState, rid: ResourceId) -> ZeroCopyBuf { + let resource = state.resource_table.get::(rid).unwrap(); + resource.buffer.take().unwrap().into() +} + +#[op] +pub fn op_ws_get_buffer_as_string( + state: &mut OpState, + rid: ResourceId, +) -> String { + let resource = state.resource_table.get::(rid).unwrap(); + // TODO(mmastrac): We won't panic on a bad string, but we return an empty one. + String::from_utf8(resource.buffer.take().unwrap()).unwrap_or_default() +} + +#[op] +pub fn op_ws_get_error(state: &mut OpState, rid: ResourceId) -> String { + let Ok(resource) = state.resource_table.get::(rid) else { + return "Bad resource".into(); + }; + resource.errored.set(false); + resource.error.take().unwrap_or_default() +} + #[op(fast)] pub async fn op_ws_next_event( state: Rc>, rid: ResourceId, -) -> Result<(u16, StringOrBuffer), AnyError> { - let resource = state +) -> u16 { + let Ok(resource) = state .borrow_mut() .resource_table - .get::(rid)?; + .get::(rid) else { + // op_ws_get_error will correctly handle a bad resource + return MessageKind::Error as u16; + }; - if let Some(err) = resource.errored.take() { - return Err(err); + // If there's a pending error, this always returns error + if resource.errored.get() { + return MessageKind::Error as u16; } let mut ws = RcRef::map(&resource, |r| &r.ws).borrow_mut().await; @@ -537,46 +575,44 @@ pub async fn op_ws_next_event( // Try close the stream, ignoring any errors, and report closed status to JavaScript. if resource.closed.get() { let _ = state.borrow_mut().resource_table.close(rid); - return Ok(( - MessageKind::Closed as u16, - StringOrBuffer::Buffer(vec![].into()), - )); + resource.set_error(None); + return MessageKind::ClosedDefault as u16; } - return Ok(( - MessageKind::Error as u16, - StringOrBuffer::String(err.to_string()), - )); + resource.set_error(Some(err.to_string())); + return MessageKind::Error as u16; } }; - break Ok(match val.opcode { - OpCode::Text => ( - MessageKind::Text as u16, - StringOrBuffer::String(String::from_utf8(val.payload).unwrap()), - ), - OpCode::Binary => ( - MessageKind::Binary as u16, - StringOrBuffer::Buffer(val.payload.into()), - ), - OpCode::Close => { - if val.payload.len() < 2 { - return Ok((1005, StringOrBuffer::String("".to_string()))); - } - - let close_code = - CloseCode::from(u16::from_be_bytes([val.payload[0], val.payload[1]])); - let reason = String::from_utf8(val.payload[2..].to_vec()).unwrap(); - (close_code.into(), StringOrBuffer::String(reason)) + break match val.opcode { + OpCode::Text => { + resource.buffer.set(Some(val.payload)); + MessageKind::Text as u16 } - OpCode::Pong => ( - MessageKind::Pong as u16, - StringOrBuffer::Buffer(vec![].into()), - ), + OpCode::Binary => { + resource.buffer.set(Some(val.payload)); + MessageKind::Binary as u16 + } + OpCode::Close => { + // Close reason is returned through error + if val.payload.len() < 2 { + resource.set_error(None); + MessageKind::ClosedDefault as u16 + } else { + let close_code = CloseCode::from(u16::from_be_bytes([ + val.payload[0], + val.payload[1], + ])); + let reason = String::from_utf8(val.payload[2..].to_vec()).ok(); + resource.set_error(reason); + close_code.into() + } + } + OpCode::Pong => MessageKind::Pong as u16, OpCode::Continuation | OpCode::Ping => { continue; } - }); + }; } } @@ -588,6 +624,9 @@ deno_core::extension!(deno_websocket, op_ws_create

, op_ws_close, op_ws_next_event, + op_ws_get_buffer, + op_ws_get_buffer_as_string, + op_ws_get_error, op_ws_send_binary, op_ws_send_text, op_ws_send_binary_async,