summaryrefslogtreecommitdiff
path: root/test_util/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'test_util/src/lib.rs')
-rw-r--r--test_util/src/lib.rs169
1 files changed, 110 insertions, 59 deletions
diff --git a/test_util/src/lib.rs b/test_util/src/lib.rs
index 6a6614ad0..e647c0a4c 100644
--- a/test_util/src/lib.rs
+++ b/test_util/src/lib.rs
@@ -2,6 +2,7 @@
// Usage: provide a port as argument to run hyper_hello benchmark server
// otherwise this starts multiple servers on many ports for test endpoints.
use anyhow::anyhow;
+use futures::Future;
use futures::FutureExt;
use futures::Stream;
use futures::StreamExt;
@@ -9,6 +10,7 @@ use hyper::header::HeaderValue;
use hyper::server::Server;
use hyper::service::make_service_fn;
use hyper::service::service_fn;
+use hyper::upgrade::Upgraded;
use hyper::Body;
use hyper::Request;
use hyper::Response;
@@ -49,7 +51,6 @@ use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio_rustls::rustls;
use tokio_rustls::TlsAcceptor;
-use tokio_tungstenite::accept_async;
use url::Url;
pub mod assertions;
@@ -302,69 +303,128 @@ async fn basic_auth_redirect(
Ok(resp)
}
+async fn echo_websocket_handler(
+ ws: fastwebsockets::WebSocket<Upgraded>,
+) -> Result<(), anyhow::Error> {
+ let mut ws = fastwebsockets::FragmentCollector::new(ws);
+
+ loop {
+ let frame = ws.read_frame().await.unwrap();
+ match frame.opcode {
+ fastwebsockets::OpCode::Close => break,
+ fastwebsockets::OpCode::Text | fastwebsockets::OpCode::Binary => {
+ ws.write_frame(frame).await.unwrap();
+ }
+ _ => {}
+ }
+ }
+
+ Ok(())
+}
+
+type WsHandler =
+ fn(
+ fastwebsockets::WebSocket<Upgraded>,
+ ) -> Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>;
+
+fn spawn_ws_server<S>(stream: S, handler: WsHandler)
+where
+ S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
+{
+ let srv_fn = service_fn(move |mut req: Request<Body>| async move {
+ let (response, upgrade_fut) = fastwebsockets::upgrade::upgrade(&mut req)
+ .map_err(|e| anyhow!("Error upgrading websocket connection: {}", e))?;
+
+ tokio::spawn(async move {
+ let ws = upgrade_fut
+ .await
+ .map_err(|e| anyhow!("Error upgrading websocket connection: {}", e))
+ .unwrap();
+
+ if let Err(e) = handler(ws).await {
+ eprintln!("Error in websocket connection: {}", e);
+ }
+ });
+
+ Ok::<_, anyhow::Error>(response)
+ });
+
+ tokio::spawn(async move {
+ let conn_fut = hyper::server::conn::Http::new()
+ .serve_connection(stream, srv_fn)
+ .with_upgrades();
+
+ if let Err(e) = conn_fut.await {
+ eprintln!("websocket server error: {e:?}");
+ }
+ });
+}
+
async fn run_ws_server(addr: &SocketAddr) {
let listener = TcpListener::bind(addr).await.unwrap();
println!("ready: ws"); // Eye catcher for HttpServerCount
while let Ok((stream, _addr)) = listener.accept().await {
- tokio::spawn(async move {
- let ws_stream_fut = accept_async(stream);
-
- let ws_stream = ws_stream_fut.await;
- if let Ok(ws_stream) = ws_stream {
- let (tx, rx) = ws_stream.split();
- rx.forward(tx)
- .map(|result| {
- if let Err(e) = result {
- println!("websocket server error: {e:?}");
- }
- })
- .await;
- }
- });
+ spawn_ws_server(stream, |ws| Box::pin(echo_websocket_handler(ws)));
}
}
+async fn ping_websocket_handler(
+ ws: fastwebsockets::WebSocket<Upgraded>,
+) -> Result<(), anyhow::Error> {
+ use fastwebsockets::Frame;
+ use fastwebsockets::OpCode;
+
+ let mut ws = fastwebsockets::FragmentCollector::new(ws);
+
+ for i in 0..9 {
+ ws.write_frame(Frame::new(true, OpCode::Ping, None, vec![]))
+ .await
+ .unwrap();
+
+ let frame = ws.read_frame().await.unwrap();
+ assert_eq!(frame.opcode, OpCode::Pong);
+ assert!(frame.payload.is_empty());
+
+ ws.write_frame(Frame::text(format!("hello {}", i).as_bytes().to_vec()))
+ .await
+ .unwrap();
+
+ let frame = ws.read_frame().await.unwrap();
+ assert_eq!(frame.opcode, OpCode::Text);
+ assert_eq!(frame.payload, format!("hello {}", i).as_bytes());
+ }
+
+ ws.write_frame(fastwebsockets::Frame::close(1000, b""))
+ .await
+ .unwrap();
+
+ Ok(())
+}
+
async fn run_ws_ping_server(addr: &SocketAddr) {
let listener = TcpListener::bind(addr).await.unwrap();
println!("ready: ws"); // Eye catcher for HttpServerCount
while let Ok((stream, _addr)) = listener.accept().await {
- tokio::spawn(async move {
- let ws_stream = accept_async(stream).await;
- use futures::SinkExt;
- use tokio_tungstenite::tungstenite::Message;
- if let Ok(mut ws_stream) = ws_stream {
- for i in 0..9 {
- ws_stream.send(Message::Ping(vec![])).await.unwrap();
-
- let msg = ws_stream.next().await.unwrap().unwrap();
- assert_eq!(msg, Message::Pong(vec![]));
-
- ws_stream
- .send(Message::Text(format!("hello {}", i)))
- .await
- .unwrap();
-
- let msg = ws_stream.next().await.unwrap().unwrap();
- assert_eq!(msg, Message::Text(format!("hello {}", i)));
- }
-
- ws_stream.close(None).await.unwrap();
- }
- });
+ spawn_ws_server(stream, |ws| Box::pin(ping_websocket_handler(ws)));
}
}
+async fn close_websocket_handler(
+ ws: fastwebsockets::WebSocket<Upgraded>,
+) -> Result<(), anyhow::Error> {
+ let mut ws = fastwebsockets::FragmentCollector::new(ws);
+
+ ws.write_frame(fastwebsockets::Frame::close_raw(vec![]))
+ .await
+ .unwrap();
+
+ Ok(())
+}
+
async fn run_ws_close_server(addr: &SocketAddr) {
let listener = TcpListener::bind(addr).await.unwrap();
while let Ok((stream, _addr)) = listener.accept().await {
- tokio::spawn(async move {
- let ws_stream_fut = accept_async(stream);
-
- let ws_stream = ws_stream_fut.await;
- if let Ok(mut ws_stream) = ws_stream {
- ws_stream.close(None).await.unwrap();
- }
- });
+ spawn_ws_server(stream, |ws| Box::pin(close_websocket_handler(ws)));
}
}
@@ -471,18 +531,9 @@ async fn run_wss_server(addr: &SocketAddr) {
tokio::spawn(async move {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
- let ws_stream_fut = accept_async(tls_stream);
- let ws_stream = ws_stream_fut.await;
- if let Ok(ws_stream) = ws_stream {
- let (tx, rx) = ws_stream.split();
- rx.forward(tx)
- .map(|result| {
- if let Err(e) = result {
- println!("Websocket server error: {e:?}");
- }
- })
- .await;
- }
+ spawn_ws_server(tls_stream, |ws| {
+ Box::pin(echo_websocket_handler(ws))
+ });
}
Err(e) => {
eprintln!("TLS accept error: {e:?}");