diff options
author | Matt Mastracci <matthew@mastracci.com> | 2024-04-08 16:18:14 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-08 16:18:14 -0600 |
commit | 47061a4539feab411fbbd7db5604f4bd4a532051 (patch) | |
tree | 5f6f17066b6f967b1504ef9b762288ad670d1389 /ext/net/raw.rs | |
parent | 6157c8563484e53b1917c811e94e4b5afa01dc67 (diff) |
feat(ext/net): Refactor TCP socket listeners for future clustering mode (#23037)
Changes:
- Implements a TCP socket listener that will allow for round-robin
load-balancing in-process.
- Cleans up the raw networking code to make it easier to work with.
Diffstat (limited to 'ext/net/raw.rs')
-rw-r--r-- | ext/net/raw.rs | 472 |
1 files changed, 272 insertions, 200 deletions
diff --git a/ext/net/raw.rs b/ext/net/raw.rs index c583da3bd..f2de76065 100644 --- a/ext/net/raw.rs +++ b/ext/net/raw.rs @@ -1,176 +1,305 @@ // Copyright 2018-2024 the Deno authors. All rights reserved. MIT license. use crate::io::TcpStreamResource; -#[cfg(unix)] -use crate::io::UnixStreamResource; -use crate::ops::TcpListenerResource; -use crate::ops_tls::TlsListenerResource; use crate::ops_tls::TlsStreamResource; -use crate::ops_tls::TLS_BUFFER_SIZE; -#[cfg(unix)] -use crate::ops_unix::UnixListenerResource; use deno_core::error::bad_resource; use deno_core::error::bad_resource_id; use deno_core::error::AnyError; +use deno_core::AsyncRefCell; +use deno_core::CancelHandle; +use deno_core::Resource; use deno_core::ResourceId; use deno_core::ResourceTable; -use deno_tls::rustls::ServerConfig; -use pin_project::pin_project; -use rustls_tokio_stream::TlsStream; +use std::borrow::Cow; use std::rc::Rc; -use std::sync::Arc; -use tokio::net::TcpStream; -#[cfg(unix)] -use tokio::net::UnixStream; -/// A raw stream of one of the types handled by this extension. -#[pin_project(project = NetworkStreamProject)] -pub enum NetworkStream { - Tcp(#[pin] TcpStream), - Tls(#[pin] TlsStream), - #[cfg(unix)] - Unix(#[pin] UnixStream), +pub trait NetworkStreamTrait: Into<NetworkStream> { + type Resource; + const RESOURCE_NAME: &'static str; + fn local_address(&self) -> Result<NetworkStreamAddress, std::io::Error>; + fn peer_address(&self) -> Result<NetworkStreamAddress, std::io::Error>; } -impl From<TcpStream> for NetworkStream { - fn from(value: TcpStream) -> Self { - NetworkStream::Tcp(value) - } +#[allow(async_fn_in_trait)] +pub trait NetworkStreamListenerTrait: + Into<NetworkStreamListener> + Send + Sync +{ + type Stream: NetworkStreamTrait + 'static; + type Addr: Into<NetworkStreamAddress> + 'static; + /// Additional data, if needed + type ResourceData: Default; + const RESOURCE_NAME: &'static str; + async fn accept(&self) -> std::io::Result<(Self::Stream, Self::Addr)>; + fn listen_address(&self) -> Result<Self::Addr, std::io::Error>; } -impl From<TlsStream> for NetworkStream { - fn from(value: TlsStream) -> Self { - NetworkStream::Tls(value) - } +/// A strongly-typed network listener resource for something that +/// implements `NetworkListenerTrait`. +pub struct NetworkListenerResource<T: NetworkStreamListenerTrait> { + pub listener: AsyncRefCell<T>, + /// Associated data for this resource. Not required. + #[allow(unused)] + pub data: T::ResourceData, + pub cancel: CancelHandle, } -#[cfg(unix)] -impl From<UnixStream> for NetworkStream { - fn from(value: UnixStream) -> Self { - NetworkStream::Unix(value) +impl<T: NetworkStreamListenerTrait + 'static> Resource + for NetworkListenerResource<T> +{ + fn name(&self) -> Cow<str> { + T::RESOURCE_NAME.into() } -} -/// A raw stream of one of the types handled by this extension. -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum NetworkStreamType { - Tcp, - Tls, - #[cfg(unix)] - Unix, + fn close(self: Rc<Self>) { + self.cancel.cancel(); + } } -impl NetworkStream { - pub fn local_address(&self) -> Result<NetworkStreamAddress, std::io::Error> { - match self { - Self::Tcp(tcp) => Ok(NetworkStreamAddress::Ip(tcp.local_addr()?)), - Self::Tls(tls) => Ok(NetworkStreamAddress::Ip(tls.local_addr()?)), - #[cfg(unix)] - Self::Unix(unix) => Ok(NetworkStreamAddress::Unix(unix.local_addr()?)), +impl<T: NetworkStreamListenerTrait + 'static> NetworkListenerResource<T> { + pub fn new(t: T) -> Self { + Self { + listener: AsyncRefCell::new(t), + data: Default::default(), + cancel: Default::default(), } } - pub fn peer_address(&self) -> Result<NetworkStreamAddress, std::io::Error> { - match self { - Self::Tcp(tcp) => Ok(NetworkStreamAddress::Ip(tcp.peer_addr()?)), - Self::Tls(tls) => Ok(NetworkStreamAddress::Ip(tls.peer_addr()?)), - #[cfg(unix)] - Self::Unix(unix) => Ok(NetworkStreamAddress::Unix(unix.peer_addr()?)), + /// Returns a [`NetworkStreamListener`] from this resource if it is not in use elsewhere. + fn take( + resource_table: &mut ResourceTable, + listener_rid: ResourceId, + ) -> Result<Option<NetworkStreamListener>, AnyError> { + if let Ok(resource_rc) = resource_table.take::<Self>(listener_rid) { + let resource = Rc::try_unwrap(resource_rc) + .map_err(|_| bad_resource("Listener is currently in use"))?; + return Ok(Some(resource.listener.into_inner().into())); } + Ok(None) } +} - pub fn stream(&self) -> NetworkStreamType { - match self { - Self::Tcp(_) => NetworkStreamType::Tcp, - Self::Tls(_) => NetworkStreamType::Tls, - #[cfg(unix)] - Self::Unix(_) => NetworkStreamType::Unix, +/// Each of the network streams has the exact same pattern for listening, accepting, etc, so +/// we just codegen them all via macro to avoid repeating each one of these N times. +macro_rules! network_stream { + ( $([$i:ident, $il:ident, $stream:path, $listener:path, $addr:path, $stream_resource:ty]),* ) => { + /// A raw stream of one of the types handled by this extension. + #[pin_project::pin_project(project = NetworkStreamProject)] + pub enum NetworkStream { + $( $i (#[pin] $stream), )* } - } -} -impl tokio::io::AsyncRead for NetworkStream { - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll<std::io::Result<()>> { - match self.project() { - NetworkStreamProject::Tcp(s) => s.poll_read(cx, buf), - NetworkStreamProject::Tls(s) => s.poll_read(cx, buf), - #[cfg(unix)] - NetworkStreamProject::Unix(s) => s.poll_read(cx, buf), + /// A raw stream of one of the types handled by this extension. + #[derive(Copy, Clone, PartialEq, Eq)] + pub enum NetworkStreamType { + $( $i, )* } - } -} -impl tokio::io::AsyncWrite for NetworkStream { - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll<Result<usize, std::io::Error>> { - match self.project() { - NetworkStreamProject::Tcp(s) => s.poll_write(cx, buf), - NetworkStreamProject::Tls(s) => s.poll_write(cx, buf), - #[cfg(unix)] - NetworkStreamProject::Unix(s) => s.poll_write(cx, buf), + /// A raw stream listener of one of the types handled by this extension. + pub enum NetworkStreamListener { + $( $i( $listener ), )* } - } - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<Result<(), std::io::Error>> { - match self.project() { - NetworkStreamProject::Tcp(s) => s.poll_flush(cx), - NetworkStreamProject::Tls(s) => s.poll_flush(cx), - #[cfg(unix)] - NetworkStreamProject::Unix(s) => s.poll_flush(cx), + $( + impl NetworkStreamListenerTrait for $listener { + type Stream = $stream; + type Addr = $addr; + type ResourceData = (); + const RESOURCE_NAME: &'static str = concat!(stringify!($il), "Listener"); + async fn accept(&self) -> std::io::Result<(Self::Stream, Self::Addr)> { + <$listener> :: accept(self).await + } + fn listen_address(&self) -> std::io::Result<Self::Addr> { + self.local_addr() + } + } + + impl From<$listener> for NetworkStreamListener { + fn from(value: $listener) -> Self { + Self::$i(value) + } + } + + impl NetworkStreamTrait for $stream { + type Resource = $stream_resource; + const RESOURCE_NAME: &'static str = concat!(stringify!($il), "Stream"); + fn local_address(&self) -> Result<NetworkStreamAddress, std::io::Error> { + Ok(NetworkStreamAddress::from(self.local_addr()?)) + } + fn peer_address(&self) -> Result<NetworkStreamAddress, std::io::Error> { + Ok(NetworkStreamAddress::from(self.peer_addr()?)) + } + } + + impl From<$stream> for NetworkStream { + fn from(value: $stream) -> Self { + Self::$i(value) + } + } + )* + + impl NetworkStream { + pub fn local_address(&self) -> Result<NetworkStreamAddress, std::io::Error> { + match self { + $( Self::$i(stm) => Ok(NetworkStreamAddress::from(stm.local_addr()?)), )* + } + } + + pub fn peer_address(&self) -> Result<NetworkStreamAddress, std::io::Error> { + match self { + $( Self::$i(stm) => Ok(NetworkStreamAddress::from(stm.peer_addr()?)), )* + } + } + + pub fn stream(&self) -> NetworkStreamType { + match self { + $( Self::$i(_) => NetworkStreamType::$i, )* + } + } } - } - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<Result<(), std::io::Error>> { - match self.project() { - NetworkStreamProject::Tcp(s) => s.poll_shutdown(cx), - NetworkStreamProject::Tls(s) => s.poll_shutdown(cx), - #[cfg(unix)] - NetworkStreamProject::Unix(s) => s.poll_shutdown(cx), + impl tokio::io::AsyncRead for NetworkStream { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll<std::io::Result<()>> { + match self.project() { + $( NetworkStreamProject::$i(s) => s.poll_read(cx, buf), )* + } + } } - } - fn is_write_vectored(&self) -> bool { - match self { - Self::Tcp(s) => s.is_write_vectored(), - Self::Tls(s) => s.is_write_vectored(), - #[cfg(unix)] - Self::Unix(s) => s.is_write_vectored(), + impl tokio::io::AsyncWrite for NetworkStream { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll<Result<usize, std::io::Error>> { + match self.project() { + $( NetworkStreamProject::$i(s) => s.poll_write(cx, buf), )* + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Result<(), std::io::Error>> { + match self.project() { + $( NetworkStreamProject::$i(s) => s.poll_flush(cx), )* + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Result<(), std::io::Error>> { + match self.project() { + $( NetworkStreamProject::$i(s) => s.poll_shutdown(cx), )* + } + } + + fn is_write_vectored(&self) -> bool { + match self { + $( NetworkStream::$i(s) => s.is_write_vectored(), )* + } + } + + fn poll_write_vectored( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll<Result<usize, std::io::Error>> { + match self.project() { + $( NetworkStreamProject::$i(s) => s.poll_write_vectored(cx, bufs), )* + } + } } - } - fn poll_write_vectored( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> std::task::Poll<Result<usize, std::io::Error>> { - match self.project() { - NetworkStreamProject::Tcp(s) => s.poll_write_vectored(cx, bufs), - NetworkStreamProject::Tls(s) => s.poll_write_vectored(cx, bufs), - #[cfg(unix)] - NetworkStreamProject::Unix(s) => s.poll_write_vectored(cx, bufs), + impl NetworkStreamListener { + /// Accepts a connection on this listener. + pub async fn accept(&self) -> Result<(NetworkStream, NetworkStreamAddress), std::io::Error> { + Ok(match self { + $( + Self::$i(s) => { + let (stm, addr) = s.accept().await?; + (NetworkStream::$i(stm), addr.into()) + } + )* + }) + } + + pub fn listen_address(&self) -> Result<NetworkStreamAddress, std::io::Error> { + match self { + $( Self::$i(s) => { Ok(NetworkStreamAddress::from(s.listen_address()?)) } )* + } + } + + pub fn stream(&self) -> NetworkStreamType { + match self { + $( Self::$i(_) => { NetworkStreamType::$i } )* + } + } + + /// Return a `NetworkStreamListener` if a resource exists for this `ResourceId` and it is currently + /// not locked. + pub fn take_resource(resource_table: &mut ResourceTable, listener_rid: ResourceId) -> Result<NetworkStreamListener, AnyError> { + $( + if let Some(resource) = NetworkListenerResource::<$listener>::take(resource_table, listener_rid)? { + return Ok(resource) + } + )* + Err(bad_resource_id()) + } } - } + }; } -/// A raw stream listener of one of the types handled by this extension. -pub enum NetworkStreamListener { - Tcp(tokio::net::TcpListener), - Tls(tokio::net::TcpListener, Arc<ServerConfig>), - #[cfg(unix)] - Unix(tokio::net::UnixListener), -} +#[cfg(unix)] +network_stream!( + [ + Tcp, + tcp, + tokio::net::TcpStream, + crate::tcp::TcpListener, + std::net::SocketAddr, + TcpStreamResource + ], + [ + Tls, + tls, + crate::ops_tls::TlsStream, + crate::ops_tls::TlsListener, + std::net::SocketAddr, + TlsStreamResource + ], + [ + Unix, + unix, + tokio::net::UnixStream, + tokio::net::UnixListener, + tokio::net::unix::SocketAddr, + crate::io::UnixStreamResource + ] +); + +#[cfg(not(unix))] +network_stream!( + [ + Tcp, + tcp, + tokio::net::TcpStream, + crate::tcp::TcpListener, + std::net::SocketAddr, + TcpStreamResource + ], + [ + Tls, + tls, + crate::ops_tls::TlsStream, + crate::ops_tls::TlsListener, + std::net::SocketAddr, + TlsStreamResource + ] +); pub enum NetworkStreamAddress { Ip(std::net::SocketAddr), @@ -178,46 +307,16 @@ pub enum NetworkStreamAddress { Unix(tokio::net::unix::SocketAddr), } -impl NetworkStreamListener { - /// Accepts a connection on this listener. - pub async fn accept(&self) -> Result<NetworkStream, std::io::Error> { - Ok(match self { - Self::Tcp(tcp) => { - let (stream, _addr) = tcp.accept().await?; - NetworkStream::Tcp(stream) - } - Self::Tls(tcp, config) => { - let (stream, _addr) = tcp.accept().await?; - NetworkStream::Tls(TlsStream::new_server_side( - stream, - config.clone(), - TLS_BUFFER_SIZE, - )) - } - #[cfg(unix)] - Self::Unix(unix) => { - let (stream, _addr) = unix.accept().await?; - NetworkStream::Unix(stream) - } - }) - } - - pub fn listen_address(&self) -> Result<NetworkStreamAddress, std::io::Error> { - match self { - Self::Tcp(tcp) => Ok(NetworkStreamAddress::Ip(tcp.local_addr()?)), - Self::Tls(tcp, _) => Ok(NetworkStreamAddress::Ip(tcp.local_addr()?)), - #[cfg(unix)] - Self::Unix(unix) => Ok(NetworkStreamAddress::Unix(unix.local_addr()?)), - } +impl From<std::net::SocketAddr> for NetworkStreamAddress { + fn from(value: std::net::SocketAddr) -> Self { + NetworkStreamAddress::Ip(value) } +} - pub fn stream(&self) -> NetworkStreamType { - match self { - Self::Tcp(..) => NetworkStreamType::Tcp, - Self::Tls(..) => NetworkStreamType::Tls, - #[cfg(unix)] - Self::Unix(..) => NetworkStreamType::Unix, - } +#[cfg(unix)] +impl From<tokio::net::unix::SocketAddr> for NetworkStreamAddress { + fn from(value: tokio::net::unix::SocketAddr) -> Self { + NetworkStreamAddress::Unix(value) } } @@ -252,7 +351,8 @@ pub fn take_network_stream_resource( } #[cfg(unix)] - if let Ok(resource_rc) = resource_table.take::<UnixStreamResource>(stream_rid) + if let Ok(resource_rc) = + resource_table.take::<crate::io::UnixStreamResource>(stream_rid) { // This UNIX socket might be used somewhere else. let resource = Rc::try_unwrap(resource_rc) @@ -271,33 +371,5 @@ pub fn take_network_stream_listener_resource( resource_table: &mut ResourceTable, listener_rid: ResourceId, ) -> Result<NetworkStreamListener, AnyError> { - if let Ok(resource_rc) = - resource_table.take::<TcpListenerResource>(listener_rid) - { - let resource = Rc::try_unwrap(resource_rc) - .map_err(|_| bad_resource("TCP socket listener is currently in use"))?; - return Ok(NetworkStreamListener::Tcp(resource.listener.into_inner())); - } - - if let Ok(resource_rc) = - resource_table.take::<TlsListenerResource>(listener_rid) - { - let resource = Rc::try_unwrap(resource_rc) - .map_err(|_| bad_resource("TLS socket listener is currently in use"))?; - return Ok(NetworkStreamListener::Tls( - resource.tcp_listener.into_inner(), - resource.tls_config, - )); - } - - #[cfg(unix)] - if let Ok(resource_rc) = - resource_table.take::<UnixListenerResource>(listener_rid) - { - let resource = Rc::try_unwrap(resource_rc) - .map_err(|_| bad_resource("UNIX socket listener is currently in use"))?; - return Ok(NetworkStreamListener::Unix(resource.listener.into_inner())); - } - - Err(bad_resource_id()) + NetworkStreamListener::take_resource(resource_table, listener_rid) } |