summaryrefslogtreecommitdiff
path: root/ext/flash/socket.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ext/flash/socket.rs')
-rw-r--r--ext/flash/socket.rs147
1 files changed, 147 insertions, 0 deletions
diff --git a/ext/flash/socket.rs b/ext/flash/socket.rs
new file mode 100644
index 000000000..8256be8a0
--- /dev/null
+++ b/ext/flash/socket.rs
@@ -0,0 +1,147 @@
+use deno_core::error::AnyError;
+use mio::net::TcpStream;
+use std::{
+ cell::UnsafeCell,
+ future::Future,
+ io::{Read, Write},
+ pin::Pin,
+ sync::{Arc, Mutex},
+};
+use tokio::sync::mpsc;
+
+use crate::ParseStatus;
+
+type TlsTcpStream = rustls::StreamOwned<rustls::ServerConnection, TcpStream>;
+
+pub enum InnerStream {
+ Tcp(TcpStream),
+ Tls(Box<TlsTcpStream>),
+}
+
+pub struct Stream {
+ pub inner: InnerStream,
+ pub detached: bool,
+ pub read_rx: Option<mpsc::Receiver<()>>,
+ pub read_tx: Option<mpsc::Sender<()>>,
+ pub parse_done: ParseStatus,
+ pub buffer: UnsafeCell<Vec<u8>>,
+ pub read_lock: Arc<Mutex<()>>,
+}
+
+impl Stream {
+ pub fn detach_ownership(&mut self) {
+ self.detached = true;
+ }
+
+ /// Try to write to the socket.
+ #[inline]
+ pub fn try_write(&mut self, buf: &[u8]) -> usize {
+ let mut nwritten = 0;
+ while nwritten < buf.len() {
+ match self.write(&buf[nwritten..]) {
+ Ok(n) => nwritten += n,
+ Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
+ break;
+ }
+ Err(e) => {
+ log::trace!("Error writing to socket: {}", e);
+ break;
+ }
+ }
+ }
+ nwritten
+ }
+
+ #[inline]
+ pub fn shutdown(&mut self) {
+ match &mut self.inner {
+ InnerStream::Tcp(stream) => {
+ // Typically shutdown shouldn't fail.
+ let _ = stream.shutdown(std::net::Shutdown::Both);
+ }
+ InnerStream::Tls(stream) => {
+ let _ = stream.sock.shutdown(std::net::Shutdown::Both);
+ }
+ }
+ }
+
+ pub fn as_std(&mut self) -> std::net::TcpStream {
+ #[cfg(unix)]
+ let std_stream = {
+ use std::os::unix::prelude::AsRawFd;
+ use std::os::unix::prelude::FromRawFd;
+ let fd = match self.inner {
+ InnerStream::Tcp(ref tcp) => tcp.as_raw_fd(),
+ _ => todo!(),
+ };
+ // SAFETY: `fd` is a valid file descriptor.
+ unsafe { std::net::TcpStream::from_raw_fd(fd) }
+ };
+ #[cfg(windows)]
+ let std_stream = {
+ use std::os::windows::prelude::AsRawSocket;
+ use std::os::windows::prelude::FromRawSocket;
+ let fd = match self.inner {
+ InnerStream::Tcp(ref tcp) => tcp.as_raw_socket(),
+ _ => todo!(),
+ };
+ // SAFETY: `fd` is a valid file descriptor.
+ unsafe { std::net::TcpStream::from_raw_socket(fd) }
+ };
+ std_stream
+ }
+
+ #[inline]
+ pub async fn with_async_stream<F, T>(&mut self, f: F) -> Result<T, AnyError>
+ where
+ F: FnOnce(
+ &mut tokio::net::TcpStream,
+ ) -> Pin<Box<dyn '_ + Future<Output = Result<T, AnyError>>>>,
+ {
+ let mut async_stream = tokio::net::TcpStream::from_std(self.as_std())?;
+ let result = f(&mut async_stream).await?;
+ forget_stream(async_stream.into_std()?);
+ Ok(result)
+ }
+}
+
+#[inline]
+pub fn forget_stream(stream: std::net::TcpStream) {
+ #[cfg(unix)]
+ {
+ use std::os::unix::prelude::IntoRawFd;
+ let _ = stream.into_raw_fd();
+ }
+ #[cfg(windows)]
+ {
+ use std::os::windows::prelude::IntoRawSocket;
+ let _ = stream.into_raw_socket();
+ }
+}
+
+impl Write for Stream {
+ #[inline]
+ fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+ match self.inner {
+ InnerStream::Tcp(ref mut stream) => stream.write(buf),
+ InnerStream::Tls(ref mut stream) => stream.write(buf),
+ }
+ }
+ #[inline]
+ fn flush(&mut self) -> std::io::Result<()> {
+ match self.inner {
+ InnerStream::Tcp(ref mut stream) => stream.flush(),
+ InnerStream::Tls(ref mut stream) => stream.flush(),
+ }
+ }
+}
+
+impl Read for Stream {
+ #[inline]
+ fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+ match self.inner {
+ InnerStream::Tcp(ref mut stream) => stream.read(buf),
+ InnerStream::Tls(ref mut stream) => stream.read(buf),
+ }
+ }
+}