diff options
Diffstat (limited to 'cli/tools/jupyter/jupyter_msg.rs')
-rw-r--r-- | cli/tools/jupyter/jupyter_msg.rs | 218 |
1 files changed, 112 insertions, 106 deletions
diff --git a/cli/tools/jupyter/jupyter_msg.rs b/cli/tools/jupyter/jupyter_msg.rs index 60703e365..233efcc8e 100644 --- a/cli/tools/jupyter/jupyter_msg.rs +++ b/cli/tools/jupyter/jupyter_msg.rs @@ -16,14 +16,14 @@ use uuid::Uuid; use crate::util::time::utc_now; -pub(crate) struct Connection<S> { - pub(crate) socket: S, +pub struct Connection<S> { + socket: S, /// Will be None if our key was empty (digest authentication disabled). - pub(crate) mac: Option<hmac::Key>, + mac: Option<hmac::Key>, } impl<S: zeromq::Socket> Connection<S> { - pub(crate) fn new(socket: S, key: &str) -> Self { + pub fn new(socket: S, key: &str) -> Self { let mac = if key.is_empty() { None } else { @@ -33,21 +33,107 @@ impl<S: zeromq::Socket> Connection<S> { } } +impl<S: zeromq::SocketSend + zeromq::SocketRecv> Connection<S> { + pub async fn single_heartbeat(&mut self) -> Result<(), AnyError> { + self.socket.recv().await?; + self + .socket + .send(zeromq::ZmqMessage::from(b"ping".to_vec())) + .await?; + Ok(()) + } +} + +impl<S: zeromq::SocketRecv> Connection<S> { + pub async fn read(&mut self) -> Result<JupyterMessage, AnyError> { + let multipart = self.socket.recv().await?; + let raw_message = RawMessage::from_multipart(multipart, self.mac.as_ref())?; + JupyterMessage::from_raw_message(raw_message) + } +} + +impl<S: zeromq::SocketSend> Connection<S> { + pub async fn send( + &mut self, + message: &JupyterMessage, + ) -> Result<(), AnyError> { + // If performance is a concern, we can probably avoid the clone and to_vec calls with a bit + // of refactoring. + let mut jparts: Vec<Bytes> = vec![ + serde_json::to_string(&message.header) + .unwrap() + .as_bytes() + .to_vec() + .into(), + serde_json::to_string(&message.parent_header) + .unwrap() + .as_bytes() + .to_vec() + .into(), + serde_json::to_string(&message.metadata) + .unwrap() + .as_bytes() + .to_vec() + .into(), + serde_json::to_string(&message.content) + .unwrap() + .as_bytes() + .to_vec() + .into(), + ]; + jparts.extend_from_slice(&message.buffers); + let raw_message = RawMessage { + zmq_identities: message.zmq_identities.clone(), + jparts, + }; + self.send_raw(raw_message).await + } + + async fn send_raw( + &mut self, + raw_message: RawMessage, + ) -> Result<(), AnyError> { + let hmac = if let Some(key) = &self.mac { + let ctx = digest(key, &raw_message.jparts); + let tag = ctx.sign(); + HEXLOWER.encode(tag.as_ref()) + } else { + String::new() + }; + let mut parts: Vec<bytes::Bytes> = Vec::new(); + for part in &raw_message.zmq_identities { + parts.push(part.to_vec().into()); + } + parts.push(DELIMITER.into()); + parts.push(hmac.as_bytes().to_vec().into()); + for part in &raw_message.jparts { + parts.push(part.to_vec().into()); + } + // ZmqMessage::try_from only fails if parts is empty, which it never + // will be here. + let message = zeromq::ZmqMessage::try_from(parts).unwrap(); + self.socket.send(message).await?; + Ok(()) + } +} + +fn digest(mac: &hmac::Key, jparts: &[Bytes]) -> hmac::Context { + let mut hmac_ctx = hmac::Context::with_key(mac); + for part in jparts { + hmac_ctx.update(part); + } + hmac_ctx +} + struct RawMessage { zmq_identities: Vec<Bytes>, jparts: Vec<Bytes>, } impl RawMessage { - pub(crate) async fn read<S: zeromq::SocketRecv>( - connection: &mut Connection<S>, - ) -> Result<RawMessage, AnyError> { - Self::from_multipart(connection.socket.recv().await?, connection) - } - - pub(crate) fn from_multipart<S>( + pub fn from_multipart( multipart: zeromq::ZmqMessage, - connection: &Connection<S>, + mac: Option<&hmac::Key>, ) -> Result<RawMessage, AnyError> { let delimiter_index = multipart .iter() @@ -65,7 +151,7 @@ impl RawMessage { jparts, }; - if let Some(key) = &connection.mac { + if let Some(key) = mac { let sig = HEXLOWER.decode(&expected_hmac)?; let mut msg = Vec::new(); for part in &raw_message.jparts { @@ -79,45 +165,10 @@ impl RawMessage { Ok(raw_message) } - - async fn send<S: zeromq::SocketSend>( - self, - connection: &mut Connection<S>, - ) -> Result<(), AnyError> { - let hmac = if let Some(key) = &connection.mac { - let ctx = self.digest(key); - let tag = ctx.sign(); - HEXLOWER.encode(tag.as_ref()) - } else { - String::new() - }; - let mut parts: Vec<bytes::Bytes> = Vec::new(); - for part in &self.zmq_identities { - parts.push(part.to_vec().into()); - } - parts.push(DELIMITER.into()); - parts.push(hmac.as_bytes().to_vec().into()); - for part in &self.jparts { - parts.push(part.to_vec().into()); - } - // ZmqMessage::try_from only fails if parts is empty, which it never - // will be here. - let message = zeromq::ZmqMessage::try_from(parts).unwrap(); - connection.socket.send(message).await?; - Ok(()) - } - - fn digest(&self, mac: &hmac::Key) -> hmac::Context { - let mut hmac_ctx = hmac::Context::with_key(mac); - for part in &self.jparts { - hmac_ctx.update(part); - } - hmac_ctx - } } #[derive(Clone)] -pub(crate) struct JupyterMessage { +pub struct JupyterMessage { zmq_identities: Vec<Bytes>, header: serde_json::Value, parent_header: serde_json::Value, @@ -129,12 +180,6 @@ pub(crate) struct JupyterMessage { const DELIMITER: &[u8] = b"<IDS|MSG>"; impl JupyterMessage { - pub(crate) async fn read<S: zeromq::SocketRecv>( - connection: &mut Connection<S>, - ) -> Result<JupyterMessage, AnyError> { - Self::from_raw_message(RawMessage::read(connection).await?) - } - fn from_raw_message( raw_message: RawMessage, ) -> Result<JupyterMessage, AnyError> { @@ -156,32 +201,32 @@ impl JupyterMessage { }) } - pub(crate) fn message_type(&self) -> &str { + pub fn message_type(&self) -> &str { self.header["msg_type"].as_str().unwrap_or("") } - pub(crate) fn store_history(&self) -> bool { + pub fn store_history(&self) -> bool { self.content["store_history"].as_bool().unwrap_or(true) } - pub(crate) fn silent(&self) -> bool { + pub fn silent(&self) -> bool { self.content["silent"].as_bool().unwrap_or(false) } - pub(crate) fn code(&self) -> &str { + pub fn code(&self) -> &str { self.content["code"].as_str().unwrap_or("") } - pub(crate) fn cursor_pos(&self) -> usize { + pub fn cursor_pos(&self) -> usize { self.content["cursor_pos"].as_u64().unwrap_or(0) as usize } - pub(crate) fn comm_id(&self) -> &str { + pub fn comm_id(&self) -> &str { self.content["comm_id"].as_str().unwrap_or("") } // Creates a new child message of this message. ZMQ identities are not transferred. - pub(crate) fn new_message(&self, msg_type: &str) -> JupyterMessage { + pub fn new_message(&self, msg_type: &str) -> JupyterMessage { let mut header = self.header.clone(); header["msg_type"] = serde_json::Value::String(msg_type.to_owned()); header["username"] = serde_json::Value::String("kernel".to_owned()); @@ -200,7 +245,7 @@ impl JupyterMessage { // Creates a reply to this message. This is a child with the message type determined // automatically by replacing "request" with "reply". ZMQ identities are transferred. - pub(crate) fn new_reply(&self) -> JupyterMessage { + pub fn new_reply(&self) -> JupyterMessage { let mut reply = self.new_message(&self.message_type().replace("_request", "_reply")); reply.zmq_identities = self.zmq_identities.clone(); @@ -208,21 +253,18 @@ impl JupyterMessage { } #[must_use = "Need to send this message for it to have any effect"] - pub(crate) fn comm_close_message(&self) -> JupyterMessage { + pub fn comm_close_message(&self) -> JupyterMessage { self.new_message("comm_close").with_content(json!({ "comm_id": self.comm_id() })) } - pub(crate) fn with_content( - mut self, - content: serde_json::Value, - ) -> JupyterMessage { + pub fn with_content(mut self, content: serde_json::Value) -> JupyterMessage { self.content = content; self } - pub(crate) fn with_metadata( + pub fn with_metadata( mut self, metadata: serde_json::Value, ) -> JupyterMessage { @@ -230,46 +272,10 @@ impl JupyterMessage { self } - pub(crate) fn with_buffers(mut self, buffers: Vec<Bytes>) -> JupyterMessage { + pub fn with_buffers(mut self, buffers: Vec<Bytes>) -> JupyterMessage { self.buffers = buffers; self } - - pub(crate) async fn send<S: zeromq::SocketSend>( - &self, - connection: &mut Connection<S>, - ) -> Result<(), AnyError> { - // If performance is a concern, we can probably avoid the clone and to_vec calls with a bit - // of refactoring. - let mut jparts: Vec<Bytes> = vec![ - serde_json::to_string(&self.header) - .unwrap() - .as_bytes() - .to_vec() - .into(), - serde_json::to_string(&self.parent_header) - .unwrap() - .as_bytes() - .to_vec() - .into(), - serde_json::to_string(&self.metadata) - .unwrap() - .as_bytes() - .to_vec() - .into(), - serde_json::to_string(&self.content) - .unwrap() - .as_bytes() - .to_vec() - .into(), - ]; - jparts.extend_from_slice(&self.buffers); - let raw_message = RawMessage { - zmq_identities: self.zmq_identities.clone(), - jparts, - }; - raw_message.send(connection).await - } } impl fmt::Debug for JupyterMessage { |