diff options
Diffstat (limited to 'ext/websocket/lib.rs')
-rw-r--r-- | ext/websocket/lib.rs | 151 |
1 files changed, 95 insertions, 56 deletions
diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs index b492be0c0..1df71abaa 100644 --- a/ext/websocket/lib.rs +++ b/ext/websocket/lib.rs @@ -14,7 +14,6 @@ use deno_core::OpState; use deno_core::RcRef; use deno_core::Resource; use deno_core::ResourceId; -use deno_core::StringOrBuffer; use deno_core::ZeroCopyBuf; use deno_net::raw::NetworkStream; use deno_tls::create_client_config; @@ -290,15 +289,8 @@ where state.borrow_mut().resource_table.close(cancel_rid).ok(); } - let resource = ServerWebSocket { - buffered: Cell::new(0), - errored: Cell::new(None), - ws: AsyncRefCell::new(FragmentCollector::new(stream)), - closed: Cell::new(false), - tx_lock: AsyncRefCell::new(()), - }; let mut state = state.borrow_mut(); - let rid = state.resource_table.add(resource); + let rid = state.resource_table.add(ServerWebSocket::new(stream)); let protocol = match response.headers().get("Sec-WebSocket-Protocol") { Some(header) => header.to_str().unwrap(), @@ -323,18 +315,43 @@ pub enum MessageKind { Binary = 1, Pong = 2, Error = 3, - Closed = 4, + ClosedDefault = 1005, } +/// To avoid locks, we keep as much as we can inside of [`Cell`]s. pub struct ServerWebSocket { buffered: Cell<usize>, - errored: Cell<Option<AnyError>>, - ws: AsyncRefCell<FragmentCollector<WebSocketStream>>, + error: Cell<Option<String>>, + errored: Cell<bool>, closed: Cell<bool>, + buffer: Cell<Option<Vec<u8>>>, + ws: AsyncRefCell<FragmentCollector<WebSocketStream>>, tx_lock: AsyncRefCell<()>, } impl ServerWebSocket { + fn new(ws: WebSocket<WebSocketStream>) -> Self { + Self { + buffered: Cell::new(0), + error: Cell::new(None), + errored: Cell::new(false), + closed: Cell::new(false), + buffer: Cell::new(None), + ws: AsyncRefCell::new(FragmentCollector::new(ws)), + tx_lock: AsyncRefCell::new(()), + } + } + + fn set_error(&self, error: Option<String>) { + if let Some(error) = error { + self.error.set(Some(error)); + self.errored.set(true); + } else { + self.error.set(None); + self.errored.set(false); + } + } + #[inline] pub async fn write_frame( self: &Rc<Self>, @@ -374,15 +391,7 @@ pub fn ws_create_server_stream( ws.set_auto_close(true); ws.set_auto_pong(true); - let ws_resource = ServerWebSocket { - buffered: Cell::new(0), - errored: Cell::new(None), - ws: AsyncRefCell::new(FragmentCollector::new(ws)), - closed: Cell::new(false), - tx_lock: AsyncRefCell::new(()), - }; - - let rid = state.resource_table.add(ws_resource); + let rid = state.resource_table.add(ServerWebSocket::new(ws)); Ok(rid) } @@ -401,7 +410,7 @@ pub fn op_ws_send_binary( .write_frame(Frame::new(true, OpCode::Binary, None, data)) .await { - resource.errored.set(Some(err)); + resource.set_error(Some(err.to_string())); } else { resource.buffered.set(resource.buffered.get() - len); } @@ -418,7 +427,7 @@ pub fn op_ws_send_text(state: &mut OpState, rid: ResourceId, data: String) { .write_frame(Frame::new(true, OpCode::Text, None, data.into_bytes())) .await { - resource.errored.set(Some(err)); + resource.set_error(Some(err.to_string())); } else { resource.buffered.set(resource.buffered.get() - len); } @@ -514,18 +523,47 @@ pub async fn op_ws_close( Ok(()) } +#[op] +pub fn op_ws_get_buffer(state: &mut OpState, rid: ResourceId) -> ZeroCopyBuf { + let resource = state.resource_table.get::<ServerWebSocket>(rid).unwrap(); + resource.buffer.take().unwrap().into() +} + +#[op] +pub fn op_ws_get_buffer_as_string( + state: &mut OpState, + rid: ResourceId, +) -> String { + let resource = state.resource_table.get::<ServerWebSocket>(rid).unwrap(); + // TODO(mmastrac): We won't panic on a bad string, but we return an empty one. + String::from_utf8(resource.buffer.take().unwrap()).unwrap_or_default() +} + +#[op] +pub fn op_ws_get_error(state: &mut OpState, rid: ResourceId) -> String { + let Ok(resource) = state.resource_table.get::<ServerWebSocket>(rid) else { + return "Bad resource".into(); + }; + resource.errored.set(false); + resource.error.take().unwrap_or_default() +} + #[op(fast)] pub async fn op_ws_next_event( state: Rc<RefCell<OpState>>, rid: ResourceId, -) -> Result<(u16, StringOrBuffer), AnyError> { - let resource = state +) -> u16 { + let Ok(resource) = state .borrow_mut() .resource_table - .get::<ServerWebSocket>(rid)?; + .get::<ServerWebSocket>(rid) else { + // op_ws_get_error will correctly handle a bad resource + return MessageKind::Error as u16; + }; - if let Some(err) = resource.errored.take() { - return Err(err); + // If there's a pending error, this always returns error + if resource.errored.get() { + return MessageKind::Error as u16; } let mut ws = RcRef::map(&resource, |r| &r.ws).borrow_mut().await; @@ -537,46 +575,44 @@ pub async fn op_ws_next_event( // Try close the stream, ignoring any errors, and report closed status to JavaScript. if resource.closed.get() { let _ = state.borrow_mut().resource_table.close(rid); - return Ok(( - MessageKind::Closed as u16, - StringOrBuffer::Buffer(vec![].into()), - )); + resource.set_error(None); + return MessageKind::ClosedDefault as u16; } - return Ok(( - MessageKind::Error as u16, - StringOrBuffer::String(err.to_string()), - )); + resource.set_error(Some(err.to_string())); + return MessageKind::Error as u16; } }; - break Ok(match val.opcode { - OpCode::Text => ( - MessageKind::Text as u16, - StringOrBuffer::String(String::from_utf8(val.payload).unwrap()), - ), - OpCode::Binary => ( - MessageKind::Binary as u16, - StringOrBuffer::Buffer(val.payload.into()), - ), + break match val.opcode { + OpCode::Text => { + resource.buffer.set(Some(val.payload)); + MessageKind::Text as u16 + } + OpCode::Binary => { + resource.buffer.set(Some(val.payload)); + MessageKind::Binary as u16 + } OpCode::Close => { + // Close reason is returned through error if val.payload.len() < 2 { - return Ok((1005, StringOrBuffer::String("".to_string()))); + resource.set_error(None); + MessageKind::ClosedDefault as u16 + } else { + let close_code = CloseCode::from(u16::from_be_bytes([ + val.payload[0], + val.payload[1], + ])); + let reason = String::from_utf8(val.payload[2..].to_vec()).ok(); + resource.set_error(reason); + close_code.into() } - - let close_code = - CloseCode::from(u16::from_be_bytes([val.payload[0], val.payload[1]])); - let reason = String::from_utf8(val.payload[2..].to_vec()).unwrap(); - (close_code.into(), StringOrBuffer::String(reason)) } - OpCode::Pong => ( - MessageKind::Pong as u16, - StringOrBuffer::Buffer(vec![].into()), - ), + OpCode::Pong => MessageKind::Pong as u16, OpCode::Continuation | OpCode::Ping => { continue; } - }); + }; } } @@ -588,6 +624,9 @@ deno_core::extension!(deno_websocket, op_ws_create<P>, op_ws_close, op_ws_next_event, + op_ws_get_buffer, + op_ws_get_buffer_as_string, + op_ws_get_error, op_ws_send_binary, op_ws_send_text, op_ws_send_binary_async, |