diff options
author | Matt Mastracci <matthew@mastracci.com> | 2023-11-08 13:00:29 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-08 13:00:29 -0700 |
commit | 02c5f49a7aab7b8cfe5ad3b282e6668a1aecddbb (patch) | |
tree | 64a9d9c25002bb9769457aa809ee9e2db57bf9aa /test_util/src | |
parent | 5e82fce0a0051d694ab14467c120a1578c86bb42 (diff) |
chore: refactor test_server and move to rustls-tokio-stream (#21117)
Remove tokio-rustls as a direct dependency of Deno and refactor
test_server to reduce code duplication.
All tcp and tls listener paths go through the same streams now, with the
exception of the simpler Hyper http-only handlers (those can be done in
a later follow-up).
Minor bugs fixed:
- gRPC server should only serve h2
- WebSocket over http/2 had a port overlap
- Restored missing eye-catchers for some servers (still missing on Hyper
ones)
Diffstat (limited to 'test_util/src')
-rw-r--r-- | test_util/src/https.rs | 133 | ||||
-rw-r--r-- | test_util/src/lib.rs | 674 |
2 files changed, 316 insertions, 491 deletions
diff --git a/test_util/src/https.rs b/test_util/src/https.rs new file mode 100644 index 000000000..8793e3c37 --- /dev/null +++ b/test_util/src/https.rs @@ -0,0 +1,133 @@ +// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license. +use anyhow::anyhow; +use futures::Stream; +use futures::StreamExt; +use rustls::Certificate; +use rustls::PrivateKey; +use rustls_tokio_stream::rustls; +use rustls_tokio_stream::TlsStream; +use std::io; +use std::num::NonZeroUsize; +use std::result::Result; +use std::sync::Arc; +use tokio::net::TcpStream; + +use crate::get_tcp_listener_stream; +use crate::testdata_path; + +pub const TLS_BUFFER_SIZE: Option<NonZeroUsize> = NonZeroUsize::new(65536); + +#[derive(Default)] +pub enum SupportedHttpVersions { + #[default] + All, + Http1Only, + Http2Only, +} + +pub fn get_tls_listener_stream_from_tcp( + tls_config: Arc<rustls::ServerConfig>, + mut tcp: impl Stream<Item = Result<TcpStream, std::io::Error>> + Unpin + 'static, +) -> impl Stream<Item = Result<TlsStream, std::io::Error>> + Unpin { + async_stream::stream! { + while let Some(result) = tcp.next().await { + match result { + Ok(tcp) => yield Ok(TlsStream::new_server_side(tcp, tls_config.clone(), TLS_BUFFER_SIZE)), + Err(e) => yield Err(e), + }; + } + }.boxed_local() +} + +pub async fn get_tls_listener_stream( + name: &'static str, + port: u16, + http: SupportedHttpVersions, +) -> impl Stream<Item = Result<TlsStream, std::io::Error>> + Unpin { + let cert_file = "tls/localhost.crt"; + let key_file = "tls/localhost.key"; + let ca_cert_file = "tls/RootCA.pem"; + let tls_config = get_tls_config(cert_file, key_file, ca_cert_file, http) + .await + .unwrap(); + + let tcp = get_tcp_listener_stream(name, port).await; + get_tls_listener_stream_from_tcp(tls_config, tcp) +} + +pub async fn get_tls_config( + cert: &str, + key: &str, + ca: &str, + http_versions: SupportedHttpVersions, +) -> io::Result<Arc<rustls::ServerConfig>> { + let cert_path = testdata_path().join(cert); + let key_path = testdata_path().join(key); + let ca_path = testdata_path().join(ca); + + let cert_file = std::fs::File::open(cert_path)?; + let key_file = std::fs::File::open(key_path)?; + let ca_file = std::fs::File::open(ca_path)?; + + let certs: Vec<Certificate> = { + let mut cert_reader = io::BufReader::new(cert_file); + rustls_pemfile::certs(&mut cert_reader) + .unwrap() + .into_iter() + .map(Certificate) + .collect() + }; + + let mut ca_cert_reader = io::BufReader::new(ca_file); + let ca_cert = rustls_pemfile::certs(&mut ca_cert_reader) + .expect("Cannot load CA certificate") + .remove(0); + + let mut key_reader = io::BufReader::new(key_file); + let key = { + let pkcs8_key = rustls_pemfile::pkcs8_private_keys(&mut key_reader) + .expect("Cannot load key file"); + let rsa_key = rustls_pemfile::rsa_private_keys(&mut key_reader) + .expect("Cannot load key file"); + if !pkcs8_key.is_empty() { + Some(pkcs8_key[0].clone()) + } else if !rsa_key.is_empty() { + Some(rsa_key[0].clone()) + } else { + None + } + }; + + match key { + Some(key) => { + let mut root_cert_store = rustls::RootCertStore::empty(); + root_cert_store.add(&rustls::Certificate(ca_cert)).unwrap(); + + // Allow (but do not require) client authentication. + + let mut config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_client_cert_verifier(Arc::new( + rustls::server::AllowAnyAnonymousOrAuthenticatedClient::new( + root_cert_store, + ), + )) + .with_single_cert(certs, PrivateKey(key)) + .map_err(|e| anyhow!("Error setting cert: {:?}", e)) + .unwrap(); + + match http_versions { + SupportedHttpVersions::All => { + config.alpn_protocols = vec!["h2".into(), "http/1.1".into()]; + } + SupportedHttpVersions::Http1Only => {} + SupportedHttpVersions::Http2Only => { + config.alpn_protocols = vec!["h2".into()]; + } + } + + Ok(Arc::new(config)) + } + None => Err(io::Error::new(io::ErrorKind::Other, "Cannot find key")), + } +} diff --git a/test_util/src/lib.rs b/test_util/src/lib.rs index d63186311..e004f1474 100644 --- a/test_util/src/lib.rs +++ b/test_util/src/lib.rs @@ -26,6 +26,8 @@ use h2::server::Handshake; use h2::server::SendResponse; use h2::Reason; use h2::RecvStream; +use https::get_tls_listener_stream; +use https::SupportedHttpVersions; use hyper::header::HeaderValue; use hyper::http; use hyper::server::Server; @@ -43,15 +45,13 @@ use pretty_assertions::assert_eq; use prost::Message; use pty::Pty; use regex::Regex; -use rustls::Certificate; -use rustls::PrivateKey; +use rustls_tokio_stream::TlsStream; use serde::Serialize; use std::collections::HashMap; use std::convert::Infallible; use std::env; use std::io; use std::io::Write; -use std::mem::replace; use std::net::Ipv6Addr; use std::net::SocketAddr; use std::net::SocketAddrV6; @@ -65,7 +65,6 @@ use std::process::Command; use std::process::Output; use std::process::Stdio; use std::result::Result; -use std::sync::Arc; use std::sync::Mutex; use std::sync::MutexGuard; use std::task::Context; @@ -73,17 +72,15 @@ use std::task::Poll; use std::time::Duration; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; -use tokio::net::TcpListener; use tokio::net::TcpStream; -use tokio_rustls::rustls; -use tokio_rustls::server::TlsStream; -use tokio_rustls::TlsAcceptor; +use tokio::task::LocalSet; use url::Url; pub mod assertions; mod builders; pub mod factory; mod fs; +mod https; pub mod lsp; mod npm; pub mod pty; @@ -119,7 +116,7 @@ const H2_ONLY_PORT: u16 = 5549; const HTTPS_CLIENT_AUTH_PORT: u16 = 5552; const WS_PORT: u16 = 4242; const WSS_PORT: u16 = 4243; -const WSS2_PORT: u16 = 4248; +const WSS2_PORT: u16 = 4249; const WS_CLOSE_PORT: u16 = 4244; const WS_PING_PORT: u16 = 4245; const H2_GRPC_PORT: u16 = 4246; @@ -403,10 +400,9 @@ where }); } -async fn run_ws_server(addr: &SocketAddr) { - let listener = TcpListener::bind(addr).await.unwrap(); - println!("ready: ws"); // Eye catcher for HttpServerCount - while let Ok((stream, _addr)) = listener.accept().await { +async fn run_ws_server(port: u16) { + let mut tcp = get_tcp_listener_stream("ws", port).await; + while let Some(Ok(stream)) = tcp.next().await { spawn_ws_server(stream, |ws| Box::pin(echo_websocket_handler(ws))); } } @@ -443,10 +439,9 @@ async fn ping_websocket_handler( Ok(()) } -async fn run_ws_ping_server(addr: &SocketAddr) { - let listener = TcpListener::bind(addr).await.unwrap(); - println!("ready: ws"); // Eye catcher for HttpServerCount - while let Ok((stream, _addr)) = listener.accept().await { +async fn run_ws_ping_server(port: u16) { + let mut tcp = get_tcp_listener_stream("ws (ping)", port).await; + while let Some(Ok(stream)) = tcp.next().await { spawn_ws_server(stream, |ws| Box::pin(ping_websocket_handler(ws))); } } @@ -463,9 +458,9 @@ async fn close_websocket_handler( Ok(()) } -async fn run_ws_close_server(addr: &SocketAddr) { - let listener = TcpListener::bind(addr).await.unwrap(); - while let Ok((stream, _addr)) = listener.accept().await { +async fn run_ws_close_server(port: u16) { + let mut tcp = get_tcp_listener_stream("ws (close)", port).await; + while let Some(Ok(stream)) = tcp.next().await { spawn_ws_server(stream, |ws| Box::pin(close_websocket_handler(ws))); } } @@ -537,204 +532,73 @@ async fn handle_wss_stream( Ok(()) } -async fn run_wss2_server(addr: &SocketAddr) { - let cert_file = "tls/localhost.crt"; - let key_file = "tls/localhost.key"; - let ca_cert_file = "tls/RootCA.pem"; - - let tls_config = get_tls_config( - cert_file, - key_file, - ca_cert_file, +async fn run_wss2_server(port: u16) { + let mut tls = get_tls_listener_stream( + "wss2 (tls)", + port, SupportedHttpVersions::Http2Only, ) - .await - .unwrap(); - let tls_acceptor = TlsAcceptor::from(tls_config); - - let listener = TcpListener::bind(addr).await.unwrap(); - while let Ok((stream, _addr)) = listener.accept().await { - match tls_acceptor.accept(stream).await { - Ok(tls) => { - tokio::spawn(async move { - let mut h2 = h2::server::Builder::new(); - h2.enable_connect_protocol(); - // Using Bytes is pretty alloc-heavy but this is a test server - let server: Handshake<_, Bytes> = h2.handshake(tls); - let mut server = match server.await { - Ok(server) => server, - Err(e) => { - println!("Failed to handshake h2: {e:?}"); - return; - } - }; - loop { - let Some(conn) = server.accept().await else { - break; - }; - let (recv, send) = match conn { - Ok(conn) => conn, - Err(e) => { - println!("Failed to accept a connection: {e:?}"); - break; - } - }; - tokio::spawn(handle_wss_stream(recv, send)); - } - }); - } - Err(e) => { - println!("Failed to accept TLS: {e:?}"); - } - } - } -} - -#[derive(Default)] -enum SupportedHttpVersions { - #[default] - All, - Http1Only, - Http2Only, -} - -async fn get_tls_config( - cert: &str, - key: &str, - ca: &str, - http_versions: SupportedHttpVersions, -) -> io::Result<Arc<rustls::ServerConfig>> { - let cert_path = testdata_path().join(cert); - let key_path = testdata_path().join(key); - let ca_path = testdata_path().join(ca); - - let cert_file = std::fs::File::open(cert_path)?; - let key_file = std::fs::File::open(key_path)?; - let ca_file = std::fs::File::open(ca_path)?; - - let certs: Vec<Certificate> = { - let mut cert_reader = io::BufReader::new(cert_file); - rustls_pemfile::certs(&mut cert_reader) - .unwrap() - .into_iter() - .map(Certificate) - .collect() - }; - - let mut ca_cert_reader = io::BufReader::new(ca_file); - let ca_cert = rustls_pemfile::certs(&mut ca_cert_reader) - .expect("Cannot load CA certificate") - .remove(0); - - let mut key_reader = io::BufReader::new(key_file); - let key = { - let pkcs8_key = rustls_pemfile::pkcs8_private_keys(&mut key_reader) - .expect("Cannot load key file"); - let rsa_key = rustls_pemfile::rsa_private_keys(&mut key_reader) - .expect("Cannot load key file"); - if !pkcs8_key.is_empty() { - Some(pkcs8_key[0].clone()) - } else if !rsa_key.is_empty() { - Some(rsa_key[0].clone()) - } else { - None - } - }; - - match key { - Some(key) => { - let mut root_cert_store = rustls::RootCertStore::empty(); - root_cert_store.add(&rustls::Certificate(ca_cert)).unwrap(); - - // Allow (but do not require) client authentication. - - let mut config = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_client_cert_verifier(Arc::new( - rustls::server::AllowAnyAnonymousOrAuthenticatedClient::new( - root_cert_store, - ), - )) - .with_single_cert(certs, PrivateKey(key)) - .map_err(|e| anyhow!("Error setting cert: {:?}", e)) - .unwrap(); - - match http_versions { - SupportedHttpVersions::All => { - config.alpn_protocols = vec!["h2".into(), "http/1.1".into()]; - } - SupportedHttpVersions::Http1Only => {} - SupportedHttpVersions::Http2Only => { - config.alpn_protocols = vec!["h2".into()]; + .await; + while let Some(Ok(tls)) = tls.next().await { + tokio::spawn(async move { + let mut h2 = h2::server::Builder::new(); + h2.enable_connect_protocol(); + // Using Bytes is pretty alloc-heavy but this is a test server + let server: Handshake<_, Bytes> = h2.handshake(tls); + let mut server = match server.await { + Ok(server) => server, + Err(e) => { + println!("Failed to handshake h2: {e:?}"); + return; } + }; + loop { + let Some(conn) = server.accept().await else { + break; + }; + let (recv, send) = match conn { + Ok(conn) => conn, + Err(e) => { + println!("Failed to accept a connection: {e:?}"); + break; + } + }; + tokio::spawn(handle_wss_stream(recv, send)); } - - Ok(Arc::new(config)) - } - None => Err(io::Error::new(io::ErrorKind::Other, "Cannot find key")), + }); } } -async fn run_wss_server(addr: &SocketAddr) { - let cert_file = "tls/localhost.crt"; - let key_file = "tls/localhost.key"; - let ca_cert_file = "tls/RootCA.pem"; - - let tls_config = - get_tls_config(cert_file, key_file, ca_cert_file, Default::default()) - .await - .unwrap(); - let tls_acceptor = TlsAcceptor::from(tls_config); - let listener = TcpListener::bind(addr).await.unwrap(); - println!("ready: wss"); // Eye catcher for HttpServerCount - - while let Ok((stream, _addr)) = listener.accept().await { - let acceptor = tls_acceptor.clone(); +async fn run_wss_server(port: u16) { + let mut tls = get_tls_listener_stream("wss", port, Default::default()).await; + while let Some(Ok(tls_stream)) = tls.next().await { tokio::spawn(async move { - match acceptor.accept(stream).await { - Ok(tls_stream) => { - spawn_ws_server(tls_stream, |ws| { - Box::pin(echo_websocket_handler(ws)) - }); - } - Err(e) => { - eprintln!("TLS accept error: {e:?}"); - } - } + spawn_ws_server(tls_stream, |ws| Box::pin(echo_websocket_handler(ws))); }); } } -/// This server responds with 'PASS' if client authentication was successful. Try it by running -/// test_server and -/// curl --key cli/tests/testdata/tls/localhost.key \ -/// --cert cli/tests/testsdata/tls/localhost.crt \ -/// --cacert cli/tests/testdata/tls/RootCA.crt https://localhost:4552/ -async fn run_tls_client_auth_server() { - let cert_file = "tls/localhost.crt"; - let key_file = "tls/localhost.key"; - let ca_cert_file = "tls/RootCA.pem"; - let tls_config = - get_tls_config(cert_file, key_file, ca_cert_file, Default::default()) - .await - .unwrap(); - let tls_acceptor = TlsAcceptor::from(tls_config); +/// Returns a [`Stream`] of [`TcpStream`]s accepted from the given port. +async fn get_tcp_listener_stream( + name: &'static str, + port: u16, +) -> impl Stream<Item = Result<TcpStream, std::io::Error>> + Unpin + Send { + let host_and_port = &format!("localhost:{port}"); // Listen on ALL addresses that localhost can resolves to. let accept = |listener: tokio::net::TcpListener| { async { let result = listener.accept().await; - Some((result, listener)) + Some((result.map(|r| r.0), listener)) } .boxed() }; - let host_and_port = &format!("localhost:{TLS_CLIENT_AUTH_PORT}"); - + let mut addresses = vec![]; let listeners = tokio::net::lookup_host(host_and_port) .await .expect(host_and_port) - .inspect(|address| println!("{host_and_port} -> {address}")) + .inspect(|address| addresses.push(*address)) .map(tokio::net::TcpListener::bind) .collect::<futures::stream::FuturesUnordered<_>>() .collect::<Vec<_>>() @@ -744,29 +608,37 @@ async fn run_tls_client_auth_server() { .map(|listener| futures::stream::unfold(listener, accept)) .collect::<Vec<_>>(); - println!("ready: tls client auth"); // Eye catcher for HttpServerCount + // Eye catcher for HttpServerCount + println!("ready: {name} on {:?}", addresses); - let mut listeners = futures::stream::select_all(listeners); + futures::stream::select_all(listeners) +} - while let Some(Ok((stream, _addr))) = listeners.next().await { - let acceptor = tls_acceptor.clone(); +/// This server responds with 'PASS' if client authentication was successful. Try it by running +/// test_server and +/// curl --key cli/tests/testdata/tls/localhost.key \ +/// --cert cli/tests/testsdata/tls/localhost.crt \ +/// --cacert cli/tests/testdata/tls/RootCA.crt https://localhost:4552/ +async fn run_tls_client_auth_server() { + let mut tls = get_tls_listener_stream( + "tls client auth", + TLS_CLIENT_AUTH_PORT, + Default::default(), + ) + .await; + while let Some(Ok(mut tls_stream)) = tls.next().await { tokio::spawn(async move { - match acceptor.accept(stream).await { - Ok(mut tls_stream) => { - let (_, tls_session) = tls_stream.get_mut(); - // We only need to check for the presence of client certificates - // here. Rusttls ensures that they are valid and signed by the CA. - let response = match tls_session.peer_certificates() { - Some(_certs) => b"PASS", - None => b"FAIL", - }; - tls_stream.write_all(response).await.unwrap(); - } - - Err(e) => { - eprintln!("TLS accept error: {e:?}"); - } - } + let Ok(handshake) = tls_stream.handshake().await else { + eprintln!("Failed to handshake"); + return; + }; + // We only need to check for the presence of client certificates + // here. Rusttls ensures that they are valid and signed by the CA. + let response = match handshake.has_peer_certificates { + true => b"PASS", + false => b"FAIL", + }; + tls_stream.write_all(response).await.unwrap(); }); } } @@ -775,55 +647,11 @@ async fn run_tls_client_auth_server() { /// test_server and /// curl --cacert cli/tests/testdata/tls/RootCA.crt https://localhost:4553/ async fn run_tls_server() { - let cert_file = "tls/localhost.crt"; - let key_file = "tls/localhost.key"; - let ca_cert_file = "tls/RootCA.pem"; - let tls_config = - get_tls_config(cert_file, key_file, ca_cert_file, Default::default()) - .await - .unwrap(); - let tls_acceptor = TlsAcceptor::from(tls_config); - - // Listen on ALL addresses that localhost can resolves to. - let accept = |listener: tokio::net::TcpListener| { - async { - let result = listener.accept().await; - Some((result, listener)) - } - .boxed() - }; - - let host_and_port = &format!("localhost:{TLS_PORT}"); - - let listeners = tokio::net::lookup_host(host_and_port) - .await - .expect(host_and_port) - .inspect(|address| println!("{host_and_port} -> {address}")) - .map(tokio::net::TcpListener::bind) - .collect::<futures::stream::FuturesUnordered<_>>() - .collect::<Vec<_>>() - .await - .into_iter() - .map(|s| s.unwrap()) - .map(|listener| futures::stream::unfold(listener, accept)) - .collect::<Vec<_>>(); - - println!("ready: tls"); // Eye catcher for HttpServerCount - - let mut listeners = futures::stream::select_all(listeners); - - while let Some(Ok((stream, _addr))) = listeners.next().await { - let acceptor = tls_acceptor.clone(); + let mut tls = + get_tls_listener_stream("tls", TLS_PORT, Default::default()).await; + while let Some(Ok(mut tls_stream)) = tls.next().await { tokio::spawn(async move { - match acceptor.accept(stream).await { - Ok(mut tls_stream) => { - tls_stream.write_all(b"PASS").await.unwrap(); - } - - Err(e) => { - eprintln!("TLS accept error: {e:?}"); - } - } + tls_stream.write_all(b"PASS").await.unwrap(); }); } } @@ -1595,15 +1423,12 @@ async fn download_npm_registry_file( /// Taken from example in https://github.com/ctz/hyper-rustls/blob/a02ef72a227dcdf102f86e905baa7415c992e8b3/examples/server.rs struct HyperAcceptor<'a> { acceptor: Pin< - Box< - dyn Stream<Item = io::Result<tokio_rustls::server::TlsStream<TcpStream>>> - + 'a, - >, + Box<dyn Stream<Item = io::Result<rustls_tokio_stream::TlsStream>> + 'a>, >, } impl hyper::server::accept::Accept for HyperAcceptor<'_> { - type Conn = tokio_rustls::server::TlsStream<TcpStream>; + type Conn = rustls_tokio_stream::TlsStream; type Error = io::Error; fn poll_accept( @@ -1729,142 +1554,56 @@ async fn wrap_main_server_for_addr(main_server_addr: &SocketAddr) { } async fn wrap_main_https_server() { - let main_server_https_addr = SocketAddr::from(([127, 0, 0, 1], HTTPS_PORT)); - let cert_file = "tls/localhost.crt"; - let key_file = "tls/localhost.key"; - let ca_cert_file = "tls/RootCA.pem"; - let tls_config = - get_tls_config(cert_file, key_file, ca_cert_file, Default::default()) - .await - .unwrap(); - loop { - let tcp = TcpListener::bind(&main_server_https_addr) - .await - .expect("Cannot bind TCP"); - println!("ready: https"); // Eye catcher for HttpServerCount - let tls_acceptor = TlsAcceptor::from(tls_config.clone()); - // Prepare a long-running future stream to accept and serve clients. - let incoming_tls_stream = async_stream::stream! { - loop { - let (socket, _) = tcp.accept().await?; - let stream = tls_acceptor.accept(socket); - yield stream.await; - } - } - .boxed(); - - let main_server_https_svc = make_service_fn(|_| async { - Ok::<_, Infallible>(service_fn(main_server)) - }); - let main_server_https = Server::builder(HyperAcceptor { - acceptor: incoming_tls_stream, - }) - .serve(main_server_https_svc); - - //continue to prevent TLS error stopping the server - if main_server_https.await.is_err() { - continue; - } - } + let tls = + get_tls_listener_stream("https", HTTPS_PORT, Default::default()).await; + let main_server_https_svc = + make_service_fn(|_| async { Ok::<_, Infallible>(service_fn(main_server)) }); + let main_server_https = Server::builder(HyperAcceptor { + acceptor: tls.boxed_local(), + }) + .serve(main_server_https_svc); + let _ = main_server_https.await; } async fn wrap_https_h1_only_tls_server() { - let main_server_https_addr = - SocketAddr::from(([127, 0, 0, 1], H1_ONLY_TLS_PORT)); - let cert_file = "tls/localhost.crt"; - let key_file = "tls/localhost.key"; - let ca_cert_file = "tls/RootCA.pem"; - let tls_config = get_tls_config( - cert_file, - key_file, - ca_cert_file, + let tls = get_tls_listener_stream( + "https (h1 only)", + H1_ONLY_TLS_PORT, SupportedHttpVersions::Http1Only, ) - .await - .unwrap(); - loop { - let tcp = TcpListener::bind(&main_server_https_addr) - .await - .expect("Cannot bind TCP"); - println!("ready: https"); // Eye catcher for HttpServerCount - let tls_acceptor = TlsAcceptor::from(tls_config.clone()); - // Prepare a long-running future stream to accept and serve clients. - let incoming_tls_stream = async_stream::stream! { - loop { - let (socket, _) = tcp.accept().await?; - let stream = tls_acceptor.accept(socket); - yield stream.await; - } - } - .boxed(); + .await; - let main_server_https_svc = make_service_fn(|_| async { - Ok::<_, Infallible>(service_fn(main_server)) - }); - let main_server_https = Server::builder(HyperAcceptor { - acceptor: incoming_tls_stream, - }) - .http1_only(true) - .serve(main_server_https_svc); + let main_server_https_svc = + make_service_fn(|_| async { Ok::<_, Infallible>(service_fn(main_server)) }); + let main_server_https = Server::builder(HyperAcceptor { + acceptor: tls.boxed_local(), + }) + .http1_only(true) + .serve(main_server_https_svc); - //continue to prevent TLS error stopping the server - if main_server_https.await.is_err() { - continue; - } - } + let _ = main_server_https.await; } async fn wrap_https_h2_only_tls_server() { - let main_server_https_addr = - SocketAddr::from(([127, 0, 0, 1], H2_ONLY_TLS_PORT)); - let tls_config = create_tls_server_config().await; - loop { - let tcp = TcpListener::bind(&main_server_https_addr) - .await - .expect("Cannot bind TCP"); - println!("ready: https"); // Eye catcher for HttpServerCount - let tls_acceptor = TlsAcceptor::from(tls_config.clone()); - // Prepare a long-running future stream to accept and serve clients. - let incoming_tls_stream = async_stream::stream! { - loop { - let (socket, _) = tcp.accept().await?; - let stream = tls_acceptor.accept(socket); - yield stream.await; - } - } - .boxed(); - - let main_server_https_svc = make_service_fn(|_| async { - Ok::<_, Infallible>(service_fn(main_server)) - }); - let main_server_https = Server::builder(HyperAcceptor { - acceptor: incoming_tls_stream, - }) - .http2_only(true) - .serve(main_server_https_svc); - - //continue to prevent TLS error stopping the server - if main_server_https.await.is_err() { - continue; - } - } -} - -async fn create_tls_server_config() -> Arc<rustls::ServerConfig> { - let cert_file = "tls/localhost.crt"; - let key_file = "tls/localhost.key"; - let ca_cert_file = "tls/RootCA.pem"; - get_tls_config( - cert_file, - key_file, - ca_cert_file, + let tls = get_tls_listener_stream( + "https (h2 only)", + H2_ONLY_TLS_PORT, SupportedHttpVersions::Http2Only, ) - .await - .unwrap() + .await; + + let main_server_https_svc = + make_service_fn(|_| async { Ok::<_, Infallible>(service_fn(main_server)) }); + let main_server_https = Server::builder(HyperAcceptor { + acceptor: tls.boxed_local(), + }) + .http2_only(true) + .serve(main_server_https_svc); + + let _ = main_server_https.await; } -async fn wrap_https_h1_only_server() { +async fn wrap_http_h1_only_server() { let main_server_http_addr = SocketAddr::from(([127, 0, 0, 1], H1_ONLY_PORT)); let main_server_http_svc = @@ -1875,7 +1614,7 @@ async fn wrap_https_h1_only_server() { let _ = main_server_http.await; } -async fn wrap_https_h2_only_server() { +async fn wrap_http_h2_only_server() { let main_server_http_addr = SocketAddr::from(([127, 0, 0, 1], H2_ONLY_PORT)); let main_server_http_svc = @@ -1887,12 +1626,13 @@ async fn wrap_https_h2_only_server() { } async fn h2_grpc_server() { - let addr = SocketAddr::from(([127, 0, 0, 1], H2_GRPC_PORT)); - let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); - - let addr_tls = SocketAddr::from(([127, 0, 0, 1], H2S_GRPC_PORT)); - let listener_tls = tokio::net::TcpListener::bind(addr_tls).await.unwrap(); - let tls_config = create_tls_server_config().await; + let mut tcp = get_tcp_listener_stream("grpc", H2_GRPC_PORT).await; + let mut tls = get_tls_listener_stream( + "grpc (tls)", + H2S_GRPC_PORT, + SupportedHttpVersions::Http2Only, + ) + .await; async fn serve(socket: TcpStream) -> Result<(), anyhow::Error> { let mut connection = h2::server::handshake(socket).await?; @@ -1907,9 +1647,7 @@ async fn h2_grpc_server() { Ok(()) } - async fn serve_tls( - socket: TlsStream<TcpStream>, - ) -> Result<(), anyhow::Error> { + async fn serve_tls(socket: TlsStream) -> Result<(), anyhow::Error> { let mut connection = h2::server::handshake(socket).await?; while let Some(result) = connection.accept().await { @@ -1957,87 +1695,54 @@ async fn h2_grpc_server() { Ok(()) } - let http = tokio::spawn(async move { - loop { - if let Ok((socket, _peer_addr)) = listener.accept().await { - tokio::spawn(async move { - let _ = serve(socket).await; - }); - } + let local_set = LocalSet::new(); + local_set.spawn_local(async move { + while let Some(Ok(tcp)) = tcp.next().await { + tokio::spawn(async move { + let _ = serve(tcp).await; + }); } }); - let https = tokio::spawn(async move { - loop { - if let Ok((socket, _peer_addr)) = listener_tls.accept().await { - let tls_acceptor = TlsAcceptor::from(tls_config.clone()); - let tls = tls_acceptor.accept(socket).await.unwrap(); - tokio::spawn(async move { - let _ = serve_tls(tls).await; - }); - } + local_set.spawn_local(async move { + while let Some(Ok(tls)) = tls.next().await { + tokio::spawn(async move { + let _ = serve_tls(tls).await; + }); } }); - http.await.unwrap(); - https.await.unwrap(); + local_set.await; } async fn wrap_client_auth_https_server() { - let main_server_https_addr = - SocketAddr::from(([127, 0, 0, 1], HTTPS_CLIENT_AUTH_PORT)); - let cert_file = "tls/localhost.crt"; - let key_file = "tls/localhost.key"; - let ca_cert_file = "tls/RootCA.pem"; - let tls_config = - get_tls_config(cert_file, key_file, ca_cert_file, Default::default()) - .await - .unwrap(); - loop { - let tcp = TcpListener::bind(&main_server_https_addr) - .await - .expect("Cannot bind TCP"); - println!("ready: https_client_auth on :{HTTPS_CLIENT_AUTH_PORT:?}"); // Eye catcher for HttpServerCount - let tls_acceptor = TlsAcceptor::from(tls_config.clone()); - // Prepare a long-running future stream to accept and serve clients. - let incoming_tls_stream = async_stream::stream! { - loop { - let (socket, _) = tcp.accept().await?; - - match tls_acceptor.accept(socket).await { - Ok(mut tls_stream) => { - let (_, tls_session) = tls_stream.get_mut(); - // We only need to check for the presence of client certificates - // here. Rusttls ensures that they are valid and signed by the CA. - match tls_session.peer_certificates() { - Some(_certs) => { yield Ok(tls_stream); }, - None => { eprintln!("https_client_auth: no valid client certificate"); }, - }; - } - - Err(e) => { - eprintln!("https-client-auth accept error: {e:?}"); - yield Err(e); - } - } + let mut tls = get_tls_listener_stream( + "https_client_auth", + HTTPS_CLIENT_AUTH_PORT, + Default::default(), + ) + .await; - } + let tls = async_stream::stream! { + while let Some(Ok(mut tls)) = tls.next().await { + let handshake = tls.handshake().await?; + // We only need to check for the presence of client certificates + // here. Rusttls ensures that they are valid and signed by the CA. + match handshake.has_peer_certificates { + true => { yield Ok(tls); }, + false => { eprintln!("https_client_auth: no valid client certificate"); }, + }; } - .boxed(); + }; - let main_server_https_svc = make_service_fn(|_| async { - Ok::<_, Infallible>(service_fn(main_server)) - }); - let main_server_https = Server::builder(HyperAcceptor { - acceptor: incoming_tls_stream, - }) - .serve(main_server_https_svc); + let main_server_https_svc = + make_service_fn(|_| async { Ok::<_, Infallible>(service_fn(main_server)) }); + let main_server_https = Server::builder(HyperAcceptor { + acceptor: tls.boxed_local(), + }) + .serve(main_server_https_svc); - //continue to prevent TLS error stopping the server - if main_server_https.await.is_err() { - continue; - } - } + let _ = main_server_https.await; } // Use the single-threaded scheduler. The hyper server is used as a point of @@ -2057,16 +1762,11 @@ pub async fn run_all_servers() { let basic_auth_redirect_server_fut = wrap_basic_auth_redirect_server(); let abs_redirect_server_fut = wrap_abs_redirect_server(); - let ws_addr = SocketAddr::from(([127, 0, 0, 1], WS_PORT)); - let ws_server_fut = run_ws_server(&ws_addr); - let ws_ping_addr = SocketAddr::from(([127, 0, 0, 1], WS_PING_PORT)); - let ws_ping_server_fut = run_ws_ping_server(&ws_ping_addr); - let wss_addr = SocketAddr::from(([127, 0, 0, 1], WSS_PORT)); - let wss_server_fut = run_wss_server(&wss_addr); - let ws_close_addr = SocketAddr::from(([127, 0, 0, 1], WS_CLOSE_PORT)); - let ws_close_server_fut = run_ws_close_server(&ws_close_addr); - let wss2_addr = SocketAddr::from(([127, 0, 0, 1], WSS2_PORT)); - let wss2_server_fut = run_wss2_server(&wss2_addr); + let ws_server_fut = run_ws_server(WS_PORT); + let ws_ping_server_fut = run_ws_ping_server(WS_PING_PORT); + let wss_server_fut = run_wss_server(WSS_PORT); + let ws_close_server_fut = run_ws_close_server(WS_CLOSE_PORT); + let wss2_server_fut = run_wss2_server(WSS2_PORT); let tls_server_fut = run_tls_server(); let tls_client_auth_server_fut = run_tls_client_auth_server(); @@ -2076,11 +1776,11 @@ pub async fn run_all_servers() { let main_server_https_fut = wrap_main_https_server(); let h1_only_server_tls_fut = wrap_https_h1_only_tls_server(); let h2_only_server_tls_fut = wrap_https_h2_only_tls_server(); - let h1_only_server_fut = wrap_https_h1_only_server(); - let h2_only_server_fut = wrap_https_h2_only_server(); + let h1_only_server_fut = wrap_http_h1_only_server(); + let h2_only_server_fut = wrap_http_h2_only_server(); let h2_grpc_server_fut = h2_grpc_server(); - let mut server_fut = async { + let server_fut = async { futures::join!( redirect_server_fut, ws_server_fut, @@ -2107,17 +1807,9 @@ pub async fn run_all_servers() { h2_grpc_server_fut, ) } - .boxed(); + .boxed_local(); - let mut did_print_ready = false; - futures::future::poll_fn(move |cx| { - let poll_result = server_fut.poll_unpin(cx); - if !replace(&mut did_print_ready, true) { - println!("ready: server_fut"); // Eye catcher for HttpServerCount - } - poll_result - }) - .await; + server_fut.await; } fn custom_headers(p: &str, body: Vec<u8>) -> Response<Body> { @@ -2243,7 +1935,7 @@ impl HttpServerCount { if line.starts_with("ready:") { ready_count += 1; } - if ready_count == 6 { + if ready_count == 12 { break; } } else { |