-rw-r--r--  ext/web/stream_resource.rs  540
1 file changed, 456 insertions, 84 deletions
diff --git a/ext/web/stream_resource.rs b/ext/web/stream_resource.rs
index 8a454da73..1ee6ff963 100644
--- a/ext/web/stream_resource.rs
+++ b/ext/web/stream_resource.rs
@@ -1,9 +1,10 @@
// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
-use deno_core::anyhow::Error;
+use bytes::BytesMut;
use deno_core::error::type_error;
use deno_core::error::AnyError;
use deno_core::op2;
-use deno_core::AsyncRefCell;
+use deno_core::serde_v8::V8Slice;
+use deno_core::unsync::TaskQueue;
use deno_core::AsyncResult;
use deno_core::BufView;
use deno_core::CancelFuture;
@@ -14,61 +15,330 @@ use deno_core::RcLike;
use deno_core::RcRef;
use deno_core::Resource;
use deno_core::ResourceId;
-use futures::stream::Peekable;
-use futures::Stream;
-use futures::StreamExt;
+use futures::future::poll_fn;
use std::borrow::Cow;
use std::cell::RefCell;
+use std::cell::RefMut;
use std::ffi::c_void;
use std::future::Future;
+use std::marker::PhantomData;
+use std::mem::MaybeUninit;
use std::pin::Pin;
use std::rc::Rc;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
-use tokio::sync::mpsc::Receiver;
-use tokio::sync::mpsc::Sender;
-type SenderCell = RefCell<Option<Sender<Result<BufView, Error>>>>;
+// The number of buffer slots in the channel. One slot is always kept free to
+// distinguish a full ring from an empty one, so at most BUFFER_CHANNEL_SIZE - 1
+// buffers may be queued before writes stop.
+const BUFFER_CHANNEL_SIZE: u16 = 1024;
-// This indirection allows us to more easily integrate the fast streams work at a later date
-#[repr(transparent)]
-struct ChannelStreamAdapter<C>(C);
-
-impl<C> Stream for ChannelStreamAdapter<C>
-where
- C: ChannelBytesRead,
-{
- type Item = Result<BufView, AnyError>;
- fn poll_next(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Option<Self::Item>> {
- self.0.poll_recv(cx)
+// How many bytes may be buffered in the channel before we stop allowing writes.
+const BUFFER_BACKPRESSURE_LIMIT: usize = 64 * 1024;
+
+// Optimization: prevent multiple small writes from adding overhead.
+//
+// If the caller's read limit is at least this value, the channel holds no more than
+// this many bytes in total, and more than one buffer is queued, we allocate a single
+// buffer, copy each queued buffer into it, and return the whole batch in one read.
+const BUFFER_AGGREGATION_LIMIT: usize = 1024;
+
+struct BoundedBufferChannelInner {
+ buffers: [MaybeUninit<V8Slice<u8>>; BUFFER_CHANNEL_SIZE as _],
+ ring_producer: u16,
+ ring_consumer: u16,
+ error: Option<AnyError>,
+ current_size: usize,
+ // TODO(mmastrac): this field could be computed from the ring indexes instead of tracked separately
+ len: usize,
+ closed: bool,
+
+ read_waker: Option<Waker>,
+ write_waker: Option<Waker>,
+
+ _unsend: PhantomData<std::sync::MutexGuard<'static, ()>>,
+}
+
+impl Default for BoundedBufferChannelInner {
+ fn default() -> Self {
+ Self::new()
}
}
-pub trait ChannelBytesRead: Unpin + 'static {
- fn poll_recv(
- &mut self,
- cx: &mut Context<'_>,
- ) -> Poll<Option<Result<BufView, AnyError>>>;
+impl std::fmt::Debug for BoundedBufferChannelInner {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.write_fmt(format_args!(
+ "[BoundedBufferChannel closed={} error={:?} ring={}->{} len={} size={}]",
+ self.closed,
+ self.error,
+ self.ring_producer,
+ self.ring_consumer,
+ self.len,
+ self.current_size
+ ))
+ }
}
-impl ChannelBytesRead for tokio::sync::mpsc::Receiver<Result<BufView, Error>> {
- fn poll_recv(
- &mut self,
- cx: &mut Context<'_>,
- ) -> Poll<Option<Result<BufView, AnyError>>> {
- self.poll_recv(cx)
+impl BoundedBufferChannelInner {
+ pub fn new() -> Self {
+ const UNINIT: MaybeUninit<V8Slice<u8>> = MaybeUninit::uninit();
+ Self {
+ buffers: [UNINIT; BUFFER_CHANNEL_SIZE as _],
+ ring_producer: 0,
+ ring_consumer: 0,
+ len: 0,
+ closed: false,
+ error: None,
+ current_size: 0,
+ read_waker: None,
+ write_waker: None,
+ _unsend: PhantomData,
+ }
+ }
+
+ /// # Safety
+ ///
+ /// This does not check that `ring_consumer` points at an initialized slot; the
+ /// caller must ensure the ring is non-empty before calling this.
+ #[inline(always)]
+ unsafe fn next_unsafe(&mut self) -> &mut V8Slice<u8> {
+ self
+ .buffers
+ .get_unchecked_mut(self.ring_consumer as usize)
+ .assume_init_mut()
+ }
+
+ /// # Safety
+ ///
+ /// This does not check that `ring_consumer` points at an initialized slot; the
+ /// caller must ensure the ring is non-empty before calling this.
+ #[inline(always)]
+ unsafe fn take_next_unsafe(&mut self) -> V8Slice<u8> {
+ let res = std::ptr::read(self.next_unsafe());
+ self.ring_consumer = (self.ring_consumer + 1) % BUFFER_CHANNEL_SIZE;
+
+ res
+ }
+
+ fn drain(&mut self, mut f: impl FnMut(V8Slice<u8>)) {
+ while self.ring_producer != self.ring_consumer {
+ // SAFETY: We know the ring indexes are valid
+ let res = unsafe { std::ptr::read(self.next_unsafe()) };
+ self.ring_consumer = (self.ring_consumer + 1) % BUFFER_CHANNEL_SIZE;
+ f(res);
+ }
+ self.current_size = 0;
+ self.ring_producer = 0;
+ self.ring_consumer = 0;
+ self.len = 0;
+ }
+
+ pub fn read(&mut self, limit: usize) -> Result<Option<BufView>, AnyError> {
+ // An empty channel returns the error, if one is set, otherwise None
+ if self.len == 0 {
+ if let Some(error) = self.error.take() {
+ return Err(error);
+ } else {
+ return Ok(None);
+ }
+ }
+
+ // If the read limit permits aggregation, the channel holds no more than the
+ // aggregation limit in bytes, and more than one buffer is queued, aggregate and
+ // return everything in a single buffer.
+ if limit >= BUFFER_AGGREGATION_LIMIT
+ && self.current_size <= BUFFER_AGGREGATION_LIMIT
+ && self.len > 1
+ {
+ let mut bytes = BytesMut::with_capacity(BUFFER_AGGREGATION_LIMIT);
+ self.drain(|slice| {
+ bytes.extend_from_slice(slice.as_ref());
+ });
+
+ // We can always write again
+ if let Some(waker) = self.write_waker.take() {
+ waker.wake();
+ }
+
+ return Ok(Some(BufView::from(bytes.freeze())));
+ }
+
+ // SAFETY: We know this exists
+ let buf = unsafe { self.next_unsafe() };
+ let buf = if buf.len() <= limit {
+ self.current_size -= buf.len();
+ self.len -= 1;
+ // SAFETY: We know this exists
+ unsafe { self.take_next_unsafe() }
+ } else {
+ let buf = buf.split_to(limit);
+ self.current_size -= limit;
+ buf
+ };
+
+ // current_size must be zero if and only if len is zero
+ debug_assert!(
+ !((self.current_size == 0) ^ (self.len == 0)),
+ "Length accounting mismatch: {self:?}"
+ );
+
+ if self.write_waker.is_some() {
+ // We may be able to write again if we have buffer and byte room in the channel
+ if self.can_write() {
+ if let Some(waker) = self.write_waker.take() {
+ waker.wake();
+ }
+ }
+ }
+
+ Ok(Some(BufView::from(JsBuffer::from_parts(buf))))
+ }
+
+ pub fn write(&mut self, buffer: V8Slice<u8>) -> Result<(), V8Slice<u8>> {
+ let next_producer_index = (self.ring_producer + 1) % BUFFER_CHANNEL_SIZE;
+ if next_producer_index == self.ring_consumer {
+ return Err(buffer);
+ }
+
+ self.current_size += buffer.len();
+
+ // SAFETY: we know the ringbuffer bounds are correct
+ unsafe {
+ *self.buffers.get_unchecked_mut(self.ring_producer as usize) =
+ MaybeUninit::new(buffer)
+ };
+ self.ring_producer = next_producer_index;
+ self.len += 1;
+ debug_assert!(self.ring_producer != self.ring_consumer);
+ if let Some(waker) = self.read_waker.take() {
+ waker.wake();
+ }
+ Ok(())
+ }
+
+ pub fn write_error(&mut self, error: AnyError) {
+ self.error = Some(error);
+ if let Some(waker) = self.read_waker.take() {
+ waker.wake();
+ }
+ }
+
+ #[inline(always)]
+ pub fn can_read(&self) -> bool {
+ // A read can proceed without waiting if:
+ // - the stream is closed
+ // - there is an error
+ // - the stream is not empty
+ self.closed
+ || self.error.is_some()
+ || self.ring_consumer != self.ring_producer
+ }
+
+ #[inline(always)]
+ pub fn can_write(&self) -> bool {
+ // A write can proceed without waiting if:
+ // - the stream is closed
+ // - there is an error
+ // - the stream is not full (in either buffer count or byte size)
+ let next_producer_index = (self.ring_producer + 1) % BUFFER_CHANNEL_SIZE;
+ self.closed
+ || self.error.is_some()
+ || (next_producer_index != self.ring_consumer
+ && self.current_size < BUFFER_BACKPRESSURE_LIMIT)
+ }
+
+ pub fn poll_read_ready(&mut self, cx: &mut Context) -> Poll<()> {
+ if !self.can_read() {
+ self.read_waker = Some(cx.waker().clone());
+ Poll::Pending
+ } else {
+ self.read_waker.take();
+ Poll::Ready(())
+ }
+ }
+
+ pub fn poll_write_ready(&mut self, cx: &mut Context) -> Poll<()> {
+ if !self.can_write() {
+ self.write_waker = Some(cx.waker().clone());
+ Poll::Pending
+ } else {
+ self.write_waker.take();
+ Poll::Ready(())
+ }
+ }
+
+ pub fn close(&mut self) {
+ self.closed = true;
+ // Wake up reads and writes, since they'll both be able to proceed forever now
+ if let Some(waker) = self.write_waker.take() {
+ waker.wake();
+ }
+ if let Some(waker) = self.read_waker.take() {
+ waker.wake();
+ }
+ }
+}
+
+#[repr(transparent)]
+#[derive(Clone, Default)]
+struct BoundedBufferChannel {
+ inner: Rc<RefCell<BoundedBufferChannelInner>>,
+}
+
+impl BoundedBufferChannel {
+ // TODO(mmastrac): in release mode we should be able to make this an UnsafeCell
+ #[inline(always)]
+ fn inner(&self) -> RefMut<BoundedBufferChannelInner> {
+ self.inner.borrow_mut()
+ }
+
+ pub fn into_raw(self) -> *const BoundedBufferChannel {
+ Rc::into_raw(self.inner) as _
+ }
+
+ pub unsafe fn clone_from_raw(ptr: *const BoundedBufferChannel) -> Self {
+ let rc = Rc::from_raw(ptr as *const RefCell<BoundedBufferChannelInner>);
+ let clone = rc.clone();
+ std::mem::forget(rc);
+ std::mem::transmute(clone)
+ }
+
+ pub fn read(&self, limit: usize) -> Result<Option<BufView>, AnyError> {
+ self.inner().read(limit)
+ }
+
+ pub fn write(&self, buffer: V8Slice<u8>) -> Result<(), V8Slice<u8>> {
+ self.inner().write(buffer)
+ }
+
+ pub fn write_error(&self, error: AnyError) {
+ self.inner().write_error(error)
+ }
+
+ pub fn poll_read_ready(&self, cx: &mut Context) -> Poll<()> {
+ self.inner().poll_read_ready(cx)
+ }
+
+ pub fn poll_write_ready(&self, cx: &mut Context) -> Poll<()> {
+ self.inner().poll_write_ready(cx)
+ }
+
+ pub fn closed(&self) -> bool {
+ self.inner().closed
+ }
+
+ #[cfg(test)]
+ pub fn byte_size(&self) -> usize {
+ self.inner().current_size
+ }
+
+ pub fn close(&self) {
+ self.inner().close()
}
}
#[allow(clippy::type_complexity)]
struct ReadableStreamResource {
- reader: AsyncRefCell<
- Peekable<ChannelStreamAdapter<Receiver<Result<BufView, Error>>>>,
- >,
+ read_queue: Rc<TaskQueue>,
+ channel: BoundedBufferChannel,
cancel_handle: CancelHandle,
data: ReadableStreamResourceData,
}
@@ -80,26 +350,15 @@ impl ReadableStreamResource {
async fn read(self: Rc<Self>, limit: usize) -> Result<BufView, AnyError> {
let cancel_handle = self.cancel_handle();
- let peekable = RcRef::map(self, |this| &this.reader);
- let mut peekable = peekable.borrow_mut().await;
- match Pin::new(&mut *peekable)
- .peek_mut()
+ // Serialize all the reads using a task queue.
+ let _read_permit = self.read_queue.acquire().await;
+ poll_fn(|cx| self.channel.poll_read_ready(cx))
.or_cancel(cancel_handle)
- .await?
- {
- None => Ok(BufView::empty()),
- // Take the actual error since we only have a reference to it
- Some(Err(_)) => Err(peekable.next().await.unwrap().err().unwrap()),
- Some(Ok(bytes)) => {
- if bytes.len() <= limit {
- // We can safely take the next item since we peeked it
- return peekable.next().await.unwrap();
- }
- // The remainder of the bytes after we split it is still left in the peek buffer
- let ret = bytes.split_to(limit);
- Ok(ret)
- }
- }
+ .await?;
+ self
+ .channel
+ .read(limit)
+ .map(|buf| buf.unwrap_or_else(BufView::empty))
}
}
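The rewritten read path relies on a hand-rolled readiness protocol: poll_read_ready parks the task's Waker when the channel has nothing to report, and write, write_error, and close wake it (poll_write_ready works the same way in reverse). A minimal sketch of that pattern in isolation, assuming only the futures and tokio crates; the Flag type here is hypothetical, not part of the patch:

use futures::future::poll_fn;
use std::cell::RefCell;
use std::rc::Rc;
use std::task::{Context, Poll, Waker};

// A one-shot readiness flag: poll_ready parks the current task's waker until
// set() fires, mirroring how poll_read_ready stores read_waker and write()
// wakes it.
#[derive(Clone, Default)]
struct Flag {
  inner: Rc<RefCell<(bool, Option<Waker>)>>,
}

impl Flag {
  fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
    let mut inner = self.inner.borrow_mut();
    if inner.0 {
      inner.1.take();
      Poll::Ready(())
    } else {
      // Not ready yet: remember whom to wake, then report Pending.
      inner.1 = Some(cx.waker().clone());
      Poll::Pending
    }
  }

  fn set(&self) {
    let mut inner = self.inner.borrow_mut();
    inner.0 = true;
    if let Some(waker) = inner.1.take() {
      waker.wake();
    }
  }
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
  // Rc is !Send, so spawn on a LocalSet rather than with tokio::spawn.
  let local = tokio::task::LocalSet::new();
  local
    .run_until(async {
      let flag = Flag::default();
      let flag2 = flag.clone();
      let waiter = tokio::task::spawn_local(async move {
        poll_fn(|cx| flag2.poll_ready(cx)).await;
      });
      tokio::task::yield_now().await; // let the waiter park its waker first
      flag.set();
      waiter.await.unwrap();
    })
    .await;
}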
@@ -114,6 +373,7 @@ impl Resource for ReadableStreamResource {
fn close(self: Rc<Self>) {
self.cancel_handle.cancel();
+ self.channel.close();
}
}
@@ -163,17 +423,12 @@ impl Future for CompletionHandle {
#[op2(fast)]
#[smi]
pub fn op_readable_stream_resource_allocate(state: &mut OpState) -> ResourceId {
- let (tx, rx) = tokio::sync::mpsc::channel(1);
- let tx = RefCell::new(Some(tx));
let completion = CompletionHandle::default();
- let tx = Box::new(tx);
let resource = ReadableStreamResource {
+ read_queue: Default::default(),
cancel_handle: Default::default(),
- reader: AsyncRefCell::new(ChannelStreamAdapter(rx).peekable()),
- data: ReadableStreamResourceData {
- tx: Box::into_raw(tx),
- completion,
- },
+ channel: BoundedBufferChannel::default(),
+ data: ReadableStreamResourceData { completion },
};
state.resource_table.add(resource)
}
@@ -187,23 +442,19 @@ pub fn op_readable_stream_resource_get_sink(
else {
return std::ptr::null();
};
- resource.data.tx as _
+ resource.channel.clone().into_raw() as _
}
-fn get_sender(sender: *const c_void) -> Option<Sender<Result<BufView, Error>>> {
+fn get_sender(sender: *const c_void) -> BoundedBufferChannel {
// SAFETY: We know this is a valid v8::External
- unsafe {
- (sender as *const SenderCell)
- .as_ref()
- .and_then(|r| r.borrow_mut().as_ref().cloned())
- }
+ unsafe { BoundedBufferChannel::clone_from_raw(sender as _) }
}
fn drop_sender(sender: *const c_void) {
// SAFETY: We know this is a valid v8::External
unsafe {
assert!(!sender.is_null());
- _ = Box::from_raw(sender as *mut SenderCell);
+ _ = Rc::from_raw(sender as *mut RefCell<BoundedBufferChannelInner>);
}
}
@@ -214,10 +465,9 @@ pub fn op_readable_stream_resource_write_buf(
) -> impl Future<Output = bool> {
let sender = get_sender(sender);
async move {
- let Some(sender) = sender else {
- return false;
- };
- sender.send(Ok(buffer.into())).await.ok().is_some()
+ poll_fn(|cx| sender.poll_write_ready(cx)).await;
+ sender.write(buffer.into_parts()).unwrap();
+ !sender.closed()
}
}
@@ -228,20 +478,17 @@ pub fn op_readable_stream_resource_write_error(
) -> impl Future<Output = bool> {
let sender = get_sender(sender);
async move {
- let Some(sender) = sender else {
- return false;
- };
- sender
- .send(Err(type_error(Cow::Owned(error))))
- .await
- .ok()
- .is_some()
+ // We can always write an error, no polling required
+ // TODO(mmastrac): we can remove async from this method
+ sender.write_error(type_error(Cow::Owned(error)));
+ !sender.closed()
}
}
#[op2(fast)]
#[smi]
pub fn op_readable_stream_resource_close(sender: *const c_void) {
+ get_sender(sender).close();
drop_sender(sender);
}
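The sink pointer handed to V8 in these ops is a disguised Rc strong reference: into_raw leaks one reference into the external pointer, clone_from_raw reconstructs it just long enough to clone it (bumping the count) before forgetting the original, and drop_sender finally consumes the leaked reference. A standalone sketch of that lifecycle, with a plain Rc<RefCell<u32>> standing in for the channel:

use std::cell::RefCell;
use std::ffi::c_void;
use std::rc::Rc;

fn into_raw(rc: Rc<RefCell<u32>>) -> *const c_void {
  // Leak one strong reference into the pointer; the count stays >= 1.
  Rc::into_raw(rc) as *const c_void
}

/// # Safety
/// `ptr` must come from `into_raw` and must not yet have been released.
unsafe fn clone_from_raw(ptr: *const c_void) -> Rc<RefCell<u32>> {
  // Reconstruct the leaked Rc, clone it (count += 1), then forget the
  // original so the pointer's own reference stays alive.
  let rc = Rc::from_raw(ptr as *const RefCell<u32>);
  let clone = rc.clone();
  std::mem::forget(rc);
  clone
}

/// # Safety
/// Consumes the pointer's reference; `ptr` must not be used afterwards.
unsafe fn drop_raw(ptr: *const c_void) {
  drop(Rc::from_raw(ptr as *const RefCell<u32>));
}

fn main() {
  let ptr = into_raw(Rc::new(RefCell::new(0)));
  // SAFETY: ptr is live and came from into_raw.
  let handle = unsafe { clone_from_raw(ptr) };
  *handle.borrow_mut() += 1;
  // SAFETY: last use of ptr; releases the leaked reference.
  unsafe { drop_raw(ptr) };
  assert_eq!(*handle.borrow(), 1); // handle keeps the value alive
}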
@@ -264,7 +511,6 @@ pub fn op_readable_stream_resource_await_close(
}
struct ReadableStreamResourceData {
- tx: *const SenderCell,
completion: CompletionHandle,
}
@@ -273,3 +519,129 @@ impl Drop for ReadableStreamResourceData {
self.completion.complete(true);
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use deno_core::v8;
+ use std::cell::OnceCell;
+ use std::sync::atomic::AtomicUsize;
+ use std::sync::OnceLock;
+ use std::time::Duration;
+
+ static V8_GLOBAL: OnceLock<()> = OnceLock::new();
+
+ thread_local! {
+ static ISOLATE: OnceCell<std::sync::Mutex<v8::OwnedIsolate>> = OnceCell::new();
+ }
+
+ fn with_isolate<T>(mut f: impl FnMut(&mut v8::Isolate) -> T) -> T {
+ V8_GLOBAL.get_or_init(|| {
+ let platform =
+ v8::new_unprotected_default_platform(0, false).make_shared();
+ v8::V8::initialize_platform(platform);
+ v8::V8::initialize();
+ });
+ ISOLATE.with(|cell| {
+ let mut isolate = cell
+ .get_or_init(|| {
+ std::sync::Mutex::new(v8::Isolate::new(Default::default()))
+ })
+ .try_lock()
+ .unwrap();
+ f(&mut isolate)
+ })
+ }
+
+ fn create_buffer(byte_length: usize) -> V8Slice<u8> {
+ with_isolate(|isolate| {
+ let ptr = v8::ArrayBuffer::new_backing_store(isolate, byte_length);
+ // SAFETY: the backing store was just created with exactly this length
+ unsafe { V8Slice::from_parts(ptr.into(), 0..byte_length) }
+ })
+ }
+
+ #[test]
+ fn test_bounded_buffer_channel() {
+ let channel = BoundedBufferChannel::default();
+
+ for _ in 0..BUFFER_CHANNEL_SIZE - 1 {
+ channel.write(create_buffer(1024)).unwrap();
+ }
+ }
+
+ #[tokio::test(flavor = "current_thread")]
+ async fn test_multi_task() {
+ let channel = BoundedBufferChannel::default();
+ let channel_send = channel.clone();
+
+ // Fast writer
+ let a = deno_core::unsync::spawn(async move {
+ for _ in 0..BUFFER_CHANNEL_SIZE * 2 {
+ poll_fn(|cx| channel_send.poll_write_ready(cx)).await;
+ channel_send
+ .write(create_buffer(BUFFER_AGGREGATION_LIMIT))
+ .unwrap();
+ }
+ });
+
+ // Slightly slower reader
+ let b = deno_core::unsync::spawn(async move {
+ for _ in 0..BUFFER_CHANNEL_SIZE * 2 {
+ tokio::time::sleep(Duration::from_millis(1)).await;
+ poll_fn(|cx| channel.poll_read_ready(cx)).await;
+ channel.read(BUFFER_AGGREGATION_LIMIT).unwrap();
+ }
+ });
+
+ a.await.unwrap();
+ b.await.unwrap();
+ }
+
+ #[tokio::test(flavor = "current_thread")]
+ async fn test_multi_task_small_reads() {
+ let channel = BoundedBufferChannel::default();
+ let channel_send = channel.clone();
+
+ let total_send = Rc::new(AtomicUsize::new(0));
+ let total_send_task = total_send.clone();
+ let total_recv = Rc::new(AtomicUsize::new(0));
+ let total_recv_task = total_recv.clone();
+
+ // Fast writer
+ let a = deno_core::unsync::spawn(async move {
+ for _ in 0..BUFFER_CHANNEL_SIZE * 2 {
+ poll_fn(|cx| channel_send.poll_write_ready(cx)).await;
+ channel_send.write(create_buffer(16)).unwrap();
+ total_send_task.fetch_add(16, std::sync::atomic::Ordering::SeqCst);
+ }
+ // Close when done writing: reads may aggregate packets, so the reader needs an explicit end-of-data signal
+ channel_send.close();
+ });
+
+ // Slightly slower reader
+ let b = deno_core::unsync::spawn(async move {
+ for _ in 0..BUFFER_CHANNEL_SIZE * 2 {
+ poll_fn(|cx| channel.poll_read_ready(cx)).await;
+ // We want to make sure we're aggregating at least some packets
+ while channel.byte_size() <= 16 && !channel.closed() {
+ tokio::time::sleep(Duration::from_millis(1)).await;
+ }
+ let len = channel
+ .read(1024)
+ .unwrap()
+ .map(|b| b.len())
+ .unwrap_or_default();
+ total_recv_task.fetch_add(len, std::sync::atomic::Ordering::SeqCst);
+ }
+ });
+
+ a.await.unwrap();
+ b.await.unwrap();
+
+ assert_eq!(
+ total_send.load(std::sync::atomic::Ordering::SeqCst),
+ total_recv.load(std::sync::atomic::Ordering::SeqCst)
+ );
+ }
+}
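To make the aggregation behavior these tests exercise concrete: when the caller's limit is at least BUFFER_AGGREGATION_LIMIT and the queued data totals no more than that, read drains every queued buffer into one allocation instead of yielding buffers one at a time. A simplified model of that decision over VecDeque<Bytes> (the constant value matches the patch; the function shape is illustrative):

use bytes::{Bytes, BytesMut};
use std::collections::VecDeque;

const BUFFER_AGGREGATION_LIMIT: usize = 1024;

// Simplified analogue of BoundedBufferChannelInner::read: coalesce many small
// queued buffers into one allocation when the limits allow it.
fn read(
  queue: &mut VecDeque<Bytes>,
  current_size: &mut usize,
  limit: usize,
) -> Option<Bytes> {
  if queue.is_empty() {
    return None;
  }
  if limit >= BUFFER_AGGREGATION_LIMIT
    && *current_size <= BUFFER_AGGREGATION_LIMIT
    && queue.len() > 1
  {
    let mut out = BytesMut::with_capacity(BUFFER_AGGREGATION_LIMIT);
    while let Some(buf) = queue.pop_front() {
      out.extend_from_slice(&buf);
    }
    *current_size = 0;
    return Some(out.freeze());
  }
  let front = queue.front_mut().unwrap();
  if front.len() <= limit {
    let buf = queue.pop_front().unwrap();
    *current_size -= buf.len();
    Some(buf)
  } else {
    // Oversized buffer: return `limit` bytes; the remainder stays queued.
    let buf = front.split_to(limit);
    *current_size -= limit;
    Some(buf)
  }
}

fn main() {
  let mut queue: VecDeque<Bytes> = VecDeque::new();
  let mut size = 0;
  for _ in 0..4 {
    queue.push_back(Bytes::from_static(b"0123456789abcdef")); // 16 bytes each
    size += 16;
  }
  // Four 16-byte writes come back as a single 64-byte aggregated read.
  let buf = read(&mut queue, &mut size, 1024).unwrap();
  assert_eq!(buf.len(), 64);
  assert!(queue.is_empty());
}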