diff options
author | Matt Mastracci <matthew@mastracci.com> | 2024-05-09 10:54:47 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-09 10:54:47 -0600 |
commit | 684377c92c88877d97c522bcc4cd6a4175277dfb (patch) | |
tree | 192e84a3f3daceb5bd47d787eedba32416dcba3c /ext/tls | |
parent | dc29986ae591425f4a653a7155d41d75fbf7931a (diff) |
refactor(ext/tls): Implement required functionality for later SNI support (#23686)
Precursor to #23236
This implements the SNI features, but uses private symbols to avoid
exposing the functionality at this time. Note that to properly test this
feature, we need to add a way for `connectTls` to specify a hostname.
This is something that should be pushed into that API at a later time as
well.
```ts
Deno.test(
{ permissions: { net: true, read: true } },
async function listenResolver() {
let sniRequests = [];
const listener = Deno.listenTls({
hostname: "localhost",
port: 0,
[resolverSymbol]: (sni: string) => {
sniRequests.push(sni);
return {
cert,
key,
};
},
});
{
const conn = await Deno.connectTls({
hostname: "localhost",
[serverNameSymbol]: "server-1",
port: listener.addr.port,
});
const [_handshake, serverConn] = await Promise.all([
conn.handshake(),
listener.accept(),
]);
conn.close();
serverConn.close();
}
{
const conn = await Deno.connectTls({
hostname: "localhost",
[serverNameSymbol]: "server-2",
port: listener.addr.port,
});
const [_handshake, serverConn] = await Promise.all([
conn.handshake(),
listener.accept(),
]);
conn.close();
serverConn.close();
}
assertEquals(sniRequests, ["server-1", "server-2"]);
listener.close();
},
);
```
---------
Signed-off-by: Matt Mastracci <matthew@mastracci.com>
Diffstat (limited to 'ext/tls')
-rw-r--r-- | ext/tls/Cargo.toml | 1 | ||||
-rw-r--r-- | ext/tls/lib.rs | 47 | ||||
-rw-r--r-- | ext/tls/tls_key.rs | 321 |
3 files changed, 340 insertions, 29 deletions
diff --git a/ext/tls/Cargo.toml b/ext/tls/Cargo.toml index 6f587f101..b809b4ebe 100644 --- a/ext/tls/Cargo.toml +++ b/ext/tls/Cargo.toml @@ -22,4 +22,5 @@ rustls-pemfile.workspace = true rustls-tokio-stream.workspace = true rustls-webpki.workspace = true serde.workspace = true +tokio.workspace = true webpki-roots.workspace = true diff --git a/ext/tls/lib.rs b/ext/tls/lib.rs index 7e68971e2..5122264bf 100644 --- a/ext/tls/lib.rs +++ b/ext/tls/lib.rs @@ -30,6 +30,9 @@ use std::io::Cursor; use std::sync::Arc; use std::time::SystemTime; +mod tls_key; +pub use tls_key::*; + pub type Certificate = rustls::Certificate; pub type PrivateKey = rustls::PrivateKey; pub type RootCertStore = rustls::RootCertStore; @@ -175,7 +178,7 @@ pub fn create_client_config( root_cert_store: Option<RootCertStore>, ca_certs: Vec<Vec<u8>>, unsafely_ignore_certificate_errors: Option<Vec<String>>, - maybe_cert_chain_and_key: Option<TlsKey>, + maybe_cert_chain_and_key: TlsKeys, socket_use: SocketUse, ) -> Result<ClientConfig, AnyError> { if let Some(ic_allowlist) = unsafely_ignore_certificate_errors { @@ -189,14 +192,13 @@ pub fn create_client_config( // However it's not really feasible to deduplicate it as the `client_config` instances // are not type-compatible - one wants "client cert", the other wants "transparency policy // or client cert". - let mut client = - if let Some(TlsKey(cert_chain, private_key)) = maybe_cert_chain_and_key { - client_config - .with_client_auth_cert(cert_chain, private_key) - .expect("invalid client key or certificate") - } else { - client_config.with_no_client_auth() - }; + let mut client = match maybe_cert_chain_and_key { + TlsKeys::Static(TlsKey(cert_chain, private_key)) => client_config + .with_client_auth_cert(cert_chain, private_key) + .expect("invalid client key or certificate"), + TlsKeys::Null => client_config.with_no_client_auth(), + TlsKeys::Resolver(_) => unimplemented!(), + }; add_alpn(&mut client, socket_use); return Ok(client); @@ -226,14 +228,13 @@ pub fn create_client_config( root_cert_store }); - let mut client = - if let Some(TlsKey(cert_chain, private_key)) = maybe_cert_chain_and_key { - client_config - .with_client_auth_cert(cert_chain, private_key) - .expect("invalid client key or certificate") - } else { - client_config.with_no_client_auth() - }; + let mut client = match maybe_cert_chain_and_key { + TlsKeys::Static(TlsKey(cert_chain, private_key)) => client_config + .with_client_auth_cert(cert_chain, private_key) + .expect("invalid client key or certificate"), + TlsKeys::Null => client_config.with_no_client_auth(), + TlsKeys::Resolver(_) => unimplemented!(), + }; add_alpn(&mut client, socket_use); Ok(client) @@ -325,15 +326,3 @@ pub fn load_private_keys(bytes: &[u8]) -> Result<Vec<PrivateKey>, AnyError> { Ok(keys) } - -/// A loaded key. -// FUTURE(mmastrac): add resolver enum value to support dynamic SNI -pub enum TlsKeys { - // TODO(mmastrac): We need Option<&T> for cppgc -- this is a workaround - Null, - Static(TlsKey), -} - -/// A TLS certificate/private key pair. -#[derive(Clone, Debug)] -pub struct TlsKey(pub Vec<Certificate>, pub PrivateKey); diff --git a/ext/tls/tls_key.rs b/ext/tls/tls_key.rs new file mode 100644 index 000000000..18064a91a --- /dev/null +++ b/ext/tls/tls_key.rs @@ -0,0 +1,321 @@ +// Copyright 2018-2024 the Deno authors. All rights reserved. MIT license. + +//! These represent the various types of TLS keys we support for both client and server +//! connections. +//! +//! A TLS key will most often be static, and will loaded from a certificate and key file +//! or string. These are represented by `TlsKey`, which is stored in `TlsKeys::Static`. +//! +//! In more complex cases, you may need a `TlsKeyResolver`/`TlsKeyLookup` pair, which +//! requires polling of the `TlsKeyLookup` lookup queue. The underlying channels that used for +//! key lookup can handle closing one end of the pair, in which case they will just +//! attempt to clean up the associated resources. + +use crate::Certificate; +use crate::PrivateKey; +use deno_core::anyhow::anyhow; +use deno_core::error::AnyError; +use deno_core::futures::future::poll_fn; +use deno_core::futures::future::Either; +use deno_core::futures::FutureExt; +use deno_core::unsync::spawn; +use rustls::ServerConfig; +use rustls_tokio_stream::ServerConfigProvider; +use std::cell::RefCell; +use std::collections::HashMap; +use std::fmt::Debug; +use std::future::ready; +use std::future::Future; +use std::io::ErrorKind; +use std::rc::Rc; +use std::sync::Arc; +use tokio::sync::broadcast; +use tokio::sync::mpsc; +use tokio::sync::oneshot; + +type ErrorType = Rc<AnyError>; + +/// A TLS certificate/private key pair. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct TlsKey(pub Vec<Certificate>, pub PrivateKey); + +#[derive(Clone, Debug, Default)] +pub enum TlsKeys { + // TODO(mmastrac): We need Option<&T> for cppgc -- this is a workaround + #[default] + Null, + Static(TlsKey), + Resolver(TlsKeyResolver), +} + +pub struct TlsKeysHolder(RefCell<TlsKeys>); + +impl TlsKeysHolder { + pub fn take(&self) -> TlsKeys { + std::mem::take(&mut *self.0.borrow_mut()) + } +} + +impl From<TlsKeys> for TlsKeysHolder { + fn from(value: TlsKeys) -> Self { + TlsKeysHolder(RefCell::new(value)) + } +} + +impl TryInto<Option<TlsKey>> for TlsKeys { + type Error = Self; + fn try_into(self) -> Result<Option<TlsKey>, Self::Error> { + match self { + Self::Null => Ok(None), + Self::Static(key) => Ok(Some(key)), + Self::Resolver(_) => Err(self), + } + } +} + +impl From<Option<TlsKey>> for TlsKeys { + fn from(value: Option<TlsKey>) -> Self { + match value { + None => TlsKeys::Null, + Some(key) => TlsKeys::Static(key), + } + } +} + +enum TlsKeyState { + Resolving(broadcast::Receiver<Result<TlsKey, ErrorType>>), + Resolved(Result<TlsKey, ErrorType>), +} + +struct TlsKeyResolverInner { + resolution_tx: mpsc::UnboundedSender<( + String, + broadcast::Sender<Result<TlsKey, ErrorType>>, + )>, + cache: RefCell<HashMap<String, TlsKeyState>>, +} + +#[derive(Clone)] +pub struct TlsKeyResolver { + inner: Rc<TlsKeyResolverInner>, +} + +impl TlsKeyResolver { + async fn resolve_internal( + &self, + sni: String, + alpn: Vec<Vec<u8>>, + ) -> Result<Arc<ServerConfig>, AnyError> { + let key = self.resolve(sni).await?; + + let mut tls_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(key.0, key.1)?; + tls_config.alpn_protocols = alpn; + Ok(tls_config.into()) + } + + pub fn into_server_config_provider( + self, + alpn: Vec<Vec<u8>>, + ) -> ServerConfigProvider { + let (tx, mut rx) = mpsc::unbounded_channel::<(_, oneshot::Sender<_>)>(); + + // We don't want to make the resolver multi-threaded, but the `ServerConfigProvider` is + // required to be wrapped in an Arc. To fix this, we spawn a task in our current runtime + // to respond to the requests. + spawn(async move { + while let Some((sni, txr)) = rx.recv().await { + _ = txr.send(self.resolve_internal(sni, alpn.clone()).await); + } + }); + + Arc::new(move |hello| { + // Take ownership of the SNI information + let sni = hello.server_name().unwrap_or_default().to_owned(); + let (txr, rxr) = tokio::sync::oneshot::channel::<_>(); + _ = tx.send((sni, txr)); + rxr + .map(|res| match res { + Err(e) => Err(std::io::Error::new(ErrorKind::InvalidData, e)), + Ok(Err(e)) => Err(std::io::Error::new(ErrorKind::InvalidData, e)), + Ok(Ok(res)) => Ok(res), + }) + .boxed() + }) + } +} + +impl Debug for TlsKeyResolver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TlsKeyResolver").finish() + } +} + +pub fn new_resolver() -> (TlsKeyResolver, TlsKeyLookup) { + let (resolution_tx, resolution_rx) = mpsc::unbounded_channel(); + ( + TlsKeyResolver { + inner: Rc::new(TlsKeyResolverInner { + resolution_tx, + cache: Default::default(), + }), + }, + TlsKeyLookup { + resolution_rx: RefCell::new(resolution_rx), + pending: Default::default(), + }, + ) +} + +impl TlsKeyResolver { + /// Resolve the certificate and key for a given host. This immediately spawns a task in the + /// background and is therefore cancellation-safe. + pub fn resolve( + &self, + sni: String, + ) -> impl Future<Output = Result<TlsKey, AnyError>> { + let mut cache = self.inner.cache.borrow_mut(); + let mut recv = match cache.get(&sni) { + None => { + let (tx, rx) = broadcast::channel(1); + cache.insert(sni.clone(), TlsKeyState::Resolving(rx.resubscribe())); + _ = self.inner.resolution_tx.send((sni.clone(), tx)); + rx + } + Some(TlsKeyState::Resolving(recv)) => recv.resubscribe(), + Some(TlsKeyState::Resolved(res)) => { + return Either::Left(ready(res.clone().map_err(|_| anyhow!("Failed")))); + } + }; + drop(cache); + + // Make this cancellation safe + let inner = self.inner.clone(); + let handle = spawn(async move { + let res = recv.recv().await?; + let mut cache = inner.cache.borrow_mut(); + match cache.get(&sni) { + None | Some(TlsKeyState::Resolving(..)) => { + cache.insert(sni, TlsKeyState::Resolved(res.clone())); + } + Some(TlsKeyState::Resolved(..)) => { + // Someone beat us to it + } + } + res.map_err(|_| anyhow!("Failed")) + }); + Either::Right(async move { handle.await? }) + } +} + +pub struct TlsKeyLookup { + #[allow(clippy::type_complexity)] + resolution_rx: RefCell< + mpsc::UnboundedReceiver<( + String, + broadcast::Sender<Result<TlsKey, ErrorType>>, + )>, + >, + pending: + RefCell<HashMap<String, broadcast::Sender<Result<TlsKey, ErrorType>>>>, +} + +impl TlsKeyLookup { + /// Multiple `poll` calls are safe, but this method is not starvation-safe. Generally + /// only one `poll`er should be active at any time. + pub async fn poll(&self) -> Option<String> { + if let Some((sni, sender)) = + poll_fn(|cx| self.resolution_rx.borrow_mut().poll_recv(cx)).await + { + self.pending.borrow_mut().insert(sni.clone(), sender); + Some(sni) + } else { + None + } + } + + /// Resolve a previously polled item. + pub fn resolve(&self, sni: String, res: Result<TlsKey, AnyError>) { + _ = self + .pending + .borrow_mut() + .remove(&sni) + .unwrap() + .send(res.map_err(Rc::new)); + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + use deno_core::unsync::spawn; + use rustls::Certificate; + use rustls::PrivateKey; + + fn tls_key_for_test(sni: &str) -> TlsKey { + TlsKey( + vec![Certificate(format!("{sni}-cert").into_bytes())], + PrivateKey(format!("{sni}-key").into_bytes()), + ) + } + + #[tokio::test] + async fn test_resolve_once() { + let (resolver, lookup) = new_resolver(); + let task = spawn(async move { + while let Some(sni) = lookup.poll().await { + lookup.resolve(sni.clone(), Ok(tls_key_for_test(&sni))); + } + }); + + let key = resolver.resolve("example.com".to_owned()).await.unwrap(); + assert_eq!(tls_key_for_test("example.com"), key); + drop(resolver); + + task.await.unwrap(); + } + + #[tokio::test] + async fn test_resolve_concurrent() { + let (resolver, lookup) = new_resolver(); + let task = spawn(async move { + while let Some(sni) = lookup.poll().await { + lookup.resolve(sni.clone(), Ok(tls_key_for_test(&sni))); + } + }); + + let f1 = resolver.resolve("example.com".to_owned()); + let f2 = resolver.resolve("example.com".to_owned()); + + let key = f1.await.unwrap(); + assert_eq!(tls_key_for_test("example.com"), key); + let key = f2.await.unwrap(); + assert_eq!(tls_key_for_test("example.com"), key); + drop(resolver); + + task.await.unwrap(); + } + + #[tokio::test] + async fn test_resolve_multiple_concurrent() { + let (resolver, lookup) = new_resolver(); + let task = spawn(async move { + while let Some(sni) = lookup.poll().await { + lookup.resolve(sni.clone(), Ok(tls_key_for_test(&sni))); + } + }); + + let f1 = resolver.resolve("example1.com".to_owned()); + let f2 = resolver.resolve("example2.com".to_owned()); + + let key = f1.await.unwrap(); + assert_eq!(tls_key_for_test("example1.com"), key); + let key = f2.await.unwrap(); + assert_eq!(tls_key_for_test("example2.com"), key); + drop(resolver); + + task.await.unwrap(); + } +} |