summaryrefslogtreecommitdiff
path: root/ext/tls
diff options
context:
space:
mode:
authorMatt Mastracci <matthew@mastracci.com>2024-05-09 10:54:47 -0600
committerGitHub <noreply@github.com>2024-05-09 10:54:47 -0600
commit684377c92c88877d97c522bcc4cd6a4175277dfb (patch)
tree192e84a3f3daceb5bd47d787eedba32416dcba3c /ext/tls
parentdc29986ae591425f4a653a7155d41d75fbf7931a (diff)
refactor(ext/tls): Implement required functionality for later SNI support (#23686)
Precursor to #23236 This implements the SNI features, but uses private symbols to avoid exposing the functionality at this time. Note that to properly test this feature, we need to add a way for `connectTls` to specify a hostname. This is something that should be pushed into that API at a later time as well. ```ts Deno.test( { permissions: { net: true, read: true } }, async function listenResolver() { let sniRequests = []; const listener = Deno.listenTls({ hostname: "localhost", port: 0, [resolverSymbol]: (sni: string) => { sniRequests.push(sni); return { cert, key, }; }, }); { const conn = await Deno.connectTls({ hostname: "localhost", [serverNameSymbol]: "server-1", port: listener.addr.port, }); const [_handshake, serverConn] = await Promise.all([ conn.handshake(), listener.accept(), ]); conn.close(); serverConn.close(); } { const conn = await Deno.connectTls({ hostname: "localhost", [serverNameSymbol]: "server-2", port: listener.addr.port, }); const [_handshake, serverConn] = await Promise.all([ conn.handshake(), listener.accept(), ]); conn.close(); serverConn.close(); } assertEquals(sniRequests, ["server-1", "server-2"]); listener.close(); }, ); ``` --------- Signed-off-by: Matt Mastracci <matthew@mastracci.com>
Diffstat (limited to 'ext/tls')
-rw-r--r--ext/tls/Cargo.toml1
-rw-r--r--ext/tls/lib.rs47
-rw-r--r--ext/tls/tls_key.rs321
3 files changed, 340 insertions, 29 deletions
diff --git a/ext/tls/Cargo.toml b/ext/tls/Cargo.toml
index 6f587f101..b809b4ebe 100644
--- a/ext/tls/Cargo.toml
+++ b/ext/tls/Cargo.toml
@@ -22,4 +22,5 @@ rustls-pemfile.workspace = true
rustls-tokio-stream.workspace = true
rustls-webpki.workspace = true
serde.workspace = true
+tokio.workspace = true
webpki-roots.workspace = true
diff --git a/ext/tls/lib.rs b/ext/tls/lib.rs
index 7e68971e2..5122264bf 100644
--- a/ext/tls/lib.rs
+++ b/ext/tls/lib.rs
@@ -30,6 +30,9 @@ use std::io::Cursor;
use std::sync::Arc;
use std::time::SystemTime;
+mod tls_key;
+pub use tls_key::*;
+
pub type Certificate = rustls::Certificate;
pub type PrivateKey = rustls::PrivateKey;
pub type RootCertStore = rustls::RootCertStore;
@@ -175,7 +178,7 @@ pub fn create_client_config(
root_cert_store: Option<RootCertStore>,
ca_certs: Vec<Vec<u8>>,
unsafely_ignore_certificate_errors: Option<Vec<String>>,
- maybe_cert_chain_and_key: Option<TlsKey>,
+ maybe_cert_chain_and_key: TlsKeys,
socket_use: SocketUse,
) -> Result<ClientConfig, AnyError> {
if let Some(ic_allowlist) = unsafely_ignore_certificate_errors {
@@ -189,14 +192,13 @@ pub fn create_client_config(
// However it's not really feasible to deduplicate it as the `client_config` instances
// are not type-compatible - one wants "client cert", the other wants "transparency policy
// or client cert".
- let mut client =
- if let Some(TlsKey(cert_chain, private_key)) = maybe_cert_chain_and_key {
- client_config
- .with_client_auth_cert(cert_chain, private_key)
- .expect("invalid client key or certificate")
- } else {
- client_config.with_no_client_auth()
- };
+ let mut client = match maybe_cert_chain_and_key {
+ TlsKeys::Static(TlsKey(cert_chain, private_key)) => client_config
+ .with_client_auth_cert(cert_chain, private_key)
+ .expect("invalid client key or certificate"),
+ TlsKeys::Null => client_config.with_no_client_auth(),
+ TlsKeys::Resolver(_) => unimplemented!(),
+ };
add_alpn(&mut client, socket_use);
return Ok(client);
@@ -226,14 +228,13 @@ pub fn create_client_config(
root_cert_store
});
- let mut client =
- if let Some(TlsKey(cert_chain, private_key)) = maybe_cert_chain_and_key {
- client_config
- .with_client_auth_cert(cert_chain, private_key)
- .expect("invalid client key or certificate")
- } else {
- client_config.with_no_client_auth()
- };
+ let mut client = match maybe_cert_chain_and_key {
+ TlsKeys::Static(TlsKey(cert_chain, private_key)) => client_config
+ .with_client_auth_cert(cert_chain, private_key)
+ .expect("invalid client key or certificate"),
+ TlsKeys::Null => client_config.with_no_client_auth(),
+ TlsKeys::Resolver(_) => unimplemented!(),
+ };
add_alpn(&mut client, socket_use);
Ok(client)
@@ -325,15 +326,3 @@ pub fn load_private_keys(bytes: &[u8]) -> Result<Vec<PrivateKey>, AnyError> {
Ok(keys)
}
-
-/// A loaded key.
-// FUTURE(mmastrac): add resolver enum value to support dynamic SNI
-pub enum TlsKeys {
- // TODO(mmastrac): We need Option<&T> for cppgc -- this is a workaround
- Null,
- Static(TlsKey),
-}
-
-/// A TLS certificate/private key pair.
-#[derive(Clone, Debug)]
-pub struct TlsKey(pub Vec<Certificate>, pub PrivateKey);
diff --git a/ext/tls/tls_key.rs b/ext/tls/tls_key.rs
new file mode 100644
index 000000000..18064a91a
--- /dev/null
+++ b/ext/tls/tls_key.rs
@@ -0,0 +1,321 @@
+// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license.
+
+//! These represent the various types of TLS keys we support for both client and server
+//! connections.
+//!
+//! A TLS key will most often be static, and will loaded from a certificate and key file
+//! or string. These are represented by `TlsKey`, which is stored in `TlsKeys::Static`.
+//!
+//! In more complex cases, you may need a `TlsKeyResolver`/`TlsKeyLookup` pair, which
+//! requires polling of the `TlsKeyLookup` lookup queue. The underlying channels that used for
+//! key lookup can handle closing one end of the pair, in which case they will just
+//! attempt to clean up the associated resources.
+
+use crate::Certificate;
+use crate::PrivateKey;
+use deno_core::anyhow::anyhow;
+use deno_core::error::AnyError;
+use deno_core::futures::future::poll_fn;
+use deno_core::futures::future::Either;
+use deno_core::futures::FutureExt;
+use deno_core::unsync::spawn;
+use rustls::ServerConfig;
+use rustls_tokio_stream::ServerConfigProvider;
+use std::cell::RefCell;
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::future::ready;
+use std::future::Future;
+use std::io::ErrorKind;
+use std::rc::Rc;
+use std::sync::Arc;
+use tokio::sync::broadcast;
+use tokio::sync::mpsc;
+use tokio::sync::oneshot;
+
+type ErrorType = Rc<AnyError>;
+
+/// A TLS certificate/private key pair.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct TlsKey(pub Vec<Certificate>, pub PrivateKey);
+
+#[derive(Clone, Debug, Default)]
+pub enum TlsKeys {
+ // TODO(mmastrac): We need Option<&T> for cppgc -- this is a workaround
+ #[default]
+ Null,
+ Static(TlsKey),
+ Resolver(TlsKeyResolver),
+}
+
+pub struct TlsKeysHolder(RefCell<TlsKeys>);
+
+impl TlsKeysHolder {
+ pub fn take(&self) -> TlsKeys {
+ std::mem::take(&mut *self.0.borrow_mut())
+ }
+}
+
+impl From<TlsKeys> for TlsKeysHolder {
+ fn from(value: TlsKeys) -> Self {
+ TlsKeysHolder(RefCell::new(value))
+ }
+}
+
+impl TryInto<Option<TlsKey>> for TlsKeys {
+ type Error = Self;
+ fn try_into(self) -> Result<Option<TlsKey>, Self::Error> {
+ match self {
+ Self::Null => Ok(None),
+ Self::Static(key) => Ok(Some(key)),
+ Self::Resolver(_) => Err(self),
+ }
+ }
+}
+
+impl From<Option<TlsKey>> for TlsKeys {
+ fn from(value: Option<TlsKey>) -> Self {
+ match value {
+ None => TlsKeys::Null,
+ Some(key) => TlsKeys::Static(key),
+ }
+ }
+}
+
+enum TlsKeyState {
+ Resolving(broadcast::Receiver<Result<TlsKey, ErrorType>>),
+ Resolved(Result<TlsKey, ErrorType>),
+}
+
+struct TlsKeyResolverInner {
+ resolution_tx: mpsc::UnboundedSender<(
+ String,
+ broadcast::Sender<Result<TlsKey, ErrorType>>,
+ )>,
+ cache: RefCell<HashMap<String, TlsKeyState>>,
+}
+
+#[derive(Clone)]
+pub struct TlsKeyResolver {
+ inner: Rc<TlsKeyResolverInner>,
+}
+
+impl TlsKeyResolver {
+ async fn resolve_internal(
+ &self,
+ sni: String,
+ alpn: Vec<Vec<u8>>,
+ ) -> Result<Arc<ServerConfig>, AnyError> {
+ let key = self.resolve(sni).await?;
+
+ let mut tls_config = ServerConfig::builder()
+ .with_safe_defaults()
+ .with_no_client_auth()
+ .with_single_cert(key.0, key.1)?;
+ tls_config.alpn_protocols = alpn;
+ Ok(tls_config.into())
+ }
+
+ pub fn into_server_config_provider(
+ self,
+ alpn: Vec<Vec<u8>>,
+ ) -> ServerConfigProvider {
+ let (tx, mut rx) = mpsc::unbounded_channel::<(_, oneshot::Sender<_>)>();
+
+ // We don't want to make the resolver multi-threaded, but the `ServerConfigProvider` is
+ // required to be wrapped in an Arc. To fix this, we spawn a task in our current runtime
+ // to respond to the requests.
+ spawn(async move {
+ while let Some((sni, txr)) = rx.recv().await {
+ _ = txr.send(self.resolve_internal(sni, alpn.clone()).await);
+ }
+ });
+
+ Arc::new(move |hello| {
+ // Take ownership of the SNI information
+ let sni = hello.server_name().unwrap_or_default().to_owned();
+ let (txr, rxr) = tokio::sync::oneshot::channel::<_>();
+ _ = tx.send((sni, txr));
+ rxr
+ .map(|res| match res {
+ Err(e) => Err(std::io::Error::new(ErrorKind::InvalidData, e)),
+ Ok(Err(e)) => Err(std::io::Error::new(ErrorKind::InvalidData, e)),
+ Ok(Ok(res)) => Ok(res),
+ })
+ .boxed()
+ })
+ }
+}
+
+impl Debug for TlsKeyResolver {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("TlsKeyResolver").finish()
+ }
+}
+
+pub fn new_resolver() -> (TlsKeyResolver, TlsKeyLookup) {
+ let (resolution_tx, resolution_rx) = mpsc::unbounded_channel();
+ (
+ TlsKeyResolver {
+ inner: Rc::new(TlsKeyResolverInner {
+ resolution_tx,
+ cache: Default::default(),
+ }),
+ },
+ TlsKeyLookup {
+ resolution_rx: RefCell::new(resolution_rx),
+ pending: Default::default(),
+ },
+ )
+}
+
+impl TlsKeyResolver {
+ /// Resolve the certificate and key for a given host. This immediately spawns a task in the
+ /// background and is therefore cancellation-safe.
+ pub fn resolve(
+ &self,
+ sni: String,
+ ) -> impl Future<Output = Result<TlsKey, AnyError>> {
+ let mut cache = self.inner.cache.borrow_mut();
+ let mut recv = match cache.get(&sni) {
+ None => {
+ let (tx, rx) = broadcast::channel(1);
+ cache.insert(sni.clone(), TlsKeyState::Resolving(rx.resubscribe()));
+ _ = self.inner.resolution_tx.send((sni.clone(), tx));
+ rx
+ }
+ Some(TlsKeyState::Resolving(recv)) => recv.resubscribe(),
+ Some(TlsKeyState::Resolved(res)) => {
+ return Either::Left(ready(res.clone().map_err(|_| anyhow!("Failed"))));
+ }
+ };
+ drop(cache);
+
+ // Make this cancellation safe
+ let inner = self.inner.clone();
+ let handle = spawn(async move {
+ let res = recv.recv().await?;
+ let mut cache = inner.cache.borrow_mut();
+ match cache.get(&sni) {
+ None | Some(TlsKeyState::Resolving(..)) => {
+ cache.insert(sni, TlsKeyState::Resolved(res.clone()));
+ }
+ Some(TlsKeyState::Resolved(..)) => {
+ // Someone beat us to it
+ }
+ }
+ res.map_err(|_| anyhow!("Failed"))
+ });
+ Either::Right(async move { handle.await? })
+ }
+}
+
+pub struct TlsKeyLookup {
+ #[allow(clippy::type_complexity)]
+ resolution_rx: RefCell<
+ mpsc::UnboundedReceiver<(
+ String,
+ broadcast::Sender<Result<TlsKey, ErrorType>>,
+ )>,
+ >,
+ pending:
+ RefCell<HashMap<String, broadcast::Sender<Result<TlsKey, ErrorType>>>>,
+}
+
+impl TlsKeyLookup {
+ /// Multiple `poll` calls are safe, but this method is not starvation-safe. Generally
+ /// only one `poll`er should be active at any time.
+ pub async fn poll(&self) -> Option<String> {
+ if let Some((sni, sender)) =
+ poll_fn(|cx| self.resolution_rx.borrow_mut().poll_recv(cx)).await
+ {
+ self.pending.borrow_mut().insert(sni.clone(), sender);
+ Some(sni)
+ } else {
+ None
+ }
+ }
+
+ /// Resolve a previously polled item.
+ pub fn resolve(&self, sni: String, res: Result<TlsKey, AnyError>) {
+ _ = self
+ .pending
+ .borrow_mut()
+ .remove(&sni)
+ .unwrap()
+ .send(res.map_err(Rc::new));
+ }
+}
+
+#[cfg(test)]
+pub mod tests {
+ use super::*;
+ use deno_core::unsync::spawn;
+ use rustls::Certificate;
+ use rustls::PrivateKey;
+
+ fn tls_key_for_test(sni: &str) -> TlsKey {
+ TlsKey(
+ vec![Certificate(format!("{sni}-cert").into_bytes())],
+ PrivateKey(format!("{sni}-key").into_bytes()),
+ )
+ }
+
+ #[tokio::test]
+ async fn test_resolve_once() {
+ let (resolver, lookup) = new_resolver();
+ let task = spawn(async move {
+ while let Some(sni) = lookup.poll().await {
+ lookup.resolve(sni.clone(), Ok(tls_key_for_test(&sni)));
+ }
+ });
+
+ let key = resolver.resolve("example.com".to_owned()).await.unwrap();
+ assert_eq!(tls_key_for_test("example.com"), key);
+ drop(resolver);
+
+ task.await.unwrap();
+ }
+
+ #[tokio::test]
+ async fn test_resolve_concurrent() {
+ let (resolver, lookup) = new_resolver();
+ let task = spawn(async move {
+ while let Some(sni) = lookup.poll().await {
+ lookup.resolve(sni.clone(), Ok(tls_key_for_test(&sni)));
+ }
+ });
+
+ let f1 = resolver.resolve("example.com".to_owned());
+ let f2 = resolver.resolve("example.com".to_owned());
+
+ let key = f1.await.unwrap();
+ assert_eq!(tls_key_for_test("example.com"), key);
+ let key = f2.await.unwrap();
+ assert_eq!(tls_key_for_test("example.com"), key);
+ drop(resolver);
+
+ task.await.unwrap();
+ }
+
+ #[tokio::test]
+ async fn test_resolve_multiple_concurrent() {
+ let (resolver, lookup) = new_resolver();
+ let task = spawn(async move {
+ while let Some(sni) = lookup.poll().await {
+ lookup.resolve(sni.clone(), Ok(tls_key_for_test(&sni)));
+ }
+ });
+
+ let f1 = resolver.resolve("example1.com".to_owned());
+ let f2 = resolver.resolve("example2.com".to_owned());
+
+ let key = f1.await.unwrap();
+ assert_eq!(tls_key_for_test("example1.com"), key);
+ let key = f2.await.unwrap();
+ assert_eq!(tls_key_for_test("example2.com"), key);
+ drop(resolver);
+
+ task.await.unwrap();
+ }
+}