diff options
author | Matt Mastracci <matthew@mastracci.com> | 2023-04-23 14:07:37 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-23 14:07:37 -0600 |
commit | fafb2584efec33152fbe353d94151fa36004586a (patch) | |
tree | 839afc382be75b955abab77edd18cb9a9dbfb6bb /ext/websocket/lib.rs | |
parent | c95477c49f16a753a9d25b46014fabfd3c7eb9e6 (diff) |
refactor(ext/websocket): Remove dep on tungstenite by reworking code (#18812)
Diffstat (limited to 'ext/websocket/lib.rs')
-rw-r--r-- | ext/websocket/lib.rs | 57 |
1 files changed, 33 insertions, 24 deletions
diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs index 71aa66ff3..943b5d47c 100644 --- a/ext/websocket/lib.rs +++ b/ext/websocket/lib.rs @@ -38,11 +38,12 @@ use std::future::Future; use std::path::PathBuf; use std::rc::Rc; use std::sync::Arc; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; use tokio::net::TcpStream; use tokio_rustls::rustls::RootCertStore; use tokio_rustls::rustls::ServerName; use tokio_rustls::TlsConnector; -use tokio_tungstenite::MaybeTlsStream; use fastwebsockets::CloseCode; use fastwebsockets::FragmentCollector; @@ -129,6 +130,33 @@ pub struct CreateResponse { extensions: String, } +async fn handshake<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>( + cancel_resource: Option<Rc<CancelHandle>>, + request: Request<Body>, + socket: S, +) -> Result<(WebSocket<WebSocketStream>, http::Response<Body>), AnyError> { + let client = + fastwebsockets::handshake::client(&LocalExecutor, request, socket); + + let (upgraded, response) = if let Some(cancel_resource) = cancel_resource { + client.or_cancel(cancel_resource).await? + } else { + client.await + } + .map_err(|err| { + DomExceptionNetworkError::new(&format!( + "failed to connect to WebSocket: {err}" + )) + })?; + + let upgraded = upgraded.into_inner(); + let stream = + WebSocketStream::new(stream::WsStreamKind::Upgraded(upgraded), None); + let stream = WebSocket::after_handshake(stream, Role::Client); + + Ok((stream, response)) +} + #[op] pub async fn op_ws_create<WP>( state: Rc<RefCell<OpState>>, @@ -155,7 +183,7 @@ where .borrow_mut() .resource_table .get::<WsCancelResource>(cancel_rid)?; - Some(r) + Some(r.0.clone()) } else { None }; @@ -223,8 +251,8 @@ where let addr = format!("{domain}:{port}"); let tcp_socket = TcpStream::connect(addr).await?; - let socket: MaybeTlsStream<TcpStream> = match uri.scheme_str() { - Some("ws") => MaybeTlsStream::Plain(tcp_socket), + let (stream, response) = match uri.scheme_str() { + Some("ws") => handshake(cancel_resource, request, tcp_socket).await?, Some("wss") => { let tls_config = create_client_config( root_cert_store, @@ -236,30 +264,11 @@ where let dnsname = ServerName::try_from(domain.as_str()) .map_err(|_| invalid_hostname(domain))?; let tls_socket = tls_connector.connect(dnsname, tcp_socket).await?; - MaybeTlsStream::Rustls(tls_socket) + handshake(cancel_resource, request, tls_socket).await? } _ => unreachable!(), }; - let client = - fastwebsockets::handshake::client(&LocalExecutor, request, socket); - - let (upgraded, response) = if let Some(cancel_resource) = cancel_resource { - client.or_cancel(cancel_resource.0.to_owned()).await? - } else { - client.await - } - .map_err(|err| { - DomExceptionNetworkError::new(&format!( - "failed to connect to WebSocket: {err}" - )) - })?; - - let inner = MaybeTlsStream::Plain(upgraded.into_inner()); - let stream = - WebSocketStream::new(stream::WsStreamKind::Tungstenite(inner), None); - let stream = WebSocket::after_handshake(stream, Role::Client); - if let Some(cancel_rid) = cancel_handle { state.borrow_mut().resource_table.close(cancel_rid).ok(); } |