summaryrefslogtreecommitdiff
path: root/ext/flash/lib.rs
diff options
context:
space:
mode:
authorYusuke Tanaka <yusuktan@maguro.dev>2022-11-25 02:38:09 +0900
committerGitHub <noreply@github.com>2022-11-24 18:38:09 +0100
commitfd023cf7937e67dfde5482d34ebc60839eb7397c (patch)
tree816c976254071ecd9c15a35ad6b68d78066428d1 /ext/flash/lib.rs
parentb6f49cf4790926df125add2329611a8eff8db9da (diff)
fix(ext/flash): graceful server startup/shutdown with unsettled promises in mind (#16616)
This PR resets the revert commit made by #16610, bringing back #16383 which attempts to fix the issue happening when we use the flash server with `--watch` option enabled. Also, some code changes are made to pass the regression test added in #16610.
Diffstat (limited to 'ext/flash/lib.rs')
-rw-r--r--ext/flash/lib.rs294
1 files changed, 183 insertions, 111 deletions
diff --git a/ext/flash/lib.rs b/ext/flash/lib.rs
index d08cdbcdc..7b4308807 100644
--- a/ext/flash/lib.rs
+++ b/ext/flash/lib.rs
@@ -35,6 +35,7 @@ use mio::Events;
use mio::Interest;
use mio::Poll;
use mio::Token;
+use mio::Waker;
use serde::Deserialize;
use serde::Serialize;
use socket2::Socket;
@@ -47,6 +48,7 @@ use std::intrinsics::transmute;
use std::io::BufReader;
use std::io::Read;
use std::io::Write;
+use std::marker::PhantomPinned;
use std::mem::replace;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
@@ -55,8 +57,8 @@ use std::rc::Rc;
use std::sync::Arc;
use std::sync::Mutex;
use std::task::Context;
-use std::time::Duration;
use tokio::sync::mpsc;
+use tokio::sync::oneshot;
use tokio::task::JoinHandle;
mod chunked;
@@ -76,15 +78,24 @@ pub struct FlashContext {
pub servers: HashMap<u32, ServerContext>,
}
+impl Drop for FlashContext {
+ fn drop(&mut self) {
+ // Signal each server instance to shutdown.
+ for (_, server) in self.servers.drain() {
+ let _ = server.waker.wake();
+ }
+ }
+}
+
pub struct ServerContext {
_addr: SocketAddr,
tx: mpsc::Sender<Request>,
- rx: mpsc::Receiver<Request>,
+ rx: Option<mpsc::Receiver<Request>>,
requests: HashMap<u32, Request>,
next_token: u32,
- listening_rx: Option<mpsc::Receiver<u16>>,
- close_tx: mpsc::Sender<()>,
+ listening_rx: Option<mpsc::Receiver<Result<u16, std::io::Error>>>,
cancel_handle: Rc<CancelHandle>,
+ waker: Arc<Waker>,
}
#[derive(Debug, Eq, PartialEq)]
@@ -102,7 +113,10 @@ fn op_flash_respond(
shutdown: bool,
) -> u32 {
let flash_ctx = op_state.borrow_mut::<FlashContext>();
- let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
+ let ctx = match flash_ctx.servers.get_mut(&server_id) {
+ Some(ctx) => ctx,
+ None => return 0,
+ };
flash_respond(ctx, token, shutdown, &response)
}
@@ -116,7 +130,7 @@ fn op_try_flash_respond_chuncked(
) -> u32 {
let flash_ctx = op_state.borrow_mut::<FlashContext>();
let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
- let tx = ctx.requests.get(&token).unwrap();
+ let tx = ctx.requests.get_mut(&token).unwrap();
let sock = tx.socket();
// TODO(@littledivy): Use writev when `UnixIoSlice` lands.
@@ -153,17 +167,20 @@ async fn op_flash_respond_async(
let sock = {
let mut op_state = state.borrow_mut();
let flash_ctx = op_state.borrow_mut::<FlashContext>();
- let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
+ let ctx = match flash_ctx.servers.get_mut(&server_id) {
+ Some(ctx) => ctx,
+ None => return Ok(()),
+ };
match shutdown {
true => {
- let tx = ctx.requests.remove(&token).unwrap();
+ let mut tx = ctx.requests.remove(&token).unwrap();
close = !tx.keep_alive;
tx.socket()
}
// In case of a websocket upgrade or streaming response.
false => {
- let tx = ctx.requests.get(&token).unwrap();
+ let tx = ctx.requests.get_mut(&token).unwrap();
tx.socket()
}
}
@@ -197,12 +214,12 @@ async fn op_flash_respond_chuncked(
let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
let sock = match shutdown {
true => {
- let tx = ctx.requests.remove(&token).unwrap();
+ let mut tx = ctx.requests.remove(&token).unwrap();
tx.socket()
}
// In case of a websocket upgrade or streaming response.
false => {
- let tx = ctx.requests.get(&token).unwrap();
+ let tx = ctx.requests.get_mut(&token).unwrap();
tx.socket()
}
};
@@ -344,7 +361,7 @@ fn flash_respond(
shutdown: bool,
response: &[u8],
) -> u32 {
- let tx = ctx.requests.get(&token).unwrap();
+ let tx = ctx.requests.get_mut(&token).unwrap();
let sock = tx.socket();
sock.read_tx.take();
@@ -428,15 +445,36 @@ fn op_flash_method(state: &mut OpState, server_id: u32, token: u32) -> u32 {
}
#[op]
-async fn op_flash_close_server(state: Rc<RefCell<OpState>>, server_id: u32) {
- let close_tx = {
- let mut op_state = state.borrow_mut();
- let flash_ctx = op_state.borrow_mut::<FlashContext>();
- let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
- ctx.cancel_handle.cancel();
- ctx.close_tx.clone()
+fn op_flash_drive_server(
+ state: &mut OpState,
+ server_id: u32,
+) -> Result<impl Future<Output = Result<(), AnyError>> + 'static, AnyError> {
+ let join_handle = {
+ let flash_ctx = state.borrow_mut::<FlashContext>();
+ flash_ctx
+ .join_handles
+ .remove(&server_id)
+ .ok_or_else(|| type_error("server not found"))?
};
- let _ = close_tx.send(()).await;
+ Ok(async move {
+ join_handle
+ .await
+ .map_err(|_| type_error("server join error"))??;
+ Ok(())
+ })
+}
+
+#[op]
+fn op_flash_close_server(state: &mut OpState, server_id: u32) {
+ let flash_ctx = state.borrow_mut::<FlashContext>();
+ let ctx = flash_ctx.servers.get(&server_id).unwrap();
+
+ // NOTE: We don't drop ServerContext associated with the given `server_id`,
+ // because it may still be in use by some unsettled promise after the flash
+ // thread is finished.
+
+ ctx.cancel_handle.cancel();
+ let _ = ctx.waker.wake();
}
#[op]
@@ -463,7 +501,7 @@ fn op_flash_path(
fn next_request_sync(ctx: &mut ServerContext) -> u32 {
let offset = ctx.next_token;
- while let Ok(token) = ctx.rx.try_recv() {
+ while let Ok(token) = ctx.rx.as_mut().unwrap().try_recv() {
ctx.requests.insert(ctx.next_token, token);
ctx.next_token += 1;
}
@@ -526,6 +564,7 @@ unsafe fn op_flash_get_method_fast(
fn op_flash_make_request<'scope>(
scope: &mut v8::HandleScope<'scope>,
state: &mut OpState,
+ server_id: u32,
) -> serde_v8::Value<'scope> {
let object_template = v8::ObjectTemplate::new(scope);
assert!(object_template
@@ -533,7 +572,7 @@ fn op_flash_make_request<'scope>(
let obj = object_template.new_instance(scope).unwrap();
let ctx = {
let flash_ctx = state.borrow_mut::<FlashContext>();
- let ctx = flash_ctx.servers.get_mut(&0).unwrap();
+ let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
ctx as *mut ServerContext
};
obj.set_aligned_pointer_in_internal_field(V8_WRAPPER_OBJECT_INDEX, ctx as _);
@@ -625,7 +664,7 @@ fn op_flash_make_request<'scope>(
}
#[inline]
-fn has_body_stream(req: &Request) -> bool {
+fn has_body_stream(req: &mut Request) -> bool {
let sock = req.socket();
sock.read_rx.is_some()
}
@@ -749,7 +788,10 @@ async fn op_flash_read_body(
{
let op_state = &mut state.borrow_mut();
let flash_ctx = op_state.borrow_mut::<FlashContext>();
- flash_ctx.servers.get_mut(&server_id).unwrap() as *mut ServerContext
+ match flash_ctx.servers.get_mut(&server_id) {
+ Some(ctx) => ctx as *mut ServerContext,
+ None => return 0,
+ }
}
.as_mut()
.unwrap()
@@ -851,41 +893,40 @@ pub struct ListenOpts {
reuseport: bool,
}
+const SERVER_TOKEN: Token = Token(0);
+// Token reserved for the thread close signal.
+const WAKER_TOKEN: Token = Token(1);
+
+#[allow(clippy::too_many_arguments)]
fn run_server(
tx: mpsc::Sender<Request>,
- listening_tx: mpsc::Sender<u16>,
- mut close_rx: mpsc::Receiver<()>,
+ listening_tx: mpsc::Sender<Result<u16, std::io::Error>>,
addr: SocketAddr,
maybe_cert: Option<String>,
maybe_key: Option<String>,
reuseport: bool,
+ mut poll: Poll,
+ // We put a waker as an unused argument here as it needs to be alive both in
+ // the flash thread and in the main thread (otherwise the notification would
+ // not be caught by the event loop on Linux).
+ // See the comment in mio's example:
+ // https://docs.rs/mio/0.8.4/x86_64-unknown-linux-gnu/mio/struct.Waker.html#examples
+ _waker: Arc<Waker>,
) -> Result<(), AnyError> {
- let domain = if addr.is_ipv4() {
- socket2::Domain::IPV4
- } else {
- socket2::Domain::IPV6
+ let mut listener = match listen(addr, reuseport) {
+ Ok(listener) => listener,
+ Err(e) => {
+ listening_tx.blocking_send(Err(e)).unwrap();
+ return Err(generic_error(
+ "failed to start listening on the specified address",
+ ));
+ }
};
- let socket = Socket::new(domain, socket2::Type::STREAM, None)?;
- #[cfg(not(windows))]
- socket.set_reuse_address(true)?;
- if reuseport {
- #[cfg(target_os = "linux")]
- socket.set_reuse_port(true)?;
- }
-
- let socket_addr = socket2::SockAddr::from(addr);
- socket.bind(&socket_addr)?;
- socket.listen(128)?;
- socket.set_nonblocking(true)?;
- let std_listener: std::net::TcpListener = socket.into();
- let mut listener = TcpListener::from_std(std_listener);
-
- let mut poll = Poll::new()?;
- let token = Token(0);
+ // Register server.
poll
.registry()
- .register(&mut listener, token, Interest::READABLE)
+ .register(&mut listener, SERVER_TOKEN, Interest::READABLE)
.unwrap();
let tls_context: Option<Arc<rustls::ServerConfig>> = {
@@ -907,30 +948,25 @@ fn run_server(
};
listening_tx
- .blocking_send(listener.local_addr().unwrap().port())
+ .blocking_send(Ok(listener.local_addr().unwrap().port()))
.unwrap();
let mut sockets = HashMap::with_capacity(1000);
- let mut counter: usize = 1;
+ let mut socket_senders = HashMap::with_capacity(1000);
+ let mut counter: usize = 2;
let mut events = Events::with_capacity(1024);
'outer: loop {
- let result = close_rx.try_recv();
- if result.is_ok() {
- break 'outer;
- }
- // FIXME(bartlomieju): how does Tokio handle it? I just put random 100ms
- // timeout here to handle close signal.
- match poll.poll(&mut events, Some(Duration::from_millis(100))) {
+ match poll.poll(&mut events, None) {
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) => panic!("{}", e),
Ok(()) => (),
}
'events: for event in &events {
- if close_rx.try_recv().is_ok() {
- break 'outer;
- }
let token = event.token();
match token {
- Token(0) => loop {
+ WAKER_TOKEN => {
+ break 'outer;
+ }
+ SERVER_TOKEN => loop {
match listener.accept() {
Ok((mut socket, _)) => {
counter += 1;
@@ -958,6 +994,7 @@ fn run_server(
read_lock: Arc::new(Mutex::new(())),
parse_done: ParseStatus::None,
buffer: UnsafeCell::new(vec![0_u8; 1024]),
+ _pinned: PhantomPinned,
});
trace!("New connection: {}", token.0);
@@ -974,7 +1011,6 @@ fn run_server(
let mut_ref: Pin<&mut Stream> = Pin::as_mut(socket);
Pin::get_unchecked_mut(mut_ref)
};
- let sock_ptr = socket as *mut _;
if socket.detached {
match &mut socket.inner {
@@ -988,6 +1024,7 @@ fn run_server(
let boxed = sockets.remove(&token).unwrap();
std::mem::forget(boxed);
+ socket_senders.remove(&token);
trace!("Socket detached: {}", token.0);
continue;
}
@@ -1173,8 +1210,10 @@ fn run_server(
continue 'events;
}
+ let (socket_tx, socket_rx) = oneshot::channel();
+
tx.blocking_send(Request {
- socket: sock_ptr,
+ socket: socket as *mut _,
// SAFETY: headers backing buffer outlives the mio event loop ('static)
inner: inner_req,
keep_alive,
@@ -1183,16 +1222,57 @@ fn run_server(
content_read: 0,
content_length,
expect_continue,
+ socket_rx,
+ owned_socket: None,
})
.ok();
+
+ socket_senders.insert(token, socket_tx);
}
}
}
}
+ // Now the flash thread is about to finish, but there may be some unsettled
+ // promises in the main thread that will use the socket. To make the socket
+ // alive longer enough, we move its ownership to the main thread.
+ for (tok, socket) in sockets {
+ if let Some(sender) = socket_senders.remove(&tok) {
+ // Do nothing if the receiver has already been dropped.
+ _ = sender.send(socket);
+ }
+ }
+
Ok(())
}
+#[inline]
+fn listen(
+ addr: SocketAddr,
+ reuseport: bool,
+) -> Result<TcpListener, std::io::Error> {
+ let domain = if addr.is_ipv4() {
+ socket2::Domain::IPV4
+ } else {
+ socket2::Domain::IPV6
+ };
+ let socket = Socket::new(domain, socket2::Type::STREAM, None)?;
+
+ #[cfg(not(windows))]
+ socket.set_reuse_address(true)?;
+ if reuseport {
+ #[cfg(target_os = "linux")]
+ socket.set_reuse_port(true)?;
+ }
+
+ let socket_addr = socket2::SockAddr::from(addr);
+ socket.bind(&socket_addr)?;
+ socket.listen(128)?;
+ socket.set_nonblocking(true)?;
+ let std_listener: std::net::TcpListener = socket.into();
+ Ok(TcpListener::from_std(std_listener))
+}
+
fn make_addr_port_pair(hostname: &str, port: u16) -> (&str, u16) {
// Default to localhost if given just the port. Example: ":80"
if hostname.is_empty() {
@@ -1230,17 +1310,19 @@ where
.next()
.ok_or_else(|| generic_error("No resolved address found"))?;
let (tx, rx) = mpsc::channel(100);
- let (close_tx, close_rx) = mpsc::channel(1);
let (listening_tx, listening_rx) = mpsc::channel(1);
+
+ let poll = Poll::new()?;
+ let waker = Arc::new(Waker::new(poll.registry(), WAKER_TOKEN).unwrap());
let ctx = ServerContext {
_addr: addr,
tx,
- rx,
+ rx: Some(rx),
requests: HashMap::with_capacity(1000),
next_token: 0,
- close_tx,
listening_rx: Some(listening_rx),
cancel_handle: CancelHandle::new_rc(),
+ waker: waker.clone(),
};
let tx = ctx.tx.clone();
let maybe_cert = opts.cert;
@@ -1250,11 +1332,12 @@ where
run_server(
tx,
listening_tx,
- close_rx,
addr,
maybe_cert,
maybe_key,
reuseport,
+ poll,
+ waker,
)
});
let flash_ctx = state.borrow_mut::<FlashContext>();
@@ -1289,45 +1372,26 @@ where
}
#[op]
-fn op_flash_wait_for_listening(
- state: &mut OpState,
+async fn op_flash_wait_for_listening(
+ state: Rc<RefCell<OpState>>,
server_id: u32,
-) -> Result<impl Future<Output = Result<u16, AnyError>> + 'static, AnyError> {
+) -> Result<u16, AnyError> {
let mut listening_rx = {
- let flash_ctx = state.borrow_mut::<FlashContext>();
+ let mut op_state = state.borrow_mut();
+ let flash_ctx = op_state.borrow_mut::<FlashContext>();
let server_ctx = flash_ctx
.servers
.get_mut(&server_id)
.ok_or_else(|| type_error("server not found"))?;
server_ctx.listening_rx.take().unwrap()
};
- Ok(async move {
- if let Some(port) = listening_rx.recv().await {
- Ok(port)
- } else {
- Err(generic_error("This error will be discarded"))
- }
- })
-}
-
-#[op]
-fn op_flash_drive_server(
- state: &mut OpState,
- server_id: u32,
-) -> Result<impl Future<Output = Result<(), AnyError>> + 'static, AnyError> {
- let join_handle = {
- let flash_ctx = state.borrow_mut::<FlashContext>();
- flash_ctx
- .join_handles
- .remove(&server_id)
- .ok_or_else(|| type_error("server not found"))?
- };
- Ok(async move {
- join_handle
- .await
- .map_err(|_| type_error("server join error"))??;
- Ok(())
- })
+ match listening_rx.recv().await {
+ Some(Ok(port)) => Ok(port),
+ Some(Err(e)) => Err(e.into()),
+ _ => Err(generic_error(
+ "unknown error occurred while waiting for listening",
+ )),
+ }
}
// Asychronous version of op_flash_next. This can be a bottleneck under
@@ -1335,26 +1399,34 @@ fn op_flash_drive_server(
// requests i.e `op_flash_next() == 0`.
#[op]
async fn op_flash_next_async(
- op_state: Rc<RefCell<OpState>>,
+ state: Rc<RefCell<OpState>>,
server_id: u32,
) -> u32 {
- let ctx = {
- let mut op_state = op_state.borrow_mut();
+ let mut op_state = state.borrow_mut();
+ let flash_ctx = op_state.borrow_mut::<FlashContext>();
+ let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
+ let cancel_handle = ctx.cancel_handle.clone();
+ let mut rx = ctx.rx.take().unwrap();
+ // We need to drop the borrow before await point.
+ drop(op_state);
+
+ if let Ok(Some(req)) = rx.recv().or_cancel(&cancel_handle).await {
+ let mut op_state = state.borrow_mut();
let flash_ctx = op_state.borrow_mut::<FlashContext>();
let ctx = flash_ctx.servers.get_mut(&server_id).unwrap();
- ctx as *mut ServerContext
- };
- // SAFETY: we cannot hold op_state borrow across the await point. The JS caller
- // is responsible for ensuring this is not called concurrently.
- let ctx = unsafe { &mut *ctx };
- let cancel_handle = &ctx.cancel_handle;
-
- if let Ok(Some(req)) = ctx.rx.recv().or_cancel(cancel_handle).await {
ctx.requests.insert(ctx.next_token, req);
ctx.next_token += 1;
+ // Set the rx back.
+ ctx.rx = Some(rx);
return 1;
}
+ // Set the rx back.
+ let mut op_state = state.borrow_mut();
+ let flash_ctx = op_state.borrow_mut::<FlashContext>();
+ if let Some(ctx) = flash_ctx.servers.get_mut(&server_id) {
+ ctx.rx = Some(rx);
+ }
0
}
@@ -1427,7 +1499,7 @@ pub fn detach_socket(
// dropped on the server thread.
// * conversion from mio::net::TcpStream -> tokio::net::TcpStream. There is no public API so we
// use raw fds.
- let tx = ctx
+ let mut tx = ctx
.requests
.remove(&token)
.ok_or_else(|| type_error("request closed"))?;
@@ -1522,11 +1594,11 @@ pub fn init<P: FlashPermissions + 'static>(unstable: bool) -> Extension {
op_flash_next_async::decl(),
op_flash_read_body::decl(),
op_flash_upgrade_websocket::decl(),
- op_flash_drive_server::decl(),
op_flash_wait_for_listening::decl(),
op_flash_first_packet::decl(),
op_flash_has_body_stream::decl(),
op_flash_close_server::decl(),
+ op_flash_drive_server::decl(),
op_flash_make_request::decl(),
op_flash_write_resource::decl(),
op_try_flash_respond_chuncked::decl(),