diff options
author | Bartek IwaĆczuk <biwanczuk@gmail.com> | 2023-04-22 11:17:31 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-22 11:17:31 +0200 |
commit | 068228cb454d14a6f5943061a5a6569b9e395e23 (patch) | |
tree | 4b62ba7df771ac6dc32c70273e397e863f84dc05 /test_util/src/lib.rs | |
parent | a615eb3b56545960ec9684991442dd34a8b2abfc (diff) |
refactor: rewrite tests to "fastwebsockets" crate (#18781)
Migrating off of `tokio-tungstenite` crate.
---------
Co-authored-by: Divy Srivastava <dj.srivastava23@gmail.com>
Diffstat (limited to 'test_util/src/lib.rs')
-rw-r--r-- | test_util/src/lib.rs | 169 |
1 files changed, 110 insertions, 59 deletions
diff --git a/test_util/src/lib.rs b/test_util/src/lib.rs index 6a6614ad0..e647c0a4c 100644 --- a/test_util/src/lib.rs +++ b/test_util/src/lib.rs @@ -2,6 +2,7 @@ // Usage: provide a port as argument to run hyper_hello benchmark server // otherwise this starts multiple servers on many ports for test endpoints. use anyhow::anyhow; +use futures::Future; use futures::FutureExt; use futures::Stream; use futures::StreamExt; @@ -9,6 +10,7 @@ use hyper::header::HeaderValue; use hyper::server::Server; use hyper::service::make_service_fn; use hyper::service::service_fn; +use hyper::upgrade::Upgraded; use hyper::Body; use hyper::Request; use hyper::Response; @@ -49,7 +51,6 @@ use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls; use tokio_rustls::TlsAcceptor; -use tokio_tungstenite::accept_async; use url::Url; pub mod assertions; @@ -302,69 +303,128 @@ async fn basic_auth_redirect( Ok(resp) } +async fn echo_websocket_handler( + ws: fastwebsockets::WebSocket<Upgraded>, +) -> Result<(), anyhow::Error> { + let mut ws = fastwebsockets::FragmentCollector::new(ws); + + loop { + let frame = ws.read_frame().await.unwrap(); + match frame.opcode { + fastwebsockets::OpCode::Close => break, + fastwebsockets::OpCode::Text | fastwebsockets::OpCode::Binary => { + ws.write_frame(frame).await.unwrap(); + } + _ => {} + } + } + + Ok(()) +} + +type WsHandler = + fn( + fastwebsockets::WebSocket<Upgraded>, + ) -> Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>; + +fn spawn_ws_server<S>(stream: S, handler: WsHandler) +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, +{ + let srv_fn = service_fn(move |mut req: Request<Body>| async move { + let (response, upgrade_fut) = fastwebsockets::upgrade::upgrade(&mut req) + .map_err(|e| anyhow!("Error upgrading websocket connection: {}", e))?; + + tokio::spawn(async move { + let ws = upgrade_fut + .await + .map_err(|e| anyhow!("Error upgrading websocket connection: {}", e)) + .unwrap(); + + if let Err(e) = handler(ws).await { + eprintln!("Error in websocket connection: {}", e); + } + }); + + Ok::<_, anyhow::Error>(response) + }); + + tokio::spawn(async move { + let conn_fut = hyper::server::conn::Http::new() + .serve_connection(stream, srv_fn) + .with_upgrades(); + + if let Err(e) = conn_fut.await { + eprintln!("websocket server error: {e:?}"); + } + }); +} + async fn run_ws_server(addr: &SocketAddr) { let listener = TcpListener::bind(addr).await.unwrap(); println!("ready: ws"); // Eye catcher for HttpServerCount while let Ok((stream, _addr)) = listener.accept().await { - tokio::spawn(async move { - let ws_stream_fut = accept_async(stream); - - let ws_stream = ws_stream_fut.await; - if let Ok(ws_stream) = ws_stream { - let (tx, rx) = ws_stream.split(); - rx.forward(tx) - .map(|result| { - if let Err(e) = result { - println!("websocket server error: {e:?}"); - } - }) - .await; - } - }); + spawn_ws_server(stream, |ws| Box::pin(echo_websocket_handler(ws))); } } +async fn ping_websocket_handler( + ws: fastwebsockets::WebSocket<Upgraded>, +) -> Result<(), anyhow::Error> { + use fastwebsockets::Frame; + use fastwebsockets::OpCode; + + let mut ws = fastwebsockets::FragmentCollector::new(ws); + + for i in 0..9 { + ws.write_frame(Frame::new(true, OpCode::Ping, None, vec![])) + .await + .unwrap(); + + let frame = ws.read_frame().await.unwrap(); + assert_eq!(frame.opcode, OpCode::Pong); + assert!(frame.payload.is_empty()); + + ws.write_frame(Frame::text(format!("hello {}", i).as_bytes().to_vec())) + .await + .unwrap(); + + let frame = ws.read_frame().await.unwrap(); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload, format!("hello {}", i).as_bytes()); + } + + ws.write_frame(fastwebsockets::Frame::close(1000, b"")) + .await + .unwrap(); + + Ok(()) +} + async fn run_ws_ping_server(addr: &SocketAddr) { let listener = TcpListener::bind(addr).await.unwrap(); println!("ready: ws"); // Eye catcher for HttpServerCount while let Ok((stream, _addr)) = listener.accept().await { - tokio::spawn(async move { - let ws_stream = accept_async(stream).await; - use futures::SinkExt; - use tokio_tungstenite::tungstenite::Message; - if let Ok(mut ws_stream) = ws_stream { - for i in 0..9 { - ws_stream.send(Message::Ping(vec![])).await.unwrap(); - - let msg = ws_stream.next().await.unwrap().unwrap(); - assert_eq!(msg, Message::Pong(vec![])); - - ws_stream - .send(Message::Text(format!("hello {}", i))) - .await - .unwrap(); - - let msg = ws_stream.next().await.unwrap().unwrap(); - assert_eq!(msg, Message::Text(format!("hello {}", i))); - } - - ws_stream.close(None).await.unwrap(); - } - }); + spawn_ws_server(stream, |ws| Box::pin(ping_websocket_handler(ws))); } } +async fn close_websocket_handler( + ws: fastwebsockets::WebSocket<Upgraded>, +) -> Result<(), anyhow::Error> { + let mut ws = fastwebsockets::FragmentCollector::new(ws); + + ws.write_frame(fastwebsockets::Frame::close_raw(vec![])) + .await + .unwrap(); + + Ok(()) +} + async fn run_ws_close_server(addr: &SocketAddr) { let listener = TcpListener::bind(addr).await.unwrap(); while let Ok((stream, _addr)) = listener.accept().await { - tokio::spawn(async move { - let ws_stream_fut = accept_async(stream); - - let ws_stream = ws_stream_fut.await; - if let Ok(mut ws_stream) = ws_stream { - ws_stream.close(None).await.unwrap(); - } - }); + spawn_ws_server(stream, |ws| Box::pin(close_websocket_handler(ws))); } } @@ -471,18 +531,9 @@ async fn run_wss_server(addr: &SocketAddr) { tokio::spawn(async move { match acceptor.accept(stream).await { Ok(tls_stream) => { - let ws_stream_fut = accept_async(tls_stream); - let ws_stream = ws_stream_fut.await; - if let Ok(ws_stream) = ws_stream { - let (tx, rx) = ws_stream.split(); - rx.forward(tx) - .map(|result| { - if let Err(e) = result { - println!("Websocket server error: {e:?}"); - } - }) - .await; - } + spawn_ws_server(tls_stream, |ws| { + Box::pin(echo_websocket_handler(ws)) + }); } Err(e) => { eprintln!("TLS accept error: {e:?}"); |