2024-01-01 14:58:21 -05:00
|
|
|
// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
|
2023-12-14 11:52:12 -05:00
|
|
|
|
|
|
|
use anyhow::anyhow;
|
|
|
|
use bytes::Bytes;
|
2023-12-26 15:53:28 -05:00
|
|
|
use fastwebsockets::FragmentCollector;
|
|
|
|
use fastwebsockets::Frame;
|
|
|
|
use fastwebsockets::OpCode;
|
|
|
|
use fastwebsockets::Role;
|
|
|
|
use fastwebsockets::WebSocket;
|
2023-12-14 11:52:12 -05:00
|
|
|
use futures::future::join3;
|
|
|
|
use futures::future::poll_fn;
|
|
|
|
use futures::Future;
|
|
|
|
use futures::StreamExt;
|
2023-12-27 08:38:44 -05:00
|
|
|
use h2::server::Handshake;
|
|
|
|
use h2::server::SendResponse;
|
|
|
|
use h2::Reason;
|
|
|
|
use h2::RecvStream;
|
2023-12-27 11:59:57 -05:00
|
|
|
use hyper::upgrade::Upgraded;
|
|
|
|
use hyper::Method;
|
|
|
|
use hyper::Request;
|
|
|
|
use hyper::Response;
|
|
|
|
use hyper::StatusCode;
|
2023-12-25 11:38:48 -05:00
|
|
|
use hyper_util::rt::TokioIo;
|
2023-12-14 11:52:12 -05:00
|
|
|
use pretty_assertions::assert_eq;
|
|
|
|
use std::pin::Pin;
|
|
|
|
use std::result::Result;
|
|
|
|
use tokio::io::AsyncReadExt;
|
|
|
|
use tokio::io::AsyncWriteExt;
|
|
|
|
|
|
|
|
use super::get_tcp_listener_stream;
|
|
|
|
use super::get_tls_listener_stream;
|
|
|
|
use super::SupportedHttpVersions;
|
|
|
|
|
|
|
|
pub async fn run_ws_server(port: u16) {
|
|
|
|
let mut tcp = get_tcp_listener_stream("ws", port).await;
|
|
|
|
while let Some(Ok(stream)) = tcp.next().await {
|
|
|
|
spawn_ws_server(stream, |ws| Box::pin(echo_websocket_handler(ws)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn run_ws_ping_server(port: u16) {
|
|
|
|
let mut tcp = get_tcp_listener_stream("ws (ping)", port).await;
|
|
|
|
while let Some(Ok(stream)) = tcp.next().await {
|
|
|
|
spawn_ws_server(stream, |ws| Box::pin(ping_websocket_handler(ws)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn run_wss_server(port: u16) {
|
|
|
|
let mut tls = get_tls_listener_stream("wss", port, Default::default()).await;
|
|
|
|
while let Some(Ok(tls_stream)) = tls.next().await {
|
|
|
|
tokio::spawn(async move {
|
|
|
|
spawn_ws_server(tls_stream, |ws| Box::pin(echo_websocket_handler(ws)));
|
|
|
|
});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn run_ws_close_server(port: u16) {
|
|
|
|
let mut tcp = get_tcp_listener_stream("ws (close)", port).await;
|
|
|
|
while let Some(Ok(stream)) = tcp.next().await {
|
|
|
|
spawn_ws_server(stream, |ws| Box::pin(close_websocket_handler(ws)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub async fn run_wss2_server(port: u16) {
|
|
|
|
let mut tls = get_tls_listener_stream(
|
|
|
|
"wss2 (tls)",
|
|
|
|
port,
|
|
|
|
SupportedHttpVersions::Http2Only,
|
|
|
|
)
|
|
|
|
.await;
|
|
|
|
while let Some(Ok(tls)) = tls.next().await {
|
|
|
|
tokio::spawn(async move {
|
2023-12-27 08:38:44 -05:00
|
|
|
let mut h2 = h2::server::Builder::new();
|
2023-12-14 11:52:12 -05:00
|
|
|
h2.enable_connect_protocol();
|
|
|
|
// Using Bytes is pretty alloc-heavy but this is a test server
|
|
|
|
let server: Handshake<_, Bytes> = h2.handshake(tls);
|
|
|
|
let mut server = match server.await {
|
|
|
|
Ok(server) => server,
|
|
|
|
Err(e) => {
|
|
|
|
println!("Failed to handshake h2: {e:?}");
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
loop {
|
|
|
|
let Some(conn) = server.accept().await else {
|
|
|
|
break;
|
|
|
|
};
|
|
|
|
let (recv, send) = match conn {
|
|
|
|
Ok(conn) => conn,
|
|
|
|
Err(e) => {
|
|
|
|
println!("Failed to accept a connection: {e:?}");
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
tokio::spawn(handle_wss_stream(recv, send));
|
|
|
|
}
|
|
|
|
});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn echo_websocket_handler(
|
2023-12-26 15:53:28 -05:00
|
|
|
ws: fastwebsockets::WebSocket<TokioIo<Upgraded>>,
|
2023-12-14 11:52:12 -05:00
|
|
|
) -> Result<(), anyhow::Error> {
|
2023-12-25 11:38:48 -05:00
|
|
|
let mut ws = FragmentCollector::new(ws);
|
2023-12-14 11:52:12 -05:00
|
|
|
|
|
|
|
loop {
|
|
|
|
let frame = ws.read_frame().await.unwrap();
|
|
|
|
match frame.opcode {
|
2023-12-25 11:38:48 -05:00
|
|
|
OpCode::Close => break,
|
|
|
|
OpCode::Text | OpCode::Binary => {
|
2023-12-14 11:52:12 -05:00
|
|
|
ws.write_frame(frame).await.unwrap();
|
|
|
|
}
|
|
|
|
_ => {}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
type WsHandler =
|
|
|
|
fn(
|
2023-12-26 15:53:28 -05:00
|
|
|
fastwebsockets::WebSocket<TokioIo<Upgraded>>,
|
2023-12-14 11:52:12 -05:00
|
|
|
) -> Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
|
|
|
|
|
|
|
|
fn spawn_ws_server<S>(stream: S, handler: WsHandler)
|
|
|
|
where
|
|
|
|
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
|
|
|
{
|
2023-12-27 11:59:57 -05:00
|
|
|
let service = hyper::service::service_fn(
|
|
|
|
move |mut req: http::Request<hyper::body::Incoming>| async move {
|
2023-12-26 15:53:28 -05:00
|
|
|
let (response, upgrade_fut) = fastwebsockets::upgrade::upgrade(&mut req)
|
|
|
|
.map_err(|e| anyhow!("Error upgrading websocket connection: {}", e))?;
|
2023-12-14 11:52:12 -05:00
|
|
|
|
2023-12-25 11:38:48 -05:00
|
|
|
tokio::spawn(async move {
|
|
|
|
let ws = upgrade_fut
|
|
|
|
.await
|
|
|
|
.map_err(|e| anyhow!("Error upgrading websocket connection: {}", e))
|
|
|
|
.unwrap();
|
2023-12-14 11:52:12 -05:00
|
|
|
|
2023-12-25 11:38:48 -05:00
|
|
|
if let Err(e) = handler(ws).await {
|
|
|
|
eprintln!("Error in websocket connection: {}", e);
|
|
|
|
}
|
|
|
|
});
|
2023-12-14 11:52:12 -05:00
|
|
|
|
2023-12-25 11:38:48 -05:00
|
|
|
Ok::<_, anyhow::Error>(response)
|
|
|
|
},
|
|
|
|
);
|
2023-12-14 11:52:12 -05:00
|
|
|
|
2023-12-25 11:38:48 -05:00
|
|
|
let io = TokioIo::new(stream);
|
2023-12-14 11:52:12 -05:00
|
|
|
tokio::spawn(async move {
|
2023-12-27 11:59:57 -05:00
|
|
|
let conn = hyper::server::conn::http1::Builder::new()
|
2023-12-25 11:38:48 -05:00
|
|
|
.serve_connection(io, service)
|
2023-12-14 11:52:12 -05:00
|
|
|
.with_upgrades();
|
|
|
|
|
2023-12-25 11:38:48 -05:00
|
|
|
if let Err(e) = conn.await {
|
2023-12-14 11:52:12 -05:00
|
|
|
eprintln!("websocket server error: {e:?}");
|
|
|
|
}
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn handle_wss_stream(
|
|
|
|
recv: Request<RecvStream>,
|
|
|
|
mut send: SendResponse<Bytes>,
|
2023-12-27 08:38:44 -05:00
|
|
|
) -> Result<(), h2::Error> {
|
2023-12-14 11:52:12 -05:00
|
|
|
if recv.method() != Method::CONNECT {
|
|
|
|
eprintln!("wss2: refusing non-CONNECT stream");
|
|
|
|
send.send_reset(Reason::REFUSED_STREAM);
|
|
|
|
return Ok(());
|
|
|
|
}
|
2023-12-27 08:38:44 -05:00
|
|
|
let Some(protocol) = recv.extensions().get::<h2::ext::Protocol>() else {
|
2023-12-14 11:52:12 -05:00
|
|
|
eprintln!("wss2: refusing no-:protocol stream");
|
|
|
|
send.send_reset(Reason::REFUSED_STREAM);
|
|
|
|
return Ok(());
|
|
|
|
};
|
|
|
|
if protocol.as_str() != "websocket" && protocol.as_str() != "WebSocket" {
|
|
|
|
eprintln!("wss2: refusing non-websocket stream");
|
|
|
|
send.send_reset(Reason::REFUSED_STREAM);
|
|
|
|
return Ok(());
|
|
|
|
}
|
|
|
|
let mut body = recv.into_body();
|
|
|
|
let mut response = Response::new(());
|
|
|
|
*response.status_mut() = StatusCode::OK;
|
|
|
|
let mut resp = send.send_response(response, false)?;
|
|
|
|
// Use a duplex stream to talk to fastwebsockets because it's just faster to implement
|
|
|
|
let (a, b) = tokio::io::duplex(65536);
|
|
|
|
let f1 = tokio::spawn(tokio::task::unconstrained(async move {
|
|
|
|
let ws = WebSocket::after_handshake(a, Role::Server);
|
|
|
|
let mut ws = FragmentCollector::new(ws);
|
|
|
|
loop {
|
|
|
|
let frame = ws.read_frame().await.unwrap();
|
|
|
|
if frame.opcode == OpCode::Close {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
ws.write_frame(frame).await.unwrap();
|
|
|
|
}
|
|
|
|
}));
|
|
|
|
let (mut br, mut bw) = tokio::io::split(b);
|
|
|
|
let f2 = tokio::spawn(tokio::task::unconstrained(async move {
|
|
|
|
loop {
|
|
|
|
let Some(Ok(data)) = poll_fn(|cx| body.poll_data(cx)).await else {
|
|
|
|
return;
|
|
|
|
};
|
|
|
|
body.flow_control().release_capacity(data.len()).unwrap();
|
|
|
|
let Ok(_) = bw.write_all(&data).await else {
|
|
|
|
break;
|
|
|
|
};
|
|
|
|
}
|
|
|
|
}));
|
|
|
|
let f3 = tokio::spawn(tokio::task::unconstrained(async move {
|
|
|
|
loop {
|
|
|
|
let mut buf = [0; 65536];
|
|
|
|
let n = br.read(&mut buf).await.unwrap();
|
|
|
|
if n == 0 {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
resp.reserve_capacity(n);
|
|
|
|
poll_fn(|cx| resp.poll_capacity(cx)).await;
|
|
|
|
resp
|
|
|
|
.send_data(Bytes::copy_from_slice(&buf[0..n]), false)
|
|
|
|
.unwrap();
|
|
|
|
}
|
|
|
|
resp.send_data(Bytes::new(), true).unwrap();
|
|
|
|
}));
|
|
|
|
_ = join3(f1, f2, f3).await;
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn close_websocket_handler(
|
2023-12-26 15:53:28 -05:00
|
|
|
ws: fastwebsockets::WebSocket<TokioIo<Upgraded>>,
|
2023-12-14 11:52:12 -05:00
|
|
|
) -> Result<(), anyhow::Error> {
|
2023-12-25 11:38:48 -05:00
|
|
|
let mut ws = FragmentCollector::new(ws);
|
2023-12-14 11:52:12 -05:00
|
|
|
|
2023-12-25 11:38:48 -05:00
|
|
|
ws.write_frame(Frame::close_raw(vec![].into()))
|
2023-12-14 11:52:12 -05:00
|
|
|
.await
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn ping_websocket_handler(
|
2023-12-26 15:53:28 -05:00
|
|
|
ws: fastwebsockets::WebSocket<TokioIo<Upgraded>>,
|
2023-12-14 11:52:12 -05:00
|
|
|
) -> Result<(), anyhow::Error> {
|
2023-12-25 11:38:48 -05:00
|
|
|
let mut ws = FragmentCollector::new(ws);
|
2023-12-14 11:52:12 -05:00
|
|
|
|
|
|
|
for i in 0..9 {
|
|
|
|
ws.write_frame(Frame::new(true, OpCode::Ping, None, vec![].into()))
|
|
|
|
.await
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
let frame = ws.read_frame().await.unwrap();
|
|
|
|
assert_eq!(frame.opcode, OpCode::Pong);
|
|
|
|
assert!(frame.payload.is_empty());
|
|
|
|
|
|
|
|
ws.write_frame(Frame::text(
|
|
|
|
format!("hello {}", i).as_bytes().to_vec().into(),
|
|
|
|
))
|
|
|
|
.await
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
let frame = ws.read_frame().await.unwrap();
|
|
|
|
assert_eq!(frame.opcode, OpCode::Text);
|
|
|
|
assert_eq!(frame.payload, format!("hello {}", i).as_bytes());
|
|
|
|
}
|
|
|
|
|
2023-12-25 11:38:48 -05:00
|
|
|
ws.write_frame(Frame::close(1000, b"")).await.unwrap();
|
2023-12-14 11:52:12 -05:00
|
|
|
|
|
|
|
Ok(())
|
|
|
|
}
|