// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license. use std::pin::Pin; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::sync::Arc; use std::task::Context; use std::task::Poll; use bytes::Bytes; use deno_core::futures::Stream; use pin_project::pin_project; use tokio::io::AsyncRead; use tokio_util::io::ReaderStream; /// [ExternallyAbortableByteStream] adapts a [tokio::AsyncRead] into a [Stream]. /// It is used to bridge between the HTTP response body resource, and /// `hyper::Body`. The stream has the special property that it errors if the /// underlying reader is closed before an explicit EOF is sent (in the form of /// setting the `shutdown` flag to true). #[pin_project] pub struct ExternallyAbortableReaderStream<R: AsyncRead> { #[pin] inner: ReaderStream<R>, done: Arc<AtomicBool>, } pub struct ShutdownHandle(Arc<AtomicBool>); impl ShutdownHandle { pub fn shutdown(&self) { self.0.store(true, std::sync::atomic::Ordering::SeqCst); } } impl<R: AsyncRead> ExternallyAbortableReaderStream<R> { pub fn new(reader: R) -> (Self, ShutdownHandle) { let done = Arc::new(AtomicBool::new(false)); let this = Self { inner: ReaderStream::new(reader), done: done.clone(), }; (this, ShutdownHandle(done)) } } impl<R: AsyncRead> Stream for ExternallyAbortableReaderStream<R> { type Item = std::io::Result<Bytes>; fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Self::Item>> { let this = self.project(); let val = std::task::ready!(this.inner.poll_next(cx)); match val { None if this.done.load(Ordering::SeqCst) => Poll::Ready(None), None => Poll::Ready(Some(Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, "stream reader has shut down", )))), Some(val) => Poll::Ready(Some(val)), } } } #[cfg(test)] mod tests { use super::*; use bytes::Bytes; use deno_core::futures::StreamExt; use tokio::io::AsyncWriteExt; #[tokio::test] async fn success() { let (a, b) = tokio::io::duplex(64 * 1024); let (reader, _) = tokio::io::split(a); let (_, mut writer) = tokio::io::split(b); let (mut stream, shutdown_handle) = ExternallyAbortableReaderStream::new(reader); writer.write_all(b"hello").await.unwrap(); assert_eq!(stream.next().await.unwrap().unwrap(), Bytes::from("hello")); writer.write_all(b"world").await.unwrap(); assert_eq!(stream.next().await.unwrap().unwrap(), Bytes::from("world")); shutdown_handle.shutdown(); writer.shutdown().await.unwrap(); drop(writer); assert!(stream.next().await.is_none()); } #[tokio::test] async fn error() { let (a, b) = tokio::io::duplex(64 * 1024); let (reader, _) = tokio::io::split(a); let (_, mut writer) = tokio::io::split(b); let (mut stream, _shutdown_handle) = ExternallyAbortableReaderStream::new(reader); writer.write_all(b"hello").await.unwrap(); assert_eq!(stream.next().await.unwrap().unwrap(), Bytes::from("hello")); drop(writer); assert_eq!( stream.next().await.unwrap().unwrap_err().kind(), std::io::ErrorKind::UnexpectedEof ); } #[tokio::test] async fn error2() { let (a, b) = tokio::io::duplex(64 * 1024); let (reader, _) = tokio::io::split(a); let (_, mut writer) = tokio::io::split(b); let (mut stream, _shutdown_handle) = ExternallyAbortableReaderStream::new(reader); writer.write_all(b"hello").await.unwrap(); assert_eq!(stream.next().await.unwrap().unwrap(), Bytes::from("hello")); writer.shutdown().await.unwrap(); drop(writer); assert_eq!( stream.next().await.unwrap().unwrap_err().kind(), std::io::ErrorKind::UnexpectedEof ); } #[tokio::test] async fn write_after_shutdown() { let (a, b) = tokio::io::duplex(64 * 1024); let (reader, _) = tokio::io::split(a); let (_, mut writer) = tokio::io::split(b); let (mut stream, shutdown_handle) = ExternallyAbortableReaderStream::new(reader); writer.write_all(b"hello").await.unwrap(); assert_eq!(stream.next().await.unwrap().unwrap(), Bytes::from("hello")); writer.write_all(b"world").await.unwrap(); assert_eq!(stream.next().await.unwrap().unwrap(), Bytes::from("world")); shutdown_handle.shutdown(); writer.shutdown().await.unwrap(); assert!(writer.write_all(b"!").await.is_err()); drop(writer); assert!(stream.next().await.is_none()); } }