// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license. use anyhow::anyhow; use bytes::Bytes; use fastwebsockets::FragmentCollector; use fastwebsockets::Frame; use fastwebsockets::OpCode; use fastwebsockets::Role; use fastwebsockets::WebSocket; use futures::future::join3; use futures::future::poll_fn; use futures::Future; use futures::StreamExt; use h2::server::Handshake; use h2::server::SendResponse; use h2::Reason; use h2::RecvStream; use hyper::service::service_fn; use hyper::upgrade::Upgraded; use hyper::Body; use hyper::Method; use hyper::Request; use hyper::Response; use hyper::StatusCode; use pretty_assertions::assert_eq; use std::pin::Pin; use std::result::Result; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use super::get_tcp_listener_stream; use super::get_tls_listener_stream; use super::SupportedHttpVersions; pub async fn run_ws_server(port: u16) { let mut tcp = get_tcp_listener_stream("ws", port).await; while let Some(Ok(stream)) = tcp.next().await { spawn_ws_server(stream, |ws| Box::pin(echo_websocket_handler(ws))); } } pub async fn run_ws_ping_server(port: u16) { let mut tcp = get_tcp_listener_stream("ws (ping)", port).await; while let Some(Ok(stream)) = tcp.next().await { spawn_ws_server(stream, |ws| Box::pin(ping_websocket_handler(ws))); } } pub async fn run_wss_server(port: u16) { let mut tls = get_tls_listener_stream("wss", port, Default::default()).await; while let Some(Ok(tls_stream)) = tls.next().await { tokio::spawn(async move { spawn_ws_server(tls_stream, |ws| Box::pin(echo_websocket_handler(ws))); }); } } pub async fn run_ws_close_server(port: u16) { let mut tcp = get_tcp_listener_stream("ws (close)", port).await; while let Some(Ok(stream)) = tcp.next().await { spawn_ws_server(stream, |ws| Box::pin(close_websocket_handler(ws))); } } pub async fn run_wss2_server(port: u16) { let mut tls = get_tls_listener_stream( "wss2 (tls)", port, SupportedHttpVersions::Http2Only, ) .await; while let Some(Ok(tls)) = tls.next().await { tokio::spawn(async move { let mut h2 = h2::server::Builder::new(); h2.enable_connect_protocol(); // Using Bytes is pretty alloc-heavy but this is a test server let server: Handshake<_, Bytes> = h2.handshake(tls); let mut server = match server.await { Ok(server) => server, Err(e) => { println!("Failed to handshake h2: {e:?}"); return; } }; loop { let Some(conn) = server.accept().await else { break; }; let (recv, send) = match conn { Ok(conn) => conn, Err(e) => { println!("Failed to accept a connection: {e:?}"); break; } }; tokio::spawn(handle_wss_stream(recv, send)); } }); } } async fn echo_websocket_handler( ws: fastwebsockets::WebSocket, ) -> Result<(), anyhow::Error> { let mut ws = fastwebsockets::FragmentCollector::new(ws); loop { let frame = ws.read_frame().await.unwrap(); match frame.opcode { fastwebsockets::OpCode::Close => break, fastwebsockets::OpCode::Text | fastwebsockets::OpCode::Binary => { ws.write_frame(frame).await.unwrap(); } _ => {} } } Ok(()) } type WsHandler = fn( fastwebsockets::WebSocket, ) -> Pin> + Send>>; fn spawn_ws_server(stream: S, handler: WsHandler) where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { let srv_fn = service_fn(move |mut req: Request| async move { let (response, upgrade_fut) = fastwebsockets::upgrade::upgrade(&mut req) .map_err(|e| anyhow!("Error upgrading websocket connection: {}", e))?; tokio::spawn(async move { let ws = upgrade_fut .await .map_err(|e| anyhow!("Error upgrading websocket connection: {}", e)) .unwrap(); if let Err(e) = handler(ws).await { eprintln!("Error in websocket connection: {}", e); } }); Ok::<_, anyhow::Error>(response) }); tokio::spawn(async move { let conn_fut = hyper::server::conn::Http::new() .serve_connection(stream, srv_fn) .with_upgrades(); if let Err(e) = conn_fut.await { eprintln!("websocket server error: {e:?}"); } }); } async fn handle_wss_stream( recv: Request, mut send: SendResponse, ) -> Result<(), h2::Error> { if recv.method() != Method::CONNECT { eprintln!("wss2: refusing non-CONNECT stream"); send.send_reset(Reason::REFUSED_STREAM); return Ok(()); } let Some(protocol) = recv.extensions().get::() else { eprintln!("wss2: refusing no-:protocol stream"); send.send_reset(Reason::REFUSED_STREAM); return Ok(()); }; if protocol.as_str() != "websocket" && protocol.as_str() != "WebSocket" { eprintln!("wss2: refusing non-websocket stream"); send.send_reset(Reason::REFUSED_STREAM); return Ok(()); } let mut body = recv.into_body(); let mut response = Response::new(()); *response.status_mut() = StatusCode::OK; let mut resp = send.send_response(response, false)?; // Use a duplex stream to talk to fastwebsockets because it's just faster to implement let (a, b) = tokio::io::duplex(65536); let f1 = tokio::spawn(tokio::task::unconstrained(async move { let ws = WebSocket::after_handshake(a, Role::Server); let mut ws = FragmentCollector::new(ws); loop { let frame = ws.read_frame().await.unwrap(); if frame.opcode == OpCode::Close { break; } ws.write_frame(frame).await.unwrap(); } })); let (mut br, mut bw) = tokio::io::split(b); let f2 = tokio::spawn(tokio::task::unconstrained(async move { loop { let Some(Ok(data)) = poll_fn(|cx| body.poll_data(cx)).await else { return; }; body.flow_control().release_capacity(data.len()).unwrap(); let Ok(_) = bw.write_all(&data).await else { break; }; } })); let f3 = tokio::spawn(tokio::task::unconstrained(async move { loop { let mut buf = [0; 65536]; let n = br.read(&mut buf).await.unwrap(); if n == 0 { break; } resp.reserve_capacity(n); poll_fn(|cx| resp.poll_capacity(cx)).await; resp .send_data(Bytes::copy_from_slice(&buf[0..n]), false) .unwrap(); } resp.send_data(Bytes::new(), true).unwrap(); })); _ = join3(f1, f2, f3).await; Ok(()) } async fn close_websocket_handler( ws: fastwebsockets::WebSocket, ) -> Result<(), anyhow::Error> { let mut ws = fastwebsockets::FragmentCollector::new(ws); ws.write_frame(fastwebsockets::Frame::close_raw(vec![].into())) .await .unwrap(); Ok(()) } async fn ping_websocket_handler( ws: fastwebsockets::WebSocket, ) -> Result<(), anyhow::Error> { let mut ws = fastwebsockets::FragmentCollector::new(ws); for i in 0..9 { ws.write_frame(Frame::new(true, OpCode::Ping, None, vec![].into())) .await .unwrap(); let frame = ws.read_frame().await.unwrap(); assert_eq!(frame.opcode, OpCode::Pong); assert!(frame.payload.is_empty()); ws.write_frame(Frame::text( format!("hello {}", i).as_bytes().to_vec().into(), )) .await .unwrap(); let frame = ws.read_frame().await.unwrap(); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload, format!("hello {}", i).as_bytes()); } ws.write_frame(fastwebsockets::Frame::close(1000, b"")) .await .unwrap(); Ok(()) }