diff options
author | Matt Mastracci <matthew@mastracci.com> | 2023-11-01 15:11:01 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-01 21:11:01 +0000 |
commit | 42c426e7695a0037032d1ac5237830800eeaaed4 (patch) | |
tree | 242f9aa30187464f1b6314387654a76d8dc76fc0 /ext/websocket | |
parent | 587f2e0800a55e58b2579758d4278a4129b609c0 (diff) |
feat(ext/websocket): websockets over http2 (#21040)
Implements `WebSocket` over http/2. This requires a conformant http/2
server supporting the extended connect protocol.
Passes approximately 100 new WPT tests (mostly `?wpt_flags=h2` versions
of existing websockets APIs).
This is implemented as a fallback when http/1.1 fails, so a server that
supports both h1 and h2 WebSockets will still end up on the http/1.1
upgrade path.
The patch also cleas up the websockets handshake to split it up into
http, https+http1 and https+http2, making it a little less intertwined.
This uncovered a likely bug in the WPT test server:
https://github.com/web-platform-tests/wpt/issues/42896
Diffstat (limited to 'ext/websocket')
-rw-r--r-- | ext/websocket/Cargo.toml | 1 | ||||
-rw-r--r-- | ext/websocket/lib.rs | 332 | ||||
-rw-r--r-- | ext/websocket/stream.rs | 63 |
3 files changed, 292 insertions, 104 deletions
diff --git a/ext/websocket/Cargo.toml b/ext/websocket/Cargo.toml index da29203c4..a643e25a0 100644 --- a/ext/websocket/Cargo.toml +++ b/ext/websocket/Cargo.toml @@ -19,6 +19,7 @@ deno_core.workspace = true deno_net.workspace = true deno_tls.workspace = true fastwebsockets = { workspace = true, features = ["upgrade", "unstable-split"] } +h2.workspace = true http.workspace = true hyper = { workspace = true, features = ["backports"] } once_cell.workspace = true diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs index ac40b8304..c2599f6f6 100644 --- a/ext/websocket/lib.rs +++ b/ext/websocket/lib.rs @@ -1,16 +1,19 @@ // Copyright 2018-2023 the Deno authors. All rights reserved. MIT license. use crate::stream::WebSocketStream; use bytes::Bytes; +use deno_core::anyhow::bail; use deno_core::error::invalid_hostname; use deno_core::error::type_error; use deno_core::error::AnyError; +use deno_core::futures::TryFutureExt; use deno_core::op2; +use deno_core::unsync::spawn; use deno_core::url; use deno_core::AsyncMutFuture; use deno_core::AsyncRefCell; use deno_core::ByteString; -use deno_core::CancelFuture; use deno_core::CancelHandle; +use deno_core::CancelTryFuture; use deno_core::JsBuffer; use deno_core::OpState; use deno_core::RcRef; @@ -19,13 +22,16 @@ use deno_core::ResourceId; use deno_core::ToJsBuffer; use deno_net::raw::NetworkStream; use deno_tls::create_client_config; +use deno_tls::rustls::ClientConfig; use deno_tls::RootCertStoreProvider; +use deno_tls::SocketUse; use http::header::CONNECTION; use http::header::UPGRADE; use http::HeaderName; use http::HeaderValue; use http::Method; use http::Request; +use http::StatusCode; use http::Uri; use hyper::Body; use once_cell::sync::Lazy; @@ -146,66 +152,175 @@ pub struct CreateResponse { extensions: String, } -async fn handshake<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>( - cancel_resource: Option<Rc<CancelHandle>>, +async fn handshake_websocket( + state: &Rc<RefCell<OpState>>, + uri: &Uri, + protocols: &str, + headers: Option<Vec<(ByteString, ByteString)>>, +) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> { + let mut request = Request::builder().method(Method::GET).uri( + uri + .path_and_query() + .ok_or(type_error("Missing path in url".to_string()))? + .as_str(), + ); + + let authority = uri.authority().unwrap().as_str(); + let host = authority + .find('@') + .map(|idx| authority.split_at(idx + 1).1) + .unwrap_or_else(|| authority); + request = request + .header("Host", host) + .header(UPGRADE, "websocket") + .header(CONNECTION, "Upgrade") + .header( + "Sec-WebSocket-Key", + fastwebsockets::handshake::generate_key(), + ); + + let user_agent = state.borrow().borrow::<WsUserAgent>().0.clone(); + request = + populate_common_request_headers(request, &user_agent, protocols, &headers)?; + + let request = request.body(Body::empty())?; + let domain = &uri.host().unwrap().to_string(); + let port = &uri.port_u16().unwrap_or(match uri.scheme_str() { + Some("wss") => 443, + Some("ws") => 80, + _ => unreachable!(), + }); + let addr = format!("{domain}:{port}"); + + let res = match uri.scheme_str() { + Some("ws") => handshake_http1_ws(request, &addr).await?, + Some("wss") => { + match handshake_http1_wss(state, request, domain, &addr).await { + Ok(res) => res, + Err(_) => { + handshake_http2_wss( + state, + uri, + authority, + &user_agent, + protocols, + domain, + &headers, + &addr, + ) + .await? + } + } + } + _ => unreachable!(), + }; + Ok(res) +} + +async fn handshake_http1_ws( request: Request<Body>, - socket: S, -) -> Result<(WebSocket<WebSocketStream>, http::Response<Body>), AnyError> { - let client = - fastwebsockets::handshake::client(&LocalExecutor, request, socket); + addr: &String, +) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> { + let tcp_socket = TcpStream::connect(addr).await?; + handshake_connection(request, tcp_socket).await +} - let (upgraded, response) = if let Some(cancel_resource) = cancel_resource { - client.or_cancel(cancel_resource).await? - } else { - client.await +async fn handshake_http1_wss( + state: &Rc<RefCell<OpState>>, + request: Request<Body>, + domain: &str, + addr: &str, +) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> { + let tcp_socket = TcpStream::connect(addr).await?; + let tls_config = create_ws_client_config(state, SocketUse::Http1Only)?; + let dnsname = + ServerName::try_from(domain).map_err(|_| invalid_hostname(domain))?; + let mut tls_connector = TlsStream::new_client_side( + tcp_socket, + tls_config.into(), + dnsname, + NonZeroUsize::new(65536), + ); + // If we can bail on an http/1.1 ALPN mismatch here, we can avoid doing extra work + tls_connector.handshake().await?; + handshake_connection(request, tls_connector).await +} + +#[allow(clippy::too_many_arguments)] +async fn handshake_http2_wss( + state: &Rc<RefCell<OpState>>, + uri: &Uri, + authority: &str, + user_agent: &str, + protocols: &str, + domain: &str, + headers: &Option<Vec<(ByteString, ByteString)>>, + addr: &str, +) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> { + let tcp_socket = TcpStream::connect(addr).await?; + let tls_config = create_ws_client_config(state, SocketUse::Http2Only)?; + let dnsname = + ServerName::try_from(domain).map_err(|_| invalid_hostname(domain))?; + // We need to better expose the underlying errors here + let mut tls_connector = + TlsStream::new_client_side(tcp_socket, tls_config.into(), dnsname, None); + let handshake = tls_connector.handshake().await?; + if handshake.alpn.is_none() { + bail!("Didn't receive h2 alpn, aborting connection"); } - .map_err(|err| { - DomExceptionNetworkError::new(&format!( - "failed to connect to WebSocket: {err}" - )) - })?; + let h2 = h2::client::Builder::new(); + let (mut send, conn) = h2.handshake::<_, Bytes>(tls_connector).await?; + spawn(conn); + let mut request = Request::builder(); + request = request.method(Method::CONNECT); + let uri = Uri::builder() + .authority(authority) + .path_and_query(uri.path_and_query().unwrap().as_str()) + .scheme("https") + .build()?; + request = request.uri(uri); + request = + populate_common_request_headers(request, user_agent, protocols, headers)?; + request = request.extension(h2::ext::Protocol::from("websocket")); + let (resp, send) = send.send_request(request.body(())?, false)?; + let resp = resp.await?; + if resp.status() != StatusCode::OK { + bail!("Invalid status code: {}", resp.status()); + } + let (http::response::Parts { headers, .. }, recv) = resp.into_parts(); + let mut stream = WebSocket::after_handshake( + WebSocketStream::new(stream::WsStreamKind::H2(send, recv), None), + Role::Client, + ); + // We currently don't support vectored writes in the H2 streams + stream.set_writev(false); + // TODO(mmastrac): we should be able to use a zero masking key over HTTPS + // stream.set_auto_apply_mask(false); + Ok((stream, headers)) +} + +async fn handshake_connection< + S: AsyncRead + AsyncWrite + Send + Unpin + 'static, +>( + request: Request<Body>, + socket: S, +) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> { + let (upgraded, response) = + fastwebsockets::handshake::client(&LocalExecutor, request, socket).await?; let upgraded = upgraded.into_inner(); let stream = WebSocketStream::new(stream::WsStreamKind::Upgraded(upgraded), None); let stream = WebSocket::after_handshake(stream, Role::Client); - Ok((stream, response)) + Ok((stream, response.into_parts().0.headers)) } -#[op2(async)] -#[serde] -pub async fn op_ws_create<WP>( - state: Rc<RefCell<OpState>>, - #[string] api_name: String, - #[string] url: String, - #[string] protocols: String, - #[smi] cancel_handle: Option<ResourceId>, - #[serde] headers: Option<Vec<(ByteString, ByteString)>>, -) -> Result<CreateResponse, AnyError> -where - WP: WebSocketPermissions + 'static, -{ - { - let mut s = state.borrow_mut(); - s.borrow_mut::<WP>() - .check_net_url(&url::Url::parse(&url)?, &api_name) - .expect( - "Permission check should have been done in op_ws_check_permission", - ); - } - - let cancel_resource = if let Some(cancel_rid) = cancel_handle { - let r = state - .borrow_mut() - .resource_table - .get::<WsCancelResource>(cancel_rid)?; - Some(r.0.clone()) - } else { - None - }; - - let unsafely_ignore_certificate_errors = state +pub fn create_ws_client_config( + state: &Rc<RefCell<OpState>>, + socket_use: SocketUse, +) -> Result<ClientConfig, AnyError> { + let unsafely_ignore_certificate_errors: Option<Vec<String>> = state .borrow() .try_borrow::<UnsafelyIgnoreCertificateErrors>() .and_then(|it| it.0.clone()); @@ -213,29 +328,25 @@ where .borrow() .borrow::<WsRootStoreProvider>() .get_or_try_init()?; - let user_agent = state.borrow().borrow::<WsUserAgent>().0.clone(); - let uri: Uri = url.parse()?; - let mut request = Request::builder().method(Method::GET).uri( - uri - .path_and_query() - .ok_or(type_error("Missing path in url".to_string()))? - .as_str(), - ); - let authority = uri.authority().unwrap().as_str(); - let host = authority - .find('@') - .map(|idx| authority.split_at(idx + 1).1) - .unwrap_or_else(|| authority); + create_client_config( + root_cert_store, + vec![], + unsafely_ignore_certificate_errors, + None, + socket_use, + ) +} + +/// Headers common to both http/1.1 and h2 requests. +fn populate_common_request_headers( + mut request: http::request::Builder, + user_agent: &str, + protocols: &str, + headers: &Option<Vec<(ByteString, ByteString)>>, +) -> Result<http::request::Builder, AnyError> { request = request .header("User-Agent", user_agent) - .header("Host", host) - .header(UPGRADE, "websocket") - .header(CONNECTION, "Upgrade") - .header( - "Sec-WebSocket-Key", - fastwebsockets::handshake::generate_key(), - ) .header("Sec-WebSocket-Version", "13"); if !protocols.is_empty() { @@ -244,9 +355,9 @@ where if let Some(headers) = headers { for (key, value) in headers { - let name = HeaderName::from_bytes(&key) + let name = HeaderName::from_bytes(key) .map_err(|err| type_error(err.to_string()))?; - let v = HeaderValue::from_bytes(&value) + let v = HeaderValue::from_bytes(value) .map_err(|err| type_error(err.to_string()))?; let is_disallowed_header = matches!( @@ -265,40 +376,54 @@ where } } } + Ok(request) +} - let request = request.body(Body::empty())?; - let domain = &uri.host().unwrap().to_string(); - let port = &uri.port_u16().unwrap_or(match uri.scheme_str() { - Some("wss") => 443, - Some("ws") => 80, - _ => unreachable!(), - }); - let addr = format!("{domain}:{port}"); - let tcp_socket = TcpStream::connect(addr).await?; - - let (stream, response) = match uri.scheme_str() { - Some("ws") => handshake(cancel_resource, request, tcp_socket).await?, - Some("wss") => { - let tls_config = create_client_config( - root_cert_store, - vec![], - unsafely_ignore_certificate_errors, - None, - )?; - let dnsname = ServerName::try_from(domain.as_str()) - .map_err(|_| invalid_hostname(domain))?; - let mut tls_connector = TlsStream::new_client_side( - tcp_socket, - tls_config.into(), - dnsname, - NonZeroUsize::new(65536), +#[op2(async)] +#[serde] +pub async fn op_ws_create<WP>( + state: Rc<RefCell<OpState>>, + #[string] api_name: String, + #[string] url: String, + #[string] protocols: String, + #[smi] cancel_handle: Option<ResourceId>, + #[serde] headers: Option<Vec<(ByteString, ByteString)>>, +) -> Result<CreateResponse, AnyError> +where + WP: WebSocketPermissions + 'static, +{ + { + let mut s = state.borrow_mut(); + s.borrow_mut::<WP>() + .check_net_url(&url::Url::parse(&url)?, &api_name) + .expect( + "Permission check should have been done in op_ws_check_permission", ); - let _hs = tls_connector.handshake().await?; - handshake(cancel_resource, request, tls_connector).await? - } - _ => unreachable!(), + } + + let cancel_resource = if let Some(cancel_rid) = cancel_handle { + let r = state + .borrow_mut() + .resource_table + .get::<WsCancelResource>(cancel_rid)?; + Some(r.0.clone()) + } else { + None }; + let uri: Uri = url.parse()?; + + let handshake = handshake_websocket(&state, &uri, &protocols, headers) + .map_err(|err| { + AnyError::from(DomExceptionNetworkError::new(&format!( + "failed to connect to WebSocket: {err}" + ))) + }); + let (stream, response) = match cancel_resource { + Some(rc) => handshake.try_or_cancel(rc).await, + None => handshake.await, + }?; + if let Some(cancel_rid) = cancel_handle { if let Ok(res) = state.borrow_mut().resource_table.take_any(cancel_rid) { res.close(); @@ -308,12 +433,11 @@ where let mut state = state.borrow_mut(); let rid = state.resource_table.add(ServerWebSocket::new(stream)); - let protocol = match response.headers().get("Sec-WebSocket-Protocol") { + let protocol = match response.get("Sec-WebSocket-Protocol") { Some(header) => header.to_str().unwrap(), None => "", }; let extensions = response - .headers() .get_all("Sec-WebSocket-Extensions") .iter() .map(|header| header.to_str().unwrap()) diff --git a/ext/websocket/stream.rs b/ext/websocket/stream.rs index 6f93406f6..7e36c8147 100644 --- a/ext/websocket/stream.rs +++ b/ext/websocket/stream.rs @@ -2,8 +2,12 @@ use bytes::Buf; use bytes::Bytes; use deno_net::raw::NetworkStream; +use h2::RecvStream; +use h2::SendStream; use hyper::upgrade::Upgraded; +use std::io::ErrorKind; use std::pin::Pin; +use std::task::ready; use std::task::Poll; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; @@ -13,6 +17,7 @@ use tokio::io::ReadBuf; pub(crate) enum WsStreamKind { Upgraded(Upgraded), Network(NetworkStream), + H2(SendStream<Bytes>, RecvStream), } pub(crate) struct WebSocketStream { @@ -54,6 +59,27 @@ impl AsyncRead for WebSocketStream { match &mut self.stream { WsStreamKind::Network(stream) => Pin::new(stream).poll_read(cx, buf), WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_read(cx, buf), + WsStreamKind::H2(_, recv) => { + let data = ready!(recv.poll_data(cx)); + let Some(data) = data else { + // EOF + return Poll::Ready(Ok(())); + }; + let mut data = data.map_err(|e| { + std::io::Error::new(std::io::ErrorKind::InvalidData, e) + })?; + recv.flow_control().release_capacity(data.len()).unwrap(); + // This looks like the prefix code above -- can we share this? + let copy_len = std::cmp::min(data.len(), buf.remaining()); + // TODO: There should be a way to do following two lines cleaner... + buf.put_slice(&data[..copy_len]); + data.advance(copy_len); + // Put back what's left + if !data.is_empty() { + self.pre = Some(data); + } + Poll::Ready(Ok(())) + } } } } @@ -67,6 +93,30 @@ impl AsyncWrite for WebSocketStream { match &mut self.stream { WsStreamKind::Network(stream) => Pin::new(stream).poll_write(cx, buf), WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_write(cx, buf), + WsStreamKind::H2(send, _) => { + // Zero-length write succeeds + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + + send.reserve_capacity(buf.len()); + let res = ready!(send.poll_capacity(cx)); + + // TODO(mmastrac): the documentation is not entirely clear what to do here, so we'll continue + _ = res; + + // We'll try to send whatever we have capacity for + let size = std::cmp::min(buf.len(), send.capacity()); + assert!(size > 0); + + let buf: Bytes = Bytes::copy_from_slice(&buf[0..size]); + let len = buf.len(); + // TODO(mmastrac): surface the h2 error? + let res = send + .send_data(buf, false) + .map_err(|_| std::io::Error::from(ErrorKind::Other)); + Poll::Ready(res.map(|_| len)) + } } } @@ -77,6 +127,7 @@ impl AsyncWrite for WebSocketStream { match &mut self.stream { WsStreamKind::Network(stream) => Pin::new(stream).poll_flush(cx), WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_flush(cx), + WsStreamKind::H2(..) => Poll::Ready(Ok(())), } } @@ -87,6 +138,13 @@ impl AsyncWrite for WebSocketStream { match &mut self.stream { WsStreamKind::Network(stream) => Pin::new(stream).poll_shutdown(cx), WsStreamKind::Upgraded(stream) => Pin::new(stream).poll_shutdown(cx), + WsStreamKind::H2(send, _) => { + // TODO(mmastrac): surface the h2 error? + let res = send + .send_data(Bytes::new(), false) + .map_err(|_| std::io::Error::from(ErrorKind::Other)); + Poll::Ready(res) + } } } @@ -94,6 +152,7 @@ impl AsyncWrite for WebSocketStream { match &self.stream { WsStreamKind::Network(stream) => stream.is_write_vectored(), WsStreamKind::Upgraded(stream) => stream.is_write_vectored(), + WsStreamKind::H2(..) => false, } } @@ -109,6 +168,10 @@ impl AsyncWrite for WebSocketStream { WsStreamKind::Upgraded(stream) => { Pin::new(stream).poll_write_vectored(cx, bufs) } + WsStreamKind::H2(..) => { + // TODO(mmastrac): this is possibly just too difficult, but we'll never call it + unimplemented!() + } } } } |