summaryrefslogtreecommitdiff
path: root/ext/net/raw.rs
diff options
context:
space:
mode:
authorMatt Mastracci <matthew@mastracci.com>2024-04-08 16:18:14 -0600
committerGitHub <noreply@github.com>2024-04-08 16:18:14 -0600
commit47061a4539feab411fbbd7db5604f4bd4a532051 (patch)
tree5f6f17066b6f967b1504ef9b762288ad670d1389 /ext/net/raw.rs
parent6157c8563484e53b1917c811e94e4b5afa01dc67 (diff)
feat(ext/net): Refactor TCP socket listeners for future clustering mode (#23037)
Changes: - Implements a TCP socket listener that will allow for round-robin load-balancing in-process. - Cleans up the raw networking code to make it easier to work with.
Diffstat (limited to 'ext/net/raw.rs')
-rw-r--r--ext/net/raw.rs472
1 files changed, 272 insertions, 200 deletions
diff --git a/ext/net/raw.rs b/ext/net/raw.rs
index c583da3bd..f2de76065 100644
--- a/ext/net/raw.rs
+++ b/ext/net/raw.rs
@@ -1,176 +1,305 @@
// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
use crate::io::TcpStreamResource;
-#[cfg(unix)]
-use crate::io::UnixStreamResource;
-use crate::ops::TcpListenerResource;
-use crate::ops_tls::TlsListenerResource;
use crate::ops_tls::TlsStreamResource;
-use crate::ops_tls::TLS_BUFFER_SIZE;
-#[cfg(unix)]
-use crate::ops_unix::UnixListenerResource;
use deno_core::error::bad_resource;
use deno_core::error::bad_resource_id;
use deno_core::error::AnyError;
+use deno_core::AsyncRefCell;
+use deno_core::CancelHandle;
+use deno_core::Resource;
use deno_core::ResourceId;
use deno_core::ResourceTable;
-use deno_tls::rustls::ServerConfig;
-use pin_project::pin_project;
-use rustls_tokio_stream::TlsStream;
+use std::borrow::Cow;
use std::rc::Rc;
-use std::sync::Arc;
-use tokio::net::TcpStream;
-#[cfg(unix)]
-use tokio::net::UnixStream;
-/// A raw stream of one of the types handled by this extension.
-#[pin_project(project = NetworkStreamProject)]
-pub enum NetworkStream {
- Tcp(#[pin] TcpStream),
- Tls(#[pin] TlsStream),
- #[cfg(unix)]
- Unix(#[pin] UnixStream),
+pub trait NetworkStreamTrait: Into<NetworkStream> {
+ type Resource;
+ const RESOURCE_NAME: &'static str;
+ fn local_address(&self) -> Result<NetworkStreamAddress, std::io::Error>;
+ fn peer_address(&self) -> Result<NetworkStreamAddress, std::io::Error>;
}
-impl From<TcpStream> for NetworkStream {
- fn from(value: TcpStream) -> Self {
- NetworkStream::Tcp(value)
- }
+#[allow(async_fn_in_trait)]
+pub trait NetworkStreamListenerTrait:
+ Into<NetworkStreamListener> + Send + Sync
+{
+ type Stream: NetworkStreamTrait + 'static;
+ type Addr: Into<NetworkStreamAddress> + 'static;
+ /// Additional data, if needed
+ type ResourceData: Default;
+ const RESOURCE_NAME: &'static str;
+ async fn accept(&self) -> std::io::Result<(Self::Stream, Self::Addr)>;
+ fn listen_address(&self) -> Result<Self::Addr, std::io::Error>;
}
-impl From<TlsStream> for NetworkStream {
- fn from(value: TlsStream) -> Self {
- NetworkStream::Tls(value)
- }
+/// A strongly-typed network listener resource for something that
+/// implements `NetworkListenerTrait`.
+pub struct NetworkListenerResource<T: NetworkStreamListenerTrait> {
+ pub listener: AsyncRefCell<T>,
+ /// Associated data for this resource. Not required.
+ #[allow(unused)]
+ pub data: T::ResourceData,
+ pub cancel: CancelHandle,
}
-#[cfg(unix)]
-impl From<UnixStream> for NetworkStream {
- fn from(value: UnixStream) -> Self {
- NetworkStream::Unix(value)
+impl<T: NetworkStreamListenerTrait + 'static> Resource
+ for NetworkListenerResource<T>
+{
+ fn name(&self) -> Cow<str> {
+ T::RESOURCE_NAME.into()
}
-}
-/// A raw stream of one of the types handled by this extension.
-#[derive(Copy, Clone, PartialEq, Eq)]
-pub enum NetworkStreamType {
- Tcp,
- Tls,
- #[cfg(unix)]
- Unix,
+ fn close(self: Rc<Self>) {
+ self.cancel.cancel();
+ }
}
-impl NetworkStream {
- pub fn local_address(&self) -> Result<NetworkStreamAddress, std::io::Error> {
- match self {
- Self::Tcp(tcp) => Ok(NetworkStreamAddress::Ip(tcp.local_addr()?)),
- Self::Tls(tls) => Ok(NetworkStreamAddress::Ip(tls.local_addr()?)),
- #[cfg(unix)]
- Self::Unix(unix) => Ok(NetworkStreamAddress::Unix(unix.local_addr()?)),
+impl<T: NetworkStreamListenerTrait + 'static> NetworkListenerResource<T> {
+ pub fn new(t: T) -> Self {
+ Self {
+ listener: AsyncRefCell::new(t),
+ data: Default::default(),
+ cancel: Default::default(),
}
}
- pub fn peer_address(&self) -> Result<NetworkStreamAddress, std::io::Error> {
- match self {
- Self::Tcp(tcp) => Ok(NetworkStreamAddress::Ip(tcp.peer_addr()?)),
- Self::Tls(tls) => Ok(NetworkStreamAddress::Ip(tls.peer_addr()?)),
- #[cfg(unix)]
- Self::Unix(unix) => Ok(NetworkStreamAddress::Unix(unix.peer_addr()?)),
+ /// Returns a [`NetworkStreamListener`] from this resource if it is not in use elsewhere.
+ fn take(
+ resource_table: &mut ResourceTable,
+ listener_rid: ResourceId,
+ ) -> Result<Option<NetworkStreamListener>, AnyError> {
+ if let Ok(resource_rc) = resource_table.take::<Self>(listener_rid) {
+ let resource = Rc::try_unwrap(resource_rc)
+ .map_err(|_| bad_resource("Listener is currently in use"))?;
+ return Ok(Some(resource.listener.into_inner().into()));
}
+ Ok(None)
}
+}
- pub fn stream(&self) -> NetworkStreamType {
- match self {
- Self::Tcp(_) => NetworkStreamType::Tcp,
- Self::Tls(_) => NetworkStreamType::Tls,
- #[cfg(unix)]
- Self::Unix(_) => NetworkStreamType::Unix,
+/// Each of the network streams has the exact same pattern for listening, accepting, etc, so
+/// we just codegen them all via macro to avoid repeating each one of these N times.
+macro_rules! network_stream {
+ ( $([$i:ident, $il:ident, $stream:path, $listener:path, $addr:path, $stream_resource:ty]),* ) => {
+ /// A raw stream of one of the types handled by this extension.
+ #[pin_project::pin_project(project = NetworkStreamProject)]
+ pub enum NetworkStream {
+ $( $i (#[pin] $stream), )*
}
- }
-}
-impl tokio::io::AsyncRead for NetworkStream {
- fn poll_read(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- buf: &mut tokio::io::ReadBuf<'_>,
- ) -> std::task::Poll<std::io::Result<()>> {
- match self.project() {
- NetworkStreamProject::Tcp(s) => s.poll_read(cx, buf),
- NetworkStreamProject::Tls(s) => s.poll_read(cx, buf),
- #[cfg(unix)]
- NetworkStreamProject::Unix(s) => s.poll_read(cx, buf),
+ /// A raw stream of one of the types handled by this extension.
+ #[derive(Copy, Clone, PartialEq, Eq)]
+ pub enum NetworkStreamType {
+ $( $i, )*
}
- }
-}
-impl tokio::io::AsyncWrite for NetworkStream {
- fn poll_write(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- buf: &[u8],
- ) -> std::task::Poll<Result<usize, std::io::Error>> {
- match self.project() {
- NetworkStreamProject::Tcp(s) => s.poll_write(cx, buf),
- NetworkStreamProject::Tls(s) => s.poll_write(cx, buf),
- #[cfg(unix)]
- NetworkStreamProject::Unix(s) => s.poll_write(cx, buf),
+ /// A raw stream listener of one of the types handled by this extension.
+ pub enum NetworkStreamListener {
+ $( $i( $listener ), )*
}
- }
- fn poll_flush(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Result<(), std::io::Error>> {
- match self.project() {
- NetworkStreamProject::Tcp(s) => s.poll_flush(cx),
- NetworkStreamProject::Tls(s) => s.poll_flush(cx),
- #[cfg(unix)]
- NetworkStreamProject::Unix(s) => s.poll_flush(cx),
+ $(
+ impl NetworkStreamListenerTrait for $listener {
+ type Stream = $stream;
+ type Addr = $addr;
+ type ResourceData = ();
+ const RESOURCE_NAME: &'static str = concat!(stringify!($il), "Listener");
+ async fn accept(&self) -> std::io::Result<(Self::Stream, Self::Addr)> {
+ <$listener> :: accept(self).await
+ }
+ fn listen_address(&self) -> std::io::Result<Self::Addr> {
+ self.local_addr()
+ }
+ }
+
+ impl From<$listener> for NetworkStreamListener {
+ fn from(value: $listener) -> Self {
+ Self::$i(value)
+ }
+ }
+
+ impl NetworkStreamTrait for $stream {
+ type Resource = $stream_resource;
+ const RESOURCE_NAME: &'static str = concat!(stringify!($il), "Stream");
+ fn local_address(&self) -> Result<NetworkStreamAddress, std::io::Error> {
+ Ok(NetworkStreamAddress::from(self.local_addr()?))
+ }
+ fn peer_address(&self) -> Result<NetworkStreamAddress, std::io::Error> {
+ Ok(NetworkStreamAddress::from(self.peer_addr()?))
+ }
+ }
+
+ impl From<$stream> for NetworkStream {
+ fn from(value: $stream) -> Self {
+ Self::$i(value)
+ }
+ }
+ )*
+
+ impl NetworkStream {
+ pub fn local_address(&self) -> Result<NetworkStreamAddress, std::io::Error> {
+ match self {
+ $( Self::$i(stm) => Ok(NetworkStreamAddress::from(stm.local_addr()?)), )*
+ }
+ }
+
+ pub fn peer_address(&self) -> Result<NetworkStreamAddress, std::io::Error> {
+ match self {
+ $( Self::$i(stm) => Ok(NetworkStreamAddress::from(stm.peer_addr()?)), )*
+ }
+ }
+
+ pub fn stream(&self) -> NetworkStreamType {
+ match self {
+ $( Self::$i(_) => NetworkStreamType::$i, )*
+ }
+ }
}
- }
- fn poll_shutdown(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Result<(), std::io::Error>> {
- match self.project() {
- NetworkStreamProject::Tcp(s) => s.poll_shutdown(cx),
- NetworkStreamProject::Tls(s) => s.poll_shutdown(cx),
- #[cfg(unix)]
- NetworkStreamProject::Unix(s) => s.poll_shutdown(cx),
+ impl tokio::io::AsyncRead for NetworkStream {
+ fn poll_read(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> std::task::Poll<std::io::Result<()>> {
+ match self.project() {
+ $( NetworkStreamProject::$i(s) => s.poll_read(cx, buf), )*
+ }
+ }
}
- }
- fn is_write_vectored(&self) -> bool {
- match self {
- Self::Tcp(s) => s.is_write_vectored(),
- Self::Tls(s) => s.is_write_vectored(),
- #[cfg(unix)]
- Self::Unix(s) => s.is_write_vectored(),
+ impl tokio::io::AsyncWrite for NetworkStream {
+ fn poll_write(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ buf: &[u8],
+ ) -> std::task::Poll<Result<usize, std::io::Error>> {
+ match self.project() {
+ $( NetworkStreamProject::$i(s) => s.poll_write(cx, buf), )*
+ }
+ }
+
+ fn poll_flush(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), std::io::Error>> {
+ match self.project() {
+ $( NetworkStreamProject::$i(s) => s.poll_flush(cx), )*
+ }
+ }
+
+ fn poll_shutdown(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), std::io::Error>> {
+ match self.project() {
+ $( NetworkStreamProject::$i(s) => s.poll_shutdown(cx), )*
+ }
+ }
+
+ fn is_write_vectored(&self) -> bool {
+ match self {
+ $( NetworkStream::$i(s) => s.is_write_vectored(), )*
+ }
+ }
+
+ fn poll_write_vectored(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ bufs: &[std::io::IoSlice<'_>],
+ ) -> std::task::Poll<Result<usize, std::io::Error>> {
+ match self.project() {
+ $( NetworkStreamProject::$i(s) => s.poll_write_vectored(cx, bufs), )*
+ }
+ }
}
- }
- fn poll_write_vectored(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- bufs: &[std::io::IoSlice<'_>],
- ) -> std::task::Poll<Result<usize, std::io::Error>> {
- match self.project() {
- NetworkStreamProject::Tcp(s) => s.poll_write_vectored(cx, bufs),
- NetworkStreamProject::Tls(s) => s.poll_write_vectored(cx, bufs),
- #[cfg(unix)]
- NetworkStreamProject::Unix(s) => s.poll_write_vectored(cx, bufs),
+ impl NetworkStreamListener {
+ /// Accepts a connection on this listener.
+ pub async fn accept(&self) -> Result<(NetworkStream, NetworkStreamAddress), std::io::Error> {
+ Ok(match self {
+ $(
+ Self::$i(s) => {
+ let (stm, addr) = s.accept().await?;
+ (NetworkStream::$i(stm), addr.into())
+ }
+ )*
+ })
+ }
+
+ pub fn listen_address(&self) -> Result<NetworkStreamAddress, std::io::Error> {
+ match self {
+ $( Self::$i(s) => { Ok(NetworkStreamAddress::from(s.listen_address()?)) } )*
+ }
+ }
+
+ pub fn stream(&self) -> NetworkStreamType {
+ match self {
+ $( Self::$i(_) => { NetworkStreamType::$i } )*
+ }
+ }
+
+ /// Return a `NetworkStreamListener` if a resource exists for this `ResourceId` and it is currently
+ /// not locked.
+ pub fn take_resource(resource_table: &mut ResourceTable, listener_rid: ResourceId) -> Result<NetworkStreamListener, AnyError> {
+ $(
+ if let Some(resource) = NetworkListenerResource::<$listener>::take(resource_table, listener_rid)? {
+ return Ok(resource)
+ }
+ )*
+ Err(bad_resource_id())
+ }
}
- }
+ };
}
-/// A raw stream listener of one of the types handled by this extension.
-pub enum NetworkStreamListener {
- Tcp(tokio::net::TcpListener),
- Tls(tokio::net::TcpListener, Arc<ServerConfig>),
- #[cfg(unix)]
- Unix(tokio::net::UnixListener),
-}
+#[cfg(unix)]
+network_stream!(
+ [
+ Tcp,
+ tcp,
+ tokio::net::TcpStream,
+ crate::tcp::TcpListener,
+ std::net::SocketAddr,
+ TcpStreamResource
+ ],
+ [
+ Tls,
+ tls,
+ crate::ops_tls::TlsStream,
+ crate::ops_tls::TlsListener,
+ std::net::SocketAddr,
+ TlsStreamResource
+ ],
+ [
+ Unix,
+ unix,
+ tokio::net::UnixStream,
+ tokio::net::UnixListener,
+ tokio::net::unix::SocketAddr,
+ crate::io::UnixStreamResource
+ ]
+);
+
+#[cfg(not(unix))]
+network_stream!(
+ [
+ Tcp,
+ tcp,
+ tokio::net::TcpStream,
+ crate::tcp::TcpListener,
+ std::net::SocketAddr,
+ TcpStreamResource
+ ],
+ [
+ Tls,
+ tls,
+ crate::ops_tls::TlsStream,
+ crate::ops_tls::TlsListener,
+ std::net::SocketAddr,
+ TlsStreamResource
+ ]
+);
pub enum NetworkStreamAddress {
Ip(std::net::SocketAddr),
@@ -178,46 +307,16 @@ pub enum NetworkStreamAddress {
Unix(tokio::net::unix::SocketAddr),
}
-impl NetworkStreamListener {
- /// Accepts a connection on this listener.
- pub async fn accept(&self) -> Result<NetworkStream, std::io::Error> {
- Ok(match self {
- Self::Tcp(tcp) => {
- let (stream, _addr) = tcp.accept().await?;
- NetworkStream::Tcp(stream)
- }
- Self::Tls(tcp, config) => {
- let (stream, _addr) = tcp.accept().await?;
- NetworkStream::Tls(TlsStream::new_server_side(
- stream,
- config.clone(),
- TLS_BUFFER_SIZE,
- ))
- }
- #[cfg(unix)]
- Self::Unix(unix) => {
- let (stream, _addr) = unix.accept().await?;
- NetworkStream::Unix(stream)
- }
- })
- }
-
- pub fn listen_address(&self) -> Result<NetworkStreamAddress, std::io::Error> {
- match self {
- Self::Tcp(tcp) => Ok(NetworkStreamAddress::Ip(tcp.local_addr()?)),
- Self::Tls(tcp, _) => Ok(NetworkStreamAddress::Ip(tcp.local_addr()?)),
- #[cfg(unix)]
- Self::Unix(unix) => Ok(NetworkStreamAddress::Unix(unix.local_addr()?)),
- }
+impl From<std::net::SocketAddr> for NetworkStreamAddress {
+ fn from(value: std::net::SocketAddr) -> Self {
+ NetworkStreamAddress::Ip(value)
}
+}
- pub fn stream(&self) -> NetworkStreamType {
- match self {
- Self::Tcp(..) => NetworkStreamType::Tcp,
- Self::Tls(..) => NetworkStreamType::Tls,
- #[cfg(unix)]
- Self::Unix(..) => NetworkStreamType::Unix,
- }
+#[cfg(unix)]
+impl From<tokio::net::unix::SocketAddr> for NetworkStreamAddress {
+ fn from(value: tokio::net::unix::SocketAddr) -> Self {
+ NetworkStreamAddress::Unix(value)
}
}
@@ -252,7 +351,8 @@ pub fn take_network_stream_resource(
}
#[cfg(unix)]
- if let Ok(resource_rc) = resource_table.take::<UnixStreamResource>(stream_rid)
+ if let Ok(resource_rc) =
+ resource_table.take::<crate::io::UnixStreamResource>(stream_rid)
{
// This UNIX socket might be used somewhere else.
let resource = Rc::try_unwrap(resource_rc)
@@ -271,33 +371,5 @@ pub fn take_network_stream_listener_resource(
resource_table: &mut ResourceTable,
listener_rid: ResourceId,
) -> Result<NetworkStreamListener, AnyError> {
- if let Ok(resource_rc) =
- resource_table.take::<TcpListenerResource>(listener_rid)
- {
- let resource = Rc::try_unwrap(resource_rc)
- .map_err(|_| bad_resource("TCP socket listener is currently in use"))?;
- return Ok(NetworkStreamListener::Tcp(resource.listener.into_inner()));
- }
-
- if let Ok(resource_rc) =
- resource_table.take::<TlsListenerResource>(listener_rid)
- {
- let resource = Rc::try_unwrap(resource_rc)
- .map_err(|_| bad_resource("TLS socket listener is currently in use"))?;
- return Ok(NetworkStreamListener::Tls(
- resource.tcp_listener.into_inner(),
- resource.tls_config,
- ));
- }
-
- #[cfg(unix)]
- if let Ok(resource_rc) =
- resource_table.take::<UnixListenerResource>(listener_rid)
- {
- let resource = Rc::try_unwrap(resource_rc)
- .map_err(|_| bad_resource("UNIX socket listener is currently in use"))?;
- return Ok(NetworkStreamListener::Unix(resource.listener.into_inner()));
- }
-
- Err(bad_resource_id())
+ NetworkStreamListener::take_resource(resource_table, listener_rid)
}