summaryrefslogtreecommitdiff
path: root/tests/util/server/src/servers/ws.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/server/src/servers/ws.rs')
-rw-r--r--tests/util/server/src/servers/ws.rs268
1 files changed, 268 insertions, 0 deletions
diff --git a/tests/util/server/src/servers/ws.rs b/tests/util/server/src/servers/ws.rs
new file mode 100644
index 000000000..815119b6a
--- /dev/null
+++ b/tests/util/server/src/servers/ws.rs
@@ -0,0 +1,268 @@
+// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
+
+use anyhow::anyhow;
+use bytes::Bytes;
+use fastwebsockets::FragmentCollector;
+use fastwebsockets::Frame;
+use fastwebsockets::OpCode;
+use fastwebsockets::Role;
+use fastwebsockets::WebSocket;
+use futures::future::join3;
+use futures::future::poll_fn;
+use futures::Future;
+use futures::StreamExt;
+use h2::server::Handshake;
+use h2::server::SendResponse;
+use h2::Reason;
+use h2::RecvStream;
+use hyper::upgrade::Upgraded;
+use hyper::Method;
+use hyper::Request;
+use hyper::Response;
+use hyper::StatusCode;
+use hyper_util::rt::TokioIo;
+use pretty_assertions::assert_eq;
+use std::pin::Pin;
+use std::result::Result;
+use tokio::io::AsyncReadExt;
+use tokio::io::AsyncWriteExt;
+
+use super::get_tcp_listener_stream;
+use super::get_tls_listener_stream;
+use super::SupportedHttpVersions;
+
+pub async fn run_ws_server(port: u16) {
+ let mut tcp = get_tcp_listener_stream("ws", port).await;
+ while let Some(Ok(stream)) = tcp.next().await {
+ spawn_ws_server(stream, |ws| Box::pin(echo_websocket_handler(ws)));
+ }
+}
+
+pub async fn run_ws_ping_server(port: u16) {
+ let mut tcp = get_tcp_listener_stream("ws (ping)", port).await;
+ while let Some(Ok(stream)) = tcp.next().await {
+ spawn_ws_server(stream, |ws| Box::pin(ping_websocket_handler(ws)));
+ }
+}
+
+pub async fn run_wss_server(port: u16) {
+ let mut tls = get_tls_listener_stream("wss", port, Default::default()).await;
+ while let Some(Ok(tls_stream)) = tls.next().await {
+ tokio::spawn(async move {
+ spawn_ws_server(tls_stream, |ws| Box::pin(echo_websocket_handler(ws)));
+ });
+ }
+}
+
+pub async fn run_ws_close_server(port: u16) {
+ let mut tcp = get_tcp_listener_stream("ws (close)", port).await;
+ while let Some(Ok(stream)) = tcp.next().await {
+ spawn_ws_server(stream, |ws| Box::pin(close_websocket_handler(ws)));
+ }
+}
+
+pub async fn run_wss2_server(port: u16) {
+ let mut tls = get_tls_listener_stream(
+ "wss2 (tls)",
+ port,
+ SupportedHttpVersions::Http2Only,
+ )
+ .await;
+ while let Some(Ok(tls)) = tls.next().await {
+ tokio::spawn(async move {
+ let mut h2 = h2::server::Builder::new();
+ h2.enable_connect_protocol();
+ // Using Bytes is pretty alloc-heavy but this is a test server
+ let server: Handshake<_, Bytes> = h2.handshake(tls);
+ let mut server = match server.await {
+ Ok(server) => server,
+ Err(e) => {
+ println!("Failed to handshake h2: {e:?}");
+ return;
+ }
+ };
+ loop {
+ let Some(conn) = server.accept().await else {
+ break;
+ };
+ let (recv, send) = match conn {
+ Ok(conn) => conn,
+ Err(e) => {
+ println!("Failed to accept a connection: {e:?}");
+ break;
+ }
+ };
+ tokio::spawn(handle_wss_stream(recv, send));
+ }
+ });
+ }
+}
+
+async fn echo_websocket_handler(
+ ws: fastwebsockets::WebSocket<TokioIo<Upgraded>>,
+) -> Result<(), anyhow::Error> {
+ let mut ws = FragmentCollector::new(ws);
+
+ loop {
+ let frame = ws.read_frame().await.unwrap();
+ match frame.opcode {
+ OpCode::Close => break,
+ OpCode::Text | OpCode::Binary => {
+ ws.write_frame(frame).await.unwrap();
+ }
+ _ => {}
+ }
+ }
+
+ Ok(())
+}
+
+type WsHandler =
+ fn(
+ fastwebsockets::WebSocket<TokioIo<Upgraded>>,
+ ) -> Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
+
+fn spawn_ws_server<S>(stream: S, handler: WsHandler)
+where
+ S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
+{
+ let service = hyper::service::service_fn(
+ move |mut req: http::Request<hyper::body::Incoming>| async move {
+ let (response, upgrade_fut) = fastwebsockets::upgrade::upgrade(&mut req)
+ .map_err(|e| anyhow!("Error upgrading websocket connection: {}", e))?;
+
+ tokio::spawn(async move {
+ let ws = upgrade_fut
+ .await
+ .map_err(|e| anyhow!("Error upgrading websocket connection: {}", e))
+ .unwrap();
+
+ if let Err(e) = handler(ws).await {
+ eprintln!("Error in websocket connection: {}", e);
+ }
+ });
+
+ Ok::<_, anyhow::Error>(response)
+ },
+ );
+
+ let io = TokioIo::new(stream);
+ tokio::spawn(async move {
+ let conn = hyper::server::conn::http1::Builder::new()
+ .serve_connection(io, service)
+ .with_upgrades();
+
+ if let Err(e) = conn.await {
+ eprintln!("websocket server error: {e:?}");
+ }
+ });
+}
+
+async fn handle_wss_stream(
+ recv: Request<RecvStream>,
+ mut send: SendResponse<Bytes>,
+) -> Result<(), h2::Error> {
+ if recv.method() != Method::CONNECT {
+ eprintln!("wss2: refusing non-CONNECT stream");
+ send.send_reset(Reason::REFUSED_STREAM);
+ return Ok(());
+ }
+ let Some(protocol) = recv.extensions().get::<h2::ext::Protocol>() else {
+ eprintln!("wss2: refusing no-:protocol stream");
+ send.send_reset(Reason::REFUSED_STREAM);
+ return Ok(());
+ };
+ if protocol.as_str() != "websocket" && protocol.as_str() != "WebSocket" {
+ eprintln!("wss2: refusing non-websocket stream");
+ send.send_reset(Reason::REFUSED_STREAM);
+ return Ok(());
+ }
+ let mut body = recv.into_body();
+ let mut response = Response::new(());
+ *response.status_mut() = StatusCode::OK;
+ let mut resp = send.send_response(response, false)?;
+ // Use a duplex stream to talk to fastwebsockets because it's just faster to implement
+ let (a, b) = tokio::io::duplex(65536);
+ let f1 = tokio::spawn(tokio::task::unconstrained(async move {
+ let ws = WebSocket::after_handshake(a, Role::Server);
+ let mut ws = FragmentCollector::new(ws);
+ loop {
+ let frame = ws.read_frame().await.unwrap();
+ if frame.opcode == OpCode::Close {
+ break;
+ }
+ ws.write_frame(frame).await.unwrap();
+ }
+ }));
+ let (mut br, mut bw) = tokio::io::split(b);
+ let f2 = tokio::spawn(tokio::task::unconstrained(async move {
+ loop {
+ let Some(Ok(data)) = poll_fn(|cx| body.poll_data(cx)).await else {
+ return;
+ };
+ body.flow_control().release_capacity(data.len()).unwrap();
+ let Ok(_) = bw.write_all(&data).await else {
+ break;
+ };
+ }
+ }));
+ let f3 = tokio::spawn(tokio::task::unconstrained(async move {
+ loop {
+ let mut buf = [0; 65536];
+ let n = br.read(&mut buf).await.unwrap();
+ if n == 0 {
+ break;
+ }
+ resp.reserve_capacity(n);
+ poll_fn(|cx| resp.poll_capacity(cx)).await;
+ resp
+ .send_data(Bytes::copy_from_slice(&buf[0..n]), false)
+ .unwrap();
+ }
+ resp.send_data(Bytes::new(), true).unwrap();
+ }));
+ _ = join3(f1, f2, f3).await;
+ Ok(())
+}
+
+async fn close_websocket_handler(
+ ws: fastwebsockets::WebSocket<TokioIo<Upgraded>>,
+) -> Result<(), anyhow::Error> {
+ let mut ws = FragmentCollector::new(ws);
+
+ ws.write_frame(Frame::close_raw(vec![].into()))
+ .await
+ .unwrap();
+
+ Ok(())
+}
+
+async fn ping_websocket_handler(
+ ws: fastwebsockets::WebSocket<TokioIo<Upgraded>>,
+) -> Result<(), anyhow::Error> {
+ let mut ws = FragmentCollector::new(ws);
+
+ for i in 0..9 {
+ ws.write_frame(Frame::new(true, OpCode::Ping, None, vec![].into()))
+ .await
+ .unwrap();
+
+ let frame = ws.read_frame().await.unwrap();
+ assert_eq!(frame.opcode, OpCode::Pong);
+ assert!(frame.payload.is_empty());
+
+ ws.write_frame(Frame::text(
+ format!("hello {}", i).as_bytes().to_vec().into(),
+ ))
+ .await
+ .unwrap();
+
+ let frame = ws.read_frame().await.unwrap();
+ assert_eq!(frame.opcode, OpCode::Text);
+ assert_eq!(frame.payload, format!("hello {}", i).as_bytes());
+ }
+
+ ws.write_frame(Frame::close(1000, b"")).await.unwrap();
+
+ Ok(())
+}