summaryrefslogtreecommitdiff
path: root/tests/integration/jupyter_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tests/integration/jupyter_tests.rs')
-rw-r--r--tests/integration/jupyter_tests.rs535
1 files changed, 535 insertions, 0 deletions
diff --git a/tests/integration/jupyter_tests.rs b/tests/integration/jupyter_tests.rs
new file mode 100644
index 000000000..29636f305
--- /dev/null
+++ b/tests/integration/jupyter_tests.rs
@@ -0,0 +1,535 @@
+// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
+
+use std::process::Output;
+use std::sync::Arc;
+use std::time::Duration;
+
+use bytes::Bytes;
+use test_util::DenoChild;
+use test_util::TestContext;
+use test_util::TestContextBuilder;
+
+use deno_core::anyhow::Result;
+use deno_core::serde_json;
+use deno_core::serde_json::json;
+use deno_core::serde_json::Value;
+use serde::Deserialize;
+use serde::Serialize;
+use tokio::sync::Mutex;
+use tokio::time::timeout;
+use uuid::Uuid;
+use zeromq::SocketRecv;
+use zeromq::SocketSend;
+use zeromq::ZmqMessage;
+
+// for the `utc_now` function
+include!("../../cli/util/time.rs");
+
+/// Jupyter connection file format
+#[derive(Serialize)]
+struct ConnectionSpec {
+ // key used for HMAC signature, if empty, hmac is not used
+ key: String,
+ signature_scheme: String,
+ transport: String,
+ ip: String,
+ hb_port: u16,
+ control_port: u16,
+ shell_port: u16,
+ stdin_port: u16,
+ iopub_port: u16,
+ kernel_name: String,
+}
+
+impl ConnectionSpec {
+ fn endpoint(&self, port: u16) -> String {
+ format!("{}://{}:{}", self.transport, self.ip, port)
+ }
+}
+
+fn pick_unused_port() -> u16 {
+ let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
+ listener.local_addr().unwrap().port()
+}
+
+impl Default for ConnectionSpec {
+ fn default() -> Self {
+ Self {
+ key: "".into(),
+ signature_scheme: "hmac-sha256".into(),
+ transport: "tcp".into(),
+ ip: "127.0.0.1".into(),
+ hb_port: pick_unused_port(),
+ control_port: pick_unused_port(),
+ shell_port: pick_unused_port(),
+ stdin_port: pick_unused_port(),
+ iopub_port: pick_unused_port(),
+ kernel_name: "deno".into(),
+ }
+ }
+}
+
+const DELIMITER: &[u8] = b"<IDS|MSG>";
+
+#[derive(Debug, Clone)]
+struct JupyterMsg {
+ routing_prefix: Vec<String>,
+ signature: String,
+ header: MsgHeader,
+ parent_header: Value,
+ metadata: Value,
+ content: Value,
+ buffers: Vec<Bytes>,
+}
+
+impl Default for JupyterMsg {
+ fn default() -> Self {
+ Self {
+ routing_prefix: vec![Uuid::new_v4().to_string()],
+ signature: "".into(),
+ header: MsgHeader::default(),
+ parent_header: json!({}),
+ metadata: json!({}),
+ content: json!({}),
+ buffers: Vec::new(),
+ }
+ }
+}
+
+#[derive(Serialize, Clone, Debug, Deserialize)]
+struct MsgHeader {
+ msg_id: Uuid,
+ session: Uuid,
+ date: String,
+ username: String,
+ msg_type: String,
+ version: String,
+}
+
+impl MsgHeader {
+ fn to_json(&self) -> Value {
+ serde_json::to_value(self).unwrap()
+ }
+}
+
+impl Default for MsgHeader {
+ fn default() -> Self {
+ Self {
+ msg_id: Uuid::new_v4(),
+ session: Uuid::new_v4(),
+ date: utc_now().to_rfc3339(),
+ username: "test".into(),
+ msg_type: "kernel_info_request".into(),
+ version: "5.3".into(),
+ }
+ }
+}
+
+impl JupyterMsg {
+ fn to_raw(&self) -> ZmqMessage {
+ let mut parts = Vec::new();
+ parts.extend(
+ self
+ .routing_prefix
+ .iter()
+ .map(|uuid| uuid.as_bytes().to_vec().into()),
+ );
+ parts.push(Bytes::from_static(DELIMITER));
+ parts.push(self.signature.clone().into());
+ parts.push(serde_json::to_vec(&self.header).unwrap().into());
+ parts.push(self.parent_header.to_string().into());
+ parts.push(self.metadata.to_string().into());
+ parts.push(self.content.to_string().into());
+ parts.extend(self.buffers.clone());
+ ZmqMessage::try_from(parts).unwrap()
+ }
+
+ fn new(session: Uuid, msg_type: impl AsRef<str>, content: Value) -> Self {
+ Self {
+ header: MsgHeader {
+ session,
+ msg_type: msg_type.as_ref().into(),
+ ..Default::default()
+ },
+ content,
+ ..Default::default()
+ }
+ }
+
+ fn from_raw(msg: ZmqMessage) -> Self {
+ let parts = msg.into_vec();
+ let delimiter = parts.iter().position(|part| part == DELIMITER).unwrap();
+ let routing_prefix = parts[..delimiter]
+ .iter()
+ .map(|part: &Bytes| String::from_utf8_lossy(part.as_ref()).to_string())
+ .collect();
+ let signature = String::from_utf8(parts[delimiter + 1].to_vec())
+ .expect("Failed to parse signature");
+ let header: MsgHeader = serde_json::from_slice(&parts[delimiter + 2])
+ .expect("Failed to parse header");
+ let parent_header: Value =
+ serde_json::from_slice(&parts[delimiter + 3]).unwrap();
+ let metadata: Value =
+ serde_json::from_slice(&parts[delimiter + 4]).unwrap();
+ let content: Value = serde_json::from_slice(&parts[delimiter + 5]).unwrap();
+ let buffers = parts[delimiter + 6..].to_vec();
+ Self {
+ routing_prefix,
+ signature,
+ header,
+ parent_header,
+ metadata,
+ content,
+ buffers,
+ }
+ }
+}
+
+async fn connect_socket<S: zeromq::Socket>(
+ spec: &ConnectionSpec,
+ port: u16,
+) -> S {
+ let addr = spec.endpoint(port);
+ let mut socket = S::new();
+ socket.connect(&addr).await.unwrap();
+ socket
+}
+
+#[derive(Clone)]
+struct JupyterClient {
+ recv_timeout: Duration,
+ session: Uuid,
+ heartbeat: Arc<Mutex<zeromq::ReqSocket>>,
+ control: Arc<Mutex<zeromq::DealerSocket>>,
+ shell: Arc<Mutex<zeromq::DealerSocket>>,
+ io_pub: Arc<Mutex<zeromq::SubSocket>>,
+ stdin: Arc<Mutex<zeromq::RouterSocket>>,
+}
+
+#[derive(Debug, Clone, Copy)]
+enum JupyterChannel {
+ Control,
+ Shell,
+ #[allow(dead_code)]
+ Stdin,
+ IoPub,
+}
+
+use JupyterChannel::*;
+
+impl JupyterClient {
+ async fn new(spec: &ConnectionSpec) -> Self {
+ Self::new_with_timeout(spec, Duration::from_secs(5)).await
+ }
+
+ async fn new_with_timeout(spec: &ConnectionSpec, timeout: Duration) -> Self {
+ let (heartbeat, control, shell, io_pub, stdin) = tokio::join!(
+ connect_socket::<zeromq::ReqSocket>(spec, spec.hb_port),
+ connect_socket::<zeromq::DealerSocket>(spec, spec.control_port),
+ connect_socket::<zeromq::DealerSocket>(spec, spec.shell_port),
+ connect_socket::<zeromq::SubSocket>(spec, spec.iopub_port),
+ connect_socket::<zeromq::RouterSocket>(spec, spec.stdin_port),
+ );
+
+ Self {
+ session: Uuid::new_v4(),
+ heartbeat: Arc::new(Mutex::new(heartbeat)),
+ control: Arc::new(Mutex::new(control)),
+ shell: Arc::new(Mutex::new(shell)),
+ io_pub: Arc::new(Mutex::new(io_pub)),
+ stdin: Arc::new(Mutex::new(stdin)),
+ recv_timeout: timeout,
+ }
+ }
+
+ async fn io_subscribe(&self, topic: &str) -> Result<()> {
+ Ok(self.io_pub.lock().await.subscribe(topic).await?)
+ }
+
+ async fn recv_with_timeout<S: SocketRecv>(
+ &self,
+ s: &mut S,
+ ) -> Result<JupyterMsg> {
+ let msg = timeout(self.recv_timeout, s.recv()).await??;
+ Ok(JupyterMsg::from_raw(msg))
+ }
+
+ async fn send_msg(
+ &self,
+ channel: JupyterChannel,
+ msg: JupyterMsg,
+ ) -> Result<JupyterMsg> {
+ let raw = msg.to_raw();
+ match channel {
+ Control => self.control.lock().await.send(raw).await?,
+ Shell => self.shell.lock().await.send(raw).await?,
+ Stdin => self.stdin.lock().await.send(raw).await?,
+ IoPub => panic!("Cannot send over IOPub"),
+ }
+ Ok(msg)
+ }
+
+ async fn send(
+ &self,
+ channel: JupyterChannel,
+ msg_type: &str,
+ content: Value,
+ ) -> Result<JupyterMsg> {
+ let msg = JupyterMsg::new(self.session, msg_type, content);
+ self.send_msg(channel, msg).await
+ }
+
+ async fn recv(&self, channel: JupyterChannel) -> Result<JupyterMsg> {
+ Ok(match channel {
+ Control => {
+ self
+ .recv_with_timeout(&mut *self.control.lock().await)
+ .await?
+ }
+ Shell => {
+ self
+ .recv_with_timeout(&mut *self.shell.lock().await)
+ .await?
+ }
+ Stdin => {
+ self
+ .recv_with_timeout(&mut *self.stdin.lock().await)
+ .await?
+ }
+ IoPub => {
+ self
+ .recv_with_timeout(&mut *self.io_pub.lock().await)
+ .await?
+ }
+ })
+ }
+
+ async fn send_heartbeat(&self, bytes: impl AsRef<[u8]>) -> Result<()> {
+ Ok(
+ self
+ .heartbeat
+ .lock()
+ .await
+ .send(ZmqMessage::from(bytes.as_ref().to_vec()))
+ .await?,
+ )
+ }
+
+ async fn recv_heartbeat(&self) -> Result<Bytes> {
+ Ok(
+ timeout(self.recv_timeout, self.heartbeat.lock().await.recv())
+ .await??
+ .into_vec()[0]
+ .clone(),
+ )
+ }
+}
+
+async fn wait_or_kill(
+ mut process: DenoChild,
+ wait: Duration,
+) -> Result<Output> {
+ let start = std::time::Instant::now();
+ while start.elapsed() < wait {
+ if process.try_wait()?.is_some() {
+ return Ok(process.wait_with_output()?);
+ }
+ tokio::time::sleep(Duration::from_millis(100)).await;
+ }
+ process.kill()?;
+ Ok(process.wait_with_output()?)
+}
+
+// Wrapper around the Jupyter server process that
+// ensures the process is killed when dropped.
+struct JupyterServerProcess(Option<DenoChild>);
+
+impl JupyterServerProcess {
+ // Wait for the process to exit, or kill it after the given duration.
+ //
+ // Ideally we could use this at the end of each test, but the server
+ // doesn't seem to exit in a reasonable amount of time after getting
+ // a shutdown request.
+ #[allow(dead_code)]
+ async fn wait_or_kill(mut self, wait: Duration) -> Output {
+ wait_or_kill(self.0.take().unwrap(), wait).await.unwrap()
+ }
+}
+
+impl Drop for JupyterServerProcess {
+ fn drop(&mut self) {
+ let Some(mut proc) = self.0.take() else {
+ return;
+ };
+ if proc.try_wait().unwrap().is_some() {
+ // already exited
+ return;
+ }
+ proc.kill().unwrap();
+ }
+}
+
+fn setup_server() -> (TestContext, ConnectionSpec, JupyterServerProcess) {
+ let context = TestContextBuilder::new().use_temp_cwd().build();
+ let conn = ConnectionSpec::default();
+ let conn_file = context.temp_dir().path().join("connection.json");
+ conn_file.write_json(&conn);
+ let process = context
+ .new_command()
+ .piped_output()
+ .args_vec(vec![
+ "jupyter",
+ "--kernel",
+ "--conn",
+ conn_file.to_string().as_str(),
+ ])
+ .spawn()
+ .unwrap();
+ (context, conn, JupyterServerProcess(Some(process)))
+}
+
+async fn setup() -> (TestContext, JupyterClient, JupyterServerProcess) {
+ let (context, conn, process) = setup_server();
+ let client = JupyterClient::new(&conn).await;
+ client.io_subscribe("").await.unwrap();
+
+ (context, client, process)
+}
+
+/// Asserts that the actual value is equal to the expected value, but
+/// only for the keys present in the expected value.
+/// In other words, `assert_eq_subset(json!({"a": 1, "b": 2}), json!({"a": 1}))` would pass.
+#[track_caller]
+fn assert_eq_subset(actual: Value, expected: Value) {
+ match (actual, expected) {
+ (Value::Object(actual), Value::Object(expected)) => {
+ for (k, v) in expected.iter() {
+ let Some(actual_v) = actual.get(k) else {
+ panic!("Key {k:?} not found in actual value ({actual:#?})");
+ };
+ assert_eq_subset(actual_v.clone(), v.clone());
+ }
+ }
+ (Value::Array(actual), Value::Array(expected)) => {
+ for (i, v) in expected.iter().enumerate() {
+ assert_eq_subset(actual[i].clone(), v.clone());
+ }
+ }
+ (actual, expected) => {
+ assert_eq!(actual, expected);
+ }
+ }
+}
+
+#[tokio::test]
+async fn jupyter_heartbeat_echoes() -> Result<()> {
+ let (_ctx, client, _process) = setup().await;
+ client.send_heartbeat(b"ping").await?;
+ let msg = client.recv_heartbeat().await?;
+ assert_eq!(msg, Bytes::from_static(b"ping"));
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn jupyter_kernel_info() -> Result<()> {
+ let (_ctx, client, _process) = setup().await;
+ client
+ .send(Control, "kernel_info_request", json!({}))
+ .await?;
+ let msg = client.recv(Control).await?;
+ assert_eq!(msg.header.msg_type, "kernel_info_reply");
+ assert_eq_subset(
+ msg.content,
+ json!({
+ "status": "ok",
+ "implementation": "Deno kernel",
+ "language_info": {
+ "name": "typescript",
+ "mimetype": "text/x.typescript",
+ "file_extension": ".ts",
+ "pygments_lexer": "typescript",
+ "nb_converter": "script"
+ },
+ }),
+ );
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn jupyter_execute_request() -> Result<()> {
+ let (_ctx, client, _process) = setup().await;
+ let request = client
+ .send(
+ Shell,
+ "execute_request",
+ json!({
+ "silent": false,
+ "store_history": true,
+ "user_expressions": {},
+ "allow_stdin": true,
+ "stop_on_error": false,
+ "code": "console.log(\"asdf\")"
+ }),
+ )
+ .await?;
+ let reply = client.recv(Shell).await?;
+ assert_eq!(reply.header.msg_type, "execute_reply");
+ assert_eq_subset(
+ reply.content,
+ json!({
+ "status": "ok",
+ "execution_count": 1,
+ }),
+ );
+
+ let mut msgs = Vec::new();
+
+ for _ in 0..4 {
+ match client.recv(IoPub).await {
+ Ok(msg) => msgs.push(msg),
+ Err(e) => {
+ if e.downcast_ref::<tokio::time::error::Elapsed>().is_some() {
+ // may timeout if we missed some messages
+ break;
+ }
+ panic!("Error: {:#?}", e);
+ }
+ }
+ }
+
+ let execution_idle = msgs
+ .iter()
+ .find(|msg| {
+ if let Some(state) = msg.content.get("execution_state") {
+ state == "idle"
+ } else {
+ false
+ }
+ })
+ .expect("execution_state idle not found");
+ assert_eq!(execution_idle.parent_header, request.header.to_json());
+ assert_eq_subset(
+ execution_idle.content.clone(),
+ json!({
+ "execution_state": "idle",
+ }),
+ );
+
+ let execution_result = msgs
+ .iter()
+ .find(|msg| msg.header.msg_type == "stream")
+ .expect("stream not found");
+ assert_eq!(execution_result.header.msg_type, "stream");
+ assert_eq!(execution_result.parent_header, request.header.to_json());
+ assert_eq_subset(
+ execution_result.content.clone(),
+ json!({
+ "name": "stdout",
+ "text": "asdf\n", // the trailing newline is added by console.log
+ }),
+ );
+
+ Ok(())
+}