diff options
-rw-r--r-- | cli/tests/unit/flash_test.ts | 62 | ||||
-rw-r--r-- | ext/flash/01_http.js | 132 | ||||
-rw-r--r-- | ext/flash/lib.rs | 381 | ||||
-rw-r--r-- | ext/flash/socket.rs | 147 |
4 files changed, 384 insertions, 338 deletions
diff --git a/cli/tests/unit/flash_test.ts b/cli/tests/unit/flash_test.ts index fef45beb9..78340a390 100644 --- a/cli/tests/unit/flash_test.ts +++ b/cli/tests/unit/flash_test.ts @@ -1283,29 +1283,35 @@ function createServerLengthTest(name: string, testCase: TestCase) { await promise; const decoder = new TextDecoder(); - const buf = new Uint8Array(1024); - const readResult = await conn.read(buf); - assert(readResult); - const msg = decoder.decode(buf.subarray(0, readResult)); - - try { - assert(testCase.expects_chunked == hasHeader(msg, "Transfer-Encoding:")); - assert(testCase.expects_chunked == hasHeader(msg, "chunked")); - assert(testCase.expects_con_len == hasHeader(msg, "Content-Length:")); + let msg = ""; + while (true) { + const buf = new Uint8Array(1024); + const readResult = await conn.read(buf); + if (!readResult) { + break; + } + msg += decoder.decode(buf.subarray(0, readResult)); + try { + assert( + testCase.expects_chunked == hasHeader(msg, "Transfer-Encoding:"), + ); + assert(testCase.expects_chunked == hasHeader(msg, "chunked")); + assert(testCase.expects_con_len == hasHeader(msg, "Content-Length:")); - const n = msg.indexOf("\r\n\r\n") + 4; + const n = msg.indexOf("\r\n\r\n") + 4; - if (testCase.expects_chunked) { - assertEquals(msg.slice(n + 1, n + 3), "\r\n"); - assertEquals(msg.slice(msg.length - 7), "\r\n0\r\n\r\n"); - } + if (testCase.expects_chunked) { + assertEquals(msg.slice(n + 1, n + 3), "\r\n"); + assertEquals(msg.slice(msg.length - 7), "\r\n0\r\n\r\n"); + } - if (testCase.expects_con_len && typeof testCase.body === "string") { - assertEquals(msg.slice(n), testCase.body); + if (testCase.expects_con_len && typeof testCase.body === "string") { + assertEquals(msg.slice(n), testCase.body); + } + break; + } catch (e) { + continue; } - } catch (e) { - console.error(e); - throw e; } conn.close(); @@ -1419,11 +1425,19 @@ Deno.test( const decoder = new TextDecoder(); { - const buf = new Uint8Array(1024); - const readResult = await conn.read(buf); - assert(readResult); - const msg = decoder.decode(buf.subarray(0, readResult)); - assert(msg.endsWith("\r\nfoo bar baz\r\n0\r\n\r\n")); + let msg = ""; + while (true) { + try { + const buf = new Uint8Array(1024); + const readResult = await conn.read(buf); + assert(readResult); + msg += decoder.decode(buf.subarray(0, readResult)); + assert(msg.endsWith("\r\nfoo bar baz\r\n0\r\n\r\n")); + break; + } catch { + continue; + } + } } // once more! diff --git a/ext/flash/01_http.js b/ext/flash/01_http.js index fbc24d73d..b00c9f8e4 100644 --- a/ext/flash/01_http.js +++ b/ext/flash/01_http.js @@ -188,6 +188,44 @@ return hostname === "0.0.0.0" ? "localhost" : hostname; } + function writeFixedResponse( + server, + requestId, + response, + end, + respondFast, + ) { + let nwritten = 0; + // TypedArray + if (typeof response !== "string") { + nwritten = respondFast(requestId, response, end); + } else { + // string + const maybeResponse = stringResources[response]; + if (maybeResponse === undefined) { + stringResources[response] = core.encode(response); + nwritten = core.ops.op_flash_respond( + server, + requestId, + stringResources[response], + end, + ); + } else { + nwritten = respondFast(requestId, maybeResponse, end); + } + } + + if (nwritten < response.length) { + core.opAsync( + "op_flash_respond_async", + server, + requestId, + response.slice(nwritten), + end, + ); + } + } + async function serve(arg1, arg2) { let options = undefined; let handler = undefined; @@ -320,7 +358,7 @@ } // there might've been an HTTP upgrade. if (resp === undefined) { - continue; + return; } const innerResp = toInnerResponse(resp); @@ -389,26 +427,13 @@ innerResp.headerList, respBody, ); - - // TypedArray - if (typeof responseStr !== "string") { - respondFast(i, responseStr, !ws); - } else { - // string - const maybeResponse = stringResources[responseStr]; - if (maybeResponse === undefined) { - stringResources[responseStr] = core.encode(responseStr); - core.ops.op_flash_respond( - serverId, - i, - stringResources[responseStr], - null, - !ws, // Don't close socket if there is a deferred websocket upgrade. - ); - } else { - respondFast(i, maybeResponse, !ws); - } - } + writeFixedResponse( + serverId, + i, + responseStr, + !ws, // Don't close socket if there is a deferred websocket upgrade. + respondFast, + ); } (async () => { @@ -451,41 +476,26 @@ } } else { const reader = respBody.getReader(); - let first = true; - a: + writeFixedResponse( + serverId, + i, + http1Response( + method, + innerResp.status ?? 200, + innerResp.headerList, + null, + ), + false, + respondFast, + ); while (true) { const { value, done } = await reader.read(); - if (first) { - first = false; - core.ops.op_flash_respond( - serverId, - i, - http1Response( - method, - innerResp.status ?? 200, - innerResp.headerList, - null, - ), - value ?? new Uint8Array(), - false, - ); - } else { - if (value === undefined) { - core.ops.op_flash_respond_chuncked( - serverId, - i, - undefined, - done, - ); - } else { - respondChunked( - i, - value, - done, - ); - } - } - if (done) break a; + await respondChunked( + i, + value, + done, + ); + if (done) break; } } } @@ -528,18 +538,24 @@ once: true, }); + function respondChunked(token, chunk, shutdown) { + return core.opAsync( + "op_flash_respond_chuncked", + serverId, + token, + chunk, + shutdown, + ); + } + const fastOp = prepareFastCalls(); let nextRequestSync = () => fastOp.nextRequest(); let getMethodSync = (token) => fastOp.getMethod(token); - let respondChunked = (token, chunk, shutdown) => - fastOp.respondChunked(token, chunk, shutdown); let respondFast = (token, response, shutdown) => fastOp.respond(token, response, shutdown); if (serverId > 0) { nextRequestSync = () => core.ops.op_flash_next_server(serverId); getMethodSync = (token) => core.ops.op_flash_method(serverId, token); - respondChunked = (token, chunk, shutdown) => - core.ops.op_flash_respond_chuncked(serverId, token, chunk, shutdown); respondFast = (token, response, shutdown) => core.ops.op_flash_respond(serverId, token, response, null, shutdown); } diff --git a/ext/flash/lib.rs b/ext/flash/lib.rs index 8ed1baaad..201753bea 100644 --- a/ext/flash/lib.rs +++ b/ext/flash/lib.rs @@ -3,6 +3,8 @@ // False positive lint for explicit drops. // https://github.com/rust-lang/rust-clippy/issues/6446 #![allow(clippy::await_holding_lock)] +// https://github.com/rust-lang/rust-clippy/issues/6353 +#![allow(clippy::await_holding_refcell_ref)] use deno_core::error::generic_error; use deno_core::error::type_error; @@ -29,7 +31,6 @@ use http::header::TRANSFER_ENCODING; use http::HeaderValue; use log::trace; use mio::net::TcpListener; -use mio::net::TcpStream; use mio::Events; use mio::Interest; use mio::Poll; @@ -45,7 +46,6 @@ use std::intrinsics::transmute; use std::io::BufReader; use std::io::Read; use std::io::Write; -use std::marker::PhantomPinned; use std::mem::replace; use std::net::SocketAddr; use std::net::ToSocketAddrs; @@ -62,9 +62,12 @@ mod chunked; mod request; #[cfg(unix)] mod sendfile; +mod socket; use request::InnerRequest; use request::Request; +use socket::InnerStream; +use socket::Stream; pub struct FlashContext { next_server_id: u32, @@ -84,83 +87,82 @@ pub struct ServerContext { } #[derive(Debug, PartialEq)] -enum ParseStatus { +pub enum ParseStatus { None, Ongoing(usize), } -type TlsTcpStream = rustls::StreamOwned<rustls::ServerConnection, TcpStream>; - -enum InnerStream { - Tcp(TcpStream), - Tls(Box<TlsTcpStream>), -} - -pub struct Stream { - inner: InnerStream, - detached: bool, - read_rx: Option<mpsc::Receiver<()>>, - read_tx: Option<mpsc::Sender<()>>, - parse_done: ParseStatus, - buffer: UnsafeCell<Vec<u8>>, - read_lock: Arc<Mutex<()>>, - _pin: PhantomPinned, +#[op] +fn op_flash_respond( + op_state: &mut OpState, + server_id: u32, + token: u32, + response: StringOrBuffer, + shutdown: bool, +) -> u32 { + let flash_ctx = op_state.borrow_mut::<FlashContext>(); + let ctx = flash_ctx.servers.get_mut(&server_id).unwrap(); + flash_respond(ctx, token, shutdown, &response) } -impl Stream { - pub fn detach_ownership(&mut self) { - self.detached = true; - } +#[op] +async fn op_flash_respond_async( + state: Rc<RefCell<OpState>>, + server_id: u32, + token: u32, + response: StringOrBuffer, + shutdown: bool, +) -> Result<(), AnyError> { + trace!("op_flash_respond_async"); - fn reattach_ownership(&mut self) { - self.detached = false; - } -} + let mut close = false; + let sock = { + let mut op_state = state.borrow_mut(); + let flash_ctx = op_state.borrow_mut::<FlashContext>(); + let ctx = flash_ctx.servers.get_mut(&server_id).unwrap(); -impl Write for Stream { - #[inline] - fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { - match self.inner { - InnerStream::Tcp(ref mut stream) => stream.write(buf), - InnerStream::Tls(ref mut stream) => stream.write(buf), - } - } - #[inline] - fn flush(&mut self) -> std::io::Result<()> { - match self.inner { - InnerStream::Tcp(ref mut stream) => stream.flush(), - InnerStream::Tls(ref mut stream) => stream.flush(), + match shutdown { + true => { + let tx = ctx.requests.remove(&token).unwrap(); + close = !tx.keep_alive; + tx.socket() + } + // In case of a websocket upgrade or streaming response. + false => { + let tx = ctx.requests.get(&token).unwrap(); + tx.socket() + } } - } -} + }; -impl Read for Stream { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { - match self.inner { - InnerStream::Tcp(ref mut stream) => stream.read(buf), - InnerStream::Tls(ref mut stream) => stream.read(buf), - } + sock + .with_async_stream(|stream| { + Box::pin(async move { + Ok(tokio::io::AsyncWriteExt::write(stream, &response).await?) + }) + }) + .await?; + // server is done writing and request doesn't want to kept alive. + if shutdown && close { + sock.shutdown(); } + Ok(()) } #[op] -fn op_flash_respond( - op_state: &mut OpState, +async fn op_flash_respond_chuncked( + op_state: Rc<RefCell<OpState>>, server_id: u32, token: u32, - response: StringOrBuffer, - maybe_body: Option<ZeroCopyBuf>, + response: Option<ZeroCopyBuf>, shutdown: bool, -) { +) -> Result<(), AnyError> { + let mut op_state = op_state.borrow_mut(); let flash_ctx = op_state.borrow_mut::<FlashContext>(); let ctx = flash_ctx.servers.get_mut(&server_id).unwrap(); - - let mut close = false; let sock = match shutdown { true => { let tx = ctx.requests.remove(&token).unwrap(); - close = !tx.keep_alive; tx.socket() } // In case of a websocket upgrade or streaming response. @@ -170,49 +172,29 @@ fn op_flash_respond( } }; - sock.read_tx.take(); - sock.read_rx.take(); - - let _ = sock.write(&response); - if let Some(response) = maybe_body { - let _ = sock.write(format!("{:x}", response.len()).as_bytes()); - let _ = sock.write(b"\r\n"); - let _ = sock.write(&response); - let _ = sock.write(b"\r\n"); - } + drop(op_state); + sock + .with_async_stream(|stream| { + Box::pin(async move { + use tokio::io::AsyncWriteExt; + if let Some(response) = response { + stream + .write_all(format!("{:x}\r\n", response.len()).as_bytes()) + .await?; + stream.write_all(&response).await?; + stream.write_all(b"\r\n").await?; + } - // server is done writing and request doesn't want to kept alive. - if shutdown && close { - match &mut sock.inner { - InnerStream::Tcp(stream) => { - // Typically shutdown shouldn't fail. - let _ = stream.shutdown(std::net::Shutdown::Both); - } - InnerStream::Tls(stream) => { - let _ = stream.sock.shutdown(std::net::Shutdown::Both); - } - } - } -} + // The last chunk + if shutdown { + stream.write_all(b"0\r\n\r\n").await?; + } -#[op] -fn op_flash_respond_chuncked( - op_state: &mut OpState, - server_id: u32, - token: u32, - response: Option<ZeroCopyBuf>, - shutdown: bool, -) { - let flash_ctx = op_state.borrow_mut::<FlashContext>(); - let ctx = flash_ctx.servers.get_mut(&server_id).unwrap(); - match response { - Some(response) => { - respond_chunked(ctx, token, shutdown, Some(&response)); - } - None => { - respond_chunked(ctx, token, shutdown, None); - } - } + Ok(()) + }) + }) + .await?; + Ok(()) } #[op] @@ -258,24 +240,35 @@ async fn op_flash_write_resource( } } - let _ = sock.write(b"Transfer-Encoding: chunked\r\n\r\n"); - loop { - let vec = vec![0u8; 64 * 1024]; // 64KB - let buf = ZeroCopyBuf::new_temp(vec); - let (nread, buf) = resource.clone().read_return(buf).await?; - if nread == 0 { - let _ = sock.write(b"0\r\n\r\n"); - break; - } - let response = &buf[..nread]; - - let _ = sock.write(format!("{:x}", response.len()).as_bytes()); - let _ = sock.write(b"\r\n"); - let _ = sock.write(response); - let _ = sock.write(b"\r\n"); - } + sock + .with_async_stream(|stream| { + Box::pin(async move { + use tokio::io::AsyncWriteExt; + stream + .write_all(b"Transfer-Encoding: chunked\r\n\r\n") + .await?; + loop { + let vec = vec![0u8; 64 * 1024]; // 64KB + let buf = ZeroCopyBuf::new_temp(vec); + let (nread, buf) = resource.clone().read_return(buf).await?; + if nread == 0 { + stream.write_all(b"0\r\n\r\n").await?; + break; + } - resource.close(); + let response = &buf[..nread]; + // TODO(@littledivy): use vectored writes. + stream + .write_all(format!("{:x}\r\n", response.len()).as_bytes()) + .await?; + stream.write_all(response).await?; + stream.write_all(b"\r\n").await?; + } + resource.close(); + Ok(()) + }) + }) + .await?; Ok(()) } @@ -296,7 +289,7 @@ impl fast_api::FastFunction for RespondFast { } fn return_type(&self) -> fast_api::CType { - fast_api::CType::Void + fast_api::CType::Uint32 } } @@ -305,128 +298,43 @@ fn flash_respond( token: u32, shutdown: bool, response: &[u8], -) { - let mut close = false; - let sock = match shutdown { - true => { - let tx = ctx.requests.remove(&token).unwrap(); - close = !tx.keep_alive; - tx.socket() - } - // In case of a websocket upgrade or streaming response. - false => { - let tx = ctx.requests.get(&token).unwrap(); - tx.socket() - } - }; +) -> u32 { + let tx = ctx.requests.get(&token).unwrap(); + let sock = tx.socket(); sock.read_tx.take(); sock.read_rx.take(); - let _ = sock.write(response); - // server is done writing and request doesn't want to kept alive. - if shutdown && close { - match &mut sock.inner { - InnerStream::Tcp(stream) => { - // Typically shutdown shouldn't fail. - let _ = stream.shutdown(std::net::Shutdown::Both); - } - InnerStream::Tls(stream) => { - let _ = stream.sock.shutdown(std::net::Shutdown::Both); - } - } - } -} - -unsafe fn op_flash_respond_fast( - recv: v8::Local<v8::Object>, - token: u32, - response: *const fast_api::FastApiTypedArray<u8>, - shutdown: bool, -) { - let ptr = - recv.get_aligned_pointer_from_internal_field(V8_WRAPPER_OBJECT_INDEX); - let ctx = &mut *(ptr as *mut ServerContext); - - let response = &*response; - if let Some(response) = response.get_storage_if_aligned() { - flash_respond(ctx, token, shutdown, response); - } else { - todo!(); - } -} - -pub struct RespondChunkedFast; - -impl fast_api::FastFunction for RespondChunkedFast { - fn function(&self) -> *const c_void { - op_flash_respond_chunked_fast as *const c_void - } + let nwritten = sock.try_write(response); - fn args(&self) -> &'static [fast_api::Type] { - &[ - fast_api::Type::V8Value, - fast_api::Type::Uint32, - fast_api::Type::TypedArray(fast_api::CType::Uint8), - fast_api::Type::Bool, - ] + if shutdown && nwritten == response.len() { + if !tx.keep_alive { + sock.shutdown(); + } + ctx.requests.remove(&token).unwrap(); } - fn return_type(&self) -> fast_api::CType { - fast_api::CType::Void - } + nwritten as u32 } -unsafe fn op_flash_respond_chunked_fast( +unsafe fn op_flash_respond_fast( recv: v8::Local<v8::Object>, token: u32, response: *const fast_api::FastApiTypedArray<u8>, shutdown: bool, -) { +) -> u32 { let ptr = recv.get_aligned_pointer_from_internal_field(V8_WRAPPER_OBJECT_INDEX); let ctx = &mut *(ptr as *mut ServerContext); let response = &*response; if let Some(response) = response.get_storage_if_aligned() { - respond_chunked(ctx, token, shutdown, Some(response)); + flash_respond(ctx, token, shutdown, response) } else { todo!(); } } -fn respond_chunked( - ctx: &mut ServerContext, - token: u32, - shutdown: bool, - response: Option<&[u8]>, -) { - let sock = match shutdown { - true => { - let tx = ctx.requests.remove(&token).unwrap(); - tx.socket() - } - // In case of a websocket upgrade or streaming response. - false => { - let tx = ctx.requests.get(&token).unwrap(); - tx.socket() - } - }; - - if let Some(response) = response { - let _ = sock.write(format!("{:x}", response.len()).as_bytes()); - let _ = sock.write(b"\r\n"); - let _ = sock.write(response); - let _ = sock.write(b"\r\n"); - } - - // The last chunk - if shutdown { - let _ = sock.write(b"0\r\n\r\n"); - } - sock.reattach_ownership(); -} - macro_rules! get_request { ($op_state: ident, $token: ident) => { get_request!($op_state, 0, $token) @@ -631,51 +539,12 @@ fn op_flash_make_request<'scope>( obj.set(scope, key.into(), func).unwrap(); } - // respondChunked - { - let builder = v8::FunctionTemplate::builder( - |scope: &mut v8::HandleScope, - args: v8::FunctionCallbackArguments, - _: v8::ReturnValue| { - let external: v8::Local<v8::External> = - args.data().unwrap().try_into().unwrap(); - // SAFETY: This external is guaranteed to be a pointer to a ServerContext - let ctx = unsafe { &mut *(external.value() as *mut ServerContext) }; - - let token = args.get(0).uint32_value(scope).unwrap(); - - let response: v8::Local<v8::ArrayBufferView> = - args.get(1).try_into().unwrap(); - let ab = response.buffer(scope).unwrap(); - let store = ab.get_backing_store(); - let (offset, len) = (response.byte_offset(), response.byte_length()); - // SAFETY: v8::SharedRef<v8::BackingStore> is similar to Arc<[u8]>, - // it points to a fixed continuous slice of bytes on the heap. - // We assume it's initialized and thus safe to read (though may not contain meaningful data) - let response = unsafe { - &*(&store[offset..offset + len] as *const _ as *const [u8]) - }; - - let shutdown = args.get(2).boolean_value(scope); - - respond_chunked(ctx, token, shutdown, Some(response)); - }, - ) - .data(v8::External::new(scope, ctx as *mut _).into()); - - let func = builder.build_fast(scope, &RespondChunkedFast, None); - let func: v8::Local<v8::Value> = func.get_function(scope).unwrap().into(); - - let key = v8::String::new(scope, "respondChunked").unwrap(); - obj.set(scope, key.into(), func).unwrap(); - } - // respond { let builder = v8::FunctionTemplate::builder( |scope: &mut v8::HandleScope, args: v8::FunctionCallbackArguments, - _: v8::ReturnValue| { + mut rv: v8::ReturnValue| { let external: v8::Local<v8::External> = args.data().unwrap().try_into().unwrap(); // SAFETY: This external is guaranteed to be a pointer to a ServerContext @@ -697,7 +566,7 @@ fn op_flash_make_request<'scope>( let shutdown = args.get(2).boolean_value(scope); - flash_respond(ctx, token, shutdown, response); + rv.set_uint32(flash_respond(ctx, token, shutdown, response)); }, ) .data(v8::External::new(scope, ctx as *mut _).into()); @@ -1024,7 +893,6 @@ fn run_server( read_lock: Arc::new(Mutex::new(())), parse_done: ParseStatus::None, buffer: UnsafeCell::new(vec![0_u8; 1024]), - _pin: PhantomPinned, }); trace!("New connection: {}", token.0); @@ -1521,6 +1389,7 @@ pub fn init<P: FlashPermissions + 'static>(unstable: bool) -> Extension { .ops(vec![ op_flash_serve::decl::<P>(), op_flash_respond::decl(), + op_flash_respond_async::decl(), op_flash_respond_chuncked::decl(), op_flash_method::decl(), op_flash_path::decl(), diff --git a/ext/flash/socket.rs b/ext/flash/socket.rs new file mode 100644 index 000000000..8256be8a0 --- /dev/null +++ b/ext/flash/socket.rs @@ -0,0 +1,147 @@ +use deno_core::error::AnyError; +use mio::net::TcpStream; +use std::{ + cell::UnsafeCell, + future::Future, + io::{Read, Write}, + pin::Pin, + sync::{Arc, Mutex}, +}; +use tokio::sync::mpsc; + +use crate::ParseStatus; + +type TlsTcpStream = rustls::StreamOwned<rustls::ServerConnection, TcpStream>; + +pub enum InnerStream { + Tcp(TcpStream), + Tls(Box<TlsTcpStream>), +} + +pub struct Stream { + pub inner: InnerStream, + pub detached: bool, + pub read_rx: Option<mpsc::Receiver<()>>, + pub read_tx: Option<mpsc::Sender<()>>, + pub parse_done: ParseStatus, + pub buffer: UnsafeCell<Vec<u8>>, + pub read_lock: Arc<Mutex<()>>, +} + +impl Stream { + pub fn detach_ownership(&mut self) { + self.detached = true; + } + + /// Try to write to the socket. + #[inline] + pub fn try_write(&mut self, buf: &[u8]) -> usize { + let mut nwritten = 0; + while nwritten < buf.len() { + match self.write(&buf[nwritten..]) { + Ok(n) => nwritten += n, + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + break; + } + Err(e) => { + log::trace!("Error writing to socket: {}", e); + break; + } + } + } + nwritten + } + + #[inline] + pub fn shutdown(&mut self) { + match &mut self.inner { + InnerStream::Tcp(stream) => { + // Typically shutdown shouldn't fail. + let _ = stream.shutdown(std::net::Shutdown::Both); + } + InnerStream::Tls(stream) => { + let _ = stream.sock.shutdown(std::net::Shutdown::Both); + } + } + } + + pub fn as_std(&mut self) -> std::net::TcpStream { + #[cfg(unix)] + let std_stream = { + use std::os::unix::prelude::AsRawFd; + use std::os::unix::prelude::FromRawFd; + let fd = match self.inner { + InnerStream::Tcp(ref tcp) => tcp.as_raw_fd(), + _ => todo!(), + }; + // SAFETY: `fd` is a valid file descriptor. + unsafe { std::net::TcpStream::from_raw_fd(fd) } + }; + #[cfg(windows)] + let std_stream = { + use std::os::windows::prelude::AsRawSocket; + use std::os::windows::prelude::FromRawSocket; + let fd = match self.inner { + InnerStream::Tcp(ref tcp) => tcp.as_raw_socket(), + _ => todo!(), + }; + // SAFETY: `fd` is a valid file descriptor. + unsafe { std::net::TcpStream::from_raw_socket(fd) } + }; + std_stream + } + + #[inline] + pub async fn with_async_stream<F, T>(&mut self, f: F) -> Result<T, AnyError> + where + F: FnOnce( + &mut tokio::net::TcpStream, + ) -> Pin<Box<dyn '_ + Future<Output = Result<T, AnyError>>>>, + { + let mut async_stream = tokio::net::TcpStream::from_std(self.as_std())?; + let result = f(&mut async_stream).await?; + forget_stream(async_stream.into_std()?); + Ok(result) + } +} + +#[inline] +pub fn forget_stream(stream: std::net::TcpStream) { + #[cfg(unix)] + { + use std::os::unix::prelude::IntoRawFd; + let _ = stream.into_raw_fd(); + } + #[cfg(windows)] + { + use std::os::windows::prelude::IntoRawSocket; + let _ = stream.into_raw_socket(); + } +} + +impl Write for Stream { + #[inline] + fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { + match self.inner { + InnerStream::Tcp(ref mut stream) => stream.write(buf), + InnerStream::Tls(ref mut stream) => stream.write(buf), + } + } + #[inline] + fn flush(&mut self) -> std::io::Result<()> { + match self.inner { + InnerStream::Tcp(ref mut stream) => stream.flush(), + InnerStream::Tls(ref mut stream) => stream.flush(), + } + } +} + +impl Read for Stream { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { + match self.inner { + InnerStream::Tcp(ref mut stream) => stream.read(buf), + InnerStream::Tls(ref mut stream) => stream.read(buf), + } + } +} |