mirror of
https://github.com/denoland/deno.git
synced 2025-01-09 23:58:23 -05:00
114 lines
3.3 KiB
Rust
114 lines
3.3 KiB
Rust
// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
|
|
use bytes::Buf;
|
|
use bytes::Bytes;
|
|
use deno_net::raw::NetworkStream;
|
|
use hyper::upgrade::Upgraded;
|
|
use std::pin::Pin;
|
|
use std::task::Poll;
|
|
use tokio::io::AsyncRead;
|
|
use tokio::io::AsyncWrite;
|
|
use tokio::io::ReadBuf;
|
|
|
|
// TODO(bartlomieju): remove this
|
|
pub(crate) enum WsStreamKind {
|
|
Upgraded(Upgraded),
|
|
Network(NetworkStream),
|
|
}
|
|
|
|
pub(crate) struct WebSocketStream {
|
|
stream: WsStreamKind,
|
|
pre: Option<Bytes>,
|
|
}
|
|
|
|
impl WebSocketStream {
|
|
pub fn new(stream: WsStreamKind, buffer: Option<Bytes>) -> Self {
|
|
Self {
|
|
stream,
|
|
pre: buffer,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl AsyncRead for WebSocketStream {
|
|
// From hyper's Rewind (https://github.com/hyperium/hyper), MIT License, Copyright (c) Sean McArthur
|
|
fn poll_read(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
buf: &mut ReadBuf<'_>,
|
|
) -> Poll<std::io::Result<()>> {
|
|
if let Some(mut prefix) = self.pre.take() {
|
|
// If there are no remaining bytes, let the bytes get dropped.
|
|
if !prefix.is_empty() {
|
|
let copy_len = std::cmp::min(prefix.len(), buf.remaining());
|
|
// TODO: There should be a way to do following two lines cleaner...
|
|
buf.put_slice(&prefix[..copy_len]);
|
|
prefix.advance(copy_len);
|
|
// Put back what's left
|
|
if !prefix.is_empty() {
|
|
self.pre = Some(prefix);
|
|
}
|
|
|
|
return Poll::Ready(Ok(()));
|
|
}
|
|
}
|
|
match &mut self.stream {
|
|
WsStreamKind::Network(stream) => Pin::new(stream).poll_read(cx, buf),
|
|
WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_read(cx, buf),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl AsyncWrite for WebSocketStream {
|
|
fn poll_write(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
buf: &[u8],
|
|
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
|
match &mut self.stream {
|
|
WsStreamKind::Network(stream) => Pin::new(stream).poll_write(cx, buf),
|
|
WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_write(cx, buf),
|
|
}
|
|
}
|
|
|
|
fn poll_flush(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<Result<(), std::io::Error>> {
|
|
match &mut self.stream {
|
|
WsStreamKind::Network(stream) => Pin::new(stream).poll_flush(cx),
|
|
WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_flush(cx),
|
|
}
|
|
}
|
|
|
|
fn poll_shutdown(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<Result<(), std::io::Error>> {
|
|
match &mut self.stream {
|
|
WsStreamKind::Network(stream) => Pin::new(stream).poll_shutdown(cx),
|
|
WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_shutdown(cx),
|
|
}
|
|
}
|
|
|
|
fn is_write_vectored(&self) -> bool {
|
|
match &self.stream {
|
|
WsStreamKind::Network(stream) => stream.is_write_vectored(),
|
|
WsStreamKind::Upgraded(stream) => stream.is_write_vectored(),
|
|
}
|
|
}
|
|
|
|
fn poll_write_vectored(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
bufs: &[std::io::IoSlice<'_>],
|
|
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
|
match &mut self.stream {
|
|
WsStreamKind::Network(stream) => {
|
|
Pin::new(stream).poll_write_vectored(cx, bufs)
|
|
}
|
|
WsStreamKind::Upgraded(stream) => {
|
|
Pin::new(stream).poll_write_vectored(cx, bufs)
|
|
}
|
|
}
|
|
}
|
|
}
|