1
0
Fork 0
mirror of https://github.com/denoland/deno.git synced 2024-12-20 06:15:44 -05:00
denoland-deno/tests/util/server/src/servers/ws.rs

269 lines
7.5 KiB
Rust
Raw Normal View History

// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
use anyhow::anyhow;
use bytes::Bytes;
use fastwebsockets::FragmentCollector;
use fastwebsockets::Frame;
use fastwebsockets::OpCode;
use fastwebsockets::Role;
use fastwebsockets::WebSocket;
use futures::future::join3;
use futures::future::poll_fn;
use futures::Future;
use futures::StreamExt;
use h2::server::Handshake;
use h2::server::SendResponse;
use h2::Reason;
use h2::RecvStream;
use hyper::upgrade::Upgraded;
use hyper::Method;
use hyper::Request;
use hyper::Response;
use hyper::StatusCode;
use hyper_util::rt::TokioIo;
use pretty_assertions::assert_eq;
use std::pin::Pin;
use std::result::Result;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use super::get_tcp_listener_stream;
use super::get_tls_listener_stream;
use super::SupportedHttpVersions;
/// Runs the plain-text (`ws`) echo server on `port`.
///
/// Every accepted TCP connection is handed to `spawn_ws_server` with the
/// echo handler. The accept loop stops as soon as the listener stream ends
/// or yields an error.
pub async fn run_ws_server(port: u16) {
  let mut listener = get_tcp_listener_stream("ws", port).await;
  loop {
    let Some(Ok(conn)) = listener.next().await else {
      break;
    };
    spawn_ws_server(conn, |socket| Box::pin(echo_websocket_handler(socket)));
  }
}
/// Runs the ping-exercising WebSocket server on `port`.
///
/// Every accepted TCP connection is served by `ping_websocket_handler`.
/// The accept loop stops as soon as the listener stream ends or yields an
/// error.
pub async fn run_ws_ping_server(port: u16) {
  let mut listener = get_tcp_listener_stream("ws (ping)", port).await;
  loop {
    let Some(Ok(conn)) = listener.next().await else {
      break;
    };
    spawn_ws_server(conn, |socket| Box::pin(ping_websocket_handler(socket)));
  }
}
pub async fn run_wss_server(port: u16) {
let mut tls = get_tls_listener_stream("wss", port, Default::default()).await;
while let Some(Ok(tls_stream)) = tls.next().await {
tokio::spawn(async move {
spawn_ws_server(tls_stream, |ws| Box::pin(echo_websocket_handler(ws)));
});
}
}
/// Runs the WebSocket server on `port` whose handler only sends a close
/// frame.
///
/// Every accepted TCP connection is served by `close_websocket_handler`.
/// The accept loop stops as soon as the listener stream ends or yields an
/// error.
pub async fn run_ws_close_server(port: u16) {
  let mut listener = get_tcp_listener_stream("ws (close)", port).await;
  loop {
    let Some(Ok(conn)) = listener.next().await else {
      break;
    };
    spawn_ws_server(conn, |socket| Box::pin(close_websocket_handler(socket)));
  }
}
/// Runs the WebSocket-over-HTTP/2 (`wss2`) echo server on `port`.
///
/// Each accepted TLS stream gets its own task that performs the h2 server
/// handshake with the extended CONNECT protocol enabled, then spawns
/// `handle_wss_stream` for every stream the peer opens. The outer accept
/// loop stops as soon as the listener stream ends or yields an error.
pub async fn run_wss2_server(port: u16) {
  let mut listener = get_tls_listener_stream(
    "wss2 (tls)",
    port,
    SupportedHttpVersions::Http2Only,
  )
  .await;
  while let Some(Ok(tls_stream)) = listener.next().await {
    tokio::spawn(async move {
      let mut builder = h2::server::Builder::new();
      builder.enable_connect_protocol();
      // Using Bytes is pretty alloc-heavy but this is a test server
      let handshake: Handshake<_, Bytes> = builder.handshake(tls_stream);
      let mut connection = match handshake.await {
        Ok(connection) => connection,
        Err(e) => {
          println!("Failed to handshake h2: {e:?}");
          return;
        }
      };
      // Serve streams until the peer is done or accepting one fails.
      while let Some(accepted) = connection.accept().await {
        match accepted {
          Ok((request, respond)) => {
            tokio::spawn(handle_wss_stream(request, respond));
          }
          Err(e) => {
            println!("Failed to accept a connection: {e:?}");
            break;
          }
        }
      }
    });
  }
}
/// Echoes every text and binary message back to the client until a close
/// frame is received.
///
/// Frames with any other opcode are read and ignored.
///
/// # Errors
///
/// Returns an error if reading or writing a frame fails (for example when
/// the client drops the connection without a close handshake). The caller
/// logs the error instead of the task panicking.
async fn echo_websocket_handler(
  ws: fastwebsockets::WebSocket<TokioIo<Upgraded>>,
) -> Result<(), anyhow::Error> {
  let mut ws = FragmentCollector::new(ws);
  loop {
    let frame = ws.read_frame().await?;
    match frame.opcode {
      OpCode::Close => break,
      OpCode::Text | OpCode::Binary => ws.write_frame(frame).await?,
      _ => {}
    }
  }
  Ok(())
}
/// Signature of a per-connection WebSocket handler accepted by
/// `spawn_ws_server`.
///
/// It is a plain function pointer that takes the upgraded socket and returns
/// a boxed, `Send` future resolving to the handler's result.
type WsHandler =
fn(
fastwebsockets::WebSocket<TokioIo<Upgraded>>,
) -> Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
/// Serves a single HTTP/1 connection on `stream`, upgrading requests to
/// WebSocket and running `handler` on each upgraded socket.
///
/// The connection is driven on its own spawned task, and each upgraded
/// socket is handled on a further task. Failures are reported on stderr
/// rather than propagated: an upgrade that fails after the response was
/// produced, an error returned by `handler`, and an error from the
/// connection itself are each logged.
fn spawn_ws_server<S>(stream: S, handler: WsHandler)
where
  S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
  let service = hyper::service::service_fn(
    move |mut req: http::Request<hyper::body::Incoming>| async move {
      let (response, upgrade_fut) = fastwebsockets::upgrade::upgrade(&mut req)
        .map_err(|e| anyhow!("Error upgrading websocket connection: {}", e))?;
      tokio::spawn(async move {
        // The upgrade only completes once the response below has been sent,
        // so it has to be awaited on a separate task. A failed upgrade is
        // logged and abandoned instead of panicking the task.
        let ws = match upgrade_fut.await {
          Ok(ws) => ws,
          Err(e) => {
            eprintln!("Error upgrading websocket connection: {}", e);
            return;
          }
        };
        if let Err(e) = handler(ws).await {
          eprintln!("Error in websocket connection: {}", e);
        }
      });
      Ok::<_, anyhow::Error>(response)
    },
  );
  let io = TokioIo::new(stream);
  tokio::spawn(async move {
    let conn = hyper::server::conn::http1::Builder::new()
      .serve_connection(io, service)
      .with_upgrades();
    if let Err(e) = conn.await {
      eprintln!("websocket server error: {e:?}");
    }
  });
}
async fn handle_wss_stream(
recv: Request<RecvStream>,
mut send: SendResponse<Bytes>,
) -> Result<(), h2::Error> {
if recv.method() != Method::CONNECT {
eprintln!("wss2: refusing non-CONNECT stream");
send.send_reset(Reason::REFUSED_STREAM);
return Ok(());
}
let Some(protocol) = recv.extensions().get::<h2::ext::Protocol>() else {
eprintln!("wss2: refusing no-:protocol stream");
send.send_reset(Reason::REFUSED_STREAM);
return Ok(());
};
if protocol.as_str() != "websocket" && protocol.as_str() != "WebSocket" {
eprintln!("wss2: refusing non-websocket stream");
send.send_reset(Reason::REFUSED_STREAM);
return Ok(());
}
let mut body = recv.into_body();
let mut response = Response::new(());
*response.status_mut() = StatusCode::OK;
let mut resp = send.send_response(response, false)?;
// Use a duplex stream to talk to fastwebsockets because it's just faster to implement
let (a, b) = tokio::io::duplex(65536);
let f1 = tokio::spawn(tokio::task::unconstrained(async move {
let ws = WebSocket::after_handshake(a, Role::Server);
let mut ws = FragmentCollector::new(ws);
loop {
let frame = ws.read_frame().await.unwrap();
if frame.opcode == OpCode::Close {
break;
}
ws.write_frame(frame).await.unwrap();
}
}));
let (mut br, mut bw) = tokio::io::split(b);
let f2 = tokio::spawn(tokio::task::unconstrained(async move {
loop {
let Some(Ok(data)) = poll_fn(|cx| body.poll_data(cx)).await else {
return;
};
body.flow_control().release_capacity(data.len()).unwrap();
let Ok(_) = bw.write_all(&data).await else {
break;
};
}
}));
let f3 = tokio::spawn(tokio::task::unconstrained(async move {
loop {
let mut buf = [0; 65536];
let n = br.read(&mut buf).await.unwrap();
if n == 0 {
break;
}
resp.reserve_capacity(n);
poll_fn(|cx| resp.poll_capacity(cx)).await;
resp
.send_data(Bytes::copy_from_slice(&buf[0..n]), false)
.unwrap();
}
resp.send_data(Bytes::new(), true).unwrap();
}));
_ = join3(f1, f2, f3).await;
Ok(())
}
/// Immediately sends a close frame with an empty payload and returns.
///
/// # Errors
///
/// Returns an error if writing the close frame fails. The caller logs the
/// error instead of the task panicking.
async fn close_websocket_handler(
  ws: fastwebsockets::WebSocket<TokioIo<Upgraded>>,
) -> Result<(), anyhow::Error> {
  let mut ws = FragmentCollector::new(ws);
  ws.write_frame(Frame::close_raw(vec![].into())).await?;
  Ok(())
}
/// Drives nine ping/echo rounds against the client, then closes the socket.
///
/// Each round sends an empty ping and asserts that the next frame is an
/// empty pong, then sends the text `hello <round>` and asserts that the
/// client sends back the same text. After the last round a close frame with
/// code 1000 is written.
///
/// # Panics
///
/// Panics if any read or write fails or if the client answers with an
/// unexpected frame.
async fn ping_websocket_handler(
  ws: fastwebsockets::WebSocket<TokioIo<Upgraded>>,
) -> Result<(), anyhow::Error> {
  let mut ws = FragmentCollector::new(ws);
  for round in 0..9 {
    // Ping with an empty payload; the client must answer with an empty pong.
    let ping = Frame::new(true, OpCode::Ping, None, vec![].into());
    ws.write_frame(ping).await.unwrap();
    let pong = ws.read_frame().await.unwrap();
    assert_eq!(pong.opcode, OpCode::Pong);
    assert!(pong.payload.is_empty());

    // Text round trip; the client must send the same message back.
    let greeting = format!("hello {}", round);
    let outgoing = Frame::text(greeting.as_bytes().to_vec().into());
    ws.write_frame(outgoing).await.unwrap();
    let echoed = ws.read_frame().await.unwrap();
    assert_eq!(echoed.opcode, OpCode::Text);
    assert_eq!(echoed.payload, greeting.as_bytes());
  }
  ws.write_frame(Frame::close(1000, b"")).await.unwrap();
  Ok(())
}