summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cli/tests/unit/flash_test.ts62
-rw-r--r--ext/flash/01_http.js132
-rw-r--r--ext/flash/lib.rs381
-rw-r--r--ext/flash/socket.rs147
4 files changed, 384 insertions, 338 deletions
diff --git a/cli/tests/unit/flash_test.ts b/cli/tests/unit/flash_test.ts
index fef45beb9..78340a390 100644
--- a/cli/tests/unit/flash_test.ts
+++ b/cli/tests/unit/flash_test.ts
@@ -1283,29 +1283,35 @@ function createServerLengthTest(name: string, testCase: TestCase) {
await promise;
const decoder = new TextDecoder();
- const buf = new Uint8Array(1024);
- const readResult = await conn.read(buf);
- assert(readResult);
- const msg = decoder.decode(buf.subarray(0, readResult));
-
- try {
- assert(testCase.expects_chunked == hasHeader(msg, "Transfer-Encoding:"));
- assert(testCase.expects_chunked == hasHeader(msg, "chunked"));
- assert(testCase.expects_con_len == hasHeader(msg, "Content-Length:"));
+ let msg = "";
+ while (true) {
+ const buf = new Uint8Array(1024);
+ const readResult = await conn.read(buf);
+ if (!readResult) {
+ break;
+ }
+ msg += decoder.decode(buf.subarray(0, readResult));
+ try {
+ assert(
+ testCase.expects_chunked == hasHeader(msg, "Transfer-Encoding:"),
+ );
+ assert(testCase.expects_chunked == hasHeader(msg, "chunked"));
+ assert(testCase.expects_con_len == hasHeader(msg, "Content-Length:"));
- const n = msg.indexOf("\r\n\r\n") + 4;
+ const n = msg.indexOf("\r\n\r\n") + 4;
- if (testCase.expects_chunked) {
- assertEquals(msg.slice(n + 1, n + 3), "\r\n");
- assertEquals(msg.slice(msg.length - 7), "\r\n0\r\n\r\n");
- }
+ if (testCase.expects_chunked) {
+ assertEquals(msg.slice(n + 1, n + 3), "\r\n");
+ assertEquals(msg.slice(msg.length - 7), "\r\n0\r\n\r\n");
+ }
- if (testCase.expects_con_len && typeof testCase.body === "string") {
- assertEquals(msg.slice(n), testCase.body);
+ if (testCase.expects_con_len && typeof testCase.body === "string") {
+ assertEquals(msg.slice(n), testCase.body);
+ }
+ break;
+ } catch (e) {
+ continue;
}
- } catch (e) {
- console.error(e);
- throw e;
}
conn.close();
@@ -1419,11 +1425,19 @@ Deno.test(
const decoder = new TextDecoder();
{
- const buf = new Uint8Array(1024);
- const readResult = await conn.read(buf);
- assert(readResult);
- const msg = decoder.decode(buf.subarray(0, readResult));
- assert(msg.endsWith("\r\nfoo bar baz\r\n0\r\n\r\n"));
+ let msg = "";
+ while (true) {
+ try {
+ const buf = new Uint8Array(1024);
+ const readResult = await conn.read(buf);
+ assert(readResult);
+ msg += decoder.decode(buf.subarray(0, readResult));
+ assert(msg.endsWith("\r\nfoo bar baz\r\n0\r\n\r\n"));
+ break;
+ } catch {
+ continue;
+ }
+ }
}
// once more!
diff --git a/ext/flash/01_http.js b/ext/flash/01_http.js
index fbc24d73d..b00c9f8e4 100644
--- a/ext/flash/01_http.js
+++ b/ext/flash/01_http.js
@@ -188,6 +188,44 @@
return hostname === "0.0.0.0" ? "localhost" : hostname;
}
+ function writeFixedResponse(
+ server,
+ requestId,
+ response,
+ end,
+ respondFast,
+ ) {
+ let nwritten = 0;
+ // TypedArray
+ if (typeof response !== "string") {
+ nwritten = respondFast(requestId, response, end);
+ } else {
+ // string
+ const maybeResponse = stringResources[response];
+ if (maybeResponse === undefined) {
+ stringResources[response] = core.encode(response);
+ nwritten = core.ops.op_flash_respond(
+ server,
+ requestId,
+ stringResources[response],
+ end,
+ );
+ } else {
+ nwritten = respondFast(requestId, maybeResponse, end);
+ }
+ }
+
+ if (nwritten < response.length) {
+ core.opAsync(
+ "op_flash_respond_async",
+ server,
+ requestId,
+ response.slice(nwritten),
+ end,
+ );
+ }
+ }
+
async function serve(arg1, arg2) {
let options = undefined;
let handler = undefined;
@@ -320,7 +358,7 @@
}
// there might've been an HTTP upgrade.
if (resp === undefined) {
- continue;
+ return;
}
const innerResp = toInnerResponse(resp);
@@ -389,26 +427,13 @@
innerResp.headerList,
respBody,
);
-
- // TypedArray
- if (typeof responseStr !== "string") {
- respondFast(i, responseStr, !ws);
- } else {
- // string
- const maybeResponse = stringResources[responseStr];
- if (maybeResponse === undefined) {
- stringResources[responseStr] = core.encode(responseStr);
- core.ops.op_flash_respond(
- serverId,
- i,
- stringResources[responseStr],
- null,
- !ws, // Don't close socket if there is a deferred websocket upgrade.
- );
- } else {
- respondFast(i, maybeResponse, !ws);
- }
- }
+ writeFixedResponse(
+ serverId,
+ i,
+ responseStr,
+ !ws, // Don't close socket if there is a deferred websocket upgrade.
+ respondFast,
+ );
}
(async () => {
@@ -451,41 +476,26 @@
}
} else {
const reader = respBody.getReader();
- let first = true;
- a:
+ writeFixedResponse(
+ serverId,
+ i,
+ http1Response(
+ method,
+ innerResp.status ?? 200,
+ innerResp.headerList,
+ null,
+ ),
+ false,
+ respondFast,
+ );
while (true) {
const { value, done } = await reader.read();
- if (first) {
- first = false;
- core.ops.op_flash_respond(
- serverId,
- i,
- http1Response(
- method,
- innerResp.status ?? 200,
- innerResp.headerList,
- null,
- ),
- value ?? new Uint8Array(),
- false,
- );
- } else {
- if (value === undefined) {
- core.ops.op_flash_respond_chuncked(
- serverId,
- i,
- undefined,
- done,
- );
- } else {
- respondChunked(
- i,
- value,
- done,
- );
- }
- }
- if (done) break a;
+ await respondChunked(
+ i,
+ value,
+ done,
+ );
+ if (done) break;
}
}
}
@@ -528,18 +538,24 @@
once: true,
});
+ function respondChunked(token, chunk, shutdown) {
+ return core.opAsync(
+ "op_flash_respond_chuncked",
+ serverId,
+ token,
+ chunk,
+ shutdown,
+ );
+ }
+
const fastOp = prepareFastCalls();
let nextRequestSync = () => fastOp.nextRequest();
let getMethodSync = (token) => fastOp.getMethod(token);
- let respondChunked = (token, chunk, shutdown) =>
- fastOp.respondChunked(token, chunk, shutdown);
let respondFast = (token, response, shutdown) =>
fastOp.respond(token, response, shutdown);
if (serverId > 0) {
nextRequestSync = () => core.ops.op_flash_next_server(serverId);
getMethodSync = (token) => core.ops.op_flash_method(serverId, token);
- respondChunked = (token, chunk, shutdown) =>
- core.ops.op_flash_respond_chuncked(serverId, token, chunk, shutdown);
respondFast = (token, response, shutdown) =>
core.ops.op_flash_respond(serverId, token, response, null, shutdown);
}
diff --git a/ext/flash/lib.rs b/ext/flash/lib.rs
index 8ed1baaad..201753bea 100644
--- a/ext/flash/lib.rs
+++ b/ext/flash/lib.rs
@@ -3,6 +3,8 @@
// False positive lint for explicit drops.
// https://github.com/rust-lang/rust-clippy/issues/6446
#![allow(clippy::await_holding_lock)]
+// https://github.com/rust-lang/rust-clippy/issues/6353
+#![allow(clippy::await_holding_refcell_ref)]
use deno_core::error::generic_error;
use deno_core::error::type_error;
@@ -29,7 +31,6 @@ use http::header::TRANSFER_ENCODING;
use http::HeaderValue;
use log::trace;
use mio::net::TcpListener;
-use mio::net::TcpStream;
use mio::Events;
use mio::Interest;
use mio::Poll;
@@ -45,7 +46,6 @@ use std::intrinsics::transmute;
use std::io::BufReader;
use std::io::Read;
use std::io::Write;
-use std::marker::PhantomPinned;
use std::mem::replace;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
@@ -62,9 +62,12 @@ mod chunked;
mod request;
#[cfg(unix)]
mod sendfile;
+mod socket;
use request::InnerRequest;
use request::Request;
+use socket::InnerStream;
+use socket::Stream;
pub struct FlashContext {
next_server_id: u32,
@@ -84,83 +87,82 @@ pub struct ServerContext {
}
#[derive(Debug, PartialEq)]
-enum ParseStatus {
+pub enum ParseStatus {
None,
Ongoing(usize),
}
-type TlsTcpStream = rustls::StreamOwned<rustls::ServerConnection, TcpStream>;
-
-enum InnerStream {
- Tcp(TcpStream),
- Tls(Box<TlsTcpStream>),
-}
-
-pub struct Stream {
- inner: InnerStream,
- detached: bool,
- read_rx: Option<mpsc::Receiver<()>>,
- read_tx: Option<mpsc::Sender<()>>,
- parse_done: ParseStatus,
- buffer: UnsafeCell<Vec<u8>>,
- read_lock: Arc<Mutex<()>>,
- _pin: PhantomPinned,
+#[op]
+fn op_flash_respond(
+ op_state: &mut OpState,
+ server_id: u32,
+ token: u32,
+ response: StringOrBuffer,
+ shutdown: bool,
+) -> u32 {
+ let flash_ctx = op_state.borrow_mut::<FlashContext>();
+ let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
+ flash_respond(ctx, token, shutdown, &response)
}
-impl Stream {
- pub fn detach_ownership(&mut self) {
- self.detached = true;
- }
+#[op]
+async fn op_flash_respond_async(
+ state: Rc<RefCell<OpState>>,
+ server_id: u32,
+ token: u32,
+ response: StringOrBuffer,
+ shutdown: bool,
+) -> Result<(), AnyError> {
+ trace!("op_flash_respond_async");
- fn reattach_ownership(&mut self) {
- self.detached = false;
- }
-}
+ let mut close = false;
+ let sock = {
+ let mut op_state = state.borrow_mut();
+ let flash_ctx = op_state.borrow_mut::<FlashContext>();
+ let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
-impl Write for Stream {
- #[inline]
- fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
- match self.inner {
- InnerStream::Tcp(ref mut stream) => stream.write(buf),
- InnerStream::Tls(ref mut stream) => stream.write(buf),
- }
- }
- #[inline]
- fn flush(&mut self) -> std::io::Result<()> {
- match self.inner {
- InnerStream::Tcp(ref mut stream) => stream.flush(),
- InnerStream::Tls(ref mut stream) => stream.flush(),
+ match shutdown {
+ true => {
+ let tx = ctx.requests.remove(&token).unwrap();
+ close = !tx.keep_alive;
+ tx.socket()
+ }
+ // In case of a websocket upgrade or streaming response.
+ false => {
+ let tx = ctx.requests.get(&token).unwrap();
+ tx.socket()
+ }
}
- }
-}
+ };
-impl Read for Stream {
- #[inline]
- fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
- match self.inner {
- InnerStream::Tcp(ref mut stream) => stream.read(buf),
- InnerStream::Tls(ref mut stream) => stream.read(buf),
- }
+ sock
+ .with_async_stream(|stream| {
+ Box::pin(async move {
+ Ok(tokio::io::AsyncWriteExt::write(stream, &response).await?)
+ })
+ })
+ .await?;
+ // server is done writing and request doesn't want to kept alive.
+ if shutdown && close {
+ sock.shutdown();
}
+ Ok(())
}
#[op]
-fn op_flash_respond(
- op_state: &mut OpState,
+async fn op_flash_respond_chuncked(
+ op_state: Rc<RefCell<OpState>>,
server_id: u32,
token: u32,
- response: StringOrBuffer,
- maybe_body: Option<ZeroCopyBuf>,
+ response: Option<ZeroCopyBuf>,
shutdown: bool,
-) {
+) -> Result<(), AnyError> {
+ let mut op_state = op_state.borrow_mut();
let flash_ctx = op_state.borrow_mut::<FlashContext>();
let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
-
- let mut close = false;
let sock = match shutdown {
true => {
let tx = ctx.requests.remove(&token).unwrap();
- close = !tx.keep_alive;
tx.socket()
}
// In case of a websocket upgrade or streaming response.
@@ -170,49 +172,29 @@ fn op_flash_respond(
}
};
- sock.read_tx.take();
- sock.read_rx.take();
-
- let _ = sock.write(&response);
- if let Some(response) = maybe_body {
- let _ = sock.write(format!("{:x}", response.len()).as_bytes());
- let _ = sock.write(b"\r\n");
- let _ = sock.write(&response);
- let _ = sock.write(b"\r\n");
- }
+ drop(op_state);
+ sock
+ .with_async_stream(|stream| {
+ Box::pin(async move {
+ use tokio::io::AsyncWriteExt;
+ if let Some(response) = response {
+ stream
+ .write_all(format!("{:x}\r\n", response.len()).as_bytes())
+ .await?;
+ stream.write_all(&response).await?;
+ stream.write_all(b"\r\n").await?;
+ }
- // server is done writing and request doesn't want to kept alive.
- if shutdown && close {
- match &mut sock.inner {
- InnerStream::Tcp(stream) => {
- // Typically shutdown shouldn't fail.
- let _ = stream.shutdown(std::net::Shutdown::Both);
- }
- InnerStream::Tls(stream) => {
- let _ = stream.sock.shutdown(std::net::Shutdown::Both);
- }
- }
- }
-}
+ // The last chunk
+ if shutdown {
+ stream.write_all(b"0\r\n\r\n").await?;
+ }
-#[op]
-fn op_flash_respond_chuncked(
- op_state: &mut OpState,
- server_id: u32,
- token: u32,
- response: Option<ZeroCopyBuf>,
- shutdown: bool,
-) {
- let flash_ctx = op_state.borrow_mut::<FlashContext>();
- let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
- match response {
- Some(response) => {
- respond_chunked(ctx, token, shutdown, Some(&response));
- }
- None => {
- respond_chunked(ctx, token, shutdown, None);
- }
- }
+ Ok(())
+ })
+ })
+ .await?;
+ Ok(())
}
#[op]
@@ -258,24 +240,35 @@ async fn op_flash_write_resource(
}
}
- let _ = sock.write(b"Transfer-Encoding: chunked\r\n\r\n");
- loop {
- let vec = vec![0u8; 64 * 1024]; // 64KB
- let buf = ZeroCopyBuf::new_temp(vec);
- let (nread, buf) = resource.clone().read_return(buf).await?;
- if nread == 0 {
- let _ = sock.write(b"0\r\n\r\n");
- break;
- }
- let response = &buf[..nread];
-
- let _ = sock.write(format!("{:x}", response.len()).as_bytes());
- let _ = sock.write(b"\r\n");
- let _ = sock.write(response);
- let _ = sock.write(b"\r\n");
- }
+ sock
+ .with_async_stream(|stream| {
+ Box::pin(async move {
+ use tokio::io::AsyncWriteExt;
+ stream
+ .write_all(b"Transfer-Encoding: chunked\r\n\r\n")
+ .await?;
+ loop {
+ let vec = vec![0u8; 64 * 1024]; // 64KB
+ let buf = ZeroCopyBuf::new_temp(vec);
+ let (nread, buf) = resource.clone().read_return(buf).await?;
+ if nread == 0 {
+ stream.write_all(b"0\r\n\r\n").await?;
+ break;
+ }
- resource.close();
+ let response = &buf[..nread];
+ // TODO(@littledivy): use vectored writes.
+ stream
+ .write_all(format!("{:x}\r\n", response.len()).as_bytes())
+ .await?;
+ stream.write_all(response).await?;
+ stream.write_all(b"\r\n").await?;
+ }
+ resource.close();
+ Ok(())
+ })
+ })
+ .await?;
Ok(())
}
@@ -296,7 +289,7 @@ impl fast_api::FastFunction for RespondFast {
}
fn return_type(&self) -> fast_api::CType {
- fast_api::CType::Void
+ fast_api::CType::Uint32
}
}
@@ -305,128 +298,43 @@ fn flash_respond(
token: u32,
shutdown: bool,
response: &[u8],
-) {
- let mut close = false;
- let sock = match shutdown {
- true => {
- let tx = ctx.requests.remove(&token).unwrap();
- close = !tx.keep_alive;
- tx.socket()
- }
- // In case of a websocket upgrade or streaming response.
- false => {
- let tx = ctx.requests.get(&token).unwrap();
- tx.socket()
- }
- };
+) -> u32 {
+ let tx = ctx.requests.get(&token).unwrap();
+ let sock = tx.socket();
sock.read_tx.take();
sock.read_rx.take();
- let _ = sock.write(response);
- // server is done writing and request doesn't want to kept alive.
- if shutdown && close {
- match &mut sock.inner {
- InnerStream::Tcp(stream) => {
- // Typically shutdown shouldn't fail.
- let _ = stream.shutdown(std::net::Shutdown::Both);
- }
- InnerStream::Tls(stream) => {
- let _ = stream.sock.shutdown(std::net::Shutdown::Both);
- }
- }
- }
-}
-
-unsafe fn op_flash_respond_fast(
- recv: v8::Local<v8::Object>,
- token: u32,
- response: *const fast_api::FastApiTypedArray<u8>,
- shutdown: bool,
-) {
- let ptr =
- recv.get_aligned_pointer_from_internal_field(V8_WRAPPER_OBJECT_INDEX);
- let ctx = &mut *(ptr as *mut ServerContext);
-
- let response = &*response;
- if let Some(response) = response.get_storage_if_aligned() {
- flash_respond(ctx, token, shutdown, response);
- } else {
- todo!();
- }
-}
-
-pub struct RespondChunkedFast;
-
-impl fast_api::FastFunction for RespondChunkedFast {
- fn function(&self) -> *const c_void {
- op_flash_respond_chunked_fast as *const c_void
- }
+ let nwritten = sock.try_write(response);
- fn args(&self) -> &'static [fast_api::Type] {
- &[
- fast_api::Type::V8Value,
- fast_api::Type::Uint32,
- fast_api::Type::TypedArray(fast_api::CType::Uint8),
- fast_api::Type::Bool,
- ]
+ if shutdown && nwritten == response.len() {
+ if !tx.keep_alive {
+ sock.shutdown();
+ }
+ ctx.requests.remove(&token).unwrap();
}
- fn return_type(&self) -> fast_api::CType {
- fast_api::CType::Void
- }
+ nwritten as u32
}
-unsafe fn op_flash_respond_chunked_fast(
+unsafe fn op_flash_respond_fast(
recv: v8::Local<v8::Object>,
token: u32,
response: *const fast_api::FastApiTypedArray<u8>,
shutdown: bool,
-) {
+) -> u32 {
let ptr =
recv.get_aligned_pointer_from_internal_field(V8_WRAPPER_OBJECT_INDEX);
let ctx = &mut *(ptr as *mut ServerContext);
let response = &*response;
if let Some(response) = response.get_storage_if_aligned() {
- respond_chunked(ctx, token, shutdown, Some(response));
+ flash_respond(ctx, token, shutdown, response)
} else {
todo!();
}
}
-fn respond_chunked(
- ctx: &mut ServerContext,
- token: u32,
- shutdown: bool,
- response: Option<&[u8]>,
-) {
- let sock = match shutdown {
- true => {
- let tx = ctx.requests.remove(&token).unwrap();
- tx.socket()
- }
- // In case of a websocket upgrade or streaming response.
- false => {
- let tx = ctx.requests.get(&token).unwrap();
- tx.socket()
- }
- };
-
- if let Some(response) = response {
- let _ = sock.write(format!("{:x}", response.len()).as_bytes());
- let _ = sock.write(b"\r\n");
- let _ = sock.write(response);
- let _ = sock.write(b"\r\n");
- }
-
- // The last chunk
- if shutdown {
- let _ = sock.write(b"0\r\n\r\n");
- }
- sock.reattach_ownership();
-}
-
macro_rules! get_request {
($op_state: ident, $token: ident) => {
get_request!($op_state, 0, $token)
@@ -631,51 +539,12 @@ fn op_flash_make_request<'scope>(
obj.set(scope, key.into(), func).unwrap();
}
- // respondChunked
- {
- let builder = v8::FunctionTemplate::builder(
- |scope: &mut v8::HandleScope,
- args: v8::FunctionCallbackArguments,
- _: v8::ReturnValue| {
- let external: v8::Local<v8::External> =
- args.data().unwrap().try_into().unwrap();
- // SAFETY: This external is guaranteed to be a pointer to a ServerContext
- let ctx = unsafe { &mut *(external.value() as *mut ServerContext) };
-
- let token = args.get(0).uint32_value(scope).unwrap();
-
- let response: v8::Local<v8::ArrayBufferView> =
- args.get(1).try_into().unwrap();
- let ab = response.buffer(scope).unwrap();
- let store = ab.get_backing_store();
- let (offset, len) = (response.byte_offset(), response.byte_length());
- // SAFETY: v8::SharedRef<v8::BackingStore> is similar to Arc<[u8]>,
- // it points to a fixed continuous slice of bytes on the heap.
- // We assume it's initialized and thus safe to read (though may not contain meaningful data)
- let response = unsafe {
- &*(&store[offset..offset + len] as *const _ as *const [u8])
- };
-
- let shutdown = args.get(2).boolean_value(scope);
-
- respond_chunked(ctx, token, shutdown, Some(response));
- },
- )
- .data(v8::External::new(scope, ctx as *mut _).into());
-
- let func = builder.build_fast(scope, &RespondChunkedFast, None);
- let func: v8::Local<v8::Value> = func.get_function(scope).unwrap().into();
-
- let key = v8::String::new(scope, "respondChunked").unwrap();
- obj.set(scope, key.into(), func).unwrap();
- }
-
// respond
{
let builder = v8::FunctionTemplate::builder(
|scope: &mut v8::HandleScope,
args: v8::FunctionCallbackArguments,
- _: v8::ReturnValue| {
+ mut rv: v8::ReturnValue| {
let external: v8::Local<v8::External> =
args.data().unwrap().try_into().unwrap();
// SAFETY: This external is guaranteed to be a pointer to a ServerContext
@@ -697,7 +566,7 @@ fn op_flash_make_request<'scope>(
let shutdown = args.get(2).boolean_value(scope);
- flash_respond(ctx, token, shutdown, response);
+ rv.set_uint32(flash_respond(ctx, token, shutdown, response));
},
)
.data(v8::External::new(scope, ctx as *mut _).into());
@@ -1024,7 +893,6 @@ fn run_server(
read_lock: Arc::new(Mutex::new(())),
parse_done: ParseStatus::None,
buffer: UnsafeCell::new(vec![0_u8; 1024]),
- _pin: PhantomPinned,
});
trace!("New connection: {}", token.0);
@@ -1521,6 +1389,7 @@ pub fn init<P: FlashPermissions + 'static>(unstable: bool) -> Extension {
.ops(vec![
op_flash_serve::decl::<P>(),
op_flash_respond::decl(),
+ op_flash_respond_async::decl(),
op_flash_respond_chuncked::decl(),
op_flash_method::decl(),
op_flash_path::decl(),
diff --git a/ext/flash/socket.rs b/ext/flash/socket.rs
new file mode 100644
index 000000000..8256be8a0
--- /dev/null
+++ b/ext/flash/socket.rs
@@ -0,0 +1,147 @@
+use deno_core::error::AnyError;
+use mio::net::TcpStream;
+use std::{
+ cell::UnsafeCell,
+ future::Future,
+ io::{Read, Write},
+ pin::Pin,
+ sync::{Arc, Mutex},
+};
+use tokio::sync::mpsc;
+
+use crate::ParseStatus;
+
+type TlsTcpStream = rustls::StreamOwned<rustls::ServerConnection, TcpStream>;
+
+pub enum InnerStream {
+ Tcp(TcpStream),
+ Tls(Box<TlsTcpStream>),
+}
+
+pub struct Stream {
+ pub inner: InnerStream,
+ pub detached: bool,
+ pub read_rx: Option<mpsc::Receiver<()>>,
+ pub read_tx: Option<mpsc::Sender<()>>,
+ pub parse_done: ParseStatus,
+ pub buffer: UnsafeCell<Vec<u8>>,
+ pub read_lock: Arc<Mutex<()>>,
+}
+
+impl Stream {
+ pub fn detach_ownership(&mut self) {
+ self.detached = true;
+ }
+
+ /// Try to write to the socket.
+ #[inline]
+ pub fn try_write(&mut self, buf: &[u8]) -> usize {
+ let mut nwritten = 0;
+ while nwritten < buf.len() {
+ match self.write(&buf[nwritten..]) {
+ Ok(n) => nwritten += n,
+ Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
+ break;
+ }
+ Err(e) => {
+ log::trace!("Error writing to socket: {}", e);
+ break;
+ }
+ }
+ }
+ nwritten
+ }
+
+ #[inline]
+ pub fn shutdown(&mut self) {
+ match &mut self.inner {
+ InnerStream::Tcp(stream) => {
+ // Typically shutdown shouldn't fail.
+ let _ = stream.shutdown(std::net::Shutdown::Both);
+ }
+ InnerStream::Tls(stream) => {
+ let _ = stream.sock.shutdown(std::net::Shutdown::Both);
+ }
+ }
+ }
+
+ pub fn as_std(&mut self) -> std::net::TcpStream {
+ #[cfg(unix)]
+ let std_stream = {
+ use std::os::unix::prelude::AsRawFd;
+ use std::os::unix::prelude::FromRawFd;
+ let fd = match self.inner {
+ InnerStream::Tcp(ref tcp) => tcp.as_raw_fd(),
+ _ => todo!(),
+ };
+ // SAFETY: `fd` is a valid file descriptor.
+ unsafe { std::net::TcpStream::from_raw_fd(fd) }
+ };
+ #[cfg(windows)]
+ let std_stream = {
+ use std::os::windows::prelude::AsRawSocket;
+ use std::os::windows::prelude::FromRawSocket;
+ let fd = match self.inner {
+ InnerStream::Tcp(ref tcp) => tcp.as_raw_socket(),
+ _ => todo!(),
+ };
+ // SAFETY: `fd` is a valid file descriptor.
+ unsafe { std::net::TcpStream::from_raw_socket(fd) }
+ };
+ std_stream
+ }
+
+ #[inline]
+ pub async fn with_async_stream<F, T>(&mut self, f: F) -> Result<T, AnyError>
+ where
+ F: FnOnce(
+ &mut tokio::net::TcpStream,
+ ) -> Pin<Box<dyn '_ + Future<Output = Result<T, AnyError>>>>,
+ {
+ let mut async_stream = tokio::net::TcpStream::from_std(self.as_std())?;
+ let result = f(&mut async_stream).await?;
+ forget_stream(async_stream.into_std()?);
+ Ok(result)
+ }
+}
+
+#[inline]
+pub fn forget_stream(stream: std::net::TcpStream) {
+ #[cfg(unix)]
+ {
+ use std::os::unix::prelude::IntoRawFd;
+ let _ = stream.into_raw_fd();
+ }
+ #[cfg(windows)]
+ {
+ use std::os::windows::prelude::IntoRawSocket;
+ let _ = stream.into_raw_socket();
+ }
+}
+
+impl Write for Stream {
+ #[inline]
+ fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+ match self.inner {
+ InnerStream::Tcp(ref mut stream) => stream.write(buf),
+ InnerStream::Tls(ref mut stream) => stream.write(buf),
+ }
+ }
+ #[inline]
+ fn flush(&mut self) -> std::io::Result<()> {
+ match self.inner {
+ InnerStream::Tcp(ref mut stream) => stream.flush(),
+ InnerStream::Tls(ref mut stream) => stream.flush(),
+ }
+ }
+}
+
+impl Read for Stream {
+ #[inline]
+ fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+ match self.inner {
+ InnerStream::Tcp(ref mut stream) => stream.read(buf),
+ InnerStream::Tls(ref mut stream) => stream.read(buf),
+ }
+ }
+}