summaryrefslogtreecommitdiff
path: root/extensions/websocket/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'extensions/websocket/lib.rs')
-rw-r--r--extensions/websocket/lib.rs98
1 files changed, 81 insertions, 17 deletions
diff --git a/extensions/websocket/lib.rs b/extensions/websocket/lib.rs
index c6752d23b..f5bf15c79 100644
--- a/extensions/websocket/lib.rs
+++ b/extensions/websocket/lib.rs
@@ -64,13 +64,81 @@ impl WebSocketPermissions for NoWebSocketPermissions {
}
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
-struct WsStreamResource {
- tx: AsyncRefCell<SplitSink<WsStream, Message>>,
- rx: AsyncRefCell<SplitStream<WsStream>>,
+pub enum WebSocketStreamType {
+ Client {
+ tx: AsyncRefCell<SplitSink<WsStream, Message>>,
+ rx: AsyncRefCell<SplitStream<WsStream>>,
+ },
+ Server {
+ tx: AsyncRefCell<
+ SplitSink<WebSocketStream<hyper::upgrade::Upgraded>, Message>,
+ >,
+ rx: AsyncRefCell<SplitStream<WebSocketStream<hyper::upgrade::Upgraded>>>,
+ },
+}
+
+pub struct WsStreamResource {
+ pub stream: WebSocketStreamType,
// When a `WsStreamResource` resource is closed, all pending 'read' ops are
// canceled, while 'write' ops are allowed to complete. Therefore only
// 'read' futures are attached to this cancel handle.
- cancel: CancelHandle,
+ pub cancel: CancelHandle,
+}
+
+impl WsStreamResource {
+ async fn send(self: &Rc<Self>, message: Message) -> Result<(), AnyError> {
+ match self.stream {
+ WebSocketStreamType::Client { .. } => {
+ let mut tx = RcRef::map(self, |r| match &r.stream {
+ WebSocketStreamType::Client { tx, .. } => tx,
+ WebSocketStreamType::Server { .. } => unreachable!(),
+ })
+ .borrow_mut()
+ .await;
+ tx.send(message).await?;
+ }
+ WebSocketStreamType::Server { .. } => {
+ let mut tx = RcRef::map(self, |r| match &r.stream {
+ WebSocketStreamType::Client { .. } => unreachable!(),
+ WebSocketStreamType::Server { tx, .. } => tx,
+ })
+ .borrow_mut()
+ .await;
+ tx.send(message).await?;
+ }
+ }
+
+ Ok(())
+ }
+
+ async fn next_message(
+ self: &Rc<Self>,
+ cancel: RcRef<CancelHandle>,
+ ) -> Result<
+ Option<Result<Message, tokio_tungstenite::tungstenite::Error>>,
+ AnyError,
+ > {
+ match &self.stream {
+ WebSocketStreamType::Client { .. } => {
+ let mut rx = RcRef::map(self, |r| match &r.stream {
+ WebSocketStreamType::Client { rx, .. } => rx,
+ WebSocketStreamType::Server { .. } => unreachable!(),
+ })
+ .borrow_mut()
+ .await;
+ rx.next().or_cancel(cancel).await.map_err(AnyError::from)
+ }
+ WebSocketStreamType::Server { .. } => {
+ let mut rx = RcRef::map(self, |r| match &r.stream {
+ WebSocketStreamType::Client { .. } => unreachable!(),
+ WebSocketStreamType::Server { rx, .. } => rx,
+ })
+ .borrow_mut()
+ .await;
+ rx.next().or_cancel(cancel).await.map_err(AnyError::from)
+ }
+ }
+ }
}
impl Resource for WsStreamResource {
@@ -79,8 +147,6 @@ impl Resource for WsStreamResource {
}
}
-impl WsStreamResource {}
-
// This op is needed because creating a WS instance in JavaScript is a sync
// operation and should throw error when permissions are not fulfilled,
// but actual op that connects WS is async.
@@ -184,8 +250,10 @@ where
let (ws_tx, ws_rx) = stream.split();
let resource = WsStreamResource {
- rx: AsyncRefCell::new(ws_rx),
- tx: AsyncRefCell::new(ws_tx),
+ stream: WebSocketStreamType::Client {
+ rx: AsyncRefCell::new(ws_rx),
+ tx: AsyncRefCell::new(ws_tx),
+ },
cancel: Default::default(),
};
let mut state = state.borrow_mut();
@@ -227,15 +295,13 @@ pub async fn op_ws_send(
"pong" => Message::Pong(vec![]),
_ => unreachable!(),
};
- let rid = args.rid;
let resource = state
.borrow_mut()
.resource_table
- .get::<WsStreamResource>(rid)
+ .get::<WsStreamResource>(args.rid)
.ok_or_else(bad_resource_id)?;
- let mut tx = RcRef::map(&resource, |r| &r.tx).borrow_mut().await;
- tx.send(msg).await?;
+ resource.send(msg).await?;
Ok(())
}
@@ -266,8 +332,7 @@ pub async fn op_ws_close(
.resource_table
.get::<WsStreamResource>(rid)
.ok_or_else(bad_resource_id)?;
- let mut tx = RcRef::map(&resource, |r| &r.tx).borrow_mut().await;
- tx.send(msg).await?;
+ resource.send(msg).await?;
Ok(())
}
@@ -294,9 +359,8 @@ pub async fn op_ws_next_event(
.get::<WsStreamResource>(rid)
.ok_or_else(bad_resource_id)?;
- let mut rx = RcRef::map(&resource, |r| &r.rx).borrow_mut().await;
- let cancel = RcRef::map(resource, |r| &r.cancel);
- let val = rx.next().or_cancel(cancel).await?;
+ let cancel = RcRef::map(&resource, |r| &r.cancel);
+ let val = resource.next_message(cancel).await?;
let res = match val {
Some(Ok(Message::Text(text))) => NextEventResponse::String(text),
Some(Ok(Message::Binary(data))) => NextEventResponse::Binary(data.into()),