diff options
author | Bert Belder <bertbelder@gmail.com> | 2021-10-20 01:30:04 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-20 01:30:04 +0200 |
commit | 6a9656098671d19b2cbfedfef4db0df6f84735d1 (patch) | |
tree | a29a1fbb03e6ced6eda1e53999cf2c18f61c1451 /ext/net/ops_tls.rs | |
parent | 4f48efcc55b9e6cc0dd212ebd8e729909efed1ab (diff) |
fix(ext/net): fix TLS bugs and add 'op_tls_handshake' (#12501)
A bug was fixed that could cause a hang when a method was
called on a TlsConn object that had thrown an exception earlier.
Additionally, a bug was fixed that caused TlsConn.write() to not
completely flush large buffers (>64kB) to the socket.
The public `TlsConn.handshake()` API is scheduled for inclusion in the
next minor release. See https://github.com/denoland/deno/pull/12467.
Diffstat (limited to 'ext/net/ops_tls.rs')
-rw-r--r-- | ext/net/ops_tls.rs | 188 |
1 files changed, 159 insertions, 29 deletions
diff --git a/ext/net/ops_tls.rs b/ext/net/ops_tls.rs index d6618440f..129a702bc 100644 --- a/ext/net/ops_tls.rs +++ b/ext/net/ops_tls.rs @@ -1,7 +1,6 @@ // Copyright 2018-2021 the Deno authors. All rights reserved. MIT license. use crate::io::TcpStreamResource; -use crate::io::TlsStreamResource; use crate::ops::IpAddr; use crate::ops::OpAddr; use crate::ops::OpConn; @@ -53,6 +52,7 @@ use io::Read; use io::Write; use serde::Deserialize; use std::borrow::Cow; +use std::cell::Cell; use std::cell::RefCell; use std::convert::From; use std::fs::File; @@ -67,7 +67,9 @@ use std::rc::Rc; use std::sync::Arc; use std::sync::Weak; use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; use tokio::io::AsyncWrite; +use tokio::io::AsyncWriteExt; use tokio::io::ReadBuf; use tokio::net::TcpListener; use tokio::net::TcpStream; @@ -113,6 +115,7 @@ impl From<ServerSession> for TlsSession { #[derive(Copy, Clone, Debug, Eq, PartialEq)] enum Flow { + Handshake, Read, Write, } @@ -123,6 +126,7 @@ enum State { StreamClosed, TlsClosing, TlsClosed, + TlsError, TcpClosed, } @@ -157,10 +161,6 @@ impl TlsStream { Self::new(tcp, tls) } - pub async fn handshake(&mut self) -> io::Result<()> { - poll_fn(|cx| self.inner_mut().poll_io(cx, Flow::Write)).await - } - fn into_split(self) -> (ReadHalf, WriteHalf) { let shared = Shared::new(self); let rd = ReadHalf { @@ -180,6 +180,14 @@ impl TlsStream { fn inner_mut(&mut self) -> &mut TlsStreamInner { self.0.as_mut().unwrap() } + + pub async fn handshake(&mut self) -> io::Result<()> { + poll_fn(|cx| self.inner_mut().poll_handshake(cx)).await + } + + fn poll_handshake(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.inner_mut().poll_handshake(cx) + } } impl AsyncRead for TlsStream { @@ -282,20 +290,20 @@ impl TlsStreamInner { _ => {} } - // Poll whether there is space in the socket send buffer so we can flush - // the remaining outgoing ciphertext. - if self.tcp.poll_write_ready(cx)?.is_pending() { - break false; - } - // Write ciphertext to the TCP socket. let mut wrapped_tcp = ImplementWriteTrait(&mut self.tcp); match self.tls.write_tls(&mut wrapped_tcp) { - Ok(0) => unreachable!(), - Ok(_) => {} - Err(err) if err.kind() == ErrorKind::WouldBlock => {} + Ok(0) => {} // Wait until the socket has enough buffer space. + Ok(_) => continue, // Try to send more more data immediately. + Err(err) if err.kind() == ErrorKind::WouldBlock => unreachable!(), Err(err) => return Poll::Ready(Err(err)), } + + // Poll whether there is space in the socket send buffer so we can flush + // the remaining outgoing ciphertext. + if self.tcp.poll_write_ready(cx)?.is_pending() { + break false; + } }; let rd_ready = loop { @@ -304,6 +312,7 @@ impl TlsStreamInner { let err = Error::new(ErrorKind::UnexpectedEof, "tls handshake eof"); return Poll::Ready(Err(err)); } + State::TlsError => {} _ if self.tls.is_handshaking() && !self.tls.wants_read() => { break true; } @@ -343,22 +352,36 @@ impl TlsStreamInner { } } - // Poll whether more ciphertext is available in the socket receive - // buffer. - if self.tcp.poll_read_ready(cx)?.is_pending() { - break false; + if self.rd_state != State::TlsError { + // Receive ciphertext from the socket. + let mut wrapped_tcp = ImplementReadTrait(&mut self.tcp); + match self.tls.read_tls(&mut wrapped_tcp) { + Ok(0) => { + // End of TCP stream. + self.rd_state = State::TcpClosed; + continue; + } + Err(err) if err.kind() == ErrorKind::WouldBlock => { + // Get notified when more ciphertext becomes available in the + // socket receive buffer. + if self.tcp.poll_read_ready(cx)?.is_pending() { + break false; + } else { + continue; + } + } + Err(err) => return Poll::Ready(Err(err)), + _ => {} + } } - // Receive ciphertext from the socket. - let mut wrapped_tcp = ImplementReadTrait(&mut self.tcp); - match self.tls.read_tls(&mut wrapped_tcp) { - Ok(0) => self.rd_state = State::TcpClosed, - Ok(_) => self - .tls - .process_new_packets() - .map_err(|err| Error::new(ErrorKind::InvalidData, err))?, - Err(err) if err.kind() == ErrorKind::WouldBlock => {} - Err(err) => return Poll::Ready(Err(err)), + // Interpret and decrypt TLS protocol data. + match self.tls.process_new_packets() { + Ok(_) => assert!(self.rd_state < State::TcpClosed), + Err(err) => { + self.rd_state = State::TlsError; + return Poll::Ready(Err(Error::new(ErrorKind::InvalidData, err))); + } } }; @@ -376,6 +399,7 @@ impl TlsStreamInner { let io_ready = match flow { _ if self.tls.is_handshaking() => false, + Flow::Handshake => true, Flow::Read => rd_ready, Flow::Write => wr_ready, }; @@ -386,6 +410,13 @@ impl TlsStreamInner { } } + fn poll_handshake(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + if self.tls.is_handshaking() { + ready!(self.poll_io(cx, Flow::Handshake))?; + } + Poll::Ready(Ok(())) + } + fn poll_read( &mut self, cx: &mut Context<'_>, @@ -505,6 +536,19 @@ pub struct WriteHalf { shared: Arc<Shared>, } +impl WriteHalf { + pub async fn handshake(&mut self) -> io::Result<()> { + poll_fn(|cx| { + self + .shared + .poll_with_shared_waker(cx, Flow::Write, |mut tls, cx| { + tls.poll_handshake(cx) + }) + }) + .await + } +} + impl AsyncWrite for WriteHalf { fn poll_write( self: Pin<&mut Self>, @@ -561,6 +605,7 @@ impl Shared { mut f: impl FnMut(Pin<&mut TlsStream>, &mut Context<'_>) -> R, ) -> R { match flow { + Flow::Handshake => unreachable!(), Flow::Read => self.rd_waker.register(cx.waker()), Flow::Write => self.wr_waker.register(cx.waker()), } @@ -625,7 +670,11 @@ struct ImplementWriteTrait<'a, T>(&'a mut T); impl Write for ImplementWriteTrait<'_, TcpStream> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - self.0.try_write(buf) + match self.0.try_write(buf) { + Ok(n) => Ok(n), + Err(err) if err.kind() == ErrorKind::WouldBlock => Ok(0), + Err(err) => Err(err), + } } fn flush(&mut self) -> io::Result<()> { @@ -639,9 +688,78 @@ pub fn init<P: NetPermissions + 'static>() -> Vec<OpPair> { ("op_connect_tls", op_async(op_connect_tls::<P>)), ("op_listen_tls", op_sync(op_listen_tls::<P>)), ("op_accept_tls", op_async(op_accept_tls)), + ("op_tls_handshake", op_async(op_tls_handshake)), ] } +#[derive(Debug)] +pub struct TlsStreamResource { + rd: AsyncRefCell<ReadHalf>, + wr: AsyncRefCell<WriteHalf>, + handshake_done: Cell<bool>, + cancel_handle: CancelHandle, // Only read and handshake ops get canceled. +} + +impl TlsStreamResource { + pub fn new((rd, wr): (ReadHalf, WriteHalf)) -> Self { + Self { + rd: rd.into(), + wr: wr.into(), + handshake_done: Cell::new(false), + cancel_handle: Default::default(), + } + } + + pub fn into_inner(self) -> (ReadHalf, WriteHalf) { + (self.rd.into_inner(), self.wr.into_inner()) + } + + pub async fn read( + self: &Rc<Self>, + buf: &mut [u8], + ) -> Result<usize, AnyError> { + let mut rd = RcRef::map(self, |r| &r.rd).borrow_mut().await; + let cancel_handle = RcRef::map(self, |r| &r.cancel_handle); + let nread = rd.read(buf).try_or_cancel(cancel_handle).await?; + Ok(nread) + } + + pub async fn write(self: &Rc<Self>, buf: &[u8]) -> Result<usize, AnyError> { + self.handshake().await?; + let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await; + let nwritten = wr.write(buf).await?; + wr.flush().await?; + Ok(nwritten) + } + + pub async fn shutdown(self: &Rc<Self>) -> Result<(), AnyError> { + self.handshake().await?; + let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await; + wr.shutdown().await?; + Ok(()) + } + + pub async fn handshake(self: &Rc<Self>) -> Result<(), AnyError> { + if !self.handshake_done.get() { + let mut wr = RcRef::map(self, |r| &r.wr).borrow_mut().await; + let cancel_handle = RcRef::map(self, |r| &r.cancel_handle); + wr.handshake().try_or_cancel(cancel_handle).await?; + self.handshake_done.set(true); + } + Ok(()) + } +} + +impl Resource for TlsStreamResource { + fn name(&self) -> Cow<str> { + "tlsStream".into() + } + + fn close(self: Rc<Self>) { + self.cancel_handle.cancel(); + } +} + #[derive(Deserialize)] #[serde(rename_all = "camelCase")] pub struct ConnectTlsArgs { @@ -1015,3 +1133,15 @@ async fn op_accept_tls( })), }) } + +async fn op_tls_handshake( + state: Rc<RefCell<OpState>>, + rid: ResourceId, + _: (), +) -> Result<(), AnyError> { + let resource = state + .borrow() + .resource_table + .get::<TlsStreamResource>(rid)?; + resource.handshake().await +} |