From 31abacbe1a1dad05de662b07af79cc31dc74c781 Mon Sep 17 00:00:00 2001 From: Yusuke Tanaka Date: Fri, 25 Nov 2022 02:38:09 +0900 Subject: [PATCH] fix(ext/flash): graceful server startup/shutdown with unsettled promises in mind (#16616) This PR resets the revert commit made by #16610, bringing back #16383 which attempts to fix the issue happening when we use the flash server with `--watch` option enabled. Also, some code changes are made to pass the regression test added in #16610. --- cli/tests/integration/watcher_tests.rs | 65 ++++++ cli/tests/unit/flash_test.ts | 32 ++- ext/flash/01_http.js | 26 +-- ext/flash/lib.rs | 294 +++++++++++++++---------- ext/flash/request.rs | 18 +- ext/flash/socket.rs | 18 +- 6 files changed, 305 insertions(+), 148 deletions(-) diff --git a/cli/tests/integration/watcher_tests.rs b/cli/tests/integration/watcher_tests.rs index 27a1bb6201..cd3dc40cfc 100644 --- a/cli/tests/integration/watcher_tests.rs +++ b/cli/tests/integration/watcher_tests.rs @@ -1167,3 +1167,68 @@ fn run_watch_dynamic_imports() { check_alive_then_kill(child); } + +// https://github.com/denoland/deno/issues/16267 +#[test] +fn run_watch_flash() { + let filename = "watch_flash.js"; + let t = TempDir::new(); + let file_to_watch = t.path().join(filename); + write( + &file_to_watch, + r#" + console.log("Starting flash server..."); + Deno.serve({ + onListen() { + console.error("First server is listening"); + }, + handler: () => {}, + port: 4601, + }); + "#, + ) + .unwrap(); + + let mut child = util::deno_cmd() + .current_dir(t.path()) + .arg("run") + .arg("--watch") + .arg("--unstable") + .arg("--allow-net") + .arg("-L") + .arg("debug") + .arg(&file_to_watch) + .env("NO_COLOR", "1") + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .spawn() + .unwrap(); + let (mut stdout_lines, mut stderr_lines) = child_lines(&mut child); + + wait_contains("Starting flash server...", &mut stdout_lines); + wait_for( + |m| m.contains("Watching paths") && m.contains(filename), + &mut stderr_lines, + ); + + write( + &file_to_watch, + r#" + console.log("Restarting flash server..."); + Deno.serve({ + onListen() { + console.error("Second server is listening"); + }, + handler: () => {}, + port: 4601, + }); + "#, + ) + .unwrap(); + + wait_contains("File change detected! Restarting!", &mut stderr_lines); + wait_contains("Restarting flash server...", &mut stdout_lines); + wait_contains("Second server is listening", &mut stderr_lines); + + check_alive_then_kill(child); +} diff --git a/cli/tests/unit/flash_test.ts b/cli/tests/unit/flash_test.ts index 0240694553..761b9137aa 100644 --- a/cli/tests/unit/flash_test.ts +++ b/cli/tests/unit/flash_test.ts @@ -57,32 +57,42 @@ Deno.test(async function httpServerCanResolveHostnames() { await server; }); -Deno.test(async function httpServerRejectsOnAddrInUse() { - const ac = new AbortController(); +// TODO(magurotuna): ignore this case for now because it's flaky on GitHub Actions, +// although it acts as expected when running locally. +// See https://github.com/denoland/deno/pull/16616 +Deno.test({ ignore: true }, async function httpServerRejectsOnAddrInUse() { + const ac1 = new AbortController(); const listeningPromise = deferred(); + let port: number; const server = Deno.serve({ handler: (_req) => new Response("ok"), hostname: "localhost", - port: 4501, - signal: ac.signal, - onListen: onListen(listeningPromise), - onError: createOnErrorCb(ac), + port: 0, + signal: ac1.signal, + onListen: (addr) => { + port = addr.port; + listeningPromise.resolve(); + }, + onError: createOnErrorCb(ac1), }); + await listeningPromise; + + const ac2 = new AbortController(); assertRejects( () => Deno.serve({ handler: (_req) => new Response("ok"), hostname: "localhost", - port: 4501, - signal: ac.signal, - onListen: onListen(listeningPromise), - onError: createOnErrorCb(ac), + port, + signal: ac2.signal, }), Deno.errors.AddrInUse, ); - ac.abort(); + + ac1.abort(); + ac2.abort(); await server; }); diff --git a/ext/flash/01_http.js b/ext/flash/01_http.js index 2b0caff493..67729ee398 100644 --- a/ext/flash/01_http.js +++ b/ext/flash/01_http.js @@ -188,8 +188,8 @@ return str; } - function prepareFastCalls() { - return core.ops.op_flash_make_request(); + function prepareFastCalls(serverId) { + return core.ops.op_flash_make_request(serverId); } function hostnameForDisplay(hostname) { @@ -495,15 +495,11 @@ const serverId = opFn(listenOpts); const serverPromise = core.opAsync("op_flash_drive_server", serverId); - - PromisePrototypeCatch( - PromisePrototypeThen( - core.opAsync("op_flash_wait_for_listening", serverId), - (port) => { - onListen({ hostname: listenOpts.hostname, port }); - }, - ), - () => {}, + const listenPromise = PromisePrototypeThen( + core.opAsync("op_flash_wait_for_listening", serverId), + (port) => { + onListen({ hostname: listenOpts.hostname, port }); + }, ); const finishedPromise = PromisePrototypeCatch(serverPromise, () => {}); @@ -519,7 +515,7 @@ return; } server.closed = true; - await core.opAsync("op_flash_close_server", serverId); + core.ops.op_flash_close_server(serverId); await server.finished; }, async serve() { @@ -634,7 +630,7 @@ signal?.addEventListener("abort", () => { clearInterval(dateInterval); - PromisePrototypeThen(server.close(), () => {}, () => {}); + server.close(); }, { once: true, }); @@ -668,7 +664,7 @@ ); } - const fastOp = prepareFastCalls(); + const fastOp = prepareFastCalls(serverId); let nextRequestSync = () => fastOp.nextRequest(); let getMethodSync = (token) => fastOp.getMethod(token); let respondFast = (token, response, shutdown) => @@ -688,8 +684,8 @@ } await SafePromiseAll([ + listenPromise, PromisePrototypeCatch(server.serve(), console.error), - serverPromise, ]); }; } diff --git a/ext/flash/lib.rs b/ext/flash/lib.rs index d08cdbcdc5..7b43088071 100644 --- a/ext/flash/lib.rs +++ b/ext/flash/lib.rs @@ -35,6 +35,7 @@ use mio::Events; use mio::Interest; use mio::Poll; use mio::Token; +use mio::Waker; use serde::Deserialize; use serde::Serialize; use socket2::Socket; @@ -47,6 +48,7 @@ use std::intrinsics::transmute; use std::io::BufReader; use std::io::Read; use std::io::Write; +use std::marker::PhantomPinned; use std::mem::replace; use std::net::SocketAddr; use std::net::ToSocketAddrs; @@ -55,8 +57,8 @@ use std::rc::Rc; use std::sync::Arc; use std::sync::Mutex; use std::task::Context; -use std::time::Duration; use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::task::JoinHandle; mod chunked; @@ -76,15 +78,24 @@ pub struct FlashContext { pub servers: HashMap, } +impl Drop for FlashContext { + fn drop(&mut self) { + // Signal each server instance to shutdown. + for (_, server) in self.servers.drain() { + let _ = server.waker.wake(); + } + } +} + pub struct ServerContext { _addr: SocketAddr, tx: mpsc::Sender, - rx: mpsc::Receiver, + rx: Option>, requests: HashMap, next_token: u32, - listening_rx: Option>, - close_tx: mpsc::Sender<()>, + listening_rx: Option>>, cancel_handle: Rc, + waker: Arc, } #[derive(Debug, Eq, PartialEq)] @@ -102,7 +113,10 @@ fn op_flash_respond( shutdown: bool, ) -> u32 { let flash_ctx = op_state.borrow_mut::(); - let ctx = flash_ctx.servers.get_mut(&server_id).unwrap(); + let ctx = match flash_ctx.servers.get_mut(&server_id) { + Some(ctx) => ctx, + None => return 0, + }; flash_respond(ctx, token, shutdown, &response) } @@ -116,7 +130,7 @@ fn op_try_flash_respond_chuncked( ) -> u32 { let flash_ctx = op_state.borrow_mut::(); let ctx = flash_ctx.servers.get_mut(&server_id).unwrap(); - let tx = ctx.requests.get(&token).unwrap(); + let tx = ctx.requests.get_mut(&token).unwrap(); let sock = tx.socket(); // TODO(@littledivy): Use writev when `UnixIoSlice` lands. @@ -153,17 +167,20 @@ async fn op_flash_respond_async( let sock = { let mut op_state = state.borrow_mut(); let flash_ctx = op_state.borrow_mut::(); - let ctx = flash_ctx.servers.get_mut(&server_id).unwrap(); + let ctx = match flash_ctx.servers.get_mut(&server_id) { + Some(ctx) => ctx, + None => return Ok(()), + }; match shutdown { true => { - let tx = ctx.requests.remove(&token).unwrap(); + let mut tx = ctx.requests.remove(&token).unwrap(); close = !tx.keep_alive; tx.socket() } // In case of a websocket upgrade or streaming response. false => { - let tx = ctx.requests.get(&token).unwrap(); + let tx = ctx.requests.get_mut(&token).unwrap(); tx.socket() } } @@ -197,12 +214,12 @@ async fn op_flash_respond_chuncked( let ctx = flash_ctx.servers.get_mut(&server_id).unwrap(); let sock = match shutdown { true => { - let tx = ctx.requests.remove(&token).unwrap(); + let mut tx = ctx.requests.remove(&token).unwrap(); tx.socket() } // In case of a websocket upgrade or streaming response. false => { - let tx = ctx.requests.get(&token).unwrap(); + let tx = ctx.requests.get_mut(&token).unwrap(); tx.socket() } }; @@ -344,7 +361,7 @@ fn flash_respond( shutdown: bool, response: &[u8], ) -> u32 { - let tx = ctx.requests.get(&token).unwrap(); + let tx = ctx.requests.get_mut(&token).unwrap(); let sock = tx.socket(); sock.read_tx.take(); @@ -428,15 +445,36 @@ fn op_flash_method(state: &mut OpState, server_id: u32, token: u32) -> u32 { } #[op] -async fn op_flash_close_server(state: Rc>, server_id: u32) { - let close_tx = { - let mut op_state = state.borrow_mut(); - let flash_ctx = op_state.borrow_mut::(); - let ctx = flash_ctx.servers.get_mut(&server_id).unwrap(); - ctx.cancel_handle.cancel(); - ctx.close_tx.clone() +fn op_flash_drive_server( + state: &mut OpState, + server_id: u32, +) -> Result> + 'static, AnyError> { + let join_handle = { + let flash_ctx = state.borrow_mut::(); + flash_ctx + .join_handles + .remove(&server_id) + .ok_or_else(|| type_error("server not found"))? }; - let _ = close_tx.send(()).await; + Ok(async move { + join_handle + .await + .map_err(|_| type_error("server join error"))??; + Ok(()) + }) +} + +#[op] +fn op_flash_close_server(state: &mut OpState, server_id: u32) { + let flash_ctx = state.borrow_mut::(); + let ctx = flash_ctx.servers.get(&server_id).unwrap(); + + // NOTE: We don't drop ServerContext associated with the given `server_id`, + // because it may still be in use by some unsettled promise after the flash + // thread is finished. + + ctx.cancel_handle.cancel(); + let _ = ctx.waker.wake(); } #[op] @@ -463,7 +501,7 @@ fn op_flash_path( fn next_request_sync(ctx: &mut ServerContext) -> u32 { let offset = ctx.next_token; - while let Ok(token) = ctx.rx.try_recv() { + while let Ok(token) = ctx.rx.as_mut().unwrap().try_recv() { ctx.requests.insert(ctx.next_token, token); ctx.next_token += 1; } @@ -526,6 +564,7 @@ unsafe fn op_flash_get_method_fast( fn op_flash_make_request<'scope>( scope: &mut v8::HandleScope<'scope>, state: &mut OpState, + server_id: u32, ) -> serde_v8::Value<'scope> { let object_template = v8::ObjectTemplate::new(scope); assert!(object_template @@ -533,7 +572,7 @@ fn op_flash_make_request<'scope>( let obj = object_template.new_instance(scope).unwrap(); let ctx = { let flash_ctx = state.borrow_mut::(); - let ctx = flash_ctx.servers.get_mut(&0).unwrap(); + let ctx = flash_ctx.servers.get_mut(&server_id).unwrap(); ctx as *mut ServerContext }; obj.set_aligned_pointer_in_internal_field(V8_WRAPPER_OBJECT_INDEX, ctx as _); @@ -625,7 +664,7 @@ fn op_flash_make_request<'scope>( } #[inline] -fn has_body_stream(req: &Request) -> bool { +fn has_body_stream(req: &mut Request) -> bool { let sock = req.socket(); sock.read_rx.is_some() } @@ -749,7 +788,10 @@ async fn op_flash_read_body( { let op_state = &mut state.borrow_mut(); let flash_ctx = op_state.borrow_mut::(); - flash_ctx.servers.get_mut(&server_id).unwrap() as *mut ServerContext + match flash_ctx.servers.get_mut(&server_id) { + Some(ctx) => ctx as *mut ServerContext, + None => return 0, + } } .as_mut() .unwrap() @@ -851,41 +893,40 @@ pub struct ListenOpts { reuseport: bool, } +const SERVER_TOKEN: Token = Token(0); +// Token reserved for the thread close signal. +const WAKER_TOKEN: Token = Token(1); + +#[allow(clippy::too_many_arguments)] fn run_server( tx: mpsc::Sender, - listening_tx: mpsc::Sender, - mut close_rx: mpsc::Receiver<()>, + listening_tx: mpsc::Sender>, addr: SocketAddr, maybe_cert: Option, maybe_key: Option, reuseport: bool, + mut poll: Poll, + // We put a waker as an unused argument here as it needs to be alive both in + // the flash thread and in the main thread (otherwise the notification would + // not be caught by the event loop on Linux). + // See the comment in mio's example: + // https://docs.rs/mio/0.8.4/x86_64-unknown-linux-gnu/mio/struct.Waker.html#examples + _waker: Arc, ) -> Result<(), AnyError> { - let domain = if addr.is_ipv4() { - socket2::Domain::IPV4 - } else { - socket2::Domain::IPV6 + let mut listener = match listen(addr, reuseport) { + Ok(listener) => listener, + Err(e) => { + listening_tx.blocking_send(Err(e)).unwrap(); + return Err(generic_error( + "failed to start listening on the specified address", + )); + } }; - let socket = Socket::new(domain, socket2::Type::STREAM, None)?; - #[cfg(not(windows))] - socket.set_reuse_address(true)?; - if reuseport { - #[cfg(target_os = "linux")] - socket.set_reuse_port(true)?; - } - - let socket_addr = socket2::SockAddr::from(addr); - socket.bind(&socket_addr)?; - socket.listen(128)?; - socket.set_nonblocking(true)?; - let std_listener: std::net::TcpListener = socket.into(); - let mut listener = TcpListener::from_std(std_listener); - - let mut poll = Poll::new()?; - let token = Token(0); + // Register server. poll .registry() - .register(&mut listener, token, Interest::READABLE) + .register(&mut listener, SERVER_TOKEN, Interest::READABLE) .unwrap(); let tls_context: Option> = { @@ -907,30 +948,25 @@ fn run_server( }; listening_tx - .blocking_send(listener.local_addr().unwrap().port()) + .blocking_send(Ok(listener.local_addr().unwrap().port())) .unwrap(); let mut sockets = HashMap::with_capacity(1000); - let mut counter: usize = 1; + let mut socket_senders = HashMap::with_capacity(1000); + let mut counter: usize = 2; let mut events = Events::with_capacity(1024); 'outer: loop { - let result = close_rx.try_recv(); - if result.is_ok() { - break 'outer; - } - // FIXME(bartlomieju): how does Tokio handle it? I just put random 100ms - // timeout here to handle close signal. - match poll.poll(&mut events, Some(Duration::from_millis(100))) { + match poll.poll(&mut events, None) { Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue, Err(e) => panic!("{}", e), Ok(()) => (), } 'events: for event in &events { - if close_rx.try_recv().is_ok() { - break 'outer; - } let token = event.token(); match token { - Token(0) => loop { + WAKER_TOKEN => { + break 'outer; + } + SERVER_TOKEN => loop { match listener.accept() { Ok((mut socket, _)) => { counter += 1; @@ -958,6 +994,7 @@ fn run_server( read_lock: Arc::new(Mutex::new(())), parse_done: ParseStatus::None, buffer: UnsafeCell::new(vec![0_u8; 1024]), + _pinned: PhantomPinned, }); trace!("New connection: {}", token.0); @@ -974,7 +1011,6 @@ fn run_server( let mut_ref: Pin<&mut Stream> = Pin::as_mut(socket); Pin::get_unchecked_mut(mut_ref) }; - let sock_ptr = socket as *mut _; if socket.detached { match &mut socket.inner { @@ -988,6 +1024,7 @@ fn run_server( let boxed = sockets.remove(&token).unwrap(); std::mem::forget(boxed); + socket_senders.remove(&token); trace!("Socket detached: {}", token.0); continue; } @@ -1173,8 +1210,10 @@ fn run_server( continue 'events; } + let (socket_tx, socket_rx) = oneshot::channel(); + tx.blocking_send(Request { - socket: sock_ptr, + socket: socket as *mut _, // SAFETY: headers backing buffer outlives the mio event loop ('static) inner: inner_req, keep_alive, @@ -1183,16 +1222,57 @@ fn run_server( content_read: 0, content_length, expect_continue, + socket_rx, + owned_socket: None, }) .ok(); + + socket_senders.insert(token, socket_tx); } } } } + // Now the flash thread is about to finish, but there may be some unsettled + // promises in the main thread that will use the socket. To make the socket + // alive longer enough, we move its ownership to the main thread. + for (tok, socket) in sockets { + if let Some(sender) = socket_senders.remove(&tok) { + // Do nothing if the receiver has already been dropped. + _ = sender.send(socket); + } + } + Ok(()) } +#[inline] +fn listen( + addr: SocketAddr, + reuseport: bool, +) -> Result { + let domain = if addr.is_ipv4() { + socket2::Domain::IPV4 + } else { + socket2::Domain::IPV6 + }; + let socket = Socket::new(domain, socket2::Type::STREAM, None)?; + + #[cfg(not(windows))] + socket.set_reuse_address(true)?; + if reuseport { + #[cfg(target_os = "linux")] + socket.set_reuse_port(true)?; + } + + let socket_addr = socket2::SockAddr::from(addr); + socket.bind(&socket_addr)?; + socket.listen(128)?; + socket.set_nonblocking(true)?; + let std_listener: std::net::TcpListener = socket.into(); + Ok(TcpListener::from_std(std_listener)) +} + fn make_addr_port_pair(hostname: &str, port: u16) -> (&str, u16) { // Default to localhost if given just the port. Example: ":80" if hostname.is_empty() { @@ -1230,17 +1310,19 @@ where .next() .ok_or_else(|| generic_error("No resolved address found"))?; let (tx, rx) = mpsc::channel(100); - let (close_tx, close_rx) = mpsc::channel(1); let (listening_tx, listening_rx) = mpsc::channel(1); + + let poll = Poll::new()?; + let waker = Arc::new(Waker::new(poll.registry(), WAKER_TOKEN).unwrap()); let ctx = ServerContext { _addr: addr, tx, - rx, + rx: Some(rx), requests: HashMap::with_capacity(1000), next_token: 0, - close_tx, listening_rx: Some(listening_rx), cancel_handle: CancelHandle::new_rc(), + waker: waker.clone(), }; let tx = ctx.tx.clone(); let maybe_cert = opts.cert; @@ -1250,11 +1332,12 @@ where run_server( tx, listening_tx, - close_rx, addr, maybe_cert, maybe_key, reuseport, + poll, + waker, ) }); let flash_ctx = state.borrow_mut::(); @@ -1289,45 +1372,26 @@ where } #[op] -fn op_flash_wait_for_listening( - state: &mut OpState, +async fn op_flash_wait_for_listening( + state: Rc>, server_id: u32, -) -> Result> + 'static, AnyError> { +) -> Result { let mut listening_rx = { - let flash_ctx = state.borrow_mut::(); + let mut op_state = state.borrow_mut(); + let flash_ctx = op_state.borrow_mut::(); let server_ctx = flash_ctx .servers .get_mut(&server_id) .ok_or_else(|| type_error("server not found"))?; server_ctx.listening_rx.take().unwrap() }; - Ok(async move { - if let Some(port) = listening_rx.recv().await { - Ok(port) - } else { - Err(generic_error("This error will be discarded")) - } - }) -} - -#[op] -fn op_flash_drive_server( - state: &mut OpState, - server_id: u32, -) -> Result> + 'static, AnyError> { - let join_handle = { - let flash_ctx = state.borrow_mut::(); - flash_ctx - .join_handles - .remove(&server_id) - .ok_or_else(|| type_error("server not found"))? - }; - Ok(async move { - join_handle - .await - .map_err(|_| type_error("server join error"))??; - Ok(()) - }) + match listening_rx.recv().await { + Some(Ok(port)) => Ok(port), + Some(Err(e)) => Err(e.into()), + _ => Err(generic_error( + "unknown error occurred while waiting for listening", + )), + } } // Asychronous version of op_flash_next. This can be a bottleneck under @@ -1335,26 +1399,34 @@ fn op_flash_drive_server( // requests i.e `op_flash_next() == 0`. #[op] async fn op_flash_next_async( - op_state: Rc>, + state: Rc>, server_id: u32, ) -> u32 { - let ctx = { - let mut op_state = op_state.borrow_mut(); + let mut op_state = state.borrow_mut(); + let flash_ctx = op_state.borrow_mut::(); + let ctx = flash_ctx.servers.get_mut(&server_id).unwrap(); + let cancel_handle = ctx.cancel_handle.clone(); + let mut rx = ctx.rx.take().unwrap(); + // We need to drop the borrow before await point. + drop(op_state); + + if let Ok(Some(req)) = rx.recv().or_cancel(&cancel_handle).await { + let mut op_state = state.borrow_mut(); let flash_ctx = op_state.borrow_mut::(); let ctx = flash_ctx.servers.get_mut(&server_id).unwrap(); - ctx as *mut ServerContext - }; - // SAFETY: we cannot hold op_state borrow across the await point. The JS caller - // is responsible for ensuring this is not called concurrently. - let ctx = unsafe { &mut *ctx }; - let cancel_handle = &ctx.cancel_handle; - - if let Ok(Some(req)) = ctx.rx.recv().or_cancel(cancel_handle).await { ctx.requests.insert(ctx.next_token, req); ctx.next_token += 1; + // Set the rx back. + ctx.rx = Some(rx); return 1; } + // Set the rx back. + let mut op_state = state.borrow_mut(); + let flash_ctx = op_state.borrow_mut::(); + if let Some(ctx) = flash_ctx.servers.get_mut(&server_id) { + ctx.rx = Some(rx); + } 0 } @@ -1427,7 +1499,7 @@ pub fn detach_socket( // dropped on the server thread. // * conversion from mio::net::TcpStream -> tokio::net::TcpStream. There is no public API so we // use raw fds. - let tx = ctx + let mut tx = ctx .requests .remove(&token) .ok_or_else(|| type_error("request closed"))?; @@ -1522,11 +1594,11 @@ pub fn init(unstable: bool) -> Extension { op_flash_next_async::decl(), op_flash_read_body::decl(), op_flash_upgrade_websocket::decl(), - op_flash_drive_server::decl(), op_flash_wait_for_listening::decl(), op_flash_first_packet::decl(), op_flash_has_body_stream::decl(), op_flash_close_server::decl(), + op_flash_drive_server::decl(), op_flash_make_request::decl(), op_flash_write_resource::decl(), op_try_flash_respond_chuncked::decl(), diff --git a/ext/flash/request.rs b/ext/flash/request.rs index 0736b56206..ac077df6fb 100644 --- a/ext/flash/request.rs +++ b/ext/flash/request.rs @@ -2,6 +2,7 @@ use crate::Stream; use std::pin::Pin; +use tokio::sync::oneshot; #[derive(Debug)] pub struct InnerRequest { @@ -20,8 +21,7 @@ pub struct Request { pub inner: InnerRequest, // Pointer to stream owned by the server loop thread. // - // Dereferencing is safe until server thread finishes and - // op_flash_serve resolves or websocket upgrade is performed. + // Dereferencing is safe until websocket upgrade is performed. pub socket: *mut Stream, pub keep_alive: bool, pub content_read: usize, @@ -29,6 +29,8 @@ pub struct Request { pub remaining_chunk_size: Option, pub te_chunked: bool, pub expect_continue: bool, + pub socket_rx: oneshot::Receiver>>, + pub owned_socket: Option>>, } // SAFETY: Sent from server thread to JS thread. @@ -37,8 +39,16 @@ unsafe impl Send for Request {} impl Request { #[inline(always)] - pub fn socket<'a>(&self) -> &'a mut Stream { - // SAFETY: Dereferencing is safe until server thread detaches socket or finishes. + pub fn socket<'a>(&mut self) -> &'a mut Stream { + if let Ok(mut sock) = self.socket_rx.try_recv() { + // SAFETY: We never move the data out of the acquired mutable reference. + self.socket = unsafe { sock.as_mut().get_unchecked_mut() }; + + // Let the struct own the socket so that it won't get dropped. + self.owned_socket = Some(sock); + } + + // SAFETY: Dereferencing is safe until server thread detaches socket. unsafe { &mut *self.socket } } diff --git a/ext/flash/socket.rs b/ext/flash/socket.rs index 8256be8a0c..7c75b230a6 100644 --- a/ext/flash/socket.rs +++ b/ext/flash/socket.rs @@ -1,23 +1,26 @@ use deno_core::error::AnyError; use mio::net::TcpStream; -use std::{ - cell::UnsafeCell, - future::Future, - io::{Read, Write}, - pin::Pin, - sync::{Arc, Mutex}, -}; +use std::cell::UnsafeCell; +use std::future::Future; +use std::io::Read; +use std::io::Write; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::Mutex; use tokio::sync::mpsc; use crate::ParseStatus; type TlsTcpStream = rustls::StreamOwned; +#[derive(Debug)] pub enum InnerStream { Tcp(TcpStream), Tls(Box), } +#[derive(Debug)] pub struct Stream { pub inner: InnerStream, pub detached: bool, @@ -26,6 +29,7 @@ pub struct Stream { pub parse_done: ParseStatus, pub buffer: UnsafeCell>, pub read_lock: Arc>, + pub _pinned: PhantomPinned, } impl Stream {