// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
use std::marker::PhantomData;
use bytes::Bytes;
use bytes::BytesMut;
use httparse::Status;
use hyper::header::HeaderName;
use hyper::header::HeaderValue;
use hyper::Response;
use memmem::Searcher;
use memmem::TwoWaySearcher;
use once_cell::sync::OnceCell;
/// Errors that can be produced while parsing a WebSocket upgrade response.
///
/// Variants carrying a `#[from]` payload wrap the underlying library error and
/// display exactly what that error displays.
#[derive(Debug, thiserror::Error)]
pub enum WebSocketUpgradeError {
  /// The header block did not parse to completion.
  #[error("invalid headers")]
  InvalidHeaders,

  /// `httparse` rejected the header bytes.
  #[error("{0}")]
  HttpParse(#[from] httparse::Error),

  /// The `http` crate failed to build the response.
  #[error("{0}")]
  Http(#[from] http::Error),

  /// A header value was not valid UTF-8.
  #[error("{0}")]
  Utf8(#[from] std::str::Utf8Error),

  /// A header name was not acceptable to the `http` crate.
  #[error("{0}")]
  InvalidHeaderName(#[from] http::header::InvalidHeaderName),

  /// A header value was not acceptable to the `http` crate.
  #[error("{0}")]
  InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),

  /// The status line did not have the form required for an upgrade.
  #[error("invalid HTTP status line")]
  InvalidHttpStatusLine,

  /// Data was written after the upgrade had already completed.
  #[error("attempted to write to completed upgrade buffer")]
  UpgradeBufferAlreadyCompleted,
}
/// Given a buffer that ends in `\n\n` or `\r\n\r\n`, returns a parsed [`Request
`].
fn parse_response(
header_bytes: &[u8],
) -> Result<(usize, Response), WebSocketUpgradeError> {
let mut headers = [httparse::EMPTY_HEADER; 16];
let status = httparse::parse_headers(header_bytes, &mut headers)?;
match status {
Status::Complete((index, parsed)) => {
let mut resp = Response::builder().status(101).body(T::default())?;
for header in parsed.iter() {
resp.headers_mut().append(
HeaderName::from_bytes(header.name.as_bytes())?,
HeaderValue::from_str(std::str::from_utf8(header.value)?)?,
);
}
Ok((index, resp))
}
_ => Err(WebSocketUpgradeError::InvalidHeaders),
}
}
/// Finds the first newline in a slice.
///
/// Returns the index of the first `\n` byte in `slice`, or `None` when the
/// slice is empty or contains no `\n`.
fn find_newline(slice: &[u8]) -> Option<usize> {
  slice.iter().position(|&byte| byte == b'\n')
}
/// The states of the WebSocket upgrade state machine.
///
/// NOTE(review): the transitions between states are driven by code outside
/// this definition; the per-variant notes below follow the variant names and
/// should be confirmed against that code.
#[derive(Default)]
enum WebSocketUpgradeState {
  /// Starting state; this is the value produced by `Default`.
  #[default]
  Initial,

  /// Presumably: the status line is being read.
  StatusLine,

  /// Presumably: the header block is being read.
  Headers,

  /// Presumably: the upgrade has finished.
  Complete,
}
static HEADER_SEARCHER: OnceCell = OnceCell::new();
static HEADER_SEARCHER2: OnceCell = OnceCell::new();
#[derive(Default)]
pub struct WebSocketUpgrade {
state: WebSocketUpgradeState,
buf: BytesMut,
_t: PhantomData,
}
impl WebSocketUpgrade {
/// Checks that a status line begins with `HTTP/1.1 101 `, the prefix emitted
/// by all of the known node.js WebSocket libraries. Whatever status text
/// follows the prefix is not inspected.
///
/// Returns [`WebSocketUpgradeError::InvalidHttpStatusLine`] when the prefix
/// is absent.
fn validate_status(
  &self,
  status: &[u8],
) -> Result<(), WebSocketUpgradeError> {
  const EXPECTED_PREFIX: &[u8] = b"HTTP/1.1 101 ";

  if !status.starts_with(EXPECTED_PREFIX) {
    return Err(WebSocketUpgradeError::InvalidHttpStatusLine);
  }
  Ok(())
}
/// Writes bytes to our upgrade buffer, returning [`Ok(None)`] if we need to keep feeding it data,
/// [`Ok(Some(Response))`] if we got a valid upgrade header, or [`Err`] if something went badly.
pub fn write(
&mut self,
bytes: &[u8],
) -> Result