summaryrefslogtreecommitdiff
path: root/ext/net/ops_tls.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ext/net/ops_tls.rs')
-rw-r--r--ext/net/ops_tls.rs69
1 files changed, 56 insertions, 13 deletions
diff --git a/ext/net/ops_tls.rs b/ext/net/ops_tls.rs
index 87744ed63..fd2308ef1 100644
--- a/ext/net/ops_tls.rs
+++ b/ext/net/ops_tls.rs
@@ -4,6 +4,7 @@ use crate::io::TcpStreamResource;
use crate::ops::IpAddr;
use crate::ops::OpAddr;
use crate::ops::OpConn;
+use crate::ops::TlsHandshakeInfo;
use crate::resolve_addr::resolve_addr;
use crate::resolve_addr::resolve_addr_sync;
use crate::DefaultTlsOptions;
@@ -29,6 +30,7 @@ use deno_core::op_sync;
use deno_core::parking_lot::Mutex;
use deno_core::AsyncRefCell;
use deno_core::AsyncResult;
+use deno_core::ByteString;
use deno_core::CancelHandle;
use deno_core::CancelTryFuture;
use deno_core::OpPair;
@@ -54,7 +56,6 @@ use io::Read;
use io::Write;
use serde::Deserialize;
use std::borrow::Cow;
-use std::cell::Cell;
use std::cell::RefCell;
use std::convert::From;
use std::fs::File;
@@ -190,6 +191,14 @@ impl TlsStream {
fn poll_handshake(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.inner_mut().poll_handshake(cx)
}
+
+ fn get_alpn_protocol(&mut self) -> Option<ByteString> {
+ self
+ .inner_mut()
+ .tls
+ .get_alpn_protocol()
+ .map(|s| ByteString(s.to_owned()))
+ }
}
impl AsyncRead for TlsStream {
@@ -549,6 +558,10 @@ impl WriteHalf {
})
.await
}
+
+ fn get_alpn_protocol(&mut self) -> Option<ByteString> {
+ self.shared.get_alpn_protocol()
+ }
}
impl AsyncWrite for WriteHalf {
@@ -658,6 +671,11 @@ impl Shared {
fn drop_shared_waker(self_ptr: *const ()) {
let _ = unsafe { Weak::from_raw(self_ptr as *const Self) };
}
+
+ fn get_alpn_protocol(self: &Arc<Self>) -> Option<ByteString> {
+ let mut tls_stream = self.tls_stream.lock();
+ tls_stream.get_alpn_protocol()
+ }
}
struct ImplementReadTrait<'a, T>(&'a mut T);
@@ -698,7 +716,8 @@ pub fn init<P: NetPermissions + 'static>() -> Vec<OpPair> {
pub struct TlsStreamResource {
rd: AsyncRefCell<ReadHalf>,
wr: AsyncRefCell<WriteHalf>,
- handshake_done: Cell<bool>,
+ // `None` when a TLS handshake hasn't been done.
+ handshake_info: RefCell<Option<TlsHandshakeInfo>>,
cancel_handle: CancelHandle, // Only read and handshake ops get canceled.
}
@@ -707,7 +726,7 @@ impl TlsStreamResource {
Self {
rd: rd.into(),
wr: wr.into(),
- handshake_done: Cell::new(false),
+ handshake_info: RefCell::new(None),
cancel_handle: Default::default(),
}
}
@@ -744,14 +763,21 @@ impl TlsStreamResource {
Ok(())
}
- pub async fn handshake(self: &Rc<Self>) -> Result<(), AnyError> {
- if !self.handshake_done.get() {
- let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await;
- let cancel_handle = RcRef::map(self, |r| &r.cancel_handle);
- wr.handshake().try_or_cancel(cancel_handle).await?;
- self.handshake_done.set(true);
+ pub async fn handshake(
+ self: &Rc<Self>,
+ ) -> Result<TlsHandshakeInfo, AnyError> {
+ if let Some(tls_info) = &*self.handshake_info.borrow() {
+ return Ok(tls_info.clone());
}
- Ok(())
+
+ let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await;
+ let cancel_handle = RcRef::map(self, |r| &r.cancel_handle);
+ wr.handshake().try_or_cancel(cancel_handle).await?;
+
+ let alpn_protocol = wr.get_alpn_protocol();
+ let tls_info = TlsHandshakeInfo { alpn_protocol };
+ self.handshake_info.replace(Some(tls_info.clone()));
+ Ok(tls_info)
}
}
@@ -787,6 +813,7 @@ pub struct ConnectTlsArgs {
ca_certs: Vec<String>,
cert_chain: Option<String>,
private_key: Option<String>,
+ alpn_protocols: Option<Vec<String>>,
}
#[derive(Deserialize)]
@@ -795,6 +822,7 @@ pub struct StartTlsArgs {
rid: ResourceId,
ca_certs: Vec<String>,
hostname: String,
+ alpn_protocols: Option<Vec<String>>,
}
pub async fn op_tls_start<NP>(
@@ -851,11 +879,20 @@ where
let local_addr = tcp_stream.local_addr()?;
let remote_addr = tcp_stream.peer_addr()?;
- let tls_config = Arc::new(create_client_config(
+ let mut tls_config = create_client_config(
root_cert_store,
ca_certs,
unsafely_ignore_certificate_errors,
- )?);
+ )?;
+
+ if let Some(alpn_protocols) = args.alpn_protocols {
+ super::check_unstable2(&state, "Deno.startTls#alpnProtocols");
+ tls_config.alpn_protocols =
+ alpn_protocols.into_iter().map(|s| s.into_bytes()).collect();
+ }
+
+ let tls_config = Arc::new(tls_config);
+
let tls_stream =
TlsStream::new_client_side(tcp_stream, &tls_config, hostname_dns);
@@ -948,6 +985,12 @@ where
unsafely_ignore_certificate_errors,
)?;
+ if let Some(alpn_protocols) = args.alpn_protocols {
+ super::check_unstable2(&state, "Deno.connectTls#alpnProtocols");
+ tls_config.alpn_protocols =
+ alpn_protocols.into_iter().map(|s| s.into_bytes()).collect();
+ }
+
if args.cert_chain.is_some() || args.private_key.is_some() {
let cert_chain = args
.cert_chain
@@ -1144,7 +1187,7 @@ pub async fn op_tls_handshake(
state: Rc<RefCell<OpState>>,
rid: ResourceId,
_: (),
-) -> Result<(), AnyError> {
+) -> Result<TlsHandshakeInfo, AnyError> {
let resource = state
.borrow()
.resource_table