summaryrefslogtreecommitdiff
path: root/tests/integration/jupyter_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tests/integration/jupyter_tests.rs')
-rw-r--r--tests/integration/jupyter_tests.rs148
1 files changed, 101 insertions, 47 deletions
diff --git a/tests/integration/jupyter_tests.rs b/tests/integration/jupyter_tests.rs
index af8101ea7..75b1da085 100644
--- a/tests/integration/jupyter_tests.rs
+++ b/tests/integration/jupyter_tests.rs
@@ -47,25 +47,43 @@ impl ConnectionSpec {
}
}
-fn pick_unused_port() -> u16 {
+/// Gets an unused port from the OS, and returns the port number and a
+/// `TcpListener` bound to that port. You can keep the listener alive
+/// to prevent another process from binding to the port.
+fn pick_unused_port() -> (u16, std::net::TcpListener) {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
- listener.local_addr().unwrap().port()
+ (listener.local_addr().unwrap().port(), listener)
}
-impl Default for ConnectionSpec {
- fn default() -> Self {
- Self {
- key: "".into(),
- signature_scheme: "hmac-sha256".into(),
- transport: "tcp".into(),
- ip: "127.0.0.1".into(),
- hb_port: pick_unused_port(),
- control_port: pick_unused_port(),
- shell_port: pick_unused_port(),
- stdin_port: pick_unused_port(),
- iopub_port: pick_unused_port(),
- kernel_name: "deno".into(),
- }
+impl ConnectionSpec {
+ fn new() -> (Self, Vec<std::net::TcpListener>) {
+ let mut listeners = Vec::new();
+ let (hb_port, listener) = pick_unused_port();
+ listeners.push(listener);
+ let (control_port, listener) = pick_unused_port();
+ listeners.push(listener);
+ let (shell_port, listener) = pick_unused_port();
+ listeners.push(listener);
+ let (stdin_port, listener) = pick_unused_port();
+ listeners.push(listener);
+ let (iopub_port, listener) = pick_unused_port();
+ listeners.push(listener);
+
+ (
+ Self {
+ key: "".into(),
+ signature_scheme: "hmac-sha256".into(),
+ transport: "tcp".into(),
+ ip: "127.0.0.1".into(),
+ hb_port,
+ control_port,
+ shell_port,
+ stdin_port,
+ iopub_port,
+ kernel_name: "deno".into(),
+ },
+ listeners,
+ )
}
}
@@ -191,25 +209,15 @@ async fn connect_socket<S: zeromq::Socket>(
) -> S {
let addr = spec.endpoint(port);
let mut socket = S::new();
- let mut connected = false;
- for _ in 0..5 {
- match timeout(Duration::from_secs(5), socket.connect(&addr)).await {
- Ok(Ok(_)) => {
- connected = true;
- break;
- }
- Ok(Err(e)) => {
- eprintln!("Failed to connect to {addr}: {e}");
- }
- Err(e) => {
- eprintln!("Timed out connecting to {addr}: {e}");
- }
+ match timeout(Duration::from_millis(5000), socket.connect(&addr)).await {
+ Ok(Ok(_)) => socket,
+ Ok(Err(e)) => {
+ panic!("Failed to connect to {addr}: {e}");
+ }
+ Err(e) => {
+ panic!("Timed out connecting to {addr}: {e}");
}
}
- if !connected {
- panic!("Failed to connect to {addr}");
- }
- socket
}
#[derive(Clone)]
@@ -236,7 +244,7 @@ use JupyterChannel::*;
impl JupyterClient {
async fn new(spec: &ConnectionSpec) -> Self {
- Self::new_with_timeout(spec, Duration::from_secs(5)).await
+ Self::new_with_timeout(spec, Duration::from_secs(10)).await
}
async fn new_with_timeout(spec: &ConnectionSpec, timeout: Duration) -> Self {
@@ -386,9 +394,36 @@ impl Drop for JupyterServerProcess {
}
}
+async fn server_ready_on(addr: &str) -> bool {
+ matches!(
+ timeout(
+ Duration::from_millis(1000),
+ tokio::net::TcpStream::connect(addr.trim_start_matches("tcp://")),
+ )
+ .await,
+ Ok(Ok(_))
+ )
+}
+
+async fn server_ready(conn: &ConnectionSpec) -> bool {
+ let hb = conn.endpoint(conn.hb_port);
+ let control = conn.endpoint(conn.control_port);
+ let shell = conn.endpoint(conn.shell_port);
+ let stdin = conn.endpoint(conn.stdin_port);
+ let iopub = conn.endpoint(conn.iopub_port);
+ let (a, b, c, d, e) = tokio::join!(
+ server_ready_on(&hb),
+ server_ready_on(&control),
+ server_ready_on(&shell),
+ server_ready_on(&stdin),
+ server_ready_on(&iopub),
+ );
+ a && b && c && d && e
+}
+
async fn setup_server() -> (TestContext, ConnectionSpec, JupyterServerProcess) {
let context = TestContextBuilder::new().use_temp_cwd().build();
- let mut conn = ConnectionSpec::default();
+ let (mut conn, mut listeners) = ConnectionSpec::new();
let conn_file = context.temp_dir().path().join("connection.json");
conn_file.write_json(&conn);
@@ -405,22 +440,38 @@ async fn setup_server() -> (TestContext, ConnectionSpec, JupyterServerProcess) {
.unwrap()
};
+ // drop the listeners so the server can listen on the ports
+ drop(listeners);
+
// try to start the server, retrying up to 5 times
// (this can happen due to TOCTOU errors with selecting unused TCP ports)
let mut process = start_process(&conn_file);
- tokio::time::sleep(Duration::from_millis(1000)).await;
-
- for _ in 0..5 {
- if process.try_wait().unwrap().is_none() {
- break;
- } else {
- conn = ConnectionSpec::default();
- conn_file.write_json(&conn);
- process = start_process(&conn_file);
- tokio::time::sleep(Duration::from_millis(1000)).await;
+
+ 'outer: for i in 0..10 {
+ // try to see if the server is healthy
+ for _ in 0..10 {
+ // server still running?
+ if process.try_wait().unwrap().is_none() {
+ // listening on all ports?
+ if server_ready(&conn).await {
+ // server is ready to go
+ break 'outer;
+ }
+ } else {
+ // server exited, try again
+ break;
+ }
+ tokio::time::sleep(Duration::from_millis(500)).await;
}
+
+ // pick new ports and try again
+ (conn, listeners) = ConnectionSpec::new();
+ conn_file.write_json(&conn);
+ drop(listeners);
+ process = start_process(&conn_file);
+ tokio::time::sleep(Duration::from_millis((i + 1) * 250)).await;
}
- if process.try_wait().unwrap().is_some() {
+ if process.try_wait().unwrap().is_some() || !server_ready(&conn).await {
panic!("Failed to start Jupyter server");
}
(context, conn, JupyterServerProcess(Some(process)))
@@ -430,6 +481,9 @@ async fn setup() -> (TestContext, JupyterClient, JupyterServerProcess) {
let (context, conn, process) = setup_server().await;
let client = JupyterClient::new(&conn).await;
client.io_subscribe("").await.unwrap();
+ // make sure server is ready to receive messages
+ client.send_heartbeat(b"ping").await.unwrap();
+ let _ = client.recv_heartbeat().await.unwrap();
(context, client, process)
}
@@ -530,7 +584,7 @@ async fn jupyter_execute_request() -> Result<()> {
Err(e) => {
if e.downcast_ref::<tokio::time::error::Elapsed>().is_some() {
// may timeout if we missed some messages
- break;
+ eprintln!("Timed out waiting for messages");
}
panic!("Error: {:#?}", e);
}