Diffstat (limited to 'ext/web/stream_resource.rs')
-rw-r--r-- | ext/web/stream_resource.rs | 540
1 file changed, 456 insertions, 84 deletions
diff --git a/ext/web/stream_resource.rs b/ext/web/stream_resource.rs
index 8a454da73..1ee6ff963 100644
--- a/ext/web/stream_resource.rs
+++ b/ext/web/stream_resource.rs
@@ -1,9 +1,10 @@
 // Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
-use deno_core::anyhow::Error;
+use bytes::BytesMut;
 use deno_core::error::type_error;
 use deno_core::error::AnyError;
 use deno_core::op2;
-use deno_core::AsyncRefCell;
+use deno_core::serde_v8::V8Slice;
+use deno_core::unsync::TaskQueue;
 use deno_core::AsyncResult;
 use deno_core::BufView;
 use deno_core::CancelFuture;
@@ -14,61 +15,330 @@ use deno_core::RcLike;
 use deno_core::RcRef;
 use deno_core::Resource;
 use deno_core::ResourceId;
-use futures::stream::Peekable;
-use futures::Stream;
-use futures::StreamExt;
+use futures::future::poll_fn;
 use std::borrow::Cow;
 use std::cell::RefCell;
+use std::cell::RefMut;
 use std::ffi::c_void;
 use std::future::Future;
+use std::marker::PhantomData;
+use std::mem::MaybeUninit;
 use std::pin::Pin;
 use std::rc::Rc;
 use std::task::Context;
 use std::task::Poll;
 use std::task::Waker;
-use tokio::sync::mpsc::Receiver;
-use tokio::sync::mpsc::Sender;
 
-type SenderCell = RefCell<Option<Sender<Result<BufView, Error>>>>;
+// How many buffers we'll allow in the channel before we stop allowing writes.
+const BUFFER_CHANNEL_SIZE: u16 = 1024;
 
-// This indirection allows us to more easily integrate the fast streams work at a later date
-#[repr(transparent)]
-struct ChannelStreamAdapter<C>(C);
-
-impl<C> Stream for ChannelStreamAdapter<C>
-where
-  C: ChannelBytesRead,
-{
-  type Item = Result<BufView, AnyError>;
-  fn poll_next(
-    mut self: Pin<&mut Self>,
-    cx: &mut Context<'_>,
-  ) -> Poll<Option<Self::Item>> {
-    self.0.poll_recv(cx)
+// How much data is in the channel before we stop allowing writes.
+const BUFFER_BACKPRESSURE_LIMIT: usize = 64 * 1024;
+
+// Optimization: prevent multiple small writes from adding overhead.
+//
+// If the total size of the channel is less than this value and there is more than one buffer available
+// to read, we will allocate a buffer to store the entire contents of the channel and copy each value from
+// the channel rather than yielding them one at a time.
+const BUFFER_AGGREGATION_LIMIT: usize = 1024;
+
+struct BoundedBufferChannelInner {
+  buffers: [MaybeUninit<V8Slice<u8>>; BUFFER_CHANNEL_SIZE as _],
+  ring_producer: u16,
+  ring_consumer: u16,
+  error: Option<AnyError>,
+  current_size: usize,
+  // TODO(mmastrac): we can math this field instead of accounting for it
+  len: usize,
+  closed: bool,
+
+  read_waker: Option<Waker>,
+  write_waker: Option<Waker>,
+
+  _unsend: PhantomData<std::sync::MutexGuard<'static, ()>>,
+}
+
+impl Default for BoundedBufferChannelInner {
+  fn default() -> Self {
+    Self::new()
   }
 }
 
-pub trait ChannelBytesRead: Unpin + 'static {
-  fn poll_recv(
-    &mut self,
-    cx: &mut Context<'_>,
-  ) -> Poll<Option<Result<BufView, AnyError>>>;
+impl std::fmt::Debug for BoundedBufferChannelInner {
+  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+    f.write_fmt(format_args!(
+      "[BoundedBufferChannel closed={} error={:?} ring={}->{} len={} size={}]",
+      self.closed,
+      self.error,
+      self.ring_producer,
+      self.ring_consumer,
+      self.len,
+      self.current_size
+    ))
+  }
 }
 
-impl ChannelBytesRead for tokio::sync::mpsc::Receiver<Result<BufView, Error>> {
-  fn poll_recv(
-    &mut self,
-    cx: &mut Context<'_>,
-  ) -> Poll<Option<Result<BufView, AnyError>>> {
-    self.poll_recv(cx)
+impl BoundedBufferChannelInner {
+  pub fn new() -> Self {
+    const UNINIT: MaybeUninit<V8Slice<u8>> = MaybeUninit::uninit();
+    Self {
+      buffers: [UNINIT; BUFFER_CHANNEL_SIZE as _],
+      ring_producer: 0,
+      ring_consumer: 0,
+      len: 0,
+      closed: false,
+      error: None,
+      current_size: 0,
+      read_waker: None,
+      write_waker: None,
+      _unsend: PhantomData,
+    }
+  }
+
+  /// # Safety
+  ///
+  /// This doesn't check whether `ring_consumer` is valid, so you'd better make sure it is before
+  /// calling this.
+  #[inline(always)]
+  unsafe fn next_unsafe(&mut self) -> &mut V8Slice<u8> {
+    self
+      .buffers
+      .get_unchecked_mut(self.ring_consumer as usize)
+      .assume_init_mut()
+  }
+
+  /// # Safety
+  ///
+  /// This doesn't check whether `ring_consumer` is valid, so you'd better make sure it is before
+  /// calling this.
+  #[inline(always)]
+  unsafe fn take_next_unsafe(&mut self) -> V8Slice<u8> {
+    let res = std::ptr::read(self.next_unsafe());
+    self.ring_consumer = (self.ring_consumer + 1) % BUFFER_CHANNEL_SIZE;
+
+    res
+  }
+
+  fn drain(&mut self, mut f: impl FnMut(V8Slice<u8>)) {
+    while self.ring_producer != self.ring_consumer {
+      // SAFETY: We know the ring indexes are valid
+      let res = unsafe { std::ptr::read(self.next_unsafe()) };
+      self.ring_consumer = (self.ring_consumer + 1) % BUFFER_CHANNEL_SIZE;
+      f(res);
+    }
+    self.current_size = 0;
+    self.ring_producer = 0;
+    self.ring_consumer = 0;
+    self.len = 0;
+  }
+
+  pub fn read(&mut self, limit: usize) -> Result<Option<BufView>, AnyError> {
+    // Empty buffers will return the error, if one exists, or None
+    if self.len == 0 {
+      if let Some(error) = self.error.take() {
+        return Err(error);
+      } else {
+        return Ok(None);
+      }
+    }
+
+    // If we have less than the aggregation limit AND we have more than one buffer in the channel,
+    // aggregate and return everything in a single buffer.
+    if limit >= BUFFER_AGGREGATION_LIMIT
+      && self.current_size <= BUFFER_AGGREGATION_LIMIT
+      && self.len > 1
+    {
+      let mut bytes = BytesMut::with_capacity(BUFFER_AGGREGATION_LIMIT);
+      self.drain(|slice| {
+        bytes.extend_from_slice(slice.as_ref());
+      });
+
+      // We can always write again
+      if let Some(waker) = self.write_waker.take() {
+        waker.wake();
+      }
+
+      return Ok(Some(BufView::from(bytes.freeze())));
+    }
+
+    // SAFETY: We know this exists
+    let buf = unsafe { self.next_unsafe() };
+    let buf = if buf.len() <= limit {
+      self.current_size -= buf.len();
+      self.len -= 1;
+      // SAFETY: We know this exists
+      unsafe { self.take_next_unsafe() }
+    } else {
+      let buf = buf.split_to(limit);
+      self.current_size -= limit;
+      buf
+    };
+
+    // If current_size is zero, len must be zero (and if not, len must not be)
+    debug_assert!(
+      !((self.current_size == 0) ^ (self.len == 0)),
+      "Length accounting mismatch: {self:?}"
+    );
+
+    if self.write_waker.is_some() {
+      // We may be able to write again if we have buffer and byte room in the channel
+      if self.can_write() {
+        if let Some(waker) = self.write_waker.take() {
+          waker.wake();
+        }
+      }
+    }
+
+    Ok(Some(BufView::from(JsBuffer::from_parts(buf))))
+  }
+
+  pub fn write(&mut self, buffer: V8Slice<u8>) -> Result<(), V8Slice<u8>> {
+    let next_producer_index = (self.ring_producer + 1) % BUFFER_CHANNEL_SIZE;
+    if next_producer_index == self.ring_consumer {
+      return Err(buffer);
+    }
+
+    self.current_size += buffer.len();
+
+    // SAFETY: we know the ringbuffer bounds are correct
+    unsafe {
+      *self.buffers.get_unchecked_mut(self.ring_producer as usize) =
+        MaybeUninit::new(buffer)
+    };
+    self.ring_producer = next_producer_index;
+    self.len += 1;
+    debug_assert!(self.ring_producer != self.ring_consumer);
+    if let Some(waker) = self.read_waker.take() {
+      waker.wake();
+    }
+    Ok(())
+  }
+
+  pub fn write_error(&mut self, error: AnyError) {
+    self.error = Some(error);
+    if let Some(waker) = self.read_waker.take() {
+      waker.wake();
+    }
+  }
+
+  #[inline(always)]
+  pub fn can_read(&self) -> bool {
+    // Read will return if:
+    //  - the stream is closed
+    //  - there is an error
+    //  - the stream is not empty
+    self.closed
+      || self.error.is_some()
+      || self.ring_consumer != self.ring_producer
+  }
+
+  #[inline(always)]
+  pub fn can_write(&self) -> bool {
+    // Write will return if:
+    //  - the stream is closed
+    //  - there is an error
+    //  - the stream is not full (either buffer or byte count)
+    let next_producer_index = (self.ring_producer + 1) % BUFFER_CHANNEL_SIZE;
+    self.closed
+      || self.error.is_some()
+      || (next_producer_index != self.ring_consumer
+        && self.current_size < BUFFER_BACKPRESSURE_LIMIT)
+  }
+
+  pub fn poll_read_ready(&mut self, cx: &mut Context) -> Poll<()> {
+    if !self.can_read() {
+      self.read_waker = Some(cx.waker().clone());
+      Poll::Pending
+    } else {
+      self.read_waker.take();
+      Poll::Ready(())
+    }
+  }
+
+  pub fn poll_write_ready(&mut self, cx: &mut Context) -> Poll<()> {
+    if !self.can_write() {
+      self.write_waker = Some(cx.waker().clone());
+      Poll::Pending
+    } else {
+      self.write_waker.take();
+      Poll::Ready(())
+    }
+  }
+
+  pub fn close(&mut self) {
+    self.closed = true;
+    // Wake up reads and writes, since they'll both be able to proceed forever now
+    if let Some(waker) = self.write_waker.take() {
+      waker.wake();
+    }
+    if let Some(waker) = self.read_waker.take() {
+      waker.wake();
+    }
+  }
+}
+
+#[repr(transparent)]
+#[derive(Clone, Default)]
+struct BoundedBufferChannel {
+  inner: Rc<RefCell<BoundedBufferChannelInner>>,
+}
+
+impl BoundedBufferChannel {
+  // TODO(mmastrac): in release mode we should be able to make this an UnsafeCell
+  #[inline(always)]
+  fn inner(&self) -> RefMut<BoundedBufferChannelInner> {
+    self.inner.borrow_mut()
+  }
+
+  pub fn into_raw(self) -> *const BoundedBufferChannel {
+    Rc::into_raw(self.inner) as _
+  }
+
+  pub unsafe fn clone_from_raw(ptr: *const BoundedBufferChannel) -> Self {
+    let rc = Rc::from_raw(ptr as *const RefCell<BoundedBufferChannelInner>);
+    let clone = rc.clone();
+    std::mem::forget(rc);
+    std::mem::transmute(clone)
+  }
+
+  pub fn read(&self, limit: usize) -> Result<Option<BufView>, AnyError> {
+    self.inner().read(limit)
+  }
+
+  pub fn write(&self, buffer: V8Slice<u8>) -> Result<(), V8Slice<u8>> {
+    self.inner().write(buffer)
+  }
+
+  pub fn write_error(&self, error: AnyError) {
+    self.inner().write_error(error)
+  }
+
+  pub fn poll_read_ready(&self, cx: &mut Context) -> Poll<()> {
+    self.inner().poll_read_ready(cx)
+  }
+
+  pub fn poll_write_ready(&self, cx: &mut Context) -> Poll<()> {
+    self.inner().poll_write_ready(cx)
+  }
+
+  pub fn closed(&self) -> bool {
+    self.inner().closed
+  }
+
+  #[cfg(test)]
+  pub fn byte_size(&self) -> usize {
+    self.inner().current_size
+  }
+
+  pub fn close(&self) {
+    self.inner().close()
   }
 }
 
 #[allow(clippy::type_complexity)]
 struct ReadableStreamResource {
-  reader: AsyncRefCell<
-    Peekable<ChannelStreamAdapter<Receiver<Result<BufView, Error>>>>,
-  >,
+  read_queue: Rc<TaskQueue>,
+  channel: BoundedBufferChannel,
   cancel_handle: CancelHandle,
   data: ReadableStreamResourceData,
 }
@@ -80,26 +350,15 @@ impl ReadableStreamResource {
 
   async fn read(self: Rc<Self>, limit: usize) -> Result<BufView, AnyError> {
     let cancel_handle = self.cancel_handle();
-    let peekable = RcRef::map(self, |this| &this.reader);
-    let mut peekable = peekable.borrow_mut().await;
-    match Pin::new(&mut *peekable)
-      .peek_mut()
+    // Serialize all the reads using a task queue.
+    let _read_permit = self.read_queue.acquire().await;
+    poll_fn(|cx| self.channel.poll_read_ready(cx))
      .or_cancel(cancel_handle)
-      .await?
-    {
-      None => Ok(BufView::empty()),
-      // Take the actual error since we only have a reference to it
-      Some(Err(_)) => Err(peekable.next().await.unwrap().err().unwrap()),
-      Some(Ok(bytes)) => {
-        if bytes.len() <= limit {
-          // We can safely take the next item since we peeked it
-          return peekable.next().await.unwrap();
-        }
-        // The remainder of the bytes after we split it is still left in the peek buffer
-        let ret = bytes.split_to(limit);
-        Ok(ret)
-      }
-    }
+      .await?;
+    self
+      .channel
+      .read(limit)
+      .map(|buf| buf.unwrap_or_else(BufView::empty))
   }
 }
 
@@ -114,6 +373,7 @@ impl Resource for ReadableStreamResource {
 
   fn close(self: Rc<Self>) {
     self.cancel_handle.cancel();
+    self.channel.close();
   }
 }
 
@@ -163,17 +423,12 @@ impl Future for CompletionHandle {
 #[op2(fast)]
 #[smi]
 pub fn op_readable_stream_resource_allocate(state: &mut OpState) -> ResourceId {
-  let (tx, rx) = tokio::sync::mpsc::channel(1);
-  let tx = RefCell::new(Some(tx));
   let completion = CompletionHandle::default();
-  let tx = Box::new(tx);
   let resource = ReadableStreamResource {
+    read_queue: Default::default(),
     cancel_handle: Default::default(),
-    reader: AsyncRefCell::new(ChannelStreamAdapter(rx).peekable()),
-    data: ReadableStreamResourceData {
-      tx: Box::into_raw(tx),
-      completion,
-    },
+    channel: BoundedBufferChannel::default(),
+    data: ReadableStreamResourceData { completion },
  };
  state.resource_table.add(resource)
 }
@@ -187,23 +442,19 @@ pub fn op_readable_stream_resource_get_sink(
   else {
     return std::ptr::null();
   };
-  resource.data.tx as _
+  resource.channel.clone().into_raw() as _
 }
 
-fn get_sender(sender: *const c_void) -> Option<Sender<Result<BufView, Error>>> {
+fn get_sender(sender: *const c_void) -> BoundedBufferChannel {
   // SAFETY: We know this is a valid v8::External
-  unsafe {
-    (sender as *const SenderCell)
-      .as_ref()
-      .and_then(|r| r.borrow_mut().as_ref().cloned())
-  }
+  unsafe { BoundedBufferChannel::clone_from_raw(sender as _) }
 }
 
 fn drop_sender(sender: *const c_void) {
   // SAFETY: We know this is a valid v8::External
   unsafe {
     assert!(!sender.is_null());
-    _ = Box::from_raw(sender as *mut SenderCell);
+    _ = Rc::from_raw(sender as *mut RefCell<BoundedBufferChannelInner>);
   }
 }
 
@@ -214,10 +465,9 @@ pub fn op_readable_stream_resource_write_buf(
 ) -> impl Future<Output = bool> {
   let sender = get_sender(sender);
   async move {
-    let Some(sender) = sender else {
-      return false;
-    };
-    sender.send(Ok(buffer.into())).await.ok().is_some()
+    poll_fn(|cx| sender.poll_write_ready(cx)).await;
+    sender.write(buffer.into_parts()).unwrap();
+    !sender.closed()
   }
 }
 
@@ -228,20 +478,17 @@ pub fn op_readable_stream_resource_write_error(
 ) -> impl Future<Output = bool> {
   let sender = get_sender(sender);
   async move {
-    let Some(sender) = sender else {
-      return false;
-    };
-    sender
-      .send(Err(type_error(Cow::Owned(error))))
-      .await
-      .ok()
-      .is_some()
+    // We can always write an error, no polling required
+    // TODO(mmastrac): we can remove async from this method
+    sender.write_error(type_error(Cow::Owned(error)));
+    !sender.closed()
  }
 }
 
 #[op2(fast)]
 #[smi]
 pub fn op_readable_stream_resource_close(sender: *const c_void) {
+  get_sender(sender).close();
   drop_sender(sender);
 }
 
@@ -264,7 +511,6 @@ pub fn op_readable_stream_resource_await_close(
 }
 
 struct ReadableStreamResourceData {
-  tx: *const SenderCell,
   completion: CompletionHandle,
 }
 
@@ -273,3 +519,129 @@ impl Drop for ReadableStreamResourceData {
     self.completion.complete(true);
   }
 }
+
+#[cfg(test)]
+mod tests {
+  use super::*;
+  use deno_core::v8;
+  use std::cell::OnceCell;
+  use std::sync::atomic::AtomicUsize;
+  use std::sync::OnceLock;
+  use std::time::Duration;
+
+  static V8_GLOBAL: OnceLock<()> = OnceLock::new();
+
+  thread_local! {
+    static ISOLATE: OnceCell<std::sync::Mutex<v8::OwnedIsolate>> = OnceCell::new();
+  }
+
+  fn with_isolate<T>(mut f: impl FnMut(&mut v8::Isolate) -> T) -> T {
+    V8_GLOBAL.get_or_init(|| {
+      let platform =
+        v8::new_unprotected_default_platform(0, false).make_shared();
+      v8::V8::initialize_platform(platform);
+      v8::V8::initialize();
+    });
+    ISOLATE.with(|cell| {
+      let mut isolate = cell
+        .get_or_init(|| {
+          std::sync::Mutex::new(v8::Isolate::new(Default::default()))
+        })
+        .try_lock()
+        .unwrap();
+      f(&mut isolate)
+    })
+  }
+
+  fn create_buffer(byte_length: usize) -> V8Slice<u8> {
+    with_isolate(|isolate| {
+      let ptr = v8::ArrayBuffer::new_backing_store(isolate, byte_length);
+      // SAFETY: we just made this
+      unsafe { V8Slice::from_parts(ptr.into(), 0..byte_length) }
+    })
+  }
+
+  #[test]
+  fn test_bounded_buffer_channel() {
+    let channel = BoundedBufferChannel::default();
+
+    for _ in 0..BUFFER_CHANNEL_SIZE - 1 {
+      channel.write(create_buffer(1024)).unwrap();
+    }
+  }
+
+  #[tokio::test(flavor = "current_thread")]
+  async fn test_multi_task() {
+    let channel = BoundedBufferChannel::default();
+    let channel_send = channel.clone();
+
+    // Fast writer
+    let a = deno_core::unsync::spawn(async move {
+      for _ in 0..BUFFER_CHANNEL_SIZE * 2 {
+        poll_fn(|cx| channel_send.poll_write_ready(cx)).await;
+        channel_send
+          .write(create_buffer(BUFFER_AGGREGATION_LIMIT))
+          .unwrap();
+      }
+    });
+
+    // Slightly slower reader
+    let b = deno_core::unsync::spawn(async move {
+      for _ in 0..BUFFER_CHANNEL_SIZE * 2 {
+        tokio::time::sleep(Duration::from_millis(1)).await;
+        poll_fn(|cx| channel.poll_read_ready(cx)).await;
+        channel.read(BUFFER_AGGREGATION_LIMIT).unwrap();
+      }
+    });
+
+    a.await.unwrap();
+    b.await.unwrap();
+  }
+
+  #[tokio::test(flavor = "current_thread")]
+  async fn test_multi_task_small_reads() {
+    let channel = BoundedBufferChannel::default();
+    let channel_send = channel.clone();
+
+    let total_send = Rc::new(AtomicUsize::new(0));
+    let total_send_task = total_send.clone();
+    let total_recv = Rc::new(AtomicUsize::new(0));
+    let total_recv_task = total_recv.clone();
+
+    // Fast writer
+    let a = deno_core::unsync::spawn(async move {
+      for _ in 0..BUFFER_CHANNEL_SIZE * 2 {
+        poll_fn(|cx| channel_send.poll_write_ready(cx)).await;
+        channel_send.write(create_buffer(16)).unwrap();
+        total_send_task.fetch_add(16, std::sync::atomic::Ordering::SeqCst);
+      }
+      // We need to close because we may get aggregated packets and we want a signal
+      channel_send.close();
+    });
+
+    // Slightly slower reader
+    let b = deno_core::unsync::spawn(async move {
+      for _ in 0..BUFFER_CHANNEL_SIZE * 2 {
+        poll_fn(|cx| channel.poll_read_ready(cx)).await;
+        // We want to make sure we're aggregating at least some packets
+        while channel.byte_size() <= 16 && !channel.closed() {
+          tokio::time::sleep(Duration::from_millis(1)).await;
+        }
+        let len = channel
+          .read(1024)
+          .unwrap()
+          .map(|b| b.len())
+          .unwrap_or_default();
+        total_recv_task.fetch_add(len, std::sync::atomic::Ordering::SeqCst);
+      }
+    });
+
+    a.await.unwrap();
+    b.await.unwrap();
+
+    assert_eq!(
+      total_send.load(std::sync::atomic::Ordering::SeqCst),
+      total_recv.load(std::sync::atomic::Ordering::SeqCst)
+    );
  }
}
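The heart of the patch is BoundedBufferChannelInner: a fixed-capacity, single-producer single-consumer ring buffer of V8 slices. Below is a standalone sketch of the same index math, using Option<Vec<u8>> slots in place of the patch's MaybeUninit<V8Slice<u8>> array; it is an illustration under those simplifications, not the committed code.

const BUFFER_CHANNEL_SIZE: u16 = 1024;

struct Ring {
  // `None` marks an empty slot; the real code uses MaybeUninit instead.
  buffers: Vec<Option<Vec<u8>>>,
  ring_producer: u16,
  ring_consumer: u16,
}

impl Ring {
  fn new() -> Self {
    Self {
      buffers: (0..BUFFER_CHANNEL_SIZE).map(|_| None).collect(),
      ring_producer: 0,
      ring_consumer: 0,
    }
  }

  // "Full" is the producer landing one step behind the consumer, so one
  // slot is always sacrificed to keep it distinguishable from "empty"
  // (where both indexes coincide).
  fn write(&mut self, buf: Vec<u8>) -> Result<(), Vec<u8>> {
    let next = (self.ring_producer + 1) % BUFFER_CHANNEL_SIZE;
    if next == self.ring_consumer {
      return Err(buf); // full: hand the buffer back, as the patch does
    }
    self.buffers[self.ring_producer as usize] = Some(buf);
    self.ring_producer = next;
    Ok(())
  }

  fn read(&mut self) -> Option<Vec<u8>> {
    if self.ring_producer == self.ring_consumer {
      return None; // empty
    }
    let buf = self.buffers[self.ring_consumer as usize].take();
    self.ring_consumer = (self.ring_consumer + 1) % BUFFER_CHANNEL_SIZE;
    buf
  }
}

fn main() {
  let mut ring = Ring::new();
  ring.write(vec![1, 2, 3]).unwrap();
  assert_eq!(ring.read(), Some(vec![1, 2, 3]));
  assert_eq!(ring.read(), None);
}

The sacrificed slot is why test_bounded_buffer_channel above writes only BUFFER_CHANNEL_SIZE - 1 buffers: one more write would hit the full condition.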
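In place of tokio's mpsc wakeups, readiness is now hand-rolled: each side parks a Waker when it cannot proceed (the reader on empty, the writer on a full ring or past the 64 KiB BUFFER_BACKPRESSURE_LIMIT), and the opposite side wakes it on the state transition. A condensed, std-only sketch of that handshake, with a VecDeque standing in for the ring and error/close wakeups omitted:

use std::collections::VecDeque;
use std::task::{Context, Poll, Waker};

const BUFFER_BACKPRESSURE_LIMIT: usize = 64 * 1024;

#[derive(Default)]
struct Channel {
  queue: VecDeque<Vec<u8>>,
  current_size: usize,
  closed: bool,
  read_waker: Option<Waker>,
  write_waker: Option<Waker>,
}

impl Channel {
  fn poll_read_ready(&mut self, cx: &mut Context) -> Poll<()> {
    if !self.closed && self.queue.is_empty() {
      // Nothing to read yet: park this task's waker.
      self.read_waker = Some(cx.waker().clone());
      Poll::Pending
    } else {
      self.read_waker.take();
      Poll::Ready(())
    }
  }

  fn poll_write_ready(&mut self, cx: &mut Context) -> Poll<()> {
    if !self.closed && self.current_size >= BUFFER_BACKPRESSURE_LIMIT {
      // Too much buffered data: apply backpressure to the writer.
      self.write_waker = Some(cx.waker().clone());
      Poll::Pending
    } else {
      self.write_waker.take();
      Poll::Ready(())
    }
  }

  fn write(&mut self, buf: Vec<u8>) {
    self.current_size += buf.len();
    self.queue.push_back(buf);
    // A write makes the channel readable: wake any parked reader.
    if let Some(waker) = self.read_waker.take() {
      waker.wake();
    }
  }

  fn read(&mut self) -> Option<Vec<u8>> {
    let buf = self.queue.pop_front()?;
    self.current_size -= buf.len();
    // Freed space may unblock a parked writer.
    if let Some(waker) = self.write_waker.take() {
      waker.wake();
    }
    Some(buf)
  }
}

fn main() {
  let mut ch = Channel::default();
  // Waker::noop() is stable since Rust 1.85; a real caller is polled by an
  // executor, as the ops do via futures::future::poll_fn.
  let mut cx = Context::from_waker(Waker::noop());
  assert!(ch.poll_read_ready(&mut cx).is_pending()); // empty: reader parks
  ch.write(vec![0; 8]); // wakes the parked reader
  assert!(ch.poll_read_ready(&mut cx).is_ready());
  assert_eq!(ch.read().unwrap().len(), 8);
}

The patch drives this exactly as its ops and tests do: poll_fn(|cx| sender.poll_write_ready(cx)).await before each write, poll_fn(|cx| channel.poll_read_ready(cx)).await before each read.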
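The other behavioral change is read-side aggregation. When the caller's limit is at least BUFFER_AGGREGATION_LIMIT, the whole channel currently holds no more than that many bytes, and more than one buffer is queued, read() drains everything into a single allocation instead of yielding chunk by chunk. A simplified sketch of that decision, with Vec<u8> in place of BytesMut/V8Slice and the size recomputed rather than tracked:

use std::collections::VecDeque;

const BUFFER_AGGREGATION_LIMIT: usize = 1024;

fn read(queue: &mut VecDeque<Vec<u8>>, limit: usize) -> Option<Vec<u8>> {
  let current_size: usize = queue.iter().map(|b| b.len()).sum();
  if limit >= BUFFER_AGGREGATION_LIMIT
    && current_size <= BUFFER_AGGREGATION_LIMIT
    && queue.len() > 1
  {
    // One allocation, then copy every queued chunk into it.
    let mut bytes = Vec::with_capacity(BUFFER_AGGREGATION_LIMIT);
    for chunk in queue.drain(..) {
      bytes.extend_from_slice(&chunk);
    }
    return Some(bytes);
  }
  // Otherwise yield the front buffer, splitting it if it exceeds `limit`
  // (the patch does this with V8Slice::split_to, keeping the remainder
  // queued for the next read).
  let front = queue.front_mut()?;
  if front.len() <= limit {
    queue.pop_front()
  } else {
    let rest = front.split_off(limit); // tail stays in the queue
    Some(std::mem::replace(front, rest)) // yield the `limit`-byte head
  }
}

fn main() {
  let mut queue: VecDeque<Vec<u8>> = [vec![1; 16], vec![2; 16]].into();
  // Two 16-byte writes come back as one 32-byte read.
  assert_eq!(read(&mut queue, 1024).unwrap().len(), 32);
  assert!(queue.is_empty());
}

This is the case test_multi_task_small_reads exercises: 16-byte writes, reads with a 1024-byte limit, and send/receive totals compared at the end to confirm aggregation loses nothing.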