summaryrefslogtreecommitdiff
path: root/ext/websocket/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ext/websocket/lib.rs')
-rw-r--r--ext/websocket/lib.rs46
1 files changed, 29 insertions, 17 deletions
diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs
index 83d553eeb..0f3456eef 100644
--- a/ext/websocket/lib.rs
+++ b/ext/websocket/lib.rs
@@ -41,17 +41,21 @@ use std::rc::Rc;
use std::sync::Arc;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
+use tokio::io::ReadHalf;
+use tokio::io::WriteHalf;
use tokio::net::TcpStream;
use tokio_rustls::rustls::RootCertStore;
use tokio_rustls::rustls::ServerName;
use tokio_rustls::TlsConnector;
use fastwebsockets::CloseCode;
-use fastwebsockets::FragmentCollector;
+use fastwebsockets::FragmentCollectorRead;
use fastwebsockets::Frame;
use fastwebsockets::OpCode;
use fastwebsockets::Role;
use fastwebsockets::WebSocket;
+use fastwebsockets::WebSocketWrite;
+
mod stream;
static USE_WRITEV: Lazy<bool> = Lazy::new(|| {
@@ -332,12 +336,13 @@ pub struct ServerWebSocket {
closed: Cell<bool>,
buffer: Cell<Option<Vec<u8>>>,
string: Cell<Option<String>>,
- ws: AsyncRefCell<FragmentCollector<WebSocketStream>>,
- tx_lock: AsyncRefCell<()>,
+ ws_read: AsyncRefCell<FragmentCollectorRead<ReadHalf<WebSocketStream>>>,
+ ws_write: AsyncRefCell<WebSocketWrite<WriteHalf<WebSocketStream>>>,
}
impl ServerWebSocket {
fn new(ws: WebSocket<WebSocketStream>) -> Self {
+ let (ws_read, ws_write) = ws.split(tokio::io::split);
Self {
buffered: Cell::new(0),
error: Cell::new(None),
@@ -345,8 +350,8 @@ impl ServerWebSocket {
closed: Cell::new(false),
buffer: Cell::new(None),
string: Cell::new(None),
- ws: AsyncRefCell::new(FragmentCollector::new(ws)),
- tx_lock: AsyncRefCell::new(()),
+ ws_read: AsyncRefCell::new(FragmentCollectorRead::new(ws_read)),
+ ws_write: AsyncRefCell::new(ws_write),
}
}
@@ -361,22 +366,22 @@ impl ServerWebSocket {
}
/// Reserve a lock, but don't wait on it. This gets us our place in line.
- pub fn reserve_lock(self: &Rc<Self>) -> AsyncMutFuture<()> {
- RcRef::map(self, |r| &r.tx_lock).borrow_mut()
+ fn reserve_lock(
+ self: &Rc<Self>,
+ ) -> AsyncMutFuture<WebSocketWrite<WriteHalf<WebSocketStream>>> {
+ RcRef::map(self, |r| &r.ws_write).borrow_mut()
}
#[inline]
- pub async fn write_frame(
+ async fn write_frame(
self: &Rc<Self>,
- lock: AsyncMutFuture<()>,
+ lock: AsyncMutFuture<WebSocketWrite<WriteHalf<WebSocketStream>>>,
frame: Frame<'_>,
) -> Result<(), AnyError> {
- lock.await;
-
- // SAFETY: fastwebsockets only needs a mutable reference to the WebSocket
- // to populate the write buffer. We encounter an await point when writing
- // to the socket after the frame has already been written to the buffer.
- let ws = unsafe { &mut *self.ws.as_ptr() };
+ let mut ws = lock.await;
+ if ws.is_closed() {
+ return Ok(());
+ }
ws.write_frame(frame)
.await
.map_err(|err| type_error(err.to_string()))?;
@@ -405,6 +410,7 @@ pub fn ws_create_server_stream(
ws.set_writev(*USE_WRITEV);
ws.set_auto_close(true);
ws.set_auto_pong(true);
+
let rid = state.resource_table.add(ServerWebSocket::new(ws));
Ok(rid)
}
@@ -627,9 +633,15 @@ pub async fn op_ws_next_event(
return MessageKind::Error as u16;
}
- let mut ws = RcRef::map(&resource, |r| &r.ws).borrow_mut().await;
+ let mut ws = RcRef::map(&resource, |r| &r.ws_read).borrow_mut().await;
+ let writer = RcRef::map(&resource, |r| &r.ws_write);
+ let mut sender = move |frame| {
+ let writer = writer.clone();
+ async move { writer.borrow_mut().await.write_frame(frame).await }
+ };
loop {
- let val = match ws.read_frame().await {
+ let res = ws.read_frame(&mut sender).await;
+ let val = match res {
Ok(val) => val,
Err(err) => {
// No message was received, socket closed while we waited.