From 486437fee1ac53610a901b07bda91909844ec9ab Mon Sep 17 00:00:00 2001 From: David Sherret Date: Tue, 30 Apr 2024 22:30:40 -0400 Subject: [PATCH] refactor(jupyter): move communication methods out of data structs (#23622) Moves the communication methods out of the data structs and onto the `Connection` struct. --- cli/ops/jupyter.rs | 16 ++- cli/tools/jupyter/jupyter_msg.rs | 218 ++++++++++++++++--------------- cli/tools/jupyter/mod.rs | 4 +- cli/tools/jupyter/server.rs | 148 +++++++++------------ 4 files changed, 188 insertions(+), 198 deletions(-) diff --git a/cli/ops/jupyter.rs b/cli/ops/jupyter.rs index 1d727c33fb..e7e206de58 100644 --- a/cli/ops/jupyter.rs +++ b/cli/ops/jupyter.rs @@ -50,12 +50,16 @@ pub async fn op_jupyter_broadcast( let maybe_last_request = last_execution_request.borrow().clone(); if let Some(last_request) = maybe_last_request { - last_request - .new_message(&message_type) - .with_content(content) - .with_metadata(metadata) - .with_buffers(buffers.into_iter().map(|b| b.to_vec().into()).collect()) - .send(&mut *iopub_socket.lock().await) + (*iopub_socket.lock().await) + .send( + &last_request + .new_message(&message_type) + .with_content(content) + .with_metadata(metadata) + .with_buffers( + buffers.into_iter().map(|b| b.to_vec().into()).collect(), + ), + ) .await?; } diff --git a/cli/tools/jupyter/jupyter_msg.rs b/cli/tools/jupyter/jupyter_msg.rs index 60703e3656..233efcc8e4 100644 --- a/cli/tools/jupyter/jupyter_msg.rs +++ b/cli/tools/jupyter/jupyter_msg.rs @@ -16,14 +16,14 @@ use uuid::Uuid; use crate::util::time::utc_now; -pub(crate) struct Connection { - pub(crate) socket: S, +pub struct Connection { + socket: S, /// Will be None if our key was empty (digest authentication disabled). - pub(crate) mac: Option, + mac: Option, } impl Connection { - pub(crate) fn new(socket: S, key: &str) -> Self { + pub fn new(socket: S, key: &str) -> Self { let mac = if key.is_empty() { None } else { @@ -33,21 +33,107 @@ impl Connection { } } +impl Connection { + pub async fn single_heartbeat(&mut self) -> Result<(), AnyError> { + self.socket.recv().await?; + self + .socket + .send(zeromq::ZmqMessage::from(b"ping".to_vec())) + .await?; + Ok(()) + } +} + +impl Connection { + pub async fn read(&mut self) -> Result { + let multipart = self.socket.recv().await?; + let raw_message = RawMessage::from_multipart(multipart, self.mac.as_ref())?; + JupyterMessage::from_raw_message(raw_message) + } +} + +impl Connection { + pub async fn send( + &mut self, + message: &JupyterMessage, + ) -> Result<(), AnyError> { + // If performance is a concern, we can probably avoid the clone and to_vec calls with a bit + // of refactoring. + let mut jparts: Vec = vec![ + serde_json::to_string(&message.header) + .unwrap() + .as_bytes() + .to_vec() + .into(), + serde_json::to_string(&message.parent_header) + .unwrap() + .as_bytes() + .to_vec() + .into(), + serde_json::to_string(&message.metadata) + .unwrap() + .as_bytes() + .to_vec() + .into(), + serde_json::to_string(&message.content) + .unwrap() + .as_bytes() + .to_vec() + .into(), + ]; + jparts.extend_from_slice(&message.buffers); + let raw_message = RawMessage { + zmq_identities: message.zmq_identities.clone(), + jparts, + }; + self.send_raw(raw_message).await + } + + async fn send_raw( + &mut self, + raw_message: RawMessage, + ) -> Result<(), AnyError> { + let hmac = if let Some(key) = &self.mac { + let ctx = digest(key, &raw_message.jparts); + let tag = ctx.sign(); + HEXLOWER.encode(tag.as_ref()) + } else { + String::new() + }; + let mut parts: Vec = Vec::new(); + for part in &raw_message.zmq_identities { + parts.push(part.to_vec().into()); + } + parts.push(DELIMITER.into()); + parts.push(hmac.as_bytes().to_vec().into()); + for part in &raw_message.jparts { + parts.push(part.to_vec().into()); + } + // ZmqMessage::try_from only fails if parts is empty, which it never + // will be here. + let message = zeromq::ZmqMessage::try_from(parts).unwrap(); + self.socket.send(message).await?; + Ok(()) + } +} + +fn digest(mac: &hmac::Key, jparts: &[Bytes]) -> hmac::Context { + let mut hmac_ctx = hmac::Context::with_key(mac); + for part in jparts { + hmac_ctx.update(part); + } + hmac_ctx +} + struct RawMessage { zmq_identities: Vec, jparts: Vec, } impl RawMessage { - pub(crate) async fn read( - connection: &mut Connection, - ) -> Result { - Self::from_multipart(connection.socket.recv().await?, connection) - } - - pub(crate) fn from_multipart( + pub fn from_multipart( multipart: zeromq::ZmqMessage, - connection: &Connection, + mac: Option<&hmac::Key>, ) -> Result { let delimiter_index = multipart .iter() @@ -65,7 +151,7 @@ impl RawMessage { jparts, }; - if let Some(key) = &connection.mac { + if let Some(key) = mac { let sig = HEXLOWER.decode(&expected_hmac)?; let mut msg = Vec::new(); for part in &raw_message.jparts { @@ -79,45 +165,10 @@ impl RawMessage { Ok(raw_message) } - - async fn send( - self, - connection: &mut Connection, - ) -> Result<(), AnyError> { - let hmac = if let Some(key) = &connection.mac { - let ctx = self.digest(key); - let tag = ctx.sign(); - HEXLOWER.encode(tag.as_ref()) - } else { - String::new() - }; - let mut parts: Vec = Vec::new(); - for part in &self.zmq_identities { - parts.push(part.to_vec().into()); - } - parts.push(DELIMITER.into()); - parts.push(hmac.as_bytes().to_vec().into()); - for part in &self.jparts { - parts.push(part.to_vec().into()); - } - // ZmqMessage::try_from only fails if parts is empty, which it never - // will be here. - let message = zeromq::ZmqMessage::try_from(parts).unwrap(); - connection.socket.send(message).await?; - Ok(()) - } - - fn digest(&self, mac: &hmac::Key) -> hmac::Context { - let mut hmac_ctx = hmac::Context::with_key(mac); - for part in &self.jparts { - hmac_ctx.update(part); - } - hmac_ctx - } } #[derive(Clone)] -pub(crate) struct JupyterMessage { +pub struct JupyterMessage { zmq_identities: Vec, header: serde_json::Value, parent_header: serde_json::Value, @@ -129,12 +180,6 @@ pub(crate) struct JupyterMessage { const DELIMITER: &[u8] = b""; impl JupyterMessage { - pub(crate) async fn read( - connection: &mut Connection, - ) -> Result { - Self::from_raw_message(RawMessage::read(connection).await?) - } - fn from_raw_message( raw_message: RawMessage, ) -> Result { @@ -156,32 +201,32 @@ impl JupyterMessage { }) } - pub(crate) fn message_type(&self) -> &str { + pub fn message_type(&self) -> &str { self.header["msg_type"].as_str().unwrap_or("") } - pub(crate) fn store_history(&self) -> bool { + pub fn store_history(&self) -> bool { self.content["store_history"].as_bool().unwrap_or(true) } - pub(crate) fn silent(&self) -> bool { + pub fn silent(&self) -> bool { self.content["silent"].as_bool().unwrap_or(false) } - pub(crate) fn code(&self) -> &str { + pub fn code(&self) -> &str { self.content["code"].as_str().unwrap_or("") } - pub(crate) fn cursor_pos(&self) -> usize { + pub fn cursor_pos(&self) -> usize { self.content["cursor_pos"].as_u64().unwrap_or(0) as usize } - pub(crate) fn comm_id(&self) -> &str { + pub fn comm_id(&self) -> &str { self.content["comm_id"].as_str().unwrap_or("") } // Creates a new child message of this message. ZMQ identities are not transferred. - pub(crate) fn new_message(&self, msg_type: &str) -> JupyterMessage { + pub fn new_message(&self, msg_type: &str) -> JupyterMessage { let mut header = self.header.clone(); header["msg_type"] = serde_json::Value::String(msg_type.to_owned()); header["username"] = serde_json::Value::String("kernel".to_owned()); @@ -200,7 +245,7 @@ impl JupyterMessage { // Creates a reply to this message. This is a child with the message type determined // automatically by replacing "request" with "reply". ZMQ identities are transferred. - pub(crate) fn new_reply(&self) -> JupyterMessage { + pub fn new_reply(&self) -> JupyterMessage { let mut reply = self.new_message(&self.message_type().replace("_request", "_reply")); reply.zmq_identities = self.zmq_identities.clone(); @@ -208,21 +253,18 @@ impl JupyterMessage { } #[must_use = "Need to send this message for it to have any effect"] - pub(crate) fn comm_close_message(&self) -> JupyterMessage { + pub fn comm_close_message(&self) -> JupyterMessage { self.new_message("comm_close").with_content(json!({ "comm_id": self.comm_id() })) } - pub(crate) fn with_content( - mut self, - content: serde_json::Value, - ) -> JupyterMessage { + pub fn with_content(mut self, content: serde_json::Value) -> JupyterMessage { self.content = content; self } - pub(crate) fn with_metadata( + pub fn with_metadata( mut self, metadata: serde_json::Value, ) -> JupyterMessage { @@ -230,46 +272,10 @@ impl JupyterMessage { self } - pub(crate) fn with_buffers(mut self, buffers: Vec) -> JupyterMessage { + pub fn with_buffers(mut self, buffers: Vec) -> JupyterMessage { self.buffers = buffers; self } - - pub(crate) async fn send( - &self, - connection: &mut Connection, - ) -> Result<(), AnyError> { - // If performance is a concern, we can probably avoid the clone and to_vec calls with a bit - // of refactoring. - let mut jparts: Vec = vec![ - serde_json::to_string(&self.header) - .unwrap() - .as_bytes() - .to_vec() - .into(), - serde_json::to_string(&self.parent_header) - .unwrap() - .as_bytes() - .to_vec() - .into(), - serde_json::to_string(&self.metadata) - .unwrap() - .as_bytes() - .to_vec() - .into(), - serde_json::to_string(&self.content) - .unwrap() - .as_bytes() - .to_vec() - .into(), - ]; - jparts.extend_from_slice(&self.buffers); - let raw_message = RawMessage { - zmq_identities: self.zmq_identities.clone(), - jparts, - }; - raw_message.send(connection).await - } } impl fmt::Debug for JupyterMessage { diff --git a/cli/tools/jupyter/mod.rs b/cli/tools/jupyter/mod.rs index 6531b0339a..da1c4bc4d8 100644 --- a/cli/tools/jupyter/mod.rs +++ b/cli/tools/jupyter/mod.rs @@ -28,8 +28,8 @@ use tokio::sync::mpsc; use tokio::sync::mpsc::UnboundedSender; mod install; -pub(crate) mod jupyter_msg; -pub(crate) mod server; +pub mod jupyter_msg; +pub mod server; pub async fn kernel( flags: Flags, diff --git a/cli/tools/jupyter/server.rs b/cli/tools/jupyter/server.rs index cd02b9891e..2107dcfbfc 100644 --- a/cli/tools/jupyter/server.rs +++ b/cli/tools/jupyter/server.rs @@ -18,8 +18,6 @@ use deno_core::CancelFuture; use deno_core::CancelHandle; use tokio::sync::mpsc; use tokio::sync::Mutex; -use zeromq::SocketRecv; -use zeromq::SocketSend; use super::jupyter_msg::Connection; use super::jupyter_msg::JupyterMessage; @@ -67,7 +65,6 @@ impl JupyterServer { } let cancel_handle = CancelHandle::new_rc(); - let cancel_handle2 = CancelHandle::new_rc(); let mut server = Self { execution_count: 0, @@ -82,11 +79,14 @@ impl JupyterServer { } }); - let handle2 = deno_core::unsync::spawn(async move { - if let Err(err) = - Self::handle_control(control_socket, cancel_handle2).await - { - eprintln!("Control error: {}", err); + let handle2 = deno_core::unsync::spawn({ + let cancel_handle = cancel_handle.clone(); + async move { + if let Err(err) = + Self::handle_control(control_socket, cancel_handle).await + { + eprintln!("Control error: {}", err); + } } }); @@ -129,13 +129,11 @@ impl JupyterServer { StdioMsg::Stderr(text) => ("stderr", text), }; - let result = exec_request - .new_message("stream") - .with_content(json!({ + let result = (*iopub_socket.lock().await) + .send(&exec_request.new_message("stream").with_content(json!({ "name": name, "text": text - })) - .send(&mut *iopub_socket.lock().await) + }))) .await; if let Err(err) = result { @@ -148,11 +146,7 @@ impl JupyterServer { connection: &mut Connection, ) -> Result<(), AnyError> { loop { - connection.socket.recv().await?; - connection - .socket - .send(zeromq::ZmqMessage::from(b"ping".to_vec())) - .await?; + connection.single_heartbeat().await?; } } @@ -161,13 +155,11 @@ impl JupyterServer { cancel_handle: Rc, ) -> Result<(), AnyError> { loop { - let msg = JupyterMessage::read(&mut connection).await?; + let msg = connection.read().await?; match msg.message_type() { "kernel_info_request" => { - msg - .new_reply() - .with_content(kernel_info()) - .send(&mut connection) + connection + .send(&msg.new_reply().with_content(kernel_info())) .await?; } "shutdown_request" => { @@ -191,7 +183,7 @@ impl JupyterServer { mut connection: Connection, ) -> Result<(), AnyError> { loop { - let msg = JupyterMessage::read(&mut connection).await?; + let msg = connection.read().await?; self.handle_shell_message(msg, &mut connection).await?; } } @@ -201,25 +193,23 @@ impl JupyterServer { msg: JupyterMessage, connection: &mut Connection, ) -> Result<(), AnyError> { - msg - .new_message("status") - .with_content(json!({"execution_state": "busy"})) - .send(&mut *self.iopub_socket.lock().await) + self + .send_iopub( + &msg + .new_message("status") + .with_content(json!({"execution_state": "busy"})), + ) .await?; match msg.message_type() { "kernel_info_request" => { - msg - .new_reply() - .with_content(kernel_info()) - .send(connection) + connection + .send(&msg.new_reply().with_content(kernel_info())) .await?; } "is_complete_request" => { - msg - .new_reply() - .with_content(json!({"status": "complete"})) - .send(connection) + connection + .send(&msg.new_reply().with_content(json!({"status": "complete"}))) .await?; } "execute_request" => { @@ -228,10 +218,7 @@ impl JupyterServer { .await?; } "comm_open" => { - msg - .comm_close_message() - .send(&mut *self.iopub_socket.lock().await) - .await?; + self.send_iopub(&msg.comm_close_message()).await?; } "complete_request" => { let user_code = msg.code(); @@ -259,16 +246,14 @@ impl JupyterServer { .map(|item| item.range.end) .unwrap_or(cursor_pos); - msg - .new_reply() - .with_content(json!({ + connection + .send(&msg.new_reply().with_content(json!({ "status": "ok", "matches": matches, "cursor_start": cursor_start, "cursor_end": cursor_end, "metadata": {}, - })) - .send(connection) + }))) .await?; } else { let expr = get_expr_from_line_at_pos(user_code, cursor_pos); @@ -307,16 +292,14 @@ impl JupyterServer { (candidates, cursor_pos - expr.len()) }; - msg - .new_reply() - .with_content(json!({ + connection + .send(&msg.new_reply().with_content(json!({ "status": "ok", "matches": completions, "cursor_start": cursor_start, "cursor_end": cursor_pos, "metadata": {}, - })) - .send(connection) + }))) .await?; } } @@ -328,10 +311,12 @@ impl JupyterServer { } } - msg - .new_message("status") - .with_content(json!({"execution_state": "idle"})) - .send(&mut *self.iopub_socket.lock().await) + self + .send_iopub( + &msg + .new_message("status") + .with_content(json!({"execution_state": "idle"})), + ) .await?; Ok(()) } @@ -346,13 +331,11 @@ impl JupyterServer { } *self.last_execution_request.borrow_mut() = Some(msg.clone()); - msg - .new_message("execute_input") - .with_content(json!({ + self + .send_iopub(&msg.new_message("execute_input").with_content(json!({ "execution_count": self.execution_count, "code": msg.code() - })) - .send(&mut *self.iopub_socket.lock().await) + }))) .await?; let result = self @@ -363,22 +346,18 @@ impl JupyterServer { let evaluate_response = match result { Ok(eval_response) => eval_response, Err(err) => { - msg - .new_message("error") - .with_content(json!({ + self + .send_iopub(&msg.new_message("error").with_content(json!({ "ename": err.to_string(), "evalue": " ", // Fake value, otherwise old Jupyter frontends don't show the error "traceback": [], - })) - .send(&mut *self.iopub_socket.lock().await) + }))) .await?; - msg - .new_reply() - .with_content(json!({ + connection + .send(&msg.new_reply().with_content(json!({ "status": "error", "execution_count": self.execution_count, - })) - .send(connection) + }))) .await?; return Ok(()); } @@ -393,14 +372,12 @@ impl JupyterServer { publish_result(&mut self.repl_session, &result, self.execution_count) .await?; - msg - .new_reply() - .with_content(json!({ + connection + .send(&msg.new_reply().with_content(json!({ "status": "ok", "execution_count": self.execution_count, // FIXME: also include user_expressions - })) - .send(connection) + }))) .await?; // Let's sleep here for a few ms, so we give a chance to the task that is // handling stdout and stderr streams to receive and flush the content. @@ -479,27 +456,30 @@ impl JupyterServer { message }; - msg - .new_message("error") - .with_content(json!({ + self + .send_iopub(&msg.new_message("error").with_content(json!({ "ename": ename, "evalue": evalue, "traceback": traceback, - })) - .send(&mut *self.iopub_socket.lock().await) + }))) .await?; - msg - .new_reply() - .with_content(json!({ + connection + .send(&msg.new_reply().with_content(json!({ "status": "error", "execution_count": self.execution_count, - })) - .send(connection) + }))) .await?; } Ok(()) } + + async fn send_iopub( + &mut self, + message: &JupyterMessage, + ) -> Result<(), AnyError> { + self.iopub_socket.lock().await.send(message).await + } } async fn bind_socket(