diff options
Diffstat (limited to 'test_util/src')
-rw-r--r-- | test_util/src/servers/ws.rs | 100 |
1 files changed, 51 insertions, 49 deletions
diff --git a/test_util/src/servers/ws.rs b/test_util/src/servers/ws.rs index f9f811910..d94ecb38d 100644 --- a/test_util/src/servers/ws.rs +++ b/test_util/src/servers/ws.rs @@ -2,26 +2,25 @@ use anyhow::anyhow; use bytes::Bytes; -use fastwebsockets::FragmentCollector; -use fastwebsockets::Frame; -use fastwebsockets::OpCode; -use fastwebsockets::Role; -use fastwebsockets::WebSocket; +use fastwebsockets_06::FragmentCollector; +use fastwebsockets_06::Frame; +use fastwebsockets_06::OpCode; +use fastwebsockets_06::Role; +use fastwebsockets_06::WebSocket; use futures::future::join3; use futures::future::poll_fn; use futures::Future; use futures::StreamExt; -use h2::server::Handshake; -use h2::server::SendResponse; -use h2::Reason; -use h2::RecvStream; -use hyper::service::service_fn; -use hyper::upgrade::Upgraded; -use hyper::Body; -use hyper::Method; -use hyper::Request; -use hyper::Response; -use hyper::StatusCode; +use h2_04::server::Handshake; +use h2_04::server::SendResponse; +use h2_04::Reason; +use h2_04::RecvStream; +use hyper1::upgrade::Upgraded; +use hyper1::Method; +use hyper1::Request; +use hyper1::Response; +use hyper1::StatusCode; +use hyper_util::rt::TokioIo; use pretty_assertions::assert_eq; use std::pin::Pin; use std::result::Result; @@ -71,7 +70,7 @@ pub async fn run_wss2_server(port: u16) { .await; while let Some(Ok(tls)) = tls.next().await { tokio::spawn(async move { - let mut h2 = h2::server::Builder::new(); + let mut h2 = h2_04::server::Builder::new(); h2.enable_connect_protocol(); // Using Bytes is pretty alloc-heavy but this is a test server let server: Handshake<_, Bytes> = h2.handshake(tls); @@ -100,15 +99,15 @@ pub async fn run_wss2_server(port: u16) { } async fn echo_websocket_handler( - ws: fastwebsockets::WebSocket<Upgraded>, + ws: fastwebsockets_06::WebSocket<TokioIo<Upgraded>>, ) -> Result<(), anyhow::Error> { - let mut ws = fastwebsockets::FragmentCollector::new(ws); + let mut ws = FragmentCollector::new(ws); loop { let frame = ws.read_frame().await.unwrap(); match frame.opcode { - fastwebsockets::OpCode::Close => break, - fastwebsockets::OpCode::Text | fastwebsockets::OpCode::Binary => { + OpCode::Close => break, + OpCode::Text | OpCode::Binary => { ws.write_frame(frame).await.unwrap(); } _ => {} @@ -120,37 +119,42 @@ async fn echo_websocket_handler( type WsHandler = fn( - fastwebsockets::WebSocket<Upgraded>, + fastwebsockets_06::WebSocket<TokioIo<Upgraded>>, ) -> Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send>>; fn spawn_ws_server<S>(stream: S, handler: WsHandler) where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { - let srv_fn = service_fn(move |mut req: Request<Body>| async move { - let (response, upgrade_fut) = fastwebsockets::upgrade::upgrade(&mut req) - .map_err(|e| anyhow!("Error upgrading websocket connection: {}", e))?; + let service = hyper1::service::service_fn( + move |mut req: http_1::Request<hyper1::body::Incoming>| async move { + let (response, upgrade_fut) = + fastwebsockets_06::upgrade::upgrade(&mut req).map_err(|e| { + anyhow!("Error upgrading websocket connection: {}", e) + })?; - tokio::spawn(async move { - let ws = upgrade_fut - .await - .map_err(|e| anyhow!("Error upgrading websocket connection: {}", e)) - .unwrap(); + tokio::spawn(async move { + let ws = upgrade_fut + .await + .map_err(|e| anyhow!("Error upgrading websocket connection: {}", e)) + .unwrap(); - if let Err(e) = handler(ws).await { - eprintln!("Error in websocket connection: {}", e); - } - }); + if let Err(e) = handler(ws).await { + eprintln!("Error in websocket connection: {}", e); + } + }); - Ok::<_, anyhow::Error>(response) - }); + Ok::<_, anyhow::Error>(response) + }, + ); + let io = TokioIo::new(stream); tokio::spawn(async move { - let conn_fut = hyper::server::conn::Http::new() - .serve_connection(stream, srv_fn) + let conn = hyper1::server::conn::http1::Builder::new() + .serve_connection(io, service) .with_upgrades(); - if let Err(e) = conn_fut.await { + if let Err(e) = conn.await { eprintln!("websocket server error: {e:?}"); } }); @@ -159,13 +163,13 @@ where async fn handle_wss_stream( recv: Request<RecvStream>, mut send: SendResponse<Bytes>, -) -> Result<(), h2::Error> { +) -> Result<(), h2_04::Error> { if recv.method() != Method::CONNECT { eprintln!("wss2: refusing non-CONNECT stream"); send.send_reset(Reason::REFUSED_STREAM); return Ok(()); } - let Some(protocol) = recv.extensions().get::<h2::ext::Protocol>() else { + let Some(protocol) = recv.extensions().get::<h2_04::ext::Protocol>() else { eprintln!("wss2: refusing no-:protocol stream"); send.send_reset(Reason::REFUSED_STREAM); return Ok(()); @@ -224,11 +228,11 @@ async fn handle_wss_stream( } async fn close_websocket_handler( - ws: fastwebsockets::WebSocket<Upgraded>, + ws: fastwebsockets_06::WebSocket<TokioIo<Upgraded>>, ) -> Result<(), anyhow::Error> { - let mut ws = fastwebsockets::FragmentCollector::new(ws); + let mut ws = FragmentCollector::new(ws); - ws.write_frame(fastwebsockets::Frame::close_raw(vec![].into())) + ws.write_frame(Frame::close_raw(vec![].into())) .await .unwrap(); @@ -236,9 +240,9 @@ async fn close_websocket_handler( } async fn ping_websocket_handler( - ws: fastwebsockets::WebSocket<Upgraded>, + ws: fastwebsockets_06::WebSocket<TokioIo<Upgraded>>, ) -> Result<(), anyhow::Error> { - let mut ws = fastwebsockets::FragmentCollector::new(ws); + let mut ws = FragmentCollector::new(ws); for i in 0..9 { ws.write_frame(Frame::new(true, OpCode::Ping, None, vec![].into())) @@ -260,9 +264,7 @@ async fn ping_websocket_handler( assert_eq!(frame.payload, format!("hello {}", i).as_bytes()); } - ws.write_frame(fastwebsockets::Frame::close(1000, b"")) - .await - .unwrap(); + ws.write_frame(Frame::close(1000, b"")).await.unwrap(); Ok(()) } |