diff options
Diffstat (limited to 'ext/net/ops_tls.rs')
-rw-r--r-- | ext/net/ops_tls.rs | 69 |
1 files changed, 56 insertions, 13 deletions
diff --git a/ext/net/ops_tls.rs b/ext/net/ops_tls.rs index 87744ed63..fd2308ef1 100644 --- a/ext/net/ops_tls.rs +++ b/ext/net/ops_tls.rs @@ -4,6 +4,7 @@ use crate::io::TcpStreamResource; use crate::ops::IpAddr; use crate::ops::OpAddr; use crate::ops::OpConn; +use crate::ops::TlsHandshakeInfo; use crate::resolve_addr::resolve_addr; use crate::resolve_addr::resolve_addr_sync; use crate::DefaultTlsOptions; @@ -29,6 +30,7 @@ use deno_core::op_sync; use deno_core::parking_lot::Mutex; use deno_core::AsyncRefCell; use deno_core::AsyncResult; +use deno_core::ByteString; use deno_core::CancelHandle; use deno_core::CancelTryFuture; use deno_core::OpPair; @@ -54,7 +56,6 @@ use io::Read; use io::Write; use serde::Deserialize; use std::borrow::Cow; -use std::cell::Cell; use std::cell::RefCell; use std::convert::From; use std::fs::File; @@ -190,6 +191,14 @@ impl TlsStream { fn poll_handshake(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { self.inner_mut().poll_handshake(cx) } + + fn get_alpn_protocol(&mut self) -> Option<ByteString> { + self + .inner_mut() + .tls + .get_alpn_protocol() + .map(|s| ByteString(s.to_owned())) + } } impl AsyncRead for TlsStream { @@ -549,6 +558,10 @@ impl WriteHalf { }) .await } + + fn get_alpn_protocol(&mut self) -> Option<ByteString> { + self.shared.get_alpn_protocol() + } } impl AsyncWrite for WriteHalf { @@ -658,6 +671,11 @@ impl Shared { fn drop_shared_waker(self_ptr: *const ()) { let _ = unsafe { Weak::from_raw(self_ptr as *const Self) }; } + + fn get_alpn_protocol(self: &Arc<Self>) -> Option<ByteString> { + let mut tls_stream = self.tls_stream.lock(); + tls_stream.get_alpn_protocol() + } } struct ImplementReadTrait<'a, T>(&'a mut T); @@ -698,7 +716,8 @@ pub fn init<P: NetPermissions + 'static>() -> Vec<OpPair> { pub struct TlsStreamResource { rd: AsyncRefCell<ReadHalf>, wr: AsyncRefCell<WriteHalf>, - handshake_done: Cell<bool>, + // `None` when a TLS handshake hasn't been done. + handshake_info: RefCell<Option<TlsHandshakeInfo>>, cancel_handle: CancelHandle, // Only read and handshake ops get canceled. } @@ -707,7 +726,7 @@ impl TlsStreamResource { Self { rd: rd.into(), wr: wr.into(), - handshake_done: Cell::new(false), + handshake_info: RefCell::new(None), cancel_handle: Default::default(), } } @@ -744,14 +763,21 @@ impl TlsStreamResource { Ok(()) } - pub async fn handshake(self: &Rc<Self>) -> Result<(), AnyError> { - if !self.handshake_done.get() { - let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await; - let cancel_handle = RcRef::map(self, |r| &r.cancel_handle); - wr.handshake().try_or_cancel(cancel_handle).await?; - self.handshake_done.set(true); + pub async fn handshake( + self: &Rc<Self>, + ) -> Result<TlsHandshakeInfo, AnyError> { + if let Some(tls_info) = &*self.handshake_info.borrow() { + return Ok(tls_info.clone()); } - Ok(()) + + let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await; + let cancel_handle = RcRef::map(self, |r| &r.cancel_handle); + wr.handshake().try_or_cancel(cancel_handle).await?; + + let alpn_protocol = wr.get_alpn_protocol(); + let tls_info = TlsHandshakeInfo { alpn_protocol }; + self.handshake_info.replace(Some(tls_info.clone())); + Ok(tls_info) } } @@ -787,6 +813,7 @@ pub struct ConnectTlsArgs { ca_certs: Vec<String>, cert_chain: Option<String>, private_key: Option<String>, + alpn_protocols: Option<Vec<String>>, } #[derive(Deserialize)] @@ -795,6 +822,7 @@ pub struct StartTlsArgs { rid: ResourceId, ca_certs: Vec<String>, hostname: String, + alpn_protocols: Option<Vec<String>>, } pub async fn op_tls_start<NP>( @@ -851,11 +879,20 @@ where let local_addr = tcp_stream.local_addr()?; let remote_addr = tcp_stream.peer_addr()?; - let tls_config = Arc::new(create_client_config( + let mut tls_config = create_client_config( root_cert_store, ca_certs, unsafely_ignore_certificate_errors, - )?); + )?; + + if let Some(alpn_protocols) = args.alpn_protocols { + super::check_unstable2(&state, "Deno.startTls#alpnProtocols"); + tls_config.alpn_protocols = + alpn_protocols.into_iter().map(|s| s.into_bytes()).collect(); + } + + let tls_config = Arc::new(tls_config); + let tls_stream = TlsStream::new_client_side(tcp_stream, &tls_config, hostname_dns); @@ -948,6 +985,12 @@ where unsafely_ignore_certificate_errors, )?; + if let Some(alpn_protocols) = args.alpn_protocols { + super::check_unstable2(&state, "Deno.connectTls#alpnProtocols"); + tls_config.alpn_protocols = + alpn_protocols.into_iter().map(|s| s.into_bytes()).collect(); + } + if args.cert_chain.is_some() || args.private_key.is_some() { let cert_chain = args .cert_chain @@ -1144,7 +1187,7 @@ pub async fn op_tls_handshake( state: Rc<RefCell<OpState>>, rid: ResourceId, _: (), -) -> Result<(), AnyError> { +) -> Result<TlsHandshakeInfo, AnyError> { let resource = state .borrow() .resource_table |