diff options
author | haturau <135221985+haturatu@users.noreply.github.com> | 2024-11-20 01:20:47 +0900 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-20 01:20:47 +0900 |
commit | 85719a67e59c7aa45bead26e4942d7df8b1b42d4 (patch) | |
tree | face0aecaac53e93ce2f23b53c48859bcf1a36ec /ext/websocket/lib.rs | |
parent | 67697bc2e4a62a9670699fd18ad0dd8efc5bd955 (diff) | |
parent | 186b52731c6bb326c4d32905c5e732d082e83465 (diff) |
Merge branch 'denoland:main' into main
Diffstat (limited to 'ext/websocket/lib.rs')
-rw-r--r-- | ext/websocket/lib.rs | 191 |
1 files changed, 107 insertions, 84 deletions
diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs index b8043516b..a5734271c 100644 --- a/ext/websocket/lib.rs +++ b/ext/websocket/lib.rs @@ -1,10 +1,6 @@ // Copyright 2018-2024 the Deno authors. All rights reserved. MIT license. use crate::stream::WebSocketStream; use bytes::Bytes; -use deno_core::anyhow::bail; -use deno_core::error::invalid_hostname; -use deno_core::error::type_error; -use deno_core::error::AnyError; use deno_core::futures::TryFutureExt; use deno_core::op2; use deno_core::unsync::spawn; @@ -43,7 +39,6 @@ use serde::Serialize; use std::borrow::Cow; use std::cell::Cell; use std::cell::RefCell; -use std::fmt; use std::future::Future; use std::num::NonZeroUsize; use std::path::PathBuf; @@ -55,6 +50,7 @@ use tokio::io::ReadHalf; use tokio::io::WriteHalf; use tokio::net::TcpStream; +use deno_permissions::PermissionCheckError; use fastwebsockets::CloseCode; use fastwebsockets::FragmentCollectorRead; use fastwebsockets::Frame; @@ -75,11 +71,33 @@ static USE_WRITEV: Lazy<bool> = Lazy::new(|| { false }); +#[derive(Debug, thiserror::Error)] +pub enum WebsocketError { + #[error(transparent)] + Url(url::ParseError), + #[error(transparent)] + Permission(#[from] PermissionCheckError), + #[error(transparent)] + Resource(deno_core::error::AnyError), + #[error(transparent)] + Uri(#[from] http::uri::InvalidUri), + #[error("{0}")] + Io(#[from] std::io::Error), + #[error(transparent)] + WebSocket(#[from] fastwebsockets::WebSocketError), + #[error("failed to connect to WebSocket: {0}")] + ConnectionFailed(#[from] HandshakeError), + #[error(transparent)] + Canceled(#[from] deno_core::Canceled), +} + #[derive(Clone)] pub struct WsRootStoreProvider(Option<Arc<dyn RootCertStoreProvider>>); impl WsRootStoreProvider { - pub fn get_or_try_init(&self) -> Result<Option<RootCertStore>, AnyError> { + pub fn get_or_try_init( + &self, + ) -> Result<Option<RootCertStore>, deno_core::error::AnyError> { Ok(match &self.0 { Some(provider) => Some(provider.get_or_try_init()?.clone()), None => None, @@ -95,7 +113,7 @@ pub trait WebSocketPermissions { &mut self, _url: &url::Url, _api_name: &str, - ) -> Result<(), AnyError>; + ) -> Result<(), PermissionCheckError>; } impl WebSocketPermissions for deno_permissions::PermissionsContainer { @@ -104,7 +122,7 @@ impl WebSocketPermissions for deno_permissions::PermissionsContainer { &mut self, url: &url::Url, api_name: &str, - ) -> Result<(), AnyError> { + ) -> Result<(), PermissionCheckError> { deno_permissions::PermissionsContainer::check_net_url(self, url, api_name) } } @@ -137,13 +155,14 @@ pub fn op_ws_check_permission_and_cancel_handle<WP>( #[string] api_name: String, #[string] url: String, cancel_handle: bool, -) -> Result<Option<ResourceId>, AnyError> +) -> Result<Option<ResourceId>, WebsocketError> where WP: WebSocketPermissions + 'static, { - state - .borrow_mut::<WP>() - .check_net_url(&url::Url::parse(&url)?, &api_name)?; + state.borrow_mut::<WP>().check_net_url( + &url::Url::parse(&url).map_err(WebsocketError::Url)?, + &api_name, + )?; if cancel_handle { let rid = state @@ -163,16 +182,46 @@ pub struct CreateResponse { extensions: String, } +#[derive(Debug, thiserror::Error)] +pub enum HandshakeError { + #[error("Missing path in url")] + MissingPath, + #[error("Invalid status code {0}")] + InvalidStatusCode(StatusCode), + #[error(transparent)] + Http(#[from] http::Error), + #[error(transparent)] + WebSocket(#[from] fastwebsockets::WebSocketError), + #[error("Didn't receive h2 alpn, aborting connection")] + NoH2Alpn, + #[error(transparent)] + Rustls(#[from] deno_tls::rustls::Error), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + H2(#[from] h2::Error), + #[error("Invalid hostname: '{0}'")] + InvalidHostname(String), + #[error(transparent)] + RootStoreError(deno_core::error::AnyError), + #[error(transparent)] + Tls(deno_tls::TlsError), + #[error(transparent)] + HeaderName(#[from] http::header::InvalidHeaderName), + #[error(transparent)] + HeaderValue(#[from] http::header::InvalidHeaderValue), +} + async fn handshake_websocket( state: &Rc<RefCell<OpState>>, uri: &Uri, protocols: &str, headers: Option<Vec<(ByteString, ByteString)>>, -) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> { +) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> { let mut request = Request::builder().method(Method::GET).uri( uri .path_and_query() - .ok_or(type_error("Missing path in url".to_string()))? + .ok_or(HandshakeError::MissingPath)? .as_str(), ); @@ -194,7 +243,9 @@ async fn handshake_websocket( request = populate_common_request_headers(request, &user_agent, protocols, &headers)?; - let request = request.body(http_body_util::Empty::new())?; + let request = request + .body(http_body_util::Empty::new()) + .map_err(HandshakeError::Http)?; let domain = &uri.host().unwrap().to_string(); let port = &uri.port_u16().unwrap_or(match uri.scheme_str() { Some("wss") => 443, @@ -231,7 +282,7 @@ async fn handshake_websocket( async fn handshake_http1_ws( request: Request<http_body_util::Empty<Bytes>>, addr: &String, -) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> { +) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> { let tcp_socket = TcpStream::connect(addr).await?; handshake_connection(request, tcp_socket).await } @@ -241,11 +292,11 @@ async fn handshake_http1_wss( request: Request<http_body_util::Empty<Bytes>>, domain: &str, addr: &str, -) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> { +) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> { let tcp_socket = TcpStream::connect(addr).await?; let tls_config = create_ws_client_config(state, SocketUse::Http1Only)?; let dnsname = ServerName::try_from(domain.to_string()) - .map_err(|_| invalid_hostname(domain))?; + .map_err(|_| HandshakeError::InvalidHostname(domain.to_string()))?; let mut tls_connector = TlsStream::new_client_side( tcp_socket, ClientConnection::new(tls_config.into(), dnsname)?, @@ -266,11 +317,11 @@ async fn handshake_http2_wss( domain: &str, headers: &Option<Vec<(ByteString, ByteString)>>, addr: &str, -) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> { +) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> { let tcp_socket = TcpStream::connect(addr).await?; let tls_config = create_ws_client_config(state, SocketUse::Http2Only)?; let dnsname = ServerName::try_from(domain.to_string()) - .map_err(|_| invalid_hostname(domain))?; + .map_err(|_| HandshakeError::InvalidHostname(domain.to_string()))?; // We need to better expose the underlying errors here let mut tls_connector = TlsStream::new_client_side( tcp_socket, @@ -279,7 +330,7 @@ async fn handshake_http2_wss( ); let handshake = tls_connector.handshake().await?; if handshake.alpn.is_none() { - bail!("Didn't receive h2 alpn, aborting connection"); + return Err(HandshakeError::NoH2Alpn); } let h2 = h2::client::Builder::new(); let (mut send, conn) = h2.handshake::<_, Bytes>(tls_connector).await?; @@ -298,7 +349,7 @@ async fn handshake_http2_wss( let (resp, send) = send.send_request(request.body(())?, false)?; let resp = resp.await?; if resp.status() != StatusCode::OK { - bail!("Invalid status code: {}", resp.status()); + return Err(HandshakeError::InvalidStatusCode(resp.status())); } let (http::response::Parts { headers, .. }, recv) = resp.into_parts(); let mut stream = WebSocket::after_handshake( @@ -317,7 +368,7 @@ async fn handshake_connection< >( request: Request<http_body_util::Empty<Bytes>>, socket: S, -) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> { +) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> { let (upgraded, response) = fastwebsockets::handshake::client(&LocalExecutor, request, socket).await?; @@ -332,7 +383,7 @@ async fn handshake_connection< pub fn create_ws_client_config( state: &Rc<RefCell<OpState>>, socket_use: SocketUse, -) -> Result<ClientConfig, AnyError> { +) -> Result<ClientConfig, HandshakeError> { let unsafely_ignore_certificate_errors: Option<Vec<String>> = state .borrow() .try_borrow::<UnsafelyIgnoreCertificateErrors>() @@ -340,7 +391,8 @@ pub fn create_ws_client_config( let root_cert_store = state .borrow() .borrow::<WsRootStoreProvider>() - .get_or_try_init()?; + .get_or_try_init() + .map_err(HandshakeError::RootStoreError)?; create_client_config( root_cert_store, @@ -349,7 +401,7 @@ pub fn create_ws_client_config( TlsKeys::Null, socket_use, ) - .map_err(|e| e.into()) + .map_err(HandshakeError::Tls) } /// Headers common to both http/1.1 and h2 requests. @@ -358,7 +410,7 @@ fn populate_common_request_headers( user_agent: &str, protocols: &str, headers: &Option<Vec<(ByteString, ByteString)>>, -) -> Result<http::request::Builder, AnyError> { +) -> Result<http::request::Builder, HandshakeError> { request = request .header("User-Agent", user_agent) .header("Sec-WebSocket-Version", "13"); @@ -369,10 +421,8 @@ fn populate_common_request_headers( if let Some(headers) = headers { for (key, value) in headers { - let name = HeaderName::from_bytes(key) - .map_err(|err| type_error(err.to_string()))?; - let v = HeaderValue::from_bytes(value) - .map_err(|err| type_error(err.to_string()))?; + let name = HeaderName::from_bytes(key)?; + let v = HeaderValue::from_bytes(value)?; let is_disallowed_header = matches!( name, @@ -402,14 +452,17 @@ pub async fn op_ws_create<WP>( #[string] protocols: String, #[smi] cancel_handle: Option<ResourceId>, #[serde] headers: Option<Vec<(ByteString, ByteString)>>, -) -> Result<CreateResponse, AnyError> +) -> Result<CreateResponse, WebsocketError> where WP: WebSocketPermissions + 'static, { { let mut s = state.borrow_mut(); s.borrow_mut::<WP>() - .check_net_url(&url::Url::parse(&url)?, &api_name) + .check_net_url( + &url::Url::parse(&url).map_err(WebsocketError::Url)?, + &api_name, + ) .expect( "Permission check should have been done in op_ws_check_permission", ); @@ -419,7 +472,8 @@ where let r = state .borrow_mut() .resource_table - .get::<WsCancelResource>(cancel_rid)?; + .get::<WsCancelResource>(cancel_rid) + .map_err(WebsocketError::Resource)?; Some(r.0.clone()) } else { None @@ -428,15 +482,11 @@ where let uri: Uri = url.parse()?; let handshake = handshake_websocket(&state, &uri, &protocols, headers) - .map_err(|err| { - AnyError::from(DomExceptionNetworkError::new(&format!( - "failed to connect to WebSocket: {err}" - ))) - }); + .map_err(WebsocketError::ConnectionFailed); let (stream, response) = match cancel_resource { - Some(rc) => handshake.try_or_cancel(rc).await, - None => handshake.await, - }?; + Some(rc) => handshake.try_or_cancel(rc).await?, + None => handshake.await?, + }; if let Some(cancel_rid) = cancel_handle { if let Ok(res) = state.borrow_mut().resource_table.take_any(cancel_rid) { @@ -521,14 +571,12 @@ impl ServerWebSocket { self: &Rc<Self>, lock: AsyncMutFuture<WebSocketWrite<WriteHalf<WebSocketStream>>>, frame: Frame<'_>, - ) -> Result<(), AnyError> { + ) -> Result<(), WebsocketError> { let mut ws = lock.await; if ws.is_closed() { return Ok(()); } - ws.write_frame(frame) - .await - .map_err(|err| type_error(err.to_string()))?; + ws.write_frame(frame).await?; Ok(()) } } @@ -543,7 +591,7 @@ pub fn ws_create_server_stream( state: &mut OpState, transport: NetworkStream, read_buf: Bytes, -) -> Result<ResourceId, AnyError> { +) -> ResourceId { let mut ws = WebSocket::after_handshake( WebSocketStream::new( stream::WsStreamKind::Network(transport), @@ -555,8 +603,7 @@ pub fn ws_create_server_stream( ws.set_auto_close(true); ws.set_auto_pong(true); - let rid = state.resource_table.add(ServerWebSocket::new(ws)); - Ok(rid) + state.resource_table.add(ServerWebSocket::new(ws)) } fn send_binary(state: &mut OpState, rid: ResourceId, data: &[u8]) { @@ -626,11 +673,12 @@ pub async fn op_ws_send_binary_async( state: Rc<RefCell<OpState>>, #[smi] rid: ResourceId, #[buffer] data: JsBuffer, -) -> Result<(), AnyError> { +) -> Result<(), WebsocketError> { let resource = state .borrow_mut() .resource_table - .get::<ServerWebSocket>(rid)?; + .get::<ServerWebSocket>(rid) + .map_err(WebsocketError::Resource)?; let data = data.to_vec(); let lock = resource.reserve_lock(); resource @@ -644,11 +692,12 @@ pub async fn op_ws_send_text_async( state: Rc<RefCell<OpState>>, #[smi] rid: ResourceId, #[string] data: String, -) -> Result<(), AnyError> { +) -> Result<(), WebsocketError> { let resource = state .borrow_mut() .resource_table - .get::<ServerWebSocket>(rid)?; + .get::<ServerWebSocket>(rid) + .map_err(WebsocketError::Resource)?; let lock = resource.reserve_lock(); resource .write_frame( @@ -678,11 +727,12 @@ pub fn op_ws_get_buffered_amount( pub async fn op_ws_send_ping( state: Rc<RefCell<OpState>>, #[smi] rid: ResourceId, -) -> Result<(), AnyError> { +) -> Result<(), WebsocketError> { let resource = state .borrow_mut() .resource_table - .get::<ServerWebSocket>(rid)?; + .get::<ServerWebSocket>(rid) + .map_err(WebsocketError::Resource)?; let lock = resource.reserve_lock(); resource .write_frame( @@ -698,7 +748,7 @@ pub async fn op_ws_close( #[smi] rid: ResourceId, #[smi] code: Option<u16>, #[string] reason: Option<String>, -) -> Result<(), AnyError> { +) -> Result<(), WebsocketError> { let Ok(resource) = state .borrow_mut() .resource_table @@ -713,8 +763,7 @@ pub async fn op_ws_close( resource.closed.set(true); let lock = resource.reserve_lock(); - resource.write_frame(lock, frame).await?; - Ok(()) + resource.write_frame(lock, frame).await } #[op2] @@ -868,32 +917,6 @@ pub fn get_declaration() -> PathBuf { PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("lib.deno_websocket.d.ts") } -#[derive(Debug)] -pub struct DomExceptionNetworkError { - pub msg: String, -} - -impl DomExceptionNetworkError { - pub fn new(msg: &str) -> Self { - DomExceptionNetworkError { - msg: msg.to_string(), - } - } -} - -impl fmt::Display for DomExceptionNetworkError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.pad(&self.msg) - } -} - -impl std::error::Error for DomExceptionNetworkError {} - -pub fn get_network_error_class_name(e: &AnyError) -> Option<&'static str> { - e.downcast_ref::<DomExceptionNetworkError>() - .map(|_| "DOMExceptionNetworkError") -} - // Needed so hyper can use non Send futures #[derive(Clone)] struct LocalExecutor; |