diff options
Diffstat (limited to 'test_util/src')
-rw-r--r-- | test_util/src/lib.rs | 141 |
1 files changed, 138 insertions, 3 deletions
diff --git a/test_util/src/lib.rs b/test_util/src/lib.rs index 692a6a08c..d63186311 100644 --- a/test_util/src/lib.rs +++ b/test_util/src/lib.rs @@ -4,16 +4,28 @@ use anyhow::anyhow; use base64::prelude::BASE64_STANDARD; use base64::Engine; +use bytes::Bytes; use denokv_proto::datapath::AtomicWrite; use denokv_proto::datapath::AtomicWriteOutput; use denokv_proto::datapath::AtomicWriteStatus; use denokv_proto::datapath::ReadRangeOutput; use denokv_proto::datapath::SnapshotRead; use denokv_proto::datapath::SnapshotReadOutput; +use fastwebsockets::FragmentCollector; +use fastwebsockets::Frame; +use fastwebsockets::OpCode; +use fastwebsockets::Role; +use fastwebsockets::WebSocket; +use futures::future::join3; +use futures::future::poll_fn; use futures::Future; use futures::FutureExt; use futures::Stream; use futures::StreamExt; +use h2::server::Handshake; +use h2::server::SendResponse; +use h2::Reason; +use h2::RecvStream; use hyper::header::HeaderValue; use hyper::http; use hyper::server::Server; @@ -21,6 +33,7 @@ use hyper::service::make_service_fn; use hyper::service::service_fn; use hyper::upgrade::Upgraded; use hyper::Body; +use hyper::Method; use hyper::Request; use hyper::Response; use hyper::StatusCode; @@ -58,6 +71,7 @@ use std::sync::MutexGuard; use std::task::Context; use std::task::Poll; use std::time::Duration; +use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::net::TcpListener; use tokio::net::TcpStream; @@ -105,6 +119,7 @@ const H2_ONLY_PORT: u16 = 5549; const HTTPS_CLIENT_AUTH_PORT: u16 = 5552; const WS_PORT: u16 = 4242; const WSS_PORT: u16 = 4243; +const WSS2_PORT: u16 = 4248; const WS_CLOSE_PORT: u16 = 4244; const WS_PING_PORT: u16 = 4245; const H2_GRPC_PORT: u16 = 4246; @@ -399,9 +414,6 @@ async fn run_ws_server(addr: &SocketAddr) { async fn ping_websocket_handler( ws: fastwebsockets::WebSocket<Upgraded>, ) -> Result<(), anyhow::Error> { - use fastwebsockets::Frame; - use fastwebsockets::OpCode; - let mut ws = fastwebsockets::FragmentCollector::new(ws); for i in 0..9 { @@ -458,6 +470,126 @@ async fn run_ws_close_server(addr: &SocketAddr) { } } +async fn handle_wss_stream( + recv: Request<RecvStream>, + mut send: SendResponse<Bytes>, +) -> Result<(), h2::Error> { + if recv.method() != Method::CONNECT { + eprintln!("wss2: refusing non-CONNECT stream"); + send.send_reset(Reason::REFUSED_STREAM); + return Ok(()); + } + let Some(protocol) = recv.extensions().get::<h2::ext::Protocol>() else { + eprintln!("wss2: refusing no-:protocol stream"); + send.send_reset(Reason::REFUSED_STREAM); + return Ok(()); + }; + if protocol.as_str() != "websocket" && protocol.as_str() != "WebSocket" { + eprintln!("wss2: refusing non-websocket stream"); + send.send_reset(Reason::REFUSED_STREAM); + return Ok(()); + } + let mut body = recv.into_body(); + let mut response = Response::new(()); + *response.status_mut() = StatusCode::OK; + let mut resp = send.send_response(response, false)?; + // Use a duplex stream to talk to fastwebsockets because it's just faster to implement + let (a, b) = tokio::io::duplex(65536); + let f1 = tokio::spawn(tokio::task::unconstrained(async move { + let ws = WebSocket::after_handshake(a, Role::Server); + let mut ws = FragmentCollector::new(ws); + loop { + let frame = ws.read_frame().await.unwrap(); + if frame.opcode == OpCode::Close { + break; + } + ws.write_frame(frame).await.unwrap(); + } + })); + let (mut br, mut bw) = tokio::io::split(b); + let f2 = tokio::spawn(tokio::task::unconstrained(async move { + loop { + let Some(Ok(data)) = poll_fn(|cx| body.poll_data(cx)).await else { + return; + }; + body.flow_control().release_capacity(data.len()).unwrap(); + let Ok(_) = bw.write_all(&data).await else { + break; + }; + } + })); + let f3 = tokio::spawn(tokio::task::unconstrained(async move { + loop { + let mut buf = [0; 65536]; + let n = br.read(&mut buf).await.unwrap(); + if n == 0 { + break; + } + resp.reserve_capacity(n); + poll_fn(|cx| resp.poll_capacity(cx)).await; + resp + .send_data(Bytes::copy_from_slice(&buf[0..n]), false) + .unwrap(); + } + resp.send_data(Bytes::new(), true).unwrap(); + })); + _ = join3(f1, f2, f3).await; + Ok(()) +} + +async fn run_wss2_server(addr: &SocketAddr) { + let cert_file = "tls/localhost.crt"; + let key_file = "tls/localhost.key"; + let ca_cert_file = "tls/RootCA.pem"; + + let tls_config = get_tls_config( + cert_file, + key_file, + ca_cert_file, + SupportedHttpVersions::Http2Only, + ) + .await + .unwrap(); + let tls_acceptor = TlsAcceptor::from(tls_config); + + let listener = TcpListener::bind(addr).await.unwrap(); + while let Ok((stream, _addr)) = listener.accept().await { + match tls_acceptor.accept(stream).await { + Ok(tls) => { + tokio::spawn(async move { + let mut h2 = h2::server::Builder::new(); + h2.enable_connect_protocol(); + // Using Bytes is pretty alloc-heavy but this is a test server + let server: Handshake<_, Bytes> = h2.handshake(tls); + let mut server = match server.await { + Ok(server) => server, + Err(e) => { + println!("Failed to handshake h2: {e:?}"); + return; + } + }; + loop { + let Some(conn) = server.accept().await else { + break; + }; + let (recv, send) = match conn { + Ok(conn) => conn, + Err(e) => { + println!("Failed to accept a connection: {e:?}"); + break; + } + }; + tokio::spawn(handle_wss_stream(recv, send)); + } + }); + } + Err(e) => { + println!("Failed to accept TLS: {e:?}"); + } + } + } +} + #[derive(Default)] enum SupportedHttpVersions { #[default] @@ -1933,6 +2065,8 @@ pub async fn run_all_servers() { let wss_server_fut = run_wss_server(&wss_addr); let ws_close_addr = SocketAddr::from(([127, 0, 0, 1], WS_CLOSE_PORT)); let ws_close_server_fut = run_ws_close_server(&ws_close_addr); + let wss2_addr = SocketAddr::from(([127, 0, 0, 1], WSS2_PORT)); + let wss2_server_fut = run_wss2_server(&wss2_addr); let tls_server_fut = run_tls_server(); let tls_client_auth_server_fut = run_tls_client_auth_server(); @@ -1952,6 +2086,7 @@ pub async fn run_all_servers() { ws_server_fut, ws_ping_server_fut, wss_server_fut, + wss2_server_fut, tls_server_fut, tls_client_auth_server_fut, ws_close_server_fut, |