diff options
author | Matt Mastracci <matthew@mastracci.com> | 2023-04-24 23:24:40 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-24 23:24:40 +0200 |
commit | bb74e75a049768c2949aa08de6752a16813b97de (patch) | |
tree | 0c6af6d5ef1b8e8aff878d9e2aa8c32bee1c4c39 /ext/http/network_buffered_stream.rs | |
parent | 0e97fa4d5f056e12d3c0704bfb7bcdc56316ef94 (diff) |
feat(ext/http): h2c for http/2 (#18817)
This implements HTTP/2 prior-knowledge connections, allowing clients to
request HTTP/2 over plaintext or TLS-without-ALPN connections. If a
client requests a specific protocol via ALPN (`h2` or `http/1.1`),
however, the protocol is forced and must be used.
Diffstat (limited to 'ext/http/network_buffered_stream.rs')
-rw-r--r-- | ext/http/network_buffered_stream.rs | 284 |
1 files changed, 284 insertions, 0 deletions
diff --git a/ext/http/network_buffered_stream.rs b/ext/http/network_buffered_stream.rs new file mode 100644 index 000000000..e4b2ee895 --- /dev/null +++ b/ext/http/network_buffered_stream.rs @@ -0,0 +1,284 @@ +// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license. + +use bytes::Bytes; +use deno_core::futures::future::poll_fn; +use deno_core::futures::ready; +use std::io; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::Poll; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; + +const MAX_PREFIX_SIZE: usize = 256; + +pub struct NetworkStreamPrefixCheck<S: AsyncRead + Unpin> { + io: S, + prefix: &'static [u8], + buffer: [MaybeUninit<u8>; MAX_PREFIX_SIZE * 2], +} + +impl<S: AsyncRead + Unpin> NetworkStreamPrefixCheck<S> { + pub fn new(io: S, prefix: &'static [u8]) -> Self { + debug_assert!(prefix.len() < MAX_PREFIX_SIZE); + Self { + io, + prefix, + buffer: [MaybeUninit::<u8>::uninit(); MAX_PREFIX_SIZE * 2], + } + } + + // Returns a [`NetworkBufferedStream`], rewound with the bytes we read to determine what + // type of stream this is. + pub async fn match_prefix( + self, + ) -> io::Result<(bool, NetworkBufferedStream<S>)> { + let mut buffer = self.buffer; + let mut readbuf = ReadBuf::uninit(&mut buffer); + let mut io = self.io; + let prefix = self.prefix; + loop { + enum State { + Unknown, + Matched, + NotMatched, + } + + let state = poll_fn(|cx| { + let filled_len = readbuf.filled().len(); + let res = ready!(Pin::new(&mut io).poll_read(cx, &mut readbuf)); + if let Err(e) = res { + return Poll::Ready(Err(e)); + } + let filled = readbuf.filled(); + let new_len = filled.len(); + if new_len == filled_len { + // Empty read, no match + return Poll::Ready(Ok(State::NotMatched)); + } else if new_len < prefix.len() { + // Read less than prefix, make sure we're still matching the prefix (early exit) + if !prefix.starts_with(filled) { + return Poll::Ready(Ok(State::NotMatched)); + } + } else if new_len >= prefix.len() { + // We have enough to determine + if filled.starts_with(prefix) { + return Poll::Ready(Ok(State::Matched)); + } else { + return Poll::Ready(Ok(State::NotMatched)); + } + } + + Poll::Ready(Ok(State::Unknown)) + }) + .await?; + + match state { + State::Unknown => continue, + State::Matched => { + let initialized_len = readbuf.filled().len(); + return Ok(( + true, + NetworkBufferedStream::new(io, buffer, initialized_len), + )); + } + State::NotMatched => { + let initialized_len = readbuf.filled().len(); + return Ok(( + false, + NetworkBufferedStream::new(io, buffer, initialized_len), + )); + } + } + } + } +} + +pub struct NetworkBufferedStream<S: AsyncRead + Unpin> { + io: S, + initialized_len: usize, + prefix_offset: usize, + prefix: [MaybeUninit<u8>; MAX_PREFIX_SIZE * 2], + prefix_read: bool, +} + +impl<S: AsyncRead + Unpin> NetworkBufferedStream<S> { + fn new( + io: S, + prefix: [MaybeUninit<u8>; MAX_PREFIX_SIZE * 2], + initialized_len: usize, + ) -> Self { + Self { + io, + initialized_len, + prefix_offset: 0, + prefix, + prefix_read: false, + } + } + + fn current_slice(&self) -> &[u8] { + // We trust that these bytes are initialized properly + let slice = &self.prefix[self.prefix_offset..self.initialized_len]; + + // This guarantee comes from slice_assume_init_ref (we can't use that until it's stable) + + // SAFETY: casting `slice` to a `*const [T]` is safe since the caller guarantees that + // `slice` is initialized, and `MaybeUninit` is guaranteed to have the same layout as `T`. + // The pointer obtained is valid since it refers to memory owned by `slice` which is a + // reference and thus guaranteed to be valid for reads. + + unsafe { &*(slice as *const [_] as *const [u8]) as _ } + } + + pub fn into_inner(self) -> (S, Bytes) { + let bytes = Bytes::copy_from_slice(self.current_slice()); + (self.io, bytes) + } +} + +impl<S: AsyncRead + Unpin> AsyncRead for NetworkBufferedStream<S> { + // From hyper's Rewind (https://github.com/hyperium/hyper), MIT License, Copyright (c) Sean McArthur + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + if !self.prefix_read { + let prefix = self.current_slice(); + + // If there are no remaining bytes, let the bytes get dropped. + if !prefix.is_empty() { + let copy_len = std::cmp::min(prefix.len(), buf.remaining()); + buf.put_slice(&prefix[..copy_len]); + self.prefix_offset += copy_len; + + return Poll::Ready(Ok(())); + } else { + self.prefix_read = true; + } + } + Pin::new(&mut self.io).poll_read(cx, buf) + } +} + +impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite + for NetworkBufferedStream<S> +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll<Result<usize, std::io::Error>> { + Pin::new(&mut self.io).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Result<(), std::io::Error>> { + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Result<(), std::io::Error>> { + Pin::new(&mut self.io).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.io.is_write_vectored() + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll<Result<usize, std::io::Error>> { + Pin::new(&mut self.io).poll_write_vectored(cx, bufs) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::AsyncReadExt; + + struct YieldsOneByteAtATime(&'static [u8]); + + impl AsyncRead for YieldsOneByteAtATime { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + if let Some((head, tail)) = self.as_mut().0.split_first() { + self.as_mut().0 = tail; + let dest = buf.initialize_unfilled_to(1); + dest[0] = *head; + buf.advance(1); + } + Poll::Ready(Ok(())) + } + } + + async fn test( + io: impl AsyncRead + Unpin, + prefix: &'static [u8], + expect_match: bool, + expect_string: &'static str, + ) -> io::Result<()> { + let (matches, mut io) = NetworkStreamPrefixCheck::new(io, prefix) + .match_prefix() + .await?; + assert_eq!(matches, expect_match); + let mut s = String::new(); + Pin::new(&mut io).read_to_string(&mut s).await?; + assert_eq!(s, expect_string); + Ok(()) + } + + #[tokio::test] + async fn matches_prefix_simple() -> io::Result<()> { + let buf = b"prefix match".as_slice(); + test(buf, b"prefix", true, "prefix match").await + } + + #[tokio::test] + async fn matches_prefix_exact() -> io::Result<()> { + let buf = b"prefix".as_slice(); + test(buf, b"prefix", true, "prefix").await + } + + #[tokio::test] + async fn not_matches_prefix_simple() -> io::Result<()> { + let buf = b"prefill match".as_slice(); + test(buf, b"prefix", false, "prefill match").await + } + + #[tokio::test] + async fn not_matches_prefix_short() -> io::Result<()> { + let buf = b"nope".as_slice(); + test(buf, b"prefix", false, "nope").await + } + + #[tokio::test] + async fn not_matches_prefix_empty() -> io::Result<()> { + let buf = b"".as_slice(); + test(buf, b"prefix", false, "").await + } + + #[tokio::test] + async fn matches_one_byte_at_a_time() -> io::Result<()> { + let buf = YieldsOneByteAtATime(b"prefix"); + test(buf, b"prefix", true, "prefix").await + } + + #[tokio::test] + async fn not_matches_one_byte_at_a_time() -> io::Result<()> { + let buf = YieldsOneByteAtATime(b"prefill"); + test(buf, b"prefix", false, "prefill").await + } +} |