summaryrefslogtreecommitdiff
path: root/ext/websocket/stream.rs
diff options
context:
space:
mode:
authorMatt Mastracci <matthew@mastracci.com>2023-04-22 11:48:21 -0600
committerGitHub <noreply@github.com>2023-04-22 11:48:21 -0600
commitbdffcb409fd1e257db280ab73e07cc319711256c (patch)
tree9aca1c1e73f0249bba8b66781b79c358a7a00798 /ext/websocket/stream.rs
parentd137501a639cb315772866f6775fcd9f43e28f5b (diff)
feat(ext/http): Rework Deno.serve using hyper 1.0-rc3 (#18619)
This is a rewrite of the `Deno.serve` API to live on top of hyper 1.0-rc3. The code should be more maintainable long-term, and avoids some of the slower mpsc patterns that made the older code less efficient than it could have been. Missing features: - `upgradeHttp` and `upgradeHttpRaw` (`upgradeWebSocket` is available, however). - Automatic compression is unavailable on responses.
Diffstat (limited to 'ext/websocket/stream.rs')
-rw-r--r--ext/websocket/stream.rs115
1 files changed, 115 insertions, 0 deletions
diff --git a/ext/websocket/stream.rs b/ext/websocket/stream.rs
new file mode 100644
index 000000000..69c06b7eb
--- /dev/null
+++ b/ext/websocket/stream.rs
@@ -0,0 +1,115 @@
+// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
+use bytes::Buf;
+use bytes::Bytes;
+use deno_net::raw::NetworkStream;
+use hyper::upgrade::Upgraded;
+use std::pin::Pin;
+use std::task::Poll;
+use tokio::io::AsyncRead;
+use tokio::io::AsyncWrite;
+use tokio::io::ReadBuf;
+use tokio_tungstenite::MaybeTlsStream;
+
+// TODO(bartlomieju): remove this
+pub(crate) enum WsStreamKind {
+ Tungstenite(MaybeTlsStream<Upgraded>),
+ Network(NetworkStream),
+}
+
+pub(crate) struct WebSocketStream {
+ stream: WsStreamKind,
+ pre: Option<Bytes>,
+}
+
+impl WebSocketStream {
+ pub fn new(stream: WsStreamKind, buffer: Option<Bytes>) -> Self {
+ Self {
+ stream,
+ pre: buffer,
+ }
+ }
+}
+
+impl AsyncRead for WebSocketStream {
+ // From hyper's Rewind (https://github.com/hyperium/hyper), MIT License, Copyright (c) Sean McArthur
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ if let Some(mut prefix) = self.pre.take() {
+ // If there are no remaining bytes, let the bytes get dropped.
+ if !prefix.is_empty() {
+ let copy_len = std::cmp::min(prefix.len(), buf.remaining());
+ // TODO: There should be a way to do following two lines cleaner...
+ buf.put_slice(&prefix[..copy_len]);
+ prefix.advance(copy_len);
+ // Put back what's left
+ if !prefix.is_empty() {
+ self.pre = Some(prefix);
+ }
+
+ return Poll::Ready(Ok(()));
+ }
+ }
+ match &mut self.stream {
+ WsStreamKind::Network(stream) => Pin::new(stream).poll_read(cx, buf),
+ WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_read(cx, buf),
+ }
+ }
+}
+
+impl AsyncWrite for WebSocketStream {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ buf: &[u8],
+ ) -> std::task::Poll<Result<usize, std::io::Error>> {
+ match &mut self.stream {
+ WsStreamKind::Network(stream) => Pin::new(stream).poll_write(cx, buf),
+ WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_write(cx, buf),
+ }
+ }
+
+ fn poll_flush(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), std::io::Error>> {
+ match &mut self.stream {
+ WsStreamKind::Network(stream) => Pin::new(stream).poll_flush(cx),
+ WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_flush(cx),
+ }
+ }
+
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), std::io::Error>> {
+ match &mut self.stream {
+ WsStreamKind::Network(stream) => Pin::new(stream).poll_shutdown(cx),
+ WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_shutdown(cx),
+ }
+ }
+
+ fn is_write_vectored(&self) -> bool {
+ match &self.stream {
+ WsStreamKind::Network(stream) => stream.is_write_vectored(),
+ WsStreamKind::Tungstenite(stream) => stream.is_write_vectored(),
+ }
+ }
+
+ fn poll_write_vectored(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ bufs: &[std::io::IoSlice<'_>],
+ ) -> std::task::Poll<Result<usize, std::io::Error>> {
+ match &mut self.stream {
+ WsStreamKind::Network(stream) => {
+ Pin::new(stream).poll_write_vectored(cx, bufs)
+ }
+ WsStreamKind::Tungstenite(stream) => {
+ Pin::new(stream).poll_write_vectored(cx, bufs)
+ }
+ }
+ }
+}