summaryrefslogtreecommitdiff
path: root/ext/websocket/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ext/websocket/lib.rs')
-rw-r--r--ext/websocket/lib.rs191
1 files changed, 107 insertions, 84 deletions
diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs
index b8043516b..a5734271c 100644
--- a/ext/websocket/lib.rs
+++ b/ext/websocket/lib.rs
@@ -1,10 +1,6 @@
// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
use crate::stream::WebSocketStream;
use bytes::Bytes;
-use deno_core::anyhow::bail;
-use deno_core::error::invalid_hostname;
-use deno_core::error::type_error;
-use deno_core::error::AnyError;
use deno_core::futures::TryFutureExt;
use deno_core::op2;
use deno_core::unsync::spawn;
@@ -43,7 +39,6 @@ use serde::Serialize;
use std::borrow::Cow;
use std::cell::Cell;
use std::cell::RefCell;
-use std::fmt;
use std::future::Future;
use std::num::NonZeroUsize;
use std::path::PathBuf;
@@ -55,6 +50,7 @@ use tokio::io::ReadHalf;
use tokio::io::WriteHalf;
use tokio::net::TcpStream;
+use deno_permissions::PermissionCheckError;
use fastwebsockets::CloseCode;
use fastwebsockets::FragmentCollectorRead;
use fastwebsockets::Frame;
@@ -75,11 +71,33 @@ static USE_WRITEV: Lazy<bool> = Lazy::new(|| {
false
});
+#[derive(Debug, thiserror::Error)]
+pub enum WebsocketError {
+ #[error(transparent)]
+ Url(url::ParseError),
+ #[error(transparent)]
+ Permission(#[from] PermissionCheckError),
+ #[error(transparent)]
+ Resource(deno_core::error::AnyError),
+ #[error(transparent)]
+ Uri(#[from] http::uri::InvalidUri),
+ #[error("{0}")]
+ Io(#[from] std::io::Error),
+ #[error(transparent)]
+ WebSocket(#[from] fastwebsockets::WebSocketError),
+ #[error("failed to connect to WebSocket: {0}")]
+ ConnectionFailed(#[from] HandshakeError),
+ #[error(transparent)]
+ Canceled(#[from] deno_core::Canceled),
+}
+
#[derive(Clone)]
pub struct WsRootStoreProvider(Option<Arc<dyn RootCertStoreProvider>>);
impl WsRootStoreProvider {
- pub fn get_or_try_init(&self) -> Result<Option<RootCertStore>, AnyError> {
+ pub fn get_or_try_init(
+ &self,
+ ) -> Result<Option<RootCertStore>, deno_core::error::AnyError> {
Ok(match &self.0 {
Some(provider) => Some(provider.get_or_try_init()?.clone()),
None => None,
@@ -95,7 +113,7 @@ pub trait WebSocketPermissions {
&mut self,
_url: &url::Url,
_api_name: &str,
- ) -> Result<(), AnyError>;
+ ) -> Result<(), PermissionCheckError>;
}
impl WebSocketPermissions for deno_permissions::PermissionsContainer {
@@ -104,7 +122,7 @@ impl WebSocketPermissions for deno_permissions::PermissionsContainer {
&mut self,
url: &url::Url,
api_name: &str,
- ) -> Result<(), AnyError> {
+ ) -> Result<(), PermissionCheckError> {
deno_permissions::PermissionsContainer::check_net_url(self, url, api_name)
}
}
@@ -137,13 +155,14 @@ pub fn op_ws_check_permission_and_cancel_handle<WP>(
#[string] api_name: String,
#[string] url: String,
cancel_handle: bool,
-) -> Result<Option<ResourceId>, AnyError>
+) -> Result<Option<ResourceId>, WebsocketError>
where
WP: WebSocketPermissions + 'static,
{
- state
- .borrow_mut::<WP>()
- .check_net_url(&url::Url::parse(&url)?, &api_name)?;
+ state.borrow_mut::<WP>().check_net_url(
+ &url::Url::parse(&url).map_err(WebsocketError::Url)?,
+ &api_name,
+ )?;
if cancel_handle {
let rid = state
@@ -163,16 +182,46 @@ pub struct CreateResponse {
extensions: String,
}
+#[derive(Debug, thiserror::Error)]
+pub enum HandshakeError {
+ #[error("Missing path in url")]
+ MissingPath,
+ #[error("Invalid status code {0}")]
+ InvalidStatusCode(StatusCode),
+ #[error(transparent)]
+ Http(#[from] http::Error),
+ #[error(transparent)]
+ WebSocket(#[from] fastwebsockets::WebSocketError),
+ #[error("Didn't receive h2 alpn, aborting connection")]
+ NoH2Alpn,
+ #[error(transparent)]
+ Rustls(#[from] deno_tls::rustls::Error),
+ #[error(transparent)]
+ Io(#[from] std::io::Error),
+ #[error(transparent)]
+ H2(#[from] h2::Error),
+ #[error("Invalid hostname: '{0}'")]
+ InvalidHostname(String),
+ #[error(transparent)]
+ RootStoreError(deno_core::error::AnyError),
+ #[error(transparent)]
+ Tls(deno_tls::TlsError),
+ #[error(transparent)]
+ HeaderName(#[from] http::header::InvalidHeaderName),
+ #[error(transparent)]
+ HeaderValue(#[from] http::header::InvalidHeaderValue),
+}
+
async fn handshake_websocket(
state: &Rc<RefCell<OpState>>,
uri: &Uri,
protocols: &str,
headers: Option<Vec<(ByteString, ByteString)>>,
-) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> {
+) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> {
let mut request = Request::builder().method(Method::GET).uri(
uri
.path_and_query()
- .ok_or(type_error("Missing path in url".to_string()))?
+ .ok_or(HandshakeError::MissingPath)?
.as_str(),
);
@@ -194,7 +243,9 @@ async fn handshake_websocket(
request =
populate_common_request_headers(request, &user_agent, protocols, &headers)?;
- let request = request.body(http_body_util::Empty::new())?;
+ let request = request
+ .body(http_body_util::Empty::new())
+ .map_err(HandshakeError::Http)?;
let domain = &uri.host().unwrap().to_string();
let port = &uri.port_u16().unwrap_or(match uri.scheme_str() {
Some("wss") => 443,
@@ -231,7 +282,7 @@ async fn handshake_websocket(
async fn handshake_http1_ws(
request: Request<http_body_util::Empty<Bytes>>,
addr: &String,
-) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> {
+) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> {
let tcp_socket = TcpStream::connect(addr).await?;
handshake_connection(request, tcp_socket).await
}
@@ -241,11 +292,11 @@ async fn handshake_http1_wss(
request: Request<http_body_util::Empty<Bytes>>,
domain: &str,
addr: &str,
-) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> {
+) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> {
let tcp_socket = TcpStream::connect(addr).await?;
let tls_config = create_ws_client_config(state, SocketUse::Http1Only)?;
let dnsname = ServerName::try_from(domain.to_string())
- .map_err(|_| invalid_hostname(domain))?;
+ .map_err(|_| HandshakeError::InvalidHostname(domain.to_string()))?;
let mut tls_connector = TlsStream::new_client_side(
tcp_socket,
ClientConnection::new(tls_config.into(), dnsname)?,
@@ -266,11 +317,11 @@ async fn handshake_http2_wss(
domain: &str,
headers: &Option<Vec<(ByteString, ByteString)>>,
addr: &str,
-) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> {
+) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> {
let tcp_socket = TcpStream::connect(addr).await?;
let tls_config = create_ws_client_config(state, SocketUse::Http2Only)?;
let dnsname = ServerName::try_from(domain.to_string())
- .map_err(|_| invalid_hostname(domain))?;
+ .map_err(|_| HandshakeError::InvalidHostname(domain.to_string()))?;
// We need to better expose the underlying errors here
let mut tls_connector = TlsStream::new_client_side(
tcp_socket,
@@ -279,7 +330,7 @@ async fn handshake_http2_wss(
);
let handshake = tls_connector.handshake().await?;
if handshake.alpn.is_none() {
- bail!("Didn't receive h2 alpn, aborting connection");
+ return Err(HandshakeError::NoH2Alpn);
}
let h2 = h2::client::Builder::new();
let (mut send, conn) = h2.handshake::<_, Bytes>(tls_connector).await?;
@@ -298,7 +349,7 @@ async fn handshake_http2_wss(
let (resp, send) = send.send_request(request.body(())?, false)?;
let resp = resp.await?;
if resp.status() != StatusCode::OK {
- bail!("Invalid status code: {}", resp.status());
+ return Err(HandshakeError::InvalidStatusCode(resp.status()));
}
let (http::response::Parts { headers, .. }, recv) = resp.into_parts();
let mut stream = WebSocket::after_handshake(
@@ -317,7 +368,7 @@ async fn handshake_connection<
>(
request: Request<http_body_util::Empty<Bytes>>,
socket: S,
-) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), AnyError> {
+) -> Result<(WebSocket<WebSocketStream>, http::HeaderMap), HandshakeError> {
let (upgraded, response) =
fastwebsockets::handshake::client(&LocalExecutor, request, socket).await?;
@@ -332,7 +383,7 @@ async fn handshake_connection<
pub fn create_ws_client_config(
state: &Rc<RefCell<OpState>>,
socket_use: SocketUse,
-) -> Result<ClientConfig, AnyError> {
+) -> Result<ClientConfig, HandshakeError> {
let unsafely_ignore_certificate_errors: Option<Vec<String>> = state
.borrow()
.try_borrow::<UnsafelyIgnoreCertificateErrors>()
@@ -340,7 +391,8 @@ pub fn create_ws_client_config(
let root_cert_store = state
.borrow()
.borrow::<WsRootStoreProvider>()
- .get_or_try_init()?;
+ .get_or_try_init()
+ .map_err(HandshakeError::RootStoreError)?;
create_client_config(
root_cert_store,
@@ -349,7 +401,7 @@ pub fn create_ws_client_config(
TlsKeys::Null,
socket_use,
)
- .map_err(|e| e.into())
+ .map_err(HandshakeError::Tls)
}
/// Headers common to both http/1.1 and h2 requests.
@@ -358,7 +410,7 @@ fn populate_common_request_headers(
user_agent: &str,
protocols: &str,
headers: &Option<Vec<(ByteString, ByteString)>>,
-) -> Result<http::request::Builder, AnyError> {
+) -> Result<http::request::Builder, HandshakeError> {
request = request
.header("User-Agent", user_agent)
.header("Sec-WebSocket-Version", "13");
@@ -369,10 +421,8 @@ fn populate_common_request_headers(
if let Some(headers) = headers {
for (key, value) in headers {
- let name = HeaderName::from_bytes(key)
- .map_err(|err| type_error(err.to_string()))?;
- let v = HeaderValue::from_bytes(value)
- .map_err(|err| type_error(err.to_string()))?;
+ let name = HeaderName::from_bytes(key)?;
+ let v = HeaderValue::from_bytes(value)?;
let is_disallowed_header = matches!(
name,
@@ -402,14 +452,17 @@ pub async fn op_ws_create<WP>(
#[string] protocols: String,
#[smi] cancel_handle: Option<ResourceId>,
#[serde] headers: Option<Vec<(ByteString, ByteString)>>,
-) -> Result<CreateResponse, AnyError>
+) -> Result<CreateResponse, WebsocketError>
where
WP: WebSocketPermissions + 'static,
{
{
let mut s = state.borrow_mut();
s.borrow_mut::<WP>()
- .check_net_url(&url::Url::parse(&url)?, &api_name)
+ .check_net_url(
+ &url::Url::parse(&url).map_err(WebsocketError::Url)?,
+ &api_name,
+ )
.expect(
"Permission check should have been done in op_ws_check_permission",
);
@@ -419,7 +472,8 @@ where
let r = state
.borrow_mut()
.resource_table
- .get::<WsCancelResource>(cancel_rid)?;
+ .get::<WsCancelResource>(cancel_rid)
+ .map_err(WebsocketError::Resource)?;
Some(r.0.clone())
} else {
None
@@ -428,15 +482,11 @@ where
let uri: Uri = url.parse()?;
let handshake = handshake_websocket(&state, &uri, &protocols, headers)
- .map_err(|err| {
- AnyError::from(DomExceptionNetworkError::new(&format!(
- "failed to connect to WebSocket: {err}"
- )))
- });
+ .map_err(WebsocketError::ConnectionFailed);
let (stream, response) = match cancel_resource {
- Some(rc) => handshake.try_or_cancel(rc).await,
- None => handshake.await,
- }?;
+ Some(rc) => handshake.try_or_cancel(rc).await?,
+ None => handshake.await?,
+ };
if let Some(cancel_rid) = cancel_handle {
if let Ok(res) = state.borrow_mut().resource_table.take_any(cancel_rid) {
@@ -521,14 +571,12 @@ impl ServerWebSocket {
self: &Rc<Self>,
lock: AsyncMutFuture<WebSocketWrite<WriteHalf<WebSocketStream>>>,
frame: Frame<'_>,
- ) -> Result<(), AnyError> {
+ ) -> Result<(), WebsocketError> {
let mut ws = lock.await;
if ws.is_closed() {
return Ok(());
}
- ws.write_frame(frame)
- .await
- .map_err(|err| type_error(err.to_string()))?;
+ ws.write_frame(frame).await?;
Ok(())
}
}
@@ -543,7 +591,7 @@ pub fn ws_create_server_stream(
state: &mut OpState,
transport: NetworkStream,
read_buf: Bytes,
-) -> Result<ResourceId, AnyError> {
+) -> ResourceId {
let mut ws = WebSocket::after_handshake(
WebSocketStream::new(
stream::WsStreamKind::Network(transport),
@@ -555,8 +603,7 @@ pub fn ws_create_server_stream(
ws.set_auto_close(true);
ws.set_auto_pong(true);
- let rid = state.resource_table.add(ServerWebSocket::new(ws));
- Ok(rid)
+ state.resource_table.add(ServerWebSocket::new(ws))
}
fn send_binary(state: &mut OpState, rid: ResourceId, data: &[u8]) {
@@ -626,11 +673,12 @@ pub async fn op_ws_send_binary_async(
state: Rc<RefCell<OpState>>,
#[smi] rid: ResourceId,
#[buffer] data: JsBuffer,
-) -> Result<(), AnyError> {
+) -> Result<(), WebsocketError> {
let resource = state
.borrow_mut()
.resource_table
- .get::<ServerWebSocket>(rid)?;
+ .get::<ServerWebSocket>(rid)
+ .map_err(WebsocketError::Resource)?;
let data = data.to_vec();
let lock = resource.reserve_lock();
resource
@@ -644,11 +692,12 @@ pub async fn op_ws_send_text_async(
state: Rc<RefCell<OpState>>,
#[smi] rid: ResourceId,
#[string] data: String,
-) -> Result<(), AnyError> {
+) -> Result<(), WebsocketError> {
let resource = state
.borrow_mut()
.resource_table
- .get::<ServerWebSocket>(rid)?;
+ .get::<ServerWebSocket>(rid)
+ .map_err(WebsocketError::Resource)?;
let lock = resource.reserve_lock();
resource
.write_frame(
@@ -678,11 +727,12 @@ pub fn op_ws_get_buffered_amount(
pub async fn op_ws_send_ping(
state: Rc<RefCell<OpState>>,
#[smi] rid: ResourceId,
-) -> Result<(), AnyError> {
+) -> Result<(), WebsocketError> {
let resource = state
.borrow_mut()
.resource_table
- .get::<ServerWebSocket>(rid)?;
+ .get::<ServerWebSocket>(rid)
+ .map_err(WebsocketError::Resource)?;
let lock = resource.reserve_lock();
resource
.write_frame(
@@ -698,7 +748,7 @@ pub async fn op_ws_close(
#[smi] rid: ResourceId,
#[smi] code: Option<u16>,
#[string] reason: Option<String>,
-) -> Result<(), AnyError> {
+) -> Result<(), WebsocketError> {
let Ok(resource) = state
.borrow_mut()
.resource_table
@@ -713,8 +763,7 @@ pub async fn op_ws_close(
resource.closed.set(true);
let lock = resource.reserve_lock();
- resource.write_frame(lock, frame).await?;
- Ok(())
+ resource.write_frame(lock, frame).await
}
#[op2]
@@ -868,32 +917,6 @@ pub fn get_declaration() -> PathBuf {
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("lib.deno_websocket.d.ts")
}
-#[derive(Debug)]
-pub struct DomExceptionNetworkError {
- pub msg: String,
-}
-
-impl DomExceptionNetworkError {
- pub fn new(msg: &str) -> Self {
- DomExceptionNetworkError {
- msg: msg.to_string(),
- }
- }
-}
-
-impl fmt::Display for DomExceptionNetworkError {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- f.pad(&self.msg)
- }
-}
-
-impl std::error::Error for DomExceptionNetworkError {}
-
-pub fn get_network_error_class_name(e: &AnyError) -> Option<&'static str> {
- e.downcast_ref::<DomExceptionNetworkError>()
- .map(|_| "DOMExceptionNetworkError")
-}
-
// Needed so hyper can use non Send futures
#[derive(Clone)]
struct LocalExecutor;