summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock2
-rw-r--r--cli/tests/integration_tests.rs165
-rw-r--r--cli/tests/unit/tls_test.ts685
-rw-r--r--runtime/Cargo.toml2
-rw-r--r--runtime/ops/http.rs10
-rw-r--r--runtime/ops/io.rs50
-rw-r--r--runtime/ops/tls.rs787
7 files changed, 1440 insertions, 261 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 4f34174c8..fa7d00884 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -713,12 +713,12 @@ dependencies = [
"percent-encoding",
"regex",
"ring",
+ "rustls",
"serde",
"sys-info",
"termcolor",
"test_util",
"tokio",
- "tokio-rustls",
"tokio-util",
"trust-dns-proto",
"trust-dns-resolver",
diff --git a/cli/tests/integration_tests.rs b/cli/tests/integration_tests.rs
index 56043930b..5cca0d1cc 100644
--- a/cli/tests/integration_tests.rs
+++ b/cli/tests/integration_tests.rs
@@ -5,7 +5,9 @@ use deno_core::serde_json;
use deno_core::url;
use deno_runtime::deno_fetch::reqwest;
use deno_runtime::deno_websocket::tokio_tungstenite;
-use rustls::Session;
+use deno_runtime::ops::tls::rustls;
+use deno_runtime::ops::tls::webpki;
+use deno_runtime::ops::tls::TlsStream;
use std::fs;
use std::io::BufReader;
use std::io::Cursor;
@@ -14,8 +16,7 @@ use std::process::Command;
use std::sync::Arc;
use tempfile::TempDir;
use test_util as util;
-use tokio_rustls::rustls;
-use tokio_rustls::webpki;
+use tokio::task::LocalSet;
#[test]
fn js_unit_tests_lint() {
@@ -6134,79 +6135,103 @@ console.log("finish");
#[tokio::test]
async fn listen_tls_alpn() {
- let child = util::deno_cmd()
- .current_dir(util::root_path())
- .arg("run")
- .arg("--unstable")
- .arg("--quiet")
- .arg("--allow-net")
- .arg("--allow-read")
- .arg("./cli/tests/listen_tls_alpn.ts")
- .arg("4504")
- .stdout(std::process::Stdio::piped())
- .spawn()
- .unwrap();
- let mut stdout = child.stdout.unwrap();
- let mut buffer = [0; 5];
- let read = stdout.read(&mut buffer).unwrap();
- assert_eq!(read, 5);
- let msg = std::str::from_utf8(&buffer).unwrap();
- assert_eq!(msg, "READY");
-
- let mut cfg = rustls::ClientConfig::new();
- let reader =
- &mut BufReader::new(Cursor::new(include_bytes!("./tls/RootCA.crt")));
- cfg.root_store.add_pem_file(reader).unwrap();
- cfg.alpn_protocols.push("foobar".as_bytes().to_vec());
-
- let tls_connector = tokio_rustls::TlsConnector::from(Arc::new(cfg));
- let hostname = webpki::DNSNameRef::try_from_ascii_str("localhost").unwrap();
- let stream = tokio::net::TcpStream::connect("localhost:4504")
- .await
- .unwrap();
+ // TLS streams require the presence of an ambient local task set to gracefully
+ // close dropped connections in the background.
+ LocalSet::new()
+ .run_until(async {
+ let mut child = util::deno_cmd()
+ .current_dir(util::root_path())
+ .arg("run")
+ .arg("--unstable")
+ .arg("--quiet")
+ .arg("--allow-net")
+ .arg("--allow-read")
+ .arg("./cli/tests/listen_tls_alpn.ts")
+ .arg("4504")
+ .stdout(std::process::Stdio::piped())
+ .spawn()
+ .unwrap();
+ let stdout = child.stdout.as_mut().unwrap();
+ let mut buffer = [0; 5];
+ let read = stdout.read(&mut buffer).unwrap();
+ assert_eq!(read, 5);
+ let msg = std::str::from_utf8(&buffer).unwrap();
+ assert_eq!(msg, "READY");
+
+ let mut cfg = rustls::ClientConfig::new();
+ let reader =
+ &mut BufReader::new(Cursor::new(include_bytes!("./tls/RootCA.crt")));
+ cfg.root_store.add_pem_file(reader).unwrap();
+ cfg.alpn_protocols.push("foobar".as_bytes().to_vec());
+ let cfg = Arc::new(cfg);
+
+ let hostname =
+ webpki::DNSNameRef::try_from_ascii_str("localhost").unwrap();
+
+ let tcp_stream = tokio::net::TcpStream::connect("localhost:4504")
+ .await
+ .unwrap();
+ let mut tls_stream =
+ TlsStream::new_client_side(tcp_stream, &cfg, hostname);
+ tls_stream.handshake().await.unwrap();
+ let (_, session) = tls_stream.get_ref();
- let tls_stream = tls_connector.connect(hostname, stream).await.unwrap();
- let (_, session) = tls_stream.get_ref();
+ let alpn = session.get_alpn_protocol().unwrap();
+ assert_eq!(std::str::from_utf8(alpn).unwrap(), "foobar");
- let alpn = session.get_alpn_protocol().unwrap();
- assert_eq!(std::str::from_utf8(alpn).unwrap(), "foobar");
+ child.kill().unwrap();
+ child.wait().unwrap();
+ })
+ .await;
}
#[tokio::test]
async fn listen_tls_alpn_fail() {
- let child = util::deno_cmd()
- .current_dir(util::root_path())
- .arg("run")
- .arg("--unstable")
- .arg("--quiet")
- .arg("--allow-net")
- .arg("--allow-read")
- .arg("./cli/tests/listen_tls_alpn.ts")
- .arg("4505")
- .stdout(std::process::Stdio::piped())
- .spawn()
- .unwrap();
- let mut stdout = child.stdout.unwrap();
- let mut buffer = [0; 5];
- let read = stdout.read(&mut buffer).unwrap();
- assert_eq!(read, 5);
- let msg = std::str::from_utf8(&buffer).unwrap();
- assert_eq!(msg, "READY");
-
- let mut cfg = rustls::ClientConfig::new();
- let reader =
- &mut BufReader::new(Cursor::new(include_bytes!("./tls/RootCA.crt")));
- cfg.root_store.add_pem_file(reader).unwrap();
- cfg.alpn_protocols.push("boofar".as_bytes().to_vec());
-
- let tls_connector = tokio_rustls::TlsConnector::from(Arc::new(cfg));
- let hostname = webpki::DNSNameRef::try_from_ascii_str("localhost").unwrap();
- let stream = tokio::net::TcpStream::connect("localhost:4505")
- .await
- .unwrap();
+ // TLS streams require the presence of an ambient local task set to gracefully
+ // close dropped connections in the background.
+ LocalSet::new()
+ .run_until(async {
+ let mut child = util::deno_cmd()
+ .current_dir(util::root_path())
+ .arg("run")
+ .arg("--unstable")
+ .arg("--quiet")
+ .arg("--allow-net")
+ .arg("--allow-read")
+ .arg("./cli/tests/listen_tls_alpn.ts")
+ .arg("4505")
+ .stdout(std::process::Stdio::piped())
+ .spawn()
+ .unwrap();
+ let stdout = child.stdout.as_mut().unwrap();
+ let mut buffer = [0; 5];
+ let read = stdout.read(&mut buffer).unwrap();
+ assert_eq!(read, 5);
+ let msg = std::str::from_utf8(&buffer).unwrap();
+ assert_eq!(msg, "READY");
+
+ let mut cfg = rustls::ClientConfig::new();
+ let reader =
+ &mut BufReader::new(Cursor::new(include_bytes!("./tls/RootCA.crt")));
+ cfg.root_store.add_pem_file(reader).unwrap();
+ cfg.alpn_protocols.push("boofar".as_bytes().to_vec());
+ let cfg = Arc::new(cfg);
+
+ let hostname =
+ webpki::DNSNameRef::try_from_ascii_str("localhost").unwrap();
+
+ let tcp_stream = tokio::net::TcpStream::connect("localhost:4505")
+ .await
+ .unwrap();
+ let mut tls_stream =
+ TlsStream::new_client_side(tcp_stream, &cfg, hostname);
+ tls_stream.handshake().await.unwrap();
+ let (_, session) = tls_stream.get_ref();
- let tls_stream = tls_connector.connect(hostname, stream).await.unwrap();
- let (_, session) = tls_stream.get_ref();
+ assert!(session.get_alpn_protocol().is_none());
- assert!(session.get_alpn_protocol().is_none());
+ child.kill().unwrap();
+ child.wait().unwrap();
+ })
+ .await;
}
diff --git a/cli/tests/unit/tls_test.ts b/cli/tests/unit/tls_test.ts
index 0528c8043..cedcf467d 100644
--- a/cli/tests/unit/tls_test.ts
+++ b/cli/tests/unit/tls_test.ts
@@ -2,9 +2,11 @@
import {
assert,
assertEquals,
+ assertNotEquals,
assertStrictEquals,
assertThrows,
assertThrowsAsync,
+ Deferred,
deferred,
unitTest,
} from "./test_util.ts";
@@ -14,6 +16,14 @@ import { TextProtoReader } from "../../../test_util/std/textproto/mod.ts";
const encoder = new TextEncoder();
const decoder = new TextDecoder();
+async function sleep(msec: number): Promise<void> {
+ await new Promise((res, _rej) => setTimeout(res, msec));
+}
+
+function unreachable(): never {
+ throw new Error("Unreachable code reached");
+}
+
unitTest(async function connectTLSNoPerm(): Promise<void> {
await assertThrowsAsync(async () => {
await Deno.connectTls({ hostname: "github.com", port: 443 });
@@ -201,7 +211,13 @@ unitTest(
},
);
-async function tlsPair(port: number): Promise<[Deno.Conn, Deno.Conn]> {
+let nextPort = 3501;
+function getPort() {
+ return nextPort++;
+}
+
+async function tlsPair(): Promise<[Deno.Conn, Deno.Conn]> {
+ const port = getPort();
const listener = Deno.listenTls({
hostname: "localhost",
port,
@@ -215,59 +231,169 @@ async function tlsPair(port: number): Promise<[Deno.Conn, Deno.Conn]> {
port,
certFile: "cli/tests/tls/RootCA.pem",
});
- const connections = await Promise.all([acceptPromise, connectPromise]);
+ const endpoints = await Promise.all([acceptPromise, connectPromise]);
listener.close();
- return connections;
+ return endpoints;
}
-async function sendCloseWrite(conn: Deno.Conn): Promise<void> {
- const buf = new Uint8Array(1024);
- let n: number | null;
-
- // Send 1.
- n = await conn.write(new Uint8Array([1]));
- assertStrictEquals(n, 1);
+async function sendThenCloseWriteThenReceive(
+ conn: Deno.Conn,
+ chunkCount: number,
+ chunkSize: number,
+): Promise<void> {
+ const byteCount = chunkCount * chunkSize;
+ const buf = new Uint8Array(chunkSize); // Note: buf is size of _chunk_.
+ let n: number;
+
+ // Slowly send 42s.
+ buf.fill(42);
+ for (let remaining = byteCount; remaining > 0; remaining -= n) {
+ n = await conn.write(buf.subarray(0, remaining));
+ assert(n >= 1);
+ await sleep(10);
+ }
// Send EOF.
await conn.closeWrite();
- // Receive 2.
- n = await conn.read(buf);
- assertStrictEquals(n, 1);
- assertStrictEquals(buf[0], 2);
+ // Receive 69s.
+ for (let remaining = byteCount; remaining > 0; remaining -= n) {
+ buf.fill(0);
+ n = await conn.read(buf) as number;
+ assert(n >= 1);
+ assertStrictEquals(buf[0], 69);
+ assertStrictEquals(buf[n - 1], 69);
+ }
conn.close();
}
-async function receiveCloseWrite(conn: Deno.Conn): Promise<void> {
- const buf = new Uint8Array(1024);
- let n: number | null;
-
- // Receive 1.
- n = await conn.read(buf);
- assertStrictEquals(n, 1);
- assertStrictEquals(buf[0], 1);
-
- // Receive EOF.
- n = await conn.read(buf);
- assertStrictEquals(n, null);
+async function receiveThenSend(
+ conn: Deno.Conn,
+ chunkCount: number,
+ chunkSize: number,
+): Promise<void> {
+ const byteCount = chunkCount * chunkSize;
+ const buf = new Uint8Array(byteCount); // Note: buf size equals `byteCount`.
+ let n: number;
+
+ // Receive 42s.
+ for (let remaining = byteCount; remaining > 0; remaining -= n) {
+ buf.fill(0);
+ n = await conn.read(buf) as number;
+ assert(n >= 1);
+ assertStrictEquals(buf[0], 42);
+ assertStrictEquals(buf[n - 1], 42);
+ }
- // Send 2.
- n = await conn.write(new Uint8Array([2]));
- assertStrictEquals(n, 1);
+ // Slowly send 69s.
+ buf.fill(69);
+ for (let remaining = byteCount; remaining > 0; remaining -= n) {
+ n = await conn.write(buf.subarray(0, remaining));
+ assert(n >= 1);
+ await sleep(10);
+ }
conn.close();
}
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsServerStreamHalfCloseSendOneByte(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
+ await Promise.all([
+ sendThenCloseWriteThenReceive(serverConn, 1, 1),
+ receiveThenSend(clientConn, 1, 1),
+ ]);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsClientStreamHalfCloseSendOneByte(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
+ await Promise.all([
+ sendThenCloseWriteThenReceive(clientConn, 1, 1),
+ receiveThenSend(serverConn, 1, 1),
+ ]);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsServerStreamHalfCloseSendOneChunk(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
+ await Promise.all([
+ sendThenCloseWriteThenReceive(serverConn, 1, 1 << 20 /* 1 MB */),
+ receiveThenSend(clientConn, 1, 1 << 20 /* 1 MB */),
+ ]);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsClientStreamHalfCloseSendOneChunk(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
+ await Promise.all([
+ sendThenCloseWriteThenReceive(clientConn, 1, 1 << 20 /* 1 MB */),
+ receiveThenSend(serverConn, 1, 1 << 20 /* 1 MB */),
+ ]);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsServerStreamHalfCloseSendManyBytes(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
+ await Promise.all([
+ sendThenCloseWriteThenReceive(serverConn, 100, 1),
+ receiveThenSend(clientConn, 100, 1),
+ ]);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsClientStreamHalfCloseSendManyBytes(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
+ await Promise.all([
+ sendThenCloseWriteThenReceive(clientConn, 100, 1),
+ receiveThenSend(serverConn, 100, 1),
+ ]);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsServerStreamHalfCloseSendManyChunks(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
+ await Promise.all([
+ sendThenCloseWriteThenReceive(serverConn, 100, 1 << 16 /* 64 kB */),
+ receiveThenSend(clientConn, 100, 1 << 16 /* 64 kB */),
+ ]);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsClientStreamHalfCloseSendManyChunks(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
+ await Promise.all([
+ sendThenCloseWriteThenReceive(clientConn, 100, 1 << 16 /* 64 kB */),
+ receiveThenSend(serverConn, 100, 1 << 16 /* 64 kB */),
+ ]);
+ },
+);
+
async function sendAlotReceiveNothing(conn: Deno.Conn): Promise<void> {
// Start receive op.
const readBuf = new Uint8Array(1024);
const readPromise = conn.read(readBuf);
// Send 1 MB of data.
- const writeBuf = new Uint8Array(1 << 20);
+ const writeBuf = new Uint8Array(1 << 20 /* 1 MB */);
writeBuf.fill(42);
await conn.write(writeBuf);
@@ -289,7 +415,7 @@ async function receiveAlotSendNothing(conn: Deno.Conn): Promise<void> {
let n: number | null;
// Receive 1 MB of data.
- for (let nread = 0; nread < 1 << 20; nread += n!) {
+ for (let nread = 0; nread < 1 << 20 /* 1 MB */; nread += n!) {
n = await conn.read(readBuf);
assertStrictEquals(typeof n, "number");
assert(n! > 0);
@@ -302,50 +428,515 @@ async function receiveAlotSendNothing(conn: Deno.Conn): Promise<void> {
unitTest(
{ perms: { read: true, net: true } },
- async function tlsServerStreamHalfClose(): Promise<void> {
- const [serverConn, clientConn] = await tlsPair(3501);
+ async function tlsServerStreamCancelRead(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
await Promise.all([
- sendCloseWrite(serverConn),
- receiveCloseWrite(clientConn),
+ sendAlotReceiveNothing(serverConn),
+ receiveAlotSendNothing(clientConn),
]);
},
);
unitTest(
{ perms: { read: true, net: true } },
- async function tlsClientStreamHalfClose(): Promise<void> {
- const [serverConn, clientConn] = await tlsPair(3502);
+ async function tlsClientStreamCancelRead(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
await Promise.all([
- sendCloseWrite(clientConn),
- receiveCloseWrite(serverConn),
+ sendAlotReceiveNothing(clientConn),
+ receiveAlotSendNothing(serverConn),
]);
},
);
+async function sendReceiveEmptyBuf(conn: Deno.Conn): Promise<void> {
+ const byteBuf = new Uint8Array([1]);
+ const emptyBuf = new Uint8Array(0);
+ let n: number | null;
+
+ n = await conn.write(emptyBuf);
+ assertStrictEquals(n, 0);
+
+ n = await conn.read(emptyBuf);
+ assertStrictEquals(n, 0);
+
+ n = await conn.write(byteBuf);
+ assertStrictEquals(n, 1);
+
+ n = await conn.read(byteBuf);
+ assertStrictEquals(n, 1);
+
+ await conn.closeWrite();
+
+ n = await conn.write(emptyBuf);
+ assertStrictEquals(n, 0);
+
+ await assertThrowsAsync(async () => {
+ await conn.write(byteBuf);
+ }, Deno.errors.BrokenPipe);
+
+ n = await conn.write(emptyBuf);
+ assertStrictEquals(n, 0);
+
+ n = await conn.read(byteBuf);
+ assertStrictEquals(n, null);
+
+ conn.close();
+}
+
unitTest(
{ perms: { read: true, net: true } },
- async function tlsServerStreamCancelRead(): Promise<void> {
- const [serverConn, clientConn] = await tlsPair(3503);
+ async function tlsStreamSendReceiveEmptyBuf(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
await Promise.all([
- sendAlotReceiveNothing(serverConn),
- receiveAlotSendNothing(clientConn),
+ sendReceiveEmptyBuf(serverConn),
+ sendReceiveEmptyBuf(clientConn),
]);
},
);
+function immediateClose(conn: Deno.Conn): Promise<void> {
+ conn.close();
+ return Promise.resolve();
+}
+
+async function closeWriteAndClose(conn: Deno.Conn): Promise<void> {
+ await conn.closeWrite();
+
+ if (await conn.read(new Uint8Array(1)) !== null) {
+ throw new Error("did not expect to receive data on TLS stream");
+ }
+
+ conn.close();
+}
+
unitTest(
{ perms: { read: true, net: true } },
- async function tlsClientStreamCancelRead(): Promise<void> {
- const [serverConn, clientConn] = await tlsPair(3504);
+ async function tlsServerStreamImmediateClose(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
await Promise.all([
- sendAlotReceiveNothing(clientConn),
- receiveAlotSendNothing(serverConn),
+ immediateClose(serverConn),
+ closeWriteAndClose(clientConn),
]);
},
);
unitTest(
{ perms: { read: true, net: true } },
+ async function tlsClientStreamImmediateClose(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
+ await Promise.all([
+ closeWriteAndClose(serverConn),
+ immediateClose(clientConn),
+ ]);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsClientAndServerStreamImmediateClose(): Promise<void> {
+ const [serverConn, clientConn] = await tlsPair();
+ await Promise.all([
+ immediateClose(serverConn),
+ immediateClose(clientConn),
+ ]);
+ },
+);
+
+async function tlsWithTcpFailureTestImpl(
+ phase: "handshake" | "traffic",
+ cipherByteCount: number,
+ failureMode: "corruption" | "shutdown",
+ reverse: boolean,
+): Promise<void> {
+ const tlsPort = getPort();
+ const tlsListener = Deno.listenTls({
+ hostname: "localhost",
+ port: tlsPort,
+ certFile: "cli/tests/tls/localhost.crt",
+ keyFile: "cli/tests/tls/localhost.key",
+ });
+
+ const tcpPort = getPort();
+ const tcpListener = Deno.listen({ hostname: "localhost", port: tcpPort });
+
+ const [tlsServerConn, tcpServerConn] = await Promise.all([
+ tlsListener.accept(),
+ Deno.connect({ hostname: "localhost", port: tlsPort }),
+ ]);
+
+ const [tcpClientConn, tlsClientConn] = await Promise.all([
+ tcpListener.accept(),
+ Deno.connectTls({
+ hostname: "localhost",
+ port: tcpPort,
+ certFile: "cli/tests/tls/RootCA.crt",
+ }),
+ ]);
+
+ tlsListener.close();
+ tcpListener.close();
+
+ const {
+ tlsConn1,
+ tlsConn2,
+ tcpConn1,
+ tcpConn2,
+ } = reverse
+ ? {
+ tlsConn1: tlsClientConn,
+ tlsConn2: tlsServerConn,
+ tcpConn1: tcpClientConn,
+ tcpConn2: tcpServerConn,
+ }
+ : {
+ tlsConn1: tlsServerConn,
+ tlsConn2: tlsClientConn,
+ tcpConn1: tcpServerConn,
+ tcpConn2: tcpClientConn,
+ };
+
+ const tcpForwardingInterruptPromise1 = deferred<void>();
+ const tcpForwardingPromise1 = forwardBytes(
+ tcpConn2,
+ tcpConn1,
+ cipherByteCount,
+ tcpForwardingInterruptPromise1,
+ );
+
+ const tcpForwardingInterruptPromise2 = deferred<void>();
+ const tcpForwardingPromise2 = forwardBytes(
+ tcpConn1,
+ tcpConn2,
+ Infinity,
+ tcpForwardingInterruptPromise2,
+ );
+
+ switch (phase) {
+ case "handshake": {
+ let expectedError;
+ switch (failureMode) {
+ case "corruption":
+ expectedError = Deno.errors.InvalidData;
+ break;
+ case "shutdown":
+ expectedError = Deno.errors.UnexpectedEof;
+ break;
+ default:
+ unreachable();
+ }
+
+ const tlsTrafficPromise1 = Promise.all([
+ assertThrowsAsync(
+ () => sendBytes(tlsConn1, 0x01, 1),
+ expectedError,
+ ),
+ assertThrowsAsync(
+ () => receiveBytes(tlsConn1, 0x02, 1),
+ expectedError,
+ ),
+ ]);
+
+ const tlsTrafficPromise2 = Promise.all([
+ assertThrowsAsync(
+ () => sendBytes(tlsConn2, 0x02, 1),
+ Deno.errors.UnexpectedEof,
+ ),
+ assertThrowsAsync(
+ () => receiveBytes(tlsConn2, 0x01, 1),
+ Deno.errors.UnexpectedEof,
+ ),
+ ]);
+
+ await tcpForwardingPromise1;
+
+ switch (failureMode) {
+ case "corruption":
+ await sendBytes(tcpConn1, 0xff, 1 << 14 /* 16 kB */);
+ break;
+ case "shutdown":
+ await tcpConn1.closeWrite();
+ break;
+ default:
+ unreachable();
+ }
+ await tlsTrafficPromise1;
+
+ tcpForwardingInterruptPromise2.resolve();
+ await tcpForwardingPromise2;
+ await tcpConn2.closeWrite();
+ await tlsTrafficPromise2;
+
+ break;
+ }
+
+ case "traffic": {
+ await Promise.all([
+ sendBytes(tlsConn2, 0x88, 8888),
+ receiveBytes(tlsConn1, 0x88, 8888),
+ sendBytes(tlsConn1, 0x99, 99999),
+ receiveBytes(tlsConn2, 0x99, 99999),
+ ]);
+
+ tcpForwardingInterruptPromise1.resolve();
+ await tcpForwardingPromise1;
+
+ switch (failureMode) {
+ case "corruption":
+ await sendBytes(tcpConn1, 0xff, 1 << 14 /* 16 kB */);
+ await assertThrowsAsync(
+ () => receiveEof(tlsConn1),
+ Deno.errors.InvalidData,
+ );
+ tcpForwardingInterruptPromise2.resolve();
+ break;
+ case "shutdown":
+ // Receiving a TCP FIN packet without receiving a TLS CloseNotify
+ // alert is not the expected mode of operation, but it is not a
+ // problem either, so it should be treated as if the TLS session was
+ // gracefully closed.
+ await Promise.all([
+ tcpConn1.closeWrite(),
+ await receiveEof(tlsConn1),
+ await tlsConn1.closeWrite(),
+ await receiveEof(tlsConn2),
+ ]);
+ break;
+ default:
+ unreachable();
+ }
+
+ await tcpForwardingPromise2;
+
+ break;
+ }
+
+ default:
+ unreachable();
+ }
+
+ tlsServerConn.close();
+ tlsClientConn.close();
+ tcpServerConn.close();
+ tcpClientConn.close();
+
+ async function sendBytes(
+ conn: Deno.Conn,
+ byte: number,
+ count: number,
+ ): Promise<void> {
+ let buf = new Uint8Array(1 << 12 /* 4 kB */);
+ buf.fill(byte);
+
+ while (count > 0) {
+ buf = buf.subarray(0, Math.min(buf.length, count));
+ const nwritten = await conn.write(buf);
+ assertStrictEquals(nwritten, buf.length);
+ count -= nwritten;
+ }
+ }
+
+ async function receiveBytes(
+ conn: Deno.Conn,
+ byte: number,
+ count: number,
+ ): Promise<void> {
+ let buf = new Uint8Array(1 << 12 /* 4 kB */);
+ while (count > 0) {
+ buf = buf.subarray(0, Math.min(buf.length, count));
+ const r = await conn.read(buf);
+ assertNotEquals(r, null);
+ assert(buf.subarray(0, r!).every((b) => b === byte));
+ count -= r!;
+ }
+ }
+
+ async function receiveEof(conn: Deno.Conn) {
+ const buf = new Uint8Array(1);
+ const r = await conn.read(buf);
+ assertStrictEquals(r, null);
+ }
+
+ async function forwardBytes(
+ source: Deno.Conn,
+ sink: Deno.Conn,
+ count: number,
+ interruptPromise: Deferred<void>,
+ ): Promise<void> {
+ let buf = new Uint8Array(1 << 12 /* 4 kB */);
+ while (count > 0) {
+ buf = buf.subarray(0, Math.min(buf.length, count));
+ const nread = await Promise.race([source.read(buf), interruptPromise]);
+ if (nread == null) break; // Either EOF or interrupted.
+ const nwritten = await sink.write(buf.subarray(0, nread));
+ assertStrictEquals(nread, nwritten);
+ count -= nwritten;
+ }
+ }
+}
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsHandshakeWithTcpCorruptionImmediately() {
+ await tlsWithTcpFailureTestImpl("handshake", 0, "corruption", false);
+ await tlsWithTcpFailureTestImpl("handshake", 0, "corruption", true);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsHandshakeWithTcpShutdownImmediately() {
+ await tlsWithTcpFailureTestImpl("handshake", 0, "shutdown", false);
+ await tlsWithTcpFailureTestImpl("handshake", 0, "shutdown", true);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsHandshakeWithTcpCorruptionAfter70Bytes() {
+ await tlsWithTcpFailureTestImpl("handshake", 76, "corruption", false);
+ await tlsWithTcpFailureTestImpl("handshake", 78, "corruption", true);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsHandshakeWithTcpShutdownAfter70bytes() {
+ await tlsWithTcpFailureTestImpl("handshake", 77, "shutdown", false);
+ await tlsWithTcpFailureTestImpl("handshake", 79, "shutdown", true);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsHandshakeWithTcpCorruptionAfter200Bytes() {
+ await tlsWithTcpFailureTestImpl("handshake", 200, "corruption", false);
+ await tlsWithTcpFailureTestImpl("handshake", 202, "corruption", true);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsHandshakeWithTcpShutdownAfter200bytes() {
+ await tlsWithTcpFailureTestImpl("handshake", 201, "shutdown", false);
+ await tlsWithTcpFailureTestImpl("handshake", 203, "shutdown", true);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsTrafficWithTcpCorruption() {
+ await tlsWithTcpFailureTestImpl("traffic", Infinity, "corruption", false);
+ await tlsWithTcpFailureTestImpl("traffic", Infinity, "corruption", true);
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
+ async function tlsTrafficWithTcpShutdown() {
+ await tlsWithTcpFailureTestImpl("traffic", Infinity, "shutdown", false);
+ await tlsWithTcpFailureTestImpl("traffic", Infinity, "shutdown", true);
+ },
+);
+
+function createHttpsListener(port: number): Deno.Listener {
+ // Query format: `curl --insecure https://localhost:8443/z/12345`
+ // The server returns a response consisting of 12345 times the letter 'z'.
+ const listener = Deno.listenTls({
+ hostname: "localhost",
+ port,
+ certFile: "./cli/tests/tls/localhost.crt",
+ keyFile: "./cli/tests/tls/localhost.key",
+ });
+
+ serve(listener);
+ return listener;
+
+ async function serve(listener: Deno.Listener) {
+ for await (const conn of listener) {
+ const EOL = "\r\n";
+
+ // Read GET request plus headers.
+ const buf = new Uint8Array(1 << 12 /* 4 kB */);
+ const decoder = new TextDecoder();
+ let req = "";
+ while (!req.endsWith(EOL + EOL)) {
+ const n = await conn.read(buf);
+ if (n === null) throw new Error("Unexpected EOF");
+ req += decoder.decode(buf.subarray(0, n));
+ }
+
+ // Parse GET request.
+ const { filler, count, version } =
+ /^GET \/(?<filler>[^\/]+)\/(?<count>\d+) HTTP\/(?<version>1\.\d)\r\n/
+ .exec(req)!.groups as {
+ filler: string;
+ count: string;
+ version: string;
+ };
+
+ // Generate response.
+ const resBody = new TextEncoder().encode(filler.repeat(+count));
+ const resHead = new TextEncoder().encode(
+ [
+ `HTTP/${version} 200 OK`,
+ `Content-Length: ${resBody.length}`,
+ "Content-Type: text/plain",
+ ].join(EOL) + EOL + EOL,
+ );
+
+ // Send response.
+ await conn.write(resHead);
+ await conn.write(resBody);
+
+ // Close TCP connection.
+ conn.close();
+ }
+ }
+}
+
+async function curl(url: string): Promise<string> {
+ const curl = Deno.run({
+ cmd: ["curl", "--insecure", url],
+ stdout: "piped",
+ });
+
+ try {
+ const [status, output] = await Promise.all([curl.status(), curl.output()]);
+ if (!status.success) {
+ throw new Error(`curl ${url} failed: ${status.code}`);
+ }
+ return new TextDecoder().decode(output);
+ } finally {
+ curl.close();
+ }
+}
+
+unitTest(
+ { perms: { read: true, net: true, run: true } },
+ async function curlFakeHttpsServer(): Promise<void> {
+ const port = getPort();
+ const listener = createHttpsListener(port);
+
+ const res1 = await curl(`https://localhost:${port}/d/1`);
+ assertStrictEquals(res1, "d");
+
+ const res2 = await curl(`https://localhost:${port}/e/12345`);
+ assertStrictEquals(res2, "e".repeat(12345));
+
+ const count3 = 1 << 17; // 128 kB.
+ const res3 = await curl(`https://localhost:${port}/n/${count3}`);
+ assertStrictEquals(res3, "n".repeat(count3));
+
+ const count4 = 12345678;
+ const res4 = await curl(`https://localhost:${port}/o/${count4}`);
+ assertStrictEquals(res4, "o".repeat(count4));
+
+ listener.close();
+ },
+);
+
+unitTest(
+ { perms: { read: true, net: true } },
async function startTls(): Promise<void> {
const hostname = "smtp.gmail.com";
const port = 587;
diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml
index 4cefa23ee..4ca0539db 100644
--- a/runtime/Cargo.toml
+++ b/runtime/Cargo.toml
@@ -64,12 +64,12 @@ notify = "5.0.0-pre.7"
percent-encoding = "2.1.0"
regex = "1.4.3"
ring = "0.16.20"
+rustls = "0.19.0"
serde = { version = "1.0.125", features = ["derive"] }
sys-info = "0.9.0"
termcolor = "1.1.2"
tokio = { version = "1.4.0", features = ["full"] }
tokio-util = { version = "0.6", features = ["io"] }
-tokio-rustls = "0.22.0"
uuid = { version = "0.8.2", features = ["v4"] }
webpki = "0.21.4"
webpki-roots = "0.21.1"
diff --git a/runtime/ops/http.rs b/runtime/ops/http.rs
index 3642a0ac3..e4ba2db2a 100644
--- a/runtime/ops/http.rs
+++ b/runtime/ops/http.rs
@@ -1,7 +1,8 @@
// Copyright 2018-2021 the Deno authors. All rights reserved. MIT license.
use crate::ops::io::TcpStreamResource;
-use crate::ops::io::TlsServerStreamResource;
+use crate::ops::io::TlsStreamResource;
+use crate::ops::tls::TlsStream;
use deno_core::error::bad_resource_id;
use deno_core::error::null_opbuf;
use deno_core::error::type_error;
@@ -43,7 +44,6 @@ use std::task::Poll;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::oneshot;
-use tokio_rustls::server::TlsStream;
use tokio_util::io::StreamReader;
pub fn init() -> Extension {
@@ -100,7 +100,7 @@ impl HyperService<Request<Body>> for Service {
enum ConnType {
Tcp(Rc<RefCell<Connection<TcpStream, Service, LocalExecutor>>>),
- Tls(Rc<RefCell<Connection<TlsStream<TcpStream>, Service, LocalExecutor>>>),
+ Tls(Rc<RefCell<Connection<TlsStream, Service, LocalExecutor>>>),
}
struct ConnResource {
@@ -305,12 +305,12 @@ fn op_http_start(
if let Some(resource_rc) = state
.resource_table
- .take::<TlsServerStreamResource>(tcp_stream_rid)
+ .take::<TlsStreamResource>(tcp_stream_rid)
{
let resource = Rc::try_unwrap(resource_rc)
.expect("Only a single use of this resource should happen");
let (read_half, write_half) = resource.into_inner();
- let tls_stream = read_half.unsplit(write_half);
+ let tls_stream = read_half.reunite(write_half);
let addr = tls_stream.get_ref().0.local_addr()?;
let hyper_connection = Http::new()
diff --git a/runtime/ops/io.rs b/runtime/ops/io.rs
index c7faa73d7..d9f21e1f5 100644
--- a/runtime/ops/io.rs
+++ b/runtime/ops/io.rs
@@ -1,5 +1,6 @@
// Copyright 2018-2021 the Deno authors. All rights reserved. MIT license.
+use crate::ops::tls;
use deno_core::error::null_opbuf;
use deno_core::error::resource_unavailable;
use deno_core::error::AnyError;
@@ -21,17 +22,12 @@ use std::cell::RefCell;
use std::io::Read;
use std::io::Write;
use std::rc::Rc;
-use tokio::io::split;
use tokio::io::AsyncRead;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt;
-use tokio::io::ReadHalf;
-use tokio::io::WriteHalf;
use tokio::net::tcp;
-use tokio::net::TcpStream;
use tokio::process;
-use tokio_rustls as tls;
#[cfg(unix)]
use std::os::unix::io::FromRawFd;
@@ -306,18 +302,6 @@ where
}
}
-pub type FullDuplexSplitResource<S> =
- FullDuplexResource<ReadHalf<S>, WriteHalf<S>>;
-
-impl<S> From<S> for FullDuplexSplitResource<S>
-where
- S: AsyncRead + AsyncWrite + 'static,
-{
- fn from(stream: S) -> Self {
- Self::new(split(stream))
- }
-}
-
pub type ChildStdinResource = WriteOnlyResource<process::ChildStdin>;
impl Resource for ChildStdinResource {
@@ -363,25 +347,11 @@ impl Resource for TcpStreamResource {
}
}
-pub type TlsClientStreamResource =
- FullDuplexSplitResource<tls::client::TlsStream<TcpStream>>;
+pub type TlsStreamResource = FullDuplexResource<tls::ReadHalf, tls::WriteHalf>;
-impl Resource for TlsClientStreamResource {
+impl Resource for TlsStreamResource {
fn name(&self) -> Cow<str> {
- "tlsClientStream".into()
- }
-
- fn close(self: Rc<Self>) {
- self.cancel_read_ops();
- }
-}
-
-pub type TlsServerStreamResource =
- FullDuplexSplitResource<tls::server::TlsStream<TcpStream>>;
-
-impl Resource for TlsServerStreamResource {
- fn name(&self) -> Cow<str> {
- "tlsServerStream".into()
+ "tlsStream".into()
}
fn close(self: Rc<Self>) {
@@ -572,9 +542,7 @@ async fn op_read_async(
s.read(buf).await?
} else if let Some(s) = resource.downcast_rc::<TcpStreamResource>() {
s.read(buf).await?
- } else if let Some(s) = resource.downcast_rc::<TlsClientStreamResource>() {
- s.read(buf).await?
- } else if let Some(s) = resource.downcast_rc::<TlsServerStreamResource>() {
+ } else if let Some(s) = resource.downcast_rc::<TlsStreamResource>() {
s.read(buf).await?
} else if let Some(s) = resource.downcast_rc::<UnixStreamResource>() {
s.read(buf).await?
@@ -616,9 +584,7 @@ async fn op_write_async(
s.write(buf).await?
} else if let Some(s) = resource.downcast_rc::<TcpStreamResource>() {
s.write(buf).await?
- } else if let Some(s) = resource.downcast_rc::<TlsClientStreamResource>() {
- s.write(buf).await?
- } else if let Some(s) = resource.downcast_rc::<TlsServerStreamResource>() {
+ } else if let Some(s) = resource.downcast_rc::<TlsStreamResource>() {
s.write(buf).await?
} else if let Some(s) = resource.downcast_rc::<UnixStreamResource>() {
s.write(buf).await?
@@ -644,9 +610,7 @@ async fn op_shutdown(
s.shutdown().await?;
} else if let Some(s) = resource.downcast_rc::<TcpStreamResource>() {
s.shutdown().await?;
- } else if let Some(s) = resource.downcast_rc::<TlsClientStreamResource>() {
- s.shutdown().await?;
- } else if let Some(s) = resource.downcast_rc::<TlsServerStreamResource>() {
+ } else if let Some(s) = resource.downcast_rc::<TlsStreamResource>() {
s.shutdown().await?;
} else if let Some(s) = resource.downcast_rc::<UnixStreamResource>() {
s.shutdown().await?;
diff --git a/runtime/ops/tls.rs b/runtime/ops/tls.rs
index 9143a86fa..c3f554856 100644
--- a/runtime/ops/tls.rs
+++ b/runtime/ops/tls.rs
@@ -1,11 +1,13 @@
// Copyright 2018-2021 the Deno authors. All rights reserved. MIT license.
-use super::io::TcpStreamResource;
-use super::io::TlsClientStreamResource;
-use super::io::TlsServerStreamResource;
-use super::net::IpAddr;
-use super::net::OpAddr;
-use super::net::OpConn;
+pub use rustls;
+pub use webpki;
+
+use crate::ops::io::TcpStreamResource;
+use crate::ops::io::TlsStreamResource;
+use crate::ops::net::IpAddr;
+use crate::ops::net::OpAddr;
+use crate::ops::net::OpConn;
use crate::permissions::Permissions;
use crate::resolve_addr::resolve_addr;
use crate::resolve_addr::resolve_addr_sync;
@@ -15,6 +17,15 @@ use deno_core::error::custom_error;
use deno_core::error::generic_error;
use deno_core::error::invalid_hostname;
use deno_core::error::AnyError;
+use deno_core::futures::future::poll_fn;
+use deno_core::futures::ready;
+use deno_core::futures::task::noop_waker_ref;
+use deno_core::futures::task::AtomicWaker;
+use deno_core::futures::task::Context;
+use deno_core::futures::task::Poll;
+use deno_core::futures::task::RawWaker;
+use deno_core::futures::task::RawWakerVTable;
+use deno_core::futures::task::Waker;
use deno_core::op_async;
use deno_core::op_sync;
use deno_core::AsyncRefCell;
@@ -25,27 +36,44 @@ use deno_core::OpState;
use deno_core::RcRef;
use deno_core::Resource;
use deno_core::ResourceId;
+use io::Error;
+use io::Read;
+use io::Write;
+use rustls::internal::pemfile::certs;
+use rustls::internal::pemfile::pkcs8_private_keys;
+use rustls::internal::pemfile::rsa_private_keys;
+use rustls::Certificate;
+use rustls::ClientConfig;
+use rustls::ClientSession;
+use rustls::NoClientAuth;
+use rustls::PrivateKey;
+use rustls::ServerConfig;
+use rustls::ServerSession;
+use rustls::Session;
+use rustls::StoresClientSessions;
use serde::Deserialize;
use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::HashMap;
use std::convert::From;
use std::fs::File;
+use std::io;
use std::io::BufReader;
+use std::io::ErrorKind;
+use std::ops::Deref;
+use std::ops::DerefMut;
use std::path::Path;
+use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use std::sync::Mutex;
+use std::sync::Weak;
+use tokio::io::AsyncRead;
+use tokio::io::AsyncWrite;
+use tokio::io::ReadBuf;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
-use tokio_rustls::{rustls::ClientConfig, TlsConnector};
-use tokio_rustls::{
- rustls::{
- internal::pemfile::{certs, pkcs8_private_keys, rsa_private_keys},
- Certificate, NoClientAuth, PrivateKey, ServerConfig, StoresClientSessions,
- },
- TlsAcceptor,
-};
+use tokio::task::spawn_local;
use webpki::DNSNameRef;
lazy_static::lazy_static! {
@@ -73,6 +101,567 @@ impl StoresClientSessions for ClientSessionMemoryCache {
}
}
+#[derive(Debug)]
+enum TlsSession {
+ Client(ClientSession),
+ Server(ServerSession),
+}
+
+impl Deref for TlsSession {
+ type Target = dyn Session;
+
+ fn deref(&self) -> &Self::Target {
+ match self {
+ TlsSession::Client(client_session) => client_session,
+ TlsSession::Server(server_session) => server_session,
+ }
+ }
+}
+
+impl DerefMut for TlsSession {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ match self {
+ TlsSession::Client(client_session) => client_session,
+ TlsSession::Server(server_session) => server_session,
+ }
+ }
+}
+
+impl From<ClientSession> for TlsSession {
+ fn from(client_session: ClientSession) -> Self {
+ TlsSession::Client(client_session)
+ }
+}
+
+impl From<ServerSession> for TlsSession {
+ fn from(server_session: ServerSession) -> Self {
+ TlsSession::Server(server_session)
+ }
+}
+
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+enum Flow {
+ Read,
+ Write,
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+enum State {
+ StreamOpen,
+ StreamClosed,
+ TlsClosing,
+ TlsClosed,
+ TcpClosed,
+}
+
+#[derive(Debug)]
+pub struct TlsStream(Option<TlsStreamInner>);
+
+impl TlsStream {
+ fn new(tcp: TcpStream, tls: TlsSession) -> Self {
+ let inner = TlsStreamInner {
+ tcp,
+ tls,
+ rd_state: State::StreamOpen,
+ wr_state: State::StreamOpen,
+ };
+ Self(Some(inner))
+ }
+
+ pub fn new_client_side(
+ tcp: TcpStream,
+ tls_config: &Arc<ClientConfig>,
+ hostname: DNSNameRef,
+ ) -> Self {
+ let tls = TlsSession::Client(ClientSession::new(tls_config, hostname));
+ Self::new(tcp, tls)
+ }
+
+ pub fn new_server_side(
+ tcp: TcpStream,
+ tls_config: &Arc<ServerConfig>,
+ ) -> Self {
+ let tls = TlsSession::Server(ServerSession::new(tls_config));
+ Self::new(tcp, tls)
+ }
+
+ pub async fn handshake(&mut self) -> io::Result<()> {
+ poll_fn(|cx| self.inner_mut().poll_io(cx, Flow::Write)).await
+ }
+
+ fn into_split(self) -> (ReadHalf, WriteHalf) {
+ let shared = Shared::new(self);
+ let rd = ReadHalf {
+ shared: shared.clone(),
+ };
+ let wr = WriteHalf { shared };
+ (rd, wr)
+ }
+
+ /// Tokio-rustls compatibility: returns a reference to the underlying TCP
+ /// stream, and a reference to the Rustls `Session` object.
+ pub fn get_ref(&self) -> (&TcpStream, &dyn Session) {
+ let inner = self.0.as_ref().unwrap();
+ (&inner.tcp, &*inner.tls)
+ }
+
+ fn inner_mut(&mut self) -> &mut TlsStreamInner {
+ self.0.as_mut().unwrap()
+ }
+}
+
+impl AsyncRead for TlsStream {
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ self.inner_mut().poll_read(cx, buf)
+ }
+}
+
+impl AsyncWrite for TlsStream {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ self.inner_mut().poll_write(cx, buf)
+ }
+
+ fn poll_flush(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<io::Result<()>> {
+ self.inner_mut().poll_io(cx, Flow::Write)
+ // The underlying TCP stream does not need to be flushed.
+ }
+
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<io::Result<()>> {
+ self.inner_mut().poll_shutdown(cx)
+ }
+}
+
+impl Drop for TlsStream {
+ fn drop(&mut self) {
+ let mut inner = self.0.take().unwrap();
+
+ let mut cx = Context::from_waker(noop_waker_ref());
+ let use_linger_task = inner.poll_close(&mut cx).is_pending();
+
+ if use_linger_task {
+ spawn_local(poll_fn(move |cx| inner.poll_close(cx)));
+ } else if cfg!(debug_assertions) {
+ spawn_local(async {}); // Spawn dummy task to detect missing LocalSet.
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct TlsStreamInner {
+ tls: TlsSession,
+ tcp: TcpStream,
+ rd_state: State,
+ wr_state: State,
+}
+
+impl TlsStreamInner {
+ fn poll_io(
+ &mut self,
+ cx: &mut Context<'_>,
+ flow: Flow,
+ ) -> Poll<io::Result<()>> {
+ loop {
+ let wr_ready = loop {
+ match self.wr_state {
+ _ if self.tls.is_handshaking() && !self.tls.wants_write() => {
+ break true;
+ }
+ _ if self.tls.is_handshaking() => {}
+ State::StreamOpen if !self.tls.wants_write() => break true,
+ State::StreamClosed => {
+ // Rustls will enqueue the 'CloseNotify' alert and send it after
+ // flusing the data that is already in the queue.
+ self.tls.send_close_notify();
+ self.wr_state = State::TlsClosing;
+ continue;
+ }
+ State::TlsClosing if !self.tls.wants_write() => {
+ self.wr_state = State::TlsClosed;
+ continue;
+ }
+ // If a 'CloseNotify' alert sent by the remote end has been received,
+ // shut down the underlying TCP socket. Otherwise, consider polling
+ // done for the moment.
+ State::TlsClosed if self.rd_state < State::TlsClosed => break true,
+ State::TlsClosed
+ if Pin::new(&mut self.tcp).poll_shutdown(cx)?.is_pending() =>
+ {
+ break false;
+ }
+ State::TlsClosed => {
+ self.wr_state = State::TcpClosed;
+ continue;
+ }
+ State::TcpClosed => break true,
+ _ => {}
+ }
+
+ // Poll whether there is space in the socket send buffer so we can flush
+ // the remaining outgoing ciphertext.
+ if self.tcp.poll_write_ready(cx)?.is_pending() {
+ break false;
+ }
+
+ // Write ciphertext to the TCP socket.
+ let mut wrapped_tcp = ImplementWriteTrait(&mut self.tcp);
+ match self.tls.write_tls(&mut wrapped_tcp) {
+ Ok(0) => unreachable!(),
+ Ok(_) => {}
+ Err(err) if err.kind() == ErrorKind::WouldBlock => {}
+ Err(err) => return Poll::Ready(Err(err)),
+ }
+ };
+
+ let rd_ready = loop {
+ match self.rd_state {
+ State::TcpClosed if self.tls.is_handshaking() => {
+ let err = Error::new(ErrorKind::UnexpectedEof, "tls handshake eof");
+ return Poll::Ready(Err(err));
+ }
+ _ if self.tls.is_handshaking() && !self.tls.wants_read() => {
+ break true;
+ }
+ _ if self.tls.is_handshaking() => {}
+ State::StreamOpen if !self.tls.wants_read() => break true,
+ State::StreamOpen => {}
+ State::StreamClosed if !self.tls.wants_read() => {
+ // Rustls has more incoming cleartext buffered up, but the TLS
+ // session is closing so this data will never be processed by the
+ // application layer. Just like what would happen if this were a raw
+ // TCP stream, don't gracefully end the TLS session, but abort it.
+ return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset)));
+ }
+ State::StreamClosed => {}
+ State::TlsClosed if self.wr_state == State::TcpClosed => {
+ // Wait for the remote end to gracefully close the TCP connection.
+ // TODO(piscisaureus): this is unnecessary; remove when stable.
+ }
+ _ => break true,
+ }
+
+ if self.rd_state < State::TlsClosed {
+ // Do a zero-length plaintext read so we can detect the arrival of
+ // 'CloseNotify' messages, even if only the write half is open.
+ // Actually reading data from the socket is done in `poll_read()`.
+ match self.tls.read(&mut []) {
+ Ok(0) => {}
+ Err(err) if err.kind() == ErrorKind::ConnectionAborted => {
+ // `Session::read()` returns `ConnectionAborted` when a
+ // 'CloseNotify' alert has been received, which indicates that
+ // the remote peer wants to gracefully end the TLS session.
+ self.rd_state = State::TlsClosed;
+ continue;
+ }
+ Err(err) => return Poll::Ready(Err(err)),
+ _ => unreachable!(),
+ }
+ }
+
+ // Poll whether more ciphertext is available in the socket receive
+ // buffer.
+ if self.tcp.poll_read_ready(cx)?.is_pending() {
+ break false;
+ }
+
+ // Receive ciphertext from the socket.
+ let mut wrapped_tcp = ImplementReadTrait(&mut self.tcp);
+ match self.tls.read_tls(&mut wrapped_tcp) {
+ Ok(0) => self.rd_state = State::TcpClosed,
+ Ok(_) => self
+ .tls
+ .process_new_packets()
+ .map_err(|err| Error::new(ErrorKind::InvalidData, err))?,
+ Err(err) if err.kind() == ErrorKind::WouldBlock => {}
+ Err(err) => return Poll::Ready(Err(err)),
+ }
+ };
+
+ if wr_ready {
+ if self.rd_state >= State::TlsClosed
+ && self.wr_state >= State::TlsClosed
+ && self.wr_state < State::TcpClosed
+ {
+ continue;
+ }
+ if self.tls.wants_write() {
+ continue;
+ }
+ }
+
+ let io_ready = match flow {
+ _ if self.tls.is_handshaking() => false,
+ Flow::Read => rd_ready,
+ Flow::Write => wr_ready,
+ };
+ return match io_ready {
+ false => Poll::Pending,
+ true => Poll::Ready(Ok(())),
+ };
+ }
+ }
+
+ fn poll_read(
+ &mut self,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ ready!(self.poll_io(cx, Flow::Read))?;
+
+ if self.rd_state == State::StreamOpen {
+ let buf_slice =
+ unsafe { &mut *(buf.unfilled_mut() as *mut [_] as *mut [u8]) };
+ let bytes_read = self.tls.read(buf_slice)?;
+ assert_ne!(bytes_read, 0);
+ unsafe { buf.assume_init(bytes_read) };
+ buf.advance(bytes_read);
+ }
+
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_write(
+ &mut self,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ if buf.is_empty() {
+ // Tokio-rustls compatibility: a zero byte write always succeeds.
+ Poll::Ready(Ok(0))
+ } else if self.wr_state == State::StreamOpen {
+ // Flush Rustls' ciphertext send queue.
+ ready!(self.poll_io(cx, Flow::Write))?;
+
+ // Copy data from `buf` to the Rustls cleartext send queue.
+ let bytes_written = self.tls.write(buf)?;
+ assert_ne!(bytes_written, 0);
+
+ // Try to flush as much ciphertext as possible. However, since we just
+ // handed off at least some bytes to rustls, so we can't return
+ // `Poll::Pending()` any more: this would tell the caller that it should
+ // try to send those bytes again.
+ let _ = self.poll_io(cx, Flow::Write)?;
+
+ Poll::Ready(Ok(bytes_written))
+ } else {
+ // Return error if stream has been shut down for writing.
+ Poll::Ready(Err(ErrorKind::BrokenPipe.into()))
+ }
+ }
+
+ fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ if self.wr_state == State::StreamOpen {
+ self.wr_state = State::StreamClosed;
+ }
+
+ ready!(self.poll_io(cx, Flow::Write))?;
+
+ // At minimum, a TLS 'CloseNotify' alert should have been sent.
+ assert!(self.wr_state >= State::TlsClosed);
+ // If we received a TLS 'CloseNotify' alert from the remote end
+ // already, the TCP socket should be shut down at this point.
+ assert!(
+ self.rd_state < State::TlsClosed || self.wr_state == State::TcpClosed
+ );
+
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ if self.rd_state == State::StreamOpen {
+ self.rd_state = State::StreamClosed;
+ }
+
+ // Send TLS 'CloseNotify' alert.
+ ready!(self.poll_shutdown(cx))?;
+ // Wait for 'CloseNotify', shut down TCP stream, wait for TCP FIN packet.
+ ready!(self.poll_io(cx, Flow::Read))?;
+
+ assert_eq!(self.rd_state, State::TcpClosed);
+ assert_eq!(self.wr_state, State::TcpClosed);
+
+ Poll::Ready(Ok(()))
+ }
+}
+
+#[derive(Debug)]
+pub struct ReadHalf {
+ shared: Arc<Shared>,
+}
+
+impl ReadHalf {
+ pub fn reunite(self, wr: WriteHalf) -> TlsStream {
+ assert!(Arc::ptr_eq(&self.shared, &wr.shared));
+ drop(wr); // Drop `wr`, so only one strong reference to `shared` remains.
+
+ Arc::try_unwrap(self.shared)
+ .unwrap_or_else(|_| panic!("Arc::<Shared>::try_unwrap() failed"))
+ .tls_stream
+ .into_inner()
+ .unwrap()
+ }
+}
+
+impl AsyncRead for ReadHalf {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ self
+ .shared
+ .poll_with_shared_waker(cx, Flow::Read, move |tls, cx| {
+ tls.poll_read(cx, buf)
+ })
+ }
+}
+
+#[derive(Debug)]
+pub struct WriteHalf {
+ shared: Arc<Shared>,
+}
+
+impl AsyncWrite for WriteHalf {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ self
+ .shared
+ .poll_with_shared_waker(cx, Flow::Write, move |tls, cx| {
+ tls.poll_write(cx, buf)
+ })
+ }
+
+ fn poll_flush(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<io::Result<()>> {
+ self
+ .shared
+ .poll_with_shared_waker(cx, Flow::Write, |tls, cx| tls.poll_flush(cx))
+ }
+
+ fn poll_shutdown(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<io::Result<()>> {
+ self
+ .shared
+ .poll_with_shared_waker(cx, Flow::Write, |tls, cx| tls.poll_shutdown(cx))
+ }
+}
+
+#[derive(Debug)]
+struct Shared {
+ tls_stream: Mutex<TlsStream>,
+ rd_waker: AtomicWaker,
+ wr_waker: AtomicWaker,
+}
+
+impl Shared {
+ fn new(tls_stream: TlsStream) -> Arc<Self> {
+ let self_ = Self {
+ tls_stream: Mutex::new(tls_stream),
+ rd_waker: AtomicWaker::new(),
+ wr_waker: AtomicWaker::new(),
+ };
+ Arc::new(self_)
+ }
+
+ fn poll_with_shared_waker<R>(
+ self: &Arc<Self>,
+ cx: &mut Context<'_>,
+ flow: Flow,
+ mut f: impl FnMut(Pin<&mut TlsStream>, &mut Context<'_>) -> R,
+ ) -> R {
+ match flow {
+ Flow::Read => self.rd_waker.register(cx.waker()),
+ Flow::Write => self.wr_waker.register(cx.waker()),
+ }
+
+ let shared_waker = self.new_shared_waker();
+ let mut cx = Context::from_waker(&shared_waker);
+
+ let mut tls_stream = self.tls_stream.lock().unwrap();
+ f(Pin::new(&mut tls_stream), &mut cx)
+ }
+
+ const SHARED_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
+ Self::clone_shared_waker,
+ Self::wake_shared_waker,
+ Self::wake_shared_waker_by_ref,
+ Self::drop_shared_waker,
+ );
+
+ fn new_shared_waker(self: &Arc<Self>) -> Waker {
+ let self_weak = Arc::downgrade(self);
+ let self_ptr = self_weak.into_raw() as *const ();
+ let raw_waker = RawWaker::new(self_ptr, &Self::SHARED_WAKER_VTABLE);
+ unsafe { Waker::from_raw(raw_waker) }
+ }
+
+ fn clone_shared_waker(self_ptr: *const ()) -> RawWaker {
+ let self_weak = unsafe { Weak::from_raw(self_ptr as *const Self) };
+ let ptr1 = self_weak.clone().into_raw();
+ let ptr2 = self_weak.into_raw();
+ assert!(ptr1 == ptr2);
+ RawWaker::new(self_ptr, &Self::SHARED_WAKER_VTABLE)
+ }
+
+ fn wake_shared_waker(self_ptr: *const ()) {
+ Self::wake_shared_waker_by_ref(self_ptr);
+ Self::drop_shared_waker(self_ptr);
+ }
+
+ fn wake_shared_waker_by_ref(self_ptr: *const ()) {
+ let self_weak = unsafe { Weak::from_raw(self_ptr as *const Self) };
+ if let Some(self_arc) = Weak::upgrade(&self_weak) {
+ self_arc.rd_waker.wake();
+ self_arc.wr_waker.wake();
+ }
+ self_weak.into_raw();
+ }
+
+ fn drop_shared_waker(self_ptr: *const ()) {
+ let _ = unsafe { Weak::from_raw(self_ptr as *const Self) };
+ }
+}
+
+struct ImplementReadTrait<'a, T>(&'a mut T);
+
+impl Read for ImplementReadTrait<'_, TcpStream> {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ self.0.try_read(buf)
+ }
+}
+
+struct ImplementWriteTrait<'a, T>(&'a mut T);
+
+impl Write for ImplementWriteTrait<'_, TcpStream> {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ self.0.try_write(buf)
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ Ok(())
+ }
+}
+
pub fn init() -> Extension {
Extension::builder()
.ops(vec![
@@ -107,21 +696,25 @@ async fn op_start_tls(
_: (),
) -> Result<OpConn, AnyError> {
let rid = args.rid;
+ let hostname = match &*args.hostname {
+ "" => "localhost",
+ n => n,
+ };
+ let cert_file = args.cert_file.as_deref();
- let mut domain = args.hostname.as_str();
- if domain.is_empty() {
- domain = "localhost";
- }
{
super::check_unstable2(&state, "Deno.startTls");
let mut s = state.borrow_mut();
let permissions = s.borrow_mut::<Permissions>();
- permissions.net.check(&(&domain, Some(0)))?;
- if let Some(path) = &args.cert_file {
- permissions.read.check(Path::new(&path))?;
+ permissions.net.check(&(hostname, Some(0)))?;
+ if let Some(path) = cert_file {
+ permissions.read.check(Path::new(path))?;
}
}
+ let hostname_dns = DNSNameRef::try_from_ascii_str(hostname)
+ .map_err(|_| invalid_hostname(hostname))?;
+
let resource_rc = state
.borrow_mut()
.resource_table
@@ -134,28 +727,29 @@ async fn op_start_tls(
let local_addr = tcp_stream.local_addr()?;
let remote_addr = tcp_stream.peer_addr()?;
- let mut config = ClientConfig::new();
- config.set_persistence(CLIENT_SESSION_MEMORY_CACHE.clone());
- config
+
+ let mut tls_config = ClientConfig::new();
+ tls_config.set_persistence(CLIENT_SESSION_MEMORY_CACHE.clone());
+ tls_config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
- if let Some(path) = args.cert_file {
+ if let Some(path) = cert_file {
let key_file = File::open(path)?;
let reader = &mut BufReader::new(key_file);
- config.root_store.add_pem_file(reader).unwrap();
+ tls_config.root_store.add_pem_file(reader).unwrap();
}
+ let tls_config = Arc::new(tls_config);
- let tls_connector = TlsConnector::from(Arc::new(config));
- let dnsname = DNSNameRef::try_from_ascii_str(domain)
- .map_err(|_| invalid_hostname(domain))?;
- let tls_stream = tls_connector.connect(dnsname, tcp_stream).await?;
+ let tls_stream =
+ TlsStream::new_client_side(tcp_stream, &tls_config, hostname_dns);
let rid = {
let mut state_ = state.borrow_mut();
state_
.resource_table
- .add(TlsClientStreamResource::from(tls_stream))
+ .add(TlsStreamResource::new(tls_stream.into_split()))
};
+
Ok(OpConn {
rid,
local_addr: Some(OpAddr::Tcp(IpAddr {
@@ -175,47 +769,55 @@ async fn op_connect_tls(
_: (),
) -> Result<OpConn, AnyError> {
assert_eq!(args.transport, "tcp");
+ let hostname = match &*args.hostname {
+ "" => "localhost",
+ n => n,
+ };
+ let port = args.port;
+ let cert_file = args.cert_file.as_deref();
- let mut domain = args.hostname.as_str();
- if domain.is_empty() {
- domain = "localhost";
- }
{
let mut s = state.borrow_mut();
let permissions = s.borrow_mut::<Permissions>();
- permissions.net.check(&(domain, Some(args.port)))?;
- if let Some(path) = &args.cert_file {
- permissions.read.check(Path::new(&path))?;
+ permissions.net.check(&(hostname, Some(port)))?;
+ if let Some(path) = cert_file {
+ permissions.read.check(Path::new(path))?;
}
}
- let dnsname = DNSNameRef::try_from_ascii_str(domain)
- .map_err(|_| invalid_hostname(domain))?;
- let addr = resolve_addr(domain, args.port)
+ let hostname_dns = DNSNameRef::try_from_ascii_str(hostname)
+ .map_err(|_| invalid_hostname(hostname))?;
+
+ let connect_addr = resolve_addr(hostname, port)
.await?
.next()
.ok_or_else(|| generic_error("No resolved address found"))?;
- let tcp_stream = TcpStream::connect(&addr).await?;
+ let tcp_stream = TcpStream::connect(connect_addr).await?;
let local_addr = tcp_stream.local_addr()?;
let remote_addr = tcp_stream.peer_addr()?;
- let mut config = ClientConfig::new();
- config.set_persistence(CLIENT_SESSION_MEMORY_CACHE.clone());
- config
+
+ let mut tls_config = ClientConfig::new();
+ tls_config.set_persistence(CLIENT_SESSION_MEMORY_CACHE.clone());
+ tls_config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
- if let Some(path) = args.cert_file {
+ if let Some(path) = cert_file {
let key_file = File::open(path)?;
let reader = &mut BufReader::new(key_file);
- config.root_store.add_pem_file(reader).unwrap();
+ tls_config.root_store.add_pem_file(reader).unwrap();
}
- let tls_connector = TlsConnector::from(Arc::new(config));
- let tls_stream = tls_connector.connect(dnsname, tcp_stream).await?;
+ let tls_config = Arc::new(tls_config);
+
+ let tls_stream =
+ TlsStream::new_client_side(tcp_stream, &tls_config, hostname_dns);
+
let rid = {
let mut state_ = state.borrow_mut();
state_
.resource_table
- .add(TlsClientStreamResource::from(tls_stream))
+ .add(TlsStreamResource::new(tls_stream.into_split()))
};
+
Ok(OpConn {
rid,
local_addr: Some(OpAddr::Tcp(IpAddr {
@@ -284,9 +886,9 @@ fn load_keys(path: &str) -> Result<Vec<PrivateKey>, AnyError> {
}
pub struct TlsListenerResource {
- listener: AsyncRefCell<TcpListener>,
- tls_acceptor: TlsAcceptor,
- cancel: CancelHandle,
+ tcp_listener: AsyncRefCell<TcpListener>,
+ tls_config: Arc<ServerConfig>,
+ cancel_handle: CancelHandle,
}
impl Resource for TlsListenerResource {
@@ -295,7 +897,7 @@ impl Resource for TlsListenerResource {
}
fn close(self: Rc<Self>) {
- self.cancel.cancel();
+ self.cancel_handle.cancel();
}
}
@@ -316,36 +918,40 @@ fn op_listen_tls(
_: (),
) -> Result<OpConn, AnyError> {
assert_eq!(args.transport, "tcp");
+ let hostname = &*args.hostname;
+ let port = args.port;
+ let cert_file = &*args.cert_file;
+ let key_file = &*args.key_file;
- let cert_file = args.cert_file;
- let key_file = args.key_file;
{
let permissions = state.borrow_mut::<Permissions>();
- permissions.net.check(&(&args.hostname, Some(args.port)))?;
- permissions.read.check(Path::new(&cert_file))?;
- permissions.read.check(Path::new(&key_file))?;
+ permissions.net.check(&(hostname, Some(port)))?;
+ permissions.read.check(Path::new(cert_file))?;
+ permissions.read.check(Path::new(key_file))?;
}
- let mut config = ServerConfig::new(NoClientAuth::new());
+
+ let mut tls_config = ServerConfig::new(NoClientAuth::new());
if let Some(alpn_protocols) = args.alpn_protocols {
super::check_unstable(state, "Deno.listenTls#alpn_protocols");
- config.alpn_protocols =
+ tls_config.alpn_protocols =
alpn_protocols.into_iter().map(|s| s.into_bytes()).collect();
}
- config
- .set_single_cert(load_certs(&cert_file)?, load_keys(&key_file)?.remove(0))
+ tls_config
+ .set_single_cert(load_certs(cert_file)?, load_keys(key_file)?.remove(0))
.expect("invalid key or certificate");
- let tls_acceptor = TlsAcceptor::from(Arc::new(config));
- let addr = resolve_addr_sync(&args.hostname, args.port)?
+
+ let bind_addr = resolve_addr_sync(hostname, port)?
.next()
.ok_or_else(|| generic_error("No resolved address found"))?;
- let std_listener = std::net::TcpListener::bind(&addr)?;
+ let std_listener = std::net::TcpListener::bind(bind_addr)?;
std_listener.set_nonblocking(true)?;
- let listener = TcpListener::from_std(std_listener)?;
- let local_addr = listener.local_addr()?;
+ let tcp_listener = TcpListener::from_std(std_listener)?;
+ let local_addr = tcp_listener.local_addr()?;
+
let tls_listener_resource = TlsListenerResource {
- listener: AsyncRefCell::new(listener),
- tls_acceptor,
- cancel: Default::default(),
+ tcp_listener: AsyncRefCell::new(tcp_listener),
+ tls_config: Arc::new(tls_config),
+ cancel_handle: Default::default(),
};
let rid = state.resource_table.add(tls_listener_resource);
@@ -370,38 +976,31 @@ async fn op_accept_tls(
.resource_table
.get::<TlsListenerResource>(rid)
.ok_or_else(|| bad_resource("Listener has been closed"))?;
- let listener = RcRef::map(&resource, |r| &r.listener)
+
+ let cancel_handle = RcRef::map(&resource, |r| &r.cancel_handle);
+ let tcp_listener = RcRef::map(&resource, |r| &r.tcp_listener)
.try_borrow_mut()
.ok_or_else(|| custom_error("Busy", "Another accept task is ongoing"))?;
- let cancel = RcRef::map(resource, |r| &r.cancel);
- let (tcp_stream, _socket_addr) =
- listener.accept().try_or_cancel(cancel).await.map_err(|e| {
- // FIXME(bartlomieju): compatibility with current JS implementation
- if let std::io::ErrorKind::Interrupted = e.kind() {
- bad_resource("Listener has been closed")
- } else {
- e.into()
+
+ let (tcp_stream, remote_addr) =
+ match tcp_listener.accept().try_or_cancel(&cancel_handle).await {
+ Ok(tuple) => tuple,
+ Err(err) if err.kind() == ErrorKind::Interrupted => {
+ // FIXME(bartlomieju): compatibility with current JS implementation.
+ return Err(bad_resource("Listener has been closed"));
}
- })?;
+ Err(err) => return Err(err.into()),
+ };
+
let local_addr = tcp_stream.local_addr()?;
- let remote_addr = tcp_stream.peer_addr()?;
- let resource = state
- .borrow()
- .resource_table
- .get::<TlsListenerResource>(rid)
- .ok_or_else(|| bad_resource("Listener has been closed"))?;
- let cancel = RcRef::map(&resource, |r| &r.cancel);
- let tls_acceptor = resource.tls_acceptor.clone();
- let tls_stream = tls_acceptor
- .accept(tcp_stream)
- .try_or_cancel(cancel)
- .await?;
+
+ let tls_stream = TlsStream::new_server_side(tcp_stream, &resource.tls_config);
let rid = {
let mut state_ = state.borrow_mut();
state_
.resource_table
- .add(TlsServerStreamResource::from(tls_stream))
+ .add(TlsStreamResource::new(tls_stream.into_split()))
};
Ok(OpConn {