summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatt Mastracci <matthew@mastracci.com>2023-04-23 14:07:37 -0600
committerGitHub <noreply@github.com>2023-04-23 14:07:37 -0600
commitfafb2584efec33152fbe353d94151fa36004586a (patch)
tree839afc382be75b955abab77edd18cb9a9dbfb6bb
parentc95477c49f16a753a9d25b46014fabfd3c7eb9e6 (diff)
refactor(ext/websocket): Remove dep on tungstenite by reworking code (#18812)
-rw-r--r--ext/websocket/lib.rs57
-rw-r--r--ext/websocket/stream.rs15
2 files changed, 40 insertions, 32 deletions
diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs
index 71aa66ff3..943b5d47c 100644
--- a/ext/websocket/lib.rs
+++ b/ext/websocket/lib.rs
@@ -38,11 +38,12 @@ use std::future::Future;
use std::path::PathBuf;
use std::rc::Rc;
use std::sync::Arc;
+use tokio::io::AsyncRead;
+use tokio::io::AsyncWrite;
use tokio::net::TcpStream;
use tokio_rustls::rustls::RootCertStore;
use tokio_rustls::rustls::ServerName;
use tokio_rustls::TlsConnector;
-use tokio_tungstenite::MaybeTlsStream;
use fastwebsockets::CloseCode;
use fastwebsockets::FragmentCollector;
@@ -129,6 +130,33 @@ pub struct CreateResponse {
extensions: String,
}
+async fn handshake<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
+ cancel_resource: Option<Rc<CancelHandle>>,
+ request: Request<Body>,
+ socket: S,
+) -> Result<(WebSocket<WebSocketStream>, http::Response<Body>), AnyError> {
+ let client =
+ fastwebsockets::handshake::client(&LocalExecutor, request, socket);
+
+ let (upgraded, response) = if let Some(cancel_resource) = cancel_resource {
+ client.or_cancel(cancel_resource).await?
+ } else {
+ client.await
+ }
+ .map_err(|err| {
+ DomExceptionNetworkError::new(&format!(
+ "failed to connect to WebSocket: {err}"
+ ))
+ })?;
+
+ let upgraded = upgraded.into_inner();
+ let stream =
+ WebSocketStream::new(stream::WsStreamKind::Upgraded(upgraded), None);
+ let stream = WebSocket::after_handshake(stream, Role::Client);
+
+ Ok((stream, response))
+}
+
#[op]
pub async fn op_ws_create<WP>(
state: Rc<RefCell<OpState>>,
@@ -155,7 +183,7 @@ where
.borrow_mut()
.resource_table
.get::<WsCancelResource>(cancel_rid)?;
- Some(r)
+ Some(r.0.clone())
} else {
None
};
@@ -223,8 +251,8 @@ where
let addr = format!("{domain}:{port}");
let tcp_socket = TcpStream::connect(addr).await?;
- let socket: MaybeTlsStream<TcpStream> = match uri.scheme_str() {
- Some("ws") => MaybeTlsStream::Plain(tcp_socket),
+ let (stream, response) = match uri.scheme_str() {
+ Some("ws") => handshake(cancel_resource, request, tcp_socket).await?,
Some("wss") => {
let tls_config = create_client_config(
root_cert_store,
@@ -236,30 +264,11 @@ where
let dnsname = ServerName::try_from(domain.as_str())
.map_err(|_| invalid_hostname(domain))?;
let tls_socket = tls_connector.connect(dnsname, tcp_socket).await?;
- MaybeTlsStream::Rustls(tls_socket)
+ handshake(cancel_resource, request, tls_socket).await?
}
_ => unreachable!(),
};
- let client =
- fastwebsockets::handshake::client(&LocalExecutor, request, socket);
-
- let (upgraded, response) = if let Some(cancel_resource) = cancel_resource {
- client.or_cancel(cancel_resource.0.to_owned()).await?
- } else {
- client.await
- }
- .map_err(|err| {
- DomExceptionNetworkError::new(&format!(
- "failed to connect to WebSocket: {err}"
- ))
- })?;
-
- let inner = MaybeTlsStream::Plain(upgraded.into_inner());
- let stream =
- WebSocketStream::new(stream::WsStreamKind::Tungstenite(inner), None);
- let stream = WebSocket::after_handshake(stream, Role::Client);
-
if let Some(cancel_rid) = cancel_handle {
state.borrow_mut().resource_table.close(cancel_rid).ok();
}
diff --git a/ext/websocket/stream.rs b/ext/websocket/stream.rs
index 69c06b7eb..6f93406f6 100644
--- a/ext/websocket/stream.rs
+++ b/ext/websocket/stream.rs
@@ -8,11 +8,10 @@ use std::task::Poll;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::io::ReadBuf;
-use tokio_tungstenite::MaybeTlsStream;
// TODO(bartlomieju): remove this
pub(crate) enum WsStreamKind {
- Tungstenite(MaybeTlsStream<Upgraded>),
+ Upgraded(Upgraded),
Network(NetworkStream),
}
@@ -54,7 +53,7 @@ impl AsyncRead for WebSocketStream {
}
match &mut self.stream {
WsStreamKind::Network(stream) => Pin::new(stream).poll_read(cx, buf),
- WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_read(cx, buf),
+ WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
@@ -67,7 +66,7 @@ impl AsyncWrite for WebSocketStream {
) -> std::task::Poll<Result<usize, std::io::Error>> {
match &mut self.stream {
WsStreamKind::Network(stream) => Pin::new(stream).poll_write(cx, buf),
- WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_write(cx, buf),
+ WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
@@ -77,7 +76,7 @@ impl AsyncWrite for WebSocketStream {
) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut self.stream {
WsStreamKind::Network(stream) => Pin::new(stream).poll_flush(cx),
- WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_flush(cx),
+ WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_flush(cx),
}
}
@@ -87,14 +86,14 @@ impl AsyncWrite for WebSocketStream {
) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut self.stream {
WsStreamKind::Network(stream) => Pin::new(stream).poll_shutdown(cx),
- WsStreamKind::Tungstenite(stream) => Pin::new(stream).poll_shutdown(cx),
+ WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
fn is_write_vectored(&self) -> bool {
match &self.stream {
WsStreamKind::Network(stream) => stream.is_write_vectored(),
- WsStreamKind::Tungstenite(stream) => stream.is_write_vectored(),
+ WsStreamKind::Upgraded(stream) => stream.is_write_vectored(),
}
}
@@ -107,7 +106,7 @@ impl AsyncWrite for WebSocketStream {
WsStreamKind::Network(stream) => {
Pin::new(stream).poll_write_vectored(cx, bufs)
}
- WsStreamKind::Tungstenite(stream) => {
+ WsStreamKind::Upgraded(stream) => {
Pin::new(stream).poll_write_vectored(cx, bufs)
}
}