summaryrefslogtreecommitdiff
path: root/test_util/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'test_util/src/lib.rs')
-rw-r--r--test_util/src/lib.rs141
1 files changed, 138 insertions, 3 deletions
diff --git a/test_util/src/lib.rs b/test_util/src/lib.rs
index 692a6a08c..d63186311 100644
--- a/test_util/src/lib.rs
+++ b/test_util/src/lib.rs
@@ -4,16 +4,28 @@
use anyhow::anyhow;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
+use bytes::Bytes;
use denokv_proto::datapath::AtomicWrite;
use denokv_proto::datapath::AtomicWriteOutput;
use denokv_proto::datapath::AtomicWriteStatus;
use denokv_proto::datapath::ReadRangeOutput;
use denokv_proto::datapath::SnapshotRead;
use denokv_proto::datapath::SnapshotReadOutput;
+use fastwebsockets::FragmentCollector;
+use fastwebsockets::Frame;
+use fastwebsockets::OpCode;
+use fastwebsockets::Role;
+use fastwebsockets::WebSocket;
+use futures::future::join3;
+use futures::future::poll_fn;
use futures::Future;
use futures::FutureExt;
use futures::Stream;
use futures::StreamExt;
+use h2::server::Handshake;
+use h2::server::SendResponse;
+use h2::Reason;
+use h2::RecvStream;
use hyper::header::HeaderValue;
use hyper::http;
use hyper::server::Server;
@@ -21,6 +33,7 @@ use hyper::service::make_service_fn;
use hyper::service::service_fn;
use hyper::upgrade::Upgraded;
use hyper::Body;
+use hyper::Method;
use hyper::Request;
use hyper::Response;
use hyper::StatusCode;
@@ -58,6 +71,7 @@ use std::sync::MutexGuard;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
+use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
@@ -105,6 +119,7 @@ const H2_ONLY_PORT: u16 = 5549;
const HTTPS_CLIENT_AUTH_PORT: u16 = 5552;
const WS_PORT: u16 = 4242;
const WSS_PORT: u16 = 4243;
+const WSS2_PORT: u16 = 4248;
const WS_CLOSE_PORT: u16 = 4244;
const WS_PING_PORT: u16 = 4245;
const H2_GRPC_PORT: u16 = 4246;
@@ -399,9 +414,6 @@ async fn run_ws_server(addr: &SocketAddr) {
async fn ping_websocket_handler(
ws: fastwebsockets::WebSocket<Upgraded>,
) -> Result<(), anyhow::Error> {
- use fastwebsockets::Frame;
- use fastwebsockets::OpCode;
-
let mut ws = fastwebsockets::FragmentCollector::new(ws);
for i in 0..9 {
@@ -458,6 +470,126 @@ async fn run_ws_close_server(addr: &SocketAddr) {
}
}
+async fn handle_wss_stream(
+ recv: Request<RecvStream>,
+ mut send: SendResponse<Bytes>,
+) -> Result<(), h2::Error> {
+ if recv.method() != Method::CONNECT {
+ eprintln!("wss2: refusing non-CONNECT stream");
+ send.send_reset(Reason::REFUSED_STREAM);
+ return Ok(());
+ }
+ let Some(protocol) = recv.extensions().get::<h2::ext::Protocol>() else {
+ eprintln!("wss2: refusing no-:protocol stream");
+ send.send_reset(Reason::REFUSED_STREAM);
+ return Ok(());
+ };
+ if protocol.as_str() != "websocket" && protocol.as_str() != "WebSocket" {
+ eprintln!("wss2: refusing non-websocket stream");
+ send.send_reset(Reason::REFUSED_STREAM);
+ return Ok(());
+ }
+ let mut body = recv.into_body();
+ let mut response = Response::new(());
+ *response.status_mut() = StatusCode::OK;
+ let mut resp = send.send_response(response, false)?;
+ // Use a duplex stream to talk to fastwebsockets because it's just faster to implement
+ let (a, b) = tokio::io::duplex(65536);
+ let f1 = tokio::spawn(tokio::task::unconstrained(async move {
+ let ws = WebSocket::after_handshake(a, Role::Server);
+ let mut ws = FragmentCollector::new(ws);
+ loop {
+ let frame = ws.read_frame().await.unwrap();
+ if frame.opcode == OpCode::Close {
+ break;
+ }
+ ws.write_frame(frame).await.unwrap();
+ }
+ }));
+ let (mut br, mut bw) = tokio::io::split(b);
+ let f2 = tokio::spawn(tokio::task::unconstrained(async move {
+ loop {
+ let Some(Ok(data)) = poll_fn(|cx| body.poll_data(cx)).await else {
+ return;
+ };
+ body.flow_control().release_capacity(data.len()).unwrap();
+ let Ok(_) = bw.write_all(&data).await else {
+ break;
+ };
+ }
+ }));
+ let f3 = tokio::spawn(tokio::task::unconstrained(async move {
+ loop {
+ let mut buf = [0; 65536];
+ let n = br.read(&mut buf).await.unwrap();
+ if n == 0 {
+ break;
+ }
+ resp.reserve_capacity(n);
+ poll_fn(|cx| resp.poll_capacity(cx)).await;
+ resp
+ .send_data(Bytes::copy_from_slice(&buf[0..n]), false)
+ .unwrap();
+ }
+ resp.send_data(Bytes::new(), true).unwrap();
+ }));
+ _ = join3(f1, f2, f3).await;
+ Ok(())
+}
+
+async fn run_wss2_server(addr: &SocketAddr) {
+ let cert_file = "tls/localhost.crt";
+ let key_file = "tls/localhost.key";
+ let ca_cert_file = "tls/RootCA.pem";
+
+ let tls_config = get_tls_config(
+ cert_file,
+ key_file,
+ ca_cert_file,
+ SupportedHttpVersions::Http2Only,
+ )
+ .await
+ .unwrap();
+ let tls_acceptor = TlsAcceptor::from(tls_config);
+
+ let listener = TcpListener::bind(addr).await.unwrap();
+ while let Ok((stream, _addr)) = listener.accept().await {
+ match tls_acceptor.accept(stream).await {
+ Ok(tls) => {
+ tokio::spawn(async move {
+ let mut h2 = h2::server::Builder::new();
+ h2.enable_connect_protocol();
+ // Using Bytes is pretty alloc-heavy but this is a test server
+ let server: Handshake<_, Bytes> = h2.handshake(tls);
+ let mut server = match server.await {
+ Ok(server) => server,
+ Err(e) => {
+ println!("Failed to handshake h2: {e:?}");
+ return;
+ }
+ };
+ loop {
+ let Some(conn) = server.accept().await else {
+ break;
+ };
+ let (recv, send) = match conn {
+ Ok(conn) => conn,
+ Err(e) => {
+ println!("Failed to accept a connection: {e:?}");
+ break;
+ }
+ };
+ tokio::spawn(handle_wss_stream(recv, send));
+ }
+ });
+ }
+ Err(e) => {
+ println!("Failed to accept TLS: {e:?}");
+ }
+ }
+ }
+}
+
#[derive(Default)]
enum SupportedHttpVersions {
#[default]
@@ -1933,6 +2065,8 @@ pub async fn run_all_servers() {
let wss_server_fut = run_wss_server(&wss_addr);
let ws_close_addr = SocketAddr::from(([127, 0, 0, 1], WS_CLOSE_PORT));
let ws_close_server_fut = run_ws_close_server(&ws_close_addr);
+ let wss2_addr = SocketAddr::from(([127, 0, 0, 1], WSS2_PORT));
+ let wss2_server_fut = run_wss2_server(&wss2_addr);
let tls_server_fut = run_tls_server();
let tls_client_auth_server_fut = run_tls_client_auth_server();
@@ -1952,6 +2086,7 @@ pub async fn run_all_servers() {
ws_server_fut,
ws_ping_server_fut,
wss_server_fut,
+ wss2_server_fut,
tls_server_fut,
tls_client_auth_server_fut,
ws_close_server_fut,