diff options
Diffstat (limited to 'ext/net/ops_tls.rs')
-rw-r--r-- | ext/net/ops_tls.rs | 236 |
1 files changed, 86 insertions, 150 deletions
diff --git a/ext/net/ops_tls.rs b/ext/net/ops_tls.rs index 3de5d8476..036aab5e6 100644 --- a/ext/net/ops_tls.rs +++ b/ext/net/ops_tls.rs @@ -44,13 +44,12 @@ use deno_tls::load_certs; use deno_tls::load_private_keys; use deno_tls::rustls::Certificate; use deno_tls::rustls::ClientConfig; -use deno_tls::rustls::ClientSession; -use deno_tls::rustls::NoClientAuth; +use deno_tls::rustls::ClientConnection; +use deno_tls::rustls::Connection; use deno_tls::rustls::PrivateKey; use deno_tls::rustls::ServerConfig; -use deno_tls::rustls::ServerSession; -use deno_tls::rustls::Session; -use deno_tls::webpki::DNSNameRef; +use deno_tls::rustls::ServerConnection; +use deno_tls::rustls::ServerName; use io::Error; use io::Read; use io::Write; @@ -58,12 +57,11 @@ use serde::Deserialize; use std::borrow::Cow; use std::cell::RefCell; use std::convert::From; +use std::convert::TryFrom; use std::fs::File; use std::io; use std::io::BufReader; use std::io::ErrorKind; -use std::ops::Deref; -use std::ops::DerefMut; use std::path::Path; use std::pin::Pin; use std::rc::Rc; @@ -78,44 +76,6 @@ use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio::task::spawn_local; -#[derive(Debug)] -enum TlsSession { - Client(ClientSession), - Server(ServerSession), -} - -impl Deref for TlsSession { - type Target = dyn Session; - - fn deref(&self) -> &Self::Target { - match self { - TlsSession::Client(client_session) => client_session, - TlsSession::Server(server_session) => server_session, - } - } -} - -impl DerefMut for TlsSession { - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - TlsSession::Client(client_session) => client_session, - TlsSession::Server(server_session) => server_session, - } - } -} - -impl From<ClientSession> for TlsSession { - fn from(client_session: ClientSession) -> Self { - TlsSession::Client(client_session) - } -} - -impl From<ServerSession> for TlsSession { - fn from(server_session: ServerSession) -> Self { - TlsSession::Server(server_session) - } -} - #[derive(Copy, Clone, Debug, Eq, PartialEq)] enum Flow { Handshake, @@ -129,15 +89,15 @@ enum State { StreamClosed, TlsClosing, TlsClosed, - TlsError, TcpClosed, } -#[derive(Debug)] pub struct TlsStream(Option<TlsStreamInner>); impl TlsStream { - fn new(tcp: TcpStream, tls: TlsSession) -> Self { + fn new(tcp: TcpStream, mut tls: Connection) -> Self { + tls.set_buffer_limit(None); + let inner = TlsStreamInner { tcp, tls, @@ -149,19 +109,19 @@ impl TlsStream { pub fn new_client_side( tcp: TcpStream, - tls_config: &Arc<ClientConfig>, - hostname: DNSNameRef, + tls_config: Arc<ClientConfig>, + server_name: ServerName, ) -> Self { - let tls = TlsSession::Client(ClientSession::new(tls_config, hostname)); - Self::new(tcp, tls) + let tls = ClientConnection::new(tls_config, server_name).unwrap(); + Self::new(tcp, Connection::Client(tls)) } pub fn new_server_side( tcp: TcpStream, - tls_config: &Arc<ServerConfig>, + tls_config: Arc<ServerConfig>, ) -> Self { - let tls = TlsSession::Server(ServerSession::new(tls_config)); - Self::new(tcp, tls) + let tls = ServerConnection::new(tls_config).unwrap(); + Self::new(tcp, Connection::Server(tls)) } fn into_split(self) -> (ReadHalf, WriteHalf) { @@ -174,10 +134,10 @@ impl TlsStream { } /// Tokio-rustls compatibility: returns a reference to the underlying TCP - /// stream, and a reference to the Rustls `Session` object. - pub fn get_ref(&self) -> (&TcpStream, &dyn Session) { + /// stream, and a reference to the Rustls `Connection` object. + pub fn get_ref(&self) -> (&TcpStream, &Connection) { let inner = self.0.as_ref().unwrap(); - (&inner.tcp, &*inner.tls) + (&inner.tcp, &inner.tls) } fn inner_mut(&mut self) -> &mut TlsStreamInner { @@ -196,7 +156,7 @@ impl TlsStream { self .inner_mut() .tls - .get_alpn_protocol() + .alpn_protocol() .map(|s| ByteString(s.to_owned())) } } @@ -251,9 +211,8 @@ impl Drop for TlsStream { } } -#[derive(Debug)] pub struct TlsStreamInner { - tls: TlsSession, + tls: Connection, tcp: TcpStream, rd_state: State, wr_state: State, @@ -275,7 +234,7 @@ impl TlsStreamInner { State::StreamOpen if !self.tls.wants_write() => break true, State::StreamClosed => { // Rustls will enqueue the 'CloseNotify' alert and send it after - // flusing the data that is already in the queue. + // flushing the data that is already in the queue. self.tls.send_close_notify(); self.wr_state = State::TlsClosing; continue; @@ -318,19 +277,30 @@ impl TlsStreamInner { }; let rd_ready = loop { + // Interpret and decrypt unprocessed TLS protocol data. + let tls_state = self + .tls + .process_new_packets() + .map_err(|e| Error::new(ErrorKind::InvalidData, e))?; + match self.rd_state { State::TcpClosed if self.tls.is_handshaking() => { let err = Error::new(ErrorKind::UnexpectedEof, "tls handshake eof"); return Poll::Ready(Err(err)); } - State::TlsError => {} _ if self.tls.is_handshaking() && !self.tls.wants_read() => { break true; } _ if self.tls.is_handshaking() => {} - State::StreamOpen if !self.tls.wants_read() => break true, + State::StreamOpen if tls_state.plaintext_bytes_to_read() > 0 => { + break true; + } + State::StreamOpen if tls_state.peer_has_closed() => { + self.rd_state = State::TlsClosed; + continue; + } State::StreamOpen => {} - State::StreamClosed if !self.tls.wants_read() => { + State::StreamClosed if tls_state.plaintext_bytes_to_read() > 0 => { // Rustls has more incoming cleartext buffered up, but the TLS // session is closing so this data will never be processed by the // application layer. Just like what would happen if this were a raw @@ -339,60 +309,30 @@ impl TlsStreamInner { } State::StreamClosed => {} State::TlsClosed if self.wr_state == State::TcpClosed => { - // Wait for the remote end to gracefully close the TCP connection. - // TODO(piscisaureus): this is unnecessary; remove when stable. - } - _ => break true, - } - - if self.rd_state < State::TlsClosed { - // Do a zero-length plaintext read so we can detect the arrival of - // 'CloseNotify' messages, even if only the write half is open. - // Actually reading data from the socket is done in `poll_read()`. - match self.tls.read(&mut []) { - Ok(0) => {} - Err(err) if err.kind() == ErrorKind::ConnectionAborted => { - // `Session::read()` returns `ConnectionAborted` when a - // 'CloseNotify' alert has been received, which indicates that - // the remote peer wants to gracefully end the TLS session. - self.rd_state = State::TlsClosed; - continue; - } - Err(err) => return Poll::Ready(Err(err)), - _ => unreachable!(), + // Keep trying to read from the TCP connection until the remote end + // closes it gracefully. } + State::TlsClosed => break true, + State::TcpClosed => break true, + _ => unreachable!(), } - if self.rd_state != State::TlsError { - // Receive ciphertext from the socket. - let mut wrapped_tcp = ImplementReadTrait(&mut self.tcp); - match self.tls.read_tls(&mut wrapped_tcp) { - Ok(0) => { - // End of TCP stream. - self.rd_state = State::TcpClosed; - continue; - } - Err(err) if err.kind() == ErrorKind::WouldBlock => { - // Get notified when more ciphertext becomes available in the - // socket receive buffer. - if self.tcp.poll_read_ready(cx)?.is_pending() { - break false; - } else { - continue; - } - } - Err(err) => return Poll::Ready(Err(err)), - _ => {} + // Try to read more TLS protocol data from the TCP socket. + let mut wrapped_tcp = ImplementReadTrait(&mut self.tcp); + match self.tls.read_tls(&mut wrapped_tcp) { + Ok(0) => { + self.rd_state = State::TcpClosed; + continue; } + Ok(_) => continue, + Err(err) if err.kind() == ErrorKind::WouldBlock => {} + Err(err) => return Poll::Ready(Err(err)), } - // Interpret and decrypt TLS protocol data. - match self.tls.process_new_packets() { - Ok(_) => assert!(self.rd_state < State::TcpClosed), - Err(err) => { - self.rd_state = State::TlsError; - return Poll::Ready(Err(Error::new(ErrorKind::InvalidData, err))); - } + // Get notified when more ciphertext becomes available to read from the + // TCP socket. + if self.tcp.poll_read_ready(cx)?.is_pending() { + break false; } }; @@ -438,7 +378,7 @@ impl TlsStreamInner { if self.rd_state == State::StreamOpen { let buf_slice = unsafe { &mut *(buf.unfilled_mut() as *mut [_] as *mut [u8]) }; - let bytes_read = self.tls.read(buf_slice)?; + let bytes_read = self.tls.reader().read(buf_slice)?; assert_ne!(bytes_read, 0); unsafe { buf.assume_init(bytes_read) }; buf.advance(bytes_read); @@ -460,7 +400,7 @@ impl TlsStreamInner { ready!(self.poll_io(cx, Flow::Write))?; // Copy data from `buf` to the Rustls cleartext send queue. - let bytes_written = self.tls.write(buf)?; + let bytes_written = self.tls.writer().write(buf)?; assert_ne!(bytes_written, 0); // Try to flush as much ciphertext as possible. However, since we just @@ -511,7 +451,6 @@ impl TlsStreamInner { } } -#[derive(Debug)] pub struct ReadHalf { shared: Arc<Shared>, } @@ -542,7 +481,6 @@ impl AsyncRead for ReadHalf { } } -#[derive(Debug)] pub struct WriteHalf { shared: Arc<Shared>, } @@ -596,7 +534,6 @@ impl AsyncWrite for WriteHalf { } } -#[derive(Debug)] struct Shared { tls_stream: Mutex<TlsStream>, rd_waker: AtomicWaker, @@ -851,8 +788,8 @@ where .map(|s| s.into_bytes()) .collect::<Vec<_>>(); - let hostname_dns = DNSNameRef::try_from_ascii_str(hostname) - .map_err(|_| invalid_hostname(hostname))?; + let hostname_dns = + ServerName::try_from(hostname).map_err(|_| invalid_hostname(hostname))?; let unsafely_ignore_certificate_errors = state .borrow() @@ -895,7 +832,7 @@ where let tls_config = Arc::new(tls_config); let tls_stream = - TlsStream::new_client_side(tcp_stream, &tls_config, hostname_dns); + TlsStream::new_client_side(tcp_stream, tls_config, hostname_dns); let rid = { let mut state_ = state.borrow_mut(); @@ -970,8 +907,8 @@ where .borrow::<DefaultTlsOptions>() .root_cert_store .clone(); - let hostname_dns = DNSNameRef::try_from_ascii_str(hostname) - .map_err(|_| invalid_hostname(hostname))?; + let hostname_dns = + ServerName::try_from(hostname).map_err(|_| invalid_hostname(hostname))?; let connect_addr = resolve_addr(hostname, port) .await? @@ -980,11 +917,25 @@ where let tcp_stream = TcpStream::connect(connect_addr).await?; let local_addr = tcp_stream.local_addr()?; let remote_addr = tcp_stream.peer_addr()?; + + let cert_chain_and_key = + if args.cert_chain.is_some() || args.private_key.is_some() { + let cert_chain = args + .cert_chain + .ok_or_else(|| type_error("No certificate chain provided"))?; + let private_key = args + .private_key + .ok_or_else(|| type_error("No private key provided"))?; + Some((cert_chain, private_key)) + } else { + None + }; + let mut tls_config = create_client_config( root_cert_store, ca_certs, unsafely_ignore_certificate_errors, - None, + cert_chain_and_key, )?; if let Some(alpn_protocols) = args.alpn_protocols { @@ -993,27 +944,10 @@ where alpn_protocols.into_iter().map(|s| s.into_bytes()).collect(); } - if args.cert_chain.is_some() || args.private_key.is_some() { - let cert_chain = args - .cert_chain - .ok_or_else(|| type_error("No certificate chain provided"))?; - let private_key = args - .private_key - .ok_or_else(|| type_error("No private key provided"))?; - - // The `remove` is safe because load_private_keys checks that there is at least one key. - let private_key = load_private_keys(private_key.as_bytes())?.remove(0); - - tls_config.set_single_client_cert( - load_certs(&mut cert_chain.as_bytes())?, - private_key, - )?; - } - let tls_config = Arc::new(tls_config); let tls_stream = - TlsStream::new_client_side(tcp_stream, &tls_config, hostname_dns); + TlsStream::new_client_side(tcp_stream, tls_config, hostname_dns); let rid = { let mut state_ = state.borrow_mut(); @@ -1096,18 +1030,19 @@ where permissions.check_read(Path::new(key_file))?; } - let mut tls_config = ServerConfig::new(NoClientAuth::new()); + let mut tls_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert( + load_certs_from_file(cert_file)?, + load_private_keys_from_file(key_file)?.remove(0), + ) + .expect("invalid key or certificate"); if let Some(alpn_protocols) = args.alpn_protocols { super::check_unstable(state, "Deno.listenTls#alpn_protocols"); tls_config.alpn_protocols = alpn_protocols.into_iter().map(|s| s.into_bytes()).collect(); } - tls_config - .set_single_cert( - load_certs_from_file(cert_file)?, - load_private_keys_from_file(key_file)?.remove(0), - ) - .expect("invalid key or certificate"); let bind_addr = resolve_addr_sync(hostname, port)? .next() @@ -1163,7 +1098,8 @@ pub async fn op_tls_accept( let local_addr = tcp_stream.local_addr()?; - let tls_stream = TlsStream::new_server_side(tcp_stream, &resource.tls_config); + let tls_stream = + TlsStream::new_server_side(tcp_stream, resource.tls_config.clone()); let rid = { let mut state_ = state.borrow_mut(); |