diff options
Diffstat (limited to 'ext/kv/sqlite.rs')
-rw-r--r-- | ext/kv/sqlite.rs | 501 |
1 files changed, 451 insertions, 50 deletions
diff --git a/ext/kv/sqlite.rs b/ext/kv/sqlite.rs index 80d230ab1..6cff3145d 100644 --- a/ext/kv/sqlite.rs +++ b/ext/kv/sqlite.rs @@ -7,10 +7,17 @@ use std::marker::PhantomData; use std::path::Path; use std::path::PathBuf; use std::rc::Rc; +use std::rc::Weak; +use std::sync::Arc; +use std::time::Duration; +use std::time::SystemTime; use async_trait::async_trait; use deno_core::error::type_error; use deno_core::error::AnyError; +use deno_core::futures; +use deno_core::futures::FutureExt; +use deno_core::task::spawn; use deno_core::task::spawn_blocking; use deno_core::AsyncRefCell; use deno_core::OpState; @@ -18,6 +25,12 @@ use rusqlite::params; use rusqlite::OpenFlags; use rusqlite::OptionalExtension; use rusqlite::Transaction; +use tokio::sync::mpsc; +use tokio::sync::watch; +use tokio::sync::OnceCell; +use tokio::sync::OwnedSemaphorePermit; +use tokio::sync::Semaphore; +use uuid::Uuid; use crate::AtomicWrite; use crate::CommitResult; @@ -25,6 +38,7 @@ use crate::Database; use crate::DatabaseHandler; use crate::KvEntry; use crate::MutationKind; +use crate::QueueMessageHandle; use crate::ReadRange; use crate::ReadRangeOutput; use crate::SnapshotReadOptions; @@ -44,6 +58,18 @@ const STATEMENT_KV_POINT_SET: &str = "insert into kv (k, v, v_encoding, version) values (:k, :v, :v_encoding, :version) on conflict(k) do update set v = :v, v_encoding = :v_encoding, version = :version"; const STATEMENT_KV_POINT_DELETE: &str = "delete from kv where k = ?"; +const STATEMENT_QUEUE_ADD_READY: &str = "insert into queue (ts, id, data, backoff_schedule, keys_if_undelivered) values(?, ?, ?, ?, ?)"; +const STATEMENT_QUEUE_GET_NEXT_READY: &str = "select ts, id, data, backoff_schedule, keys_if_undelivered from queue where ts <= ? order by ts limit 100"; +const STATEMENT_QUEUE_GET_EARLIEST_READY: &str = + "select ts from queue order by ts limit 1"; +const STATEMENT_QUEUE_REMOVE_READY: &str = "delete from queue where id = ?"; +const STATEMENT_QUEUE_ADD_RUNNING: &str = "insert into queue_running (deadline, id, data, backoff_schedule, keys_if_undelivered) values(?, ?, ?, ?, ?)"; +const STATEMENT_QUEUE_REMOVE_RUNNING: &str = + "delete from queue_running where id = ?"; +const STATEMENT_QUEUE_GET_RUNNING_BY_ID: &str = "select deadline, id, data, backoff_schedule, keys_if_undelivered from queue_running where id = ?"; +const STATEMENT_QUEUE_GET_RUNNING: &str = + "select id from queue_running order by deadline limit 100"; + const STATEMENT_CREATE_MIGRATION_TABLE: &str = " create table if not exists migration_state( k integer not null primary key, @@ -87,6 +113,9 @@ create table queue_running( ", ]; +const DISPATCH_CONCURRENCY_LIMIT: usize = 100; +const DEFAULT_BACKOFF_SCHEDULE: [u32; 5] = [100, 1000, 5000, 30000, 60000]; + pub struct SqliteDbHandler<P: SqliteDbHandlerPermissions + 'static> { pub default_storage_dir: Option<PathBuf>, _permissions: PhantomData<P>, @@ -182,14 +211,23 @@ impl<P: SqliteDbHandlerPermissions> DatabaseHandler for SqliteDbHandler<P> { .await .unwrap()?; - Ok(SqliteDb(Rc::new(AsyncRefCell::new(Cell::new(Some(conn)))))) + Ok(SqliteDb { + conn: Rc::new(AsyncRefCell::new(Cell::new(Some(conn)))), + queue: OnceCell::new(), + }) } } -pub struct SqliteDb(Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>); +pub struct SqliteDb { + conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>, + queue: OnceCell<SqliteQueue>, +} impl SqliteDb { - async fn run_tx<F, R>(&self, f: F) -> Result<R, AnyError> + async fn run_tx<F, R>( + conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>, + f: F, + ) -> Result<R, AnyError> where F: (FnOnce(rusqlite::Transaction<'_>) -> Result<R, AnyError>) + Send @@ -198,7 +236,7 @@ impl SqliteDb { { // Transactions need exclusive access to the connection. Wait until // we can borrow_mut the connection. - let cell = self.0.borrow_mut().await; + let cell = conn.borrow_mut().await; // Take the db out of the cell and run the transaction via spawn_blocking. let mut db = cell.take().unwrap(); @@ -220,59 +258,372 @@ impl SqliteDb { } } +pub struct DequeuedMessage { + conn: Weak<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>, + id: String, + payload: Option<Vec<u8>>, + waker_tx: mpsc::Sender<()>, + _permit: OwnedSemaphorePermit, +} + +#[async_trait(?Send)] +impl QueueMessageHandle for DequeuedMessage { + async fn finish(&self, success: bool) -> Result<(), AnyError> { + let Some(conn) = self.conn.upgrade() else { + return Ok(()); + }; + let id = self.id.clone(); + let requeued = SqliteDb::run_tx(conn, move |tx| { + let requeued = { + if success { + let changed = tx + .prepare_cached(STATEMENT_QUEUE_REMOVE_RUNNING)? + .execute([&id])?; + assert!(changed <= 1); + false + } else { + SqliteQueue::requeue_message(&id, &tx)? + } + }; + tx.commit()?; + Ok(requeued) + }) + .await?; + if requeued { + // If the message was requeued, wake up the dequeue loop. + self.waker_tx.send(()).await?; + } + Ok(()) + } + + async fn take_payload(&mut self) -> Result<Vec<u8>, AnyError> { + self + .payload + .take() + .ok_or_else(|| type_error("Payload already consumed")) + } +} + +type DequeueReceiver = mpsc::Receiver<(Vec<u8>, String)>; + +struct SqliteQueue { + conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>, + dequeue_rx: Rc<AsyncRefCell<DequeueReceiver>>, + concurrency_limiter: Arc<Semaphore>, + waker_tx: mpsc::Sender<()>, + shutdown_tx: watch::Sender<()>, +} + +impl SqliteQueue { + fn new(conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>) -> Self { + let conn_clone = conn.clone(); + let (shutdown_tx, shutdown_rx) = watch::channel::<()>(()); + let (waker_tx, waker_rx) = mpsc::channel::<()>(1); + let (dequeue_tx, dequeue_rx) = mpsc::channel::<(Vec<u8>, String)>(64); + + spawn(async move { + // Oneshot requeue of all inflight messages. + Self::requeue_inflight_messages(conn.clone()).await.unwrap(); + + // Continous dequeue loop. + Self::dequeue_loop(conn.clone(), dequeue_tx, shutdown_rx, waker_rx) + .await + .unwrap(); + }); + + Self { + conn: conn_clone, + dequeue_rx: Rc::new(AsyncRefCell::new(dequeue_rx)), + waker_tx, + shutdown_tx, + concurrency_limiter: Arc::new(Semaphore::new(DISPATCH_CONCURRENCY_LIMIT)), + } + } + + async fn dequeue(&self) -> Result<DequeuedMessage, AnyError> { + // Wait for the next message to be available from dequeue_rx. + let (payload, id) = { + let mut queue_rx = self.dequeue_rx.borrow_mut().await; + let Some(msg) = queue_rx.recv().await else { + return Err(type_error("Database closed")); + }; + msg + }; + + let permit = self.concurrency_limiter.clone().acquire_owned().await?; + + Ok(DequeuedMessage { + conn: Rc::downgrade(&self.conn), + id, + payload: Some(payload), + waker_tx: self.waker_tx.clone(), + _permit: permit, + }) + } + + async fn wake(&self) -> Result<(), AnyError> { + self.waker_tx.send(()).await?; + Ok(()) + } + + fn shutdown(&self) { + self.shutdown_tx.send(()).unwrap(); + } + + async fn dequeue_loop( + conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>, + dequeue_tx: mpsc::Sender<(Vec<u8>, String)>, + mut shutdown_rx: watch::Receiver<()>, + mut waker_rx: mpsc::Receiver<()>, + ) -> Result<(), AnyError> { + loop { + let messages = SqliteDb::run_tx(conn.clone(), move |tx| { + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + + let messages = tx + .prepare_cached(STATEMENT_QUEUE_GET_NEXT_READY)? + .query_map([now], |row| { + let ts: u64 = row.get(0)?; + let id: String = row.get(1)?; + let data: Vec<u8> = row.get(2)?; + let backoff_schedule: String = row.get(3)?; + let keys_if_undelivered: String = row.get(4)?; + Ok((ts, id, data, backoff_schedule, keys_if_undelivered)) + })? + .collect::<Result<Vec<_>, rusqlite::Error>>()?; + + for (ts, id, data, backoff_schedule, keys_if_undelivered) in &messages { + let changed = tx + .prepare_cached(STATEMENT_QUEUE_REMOVE_READY)? + .execute(params![id])?; + assert_eq!(changed, 1); + + let changed = + tx.prepare_cached(STATEMENT_QUEUE_ADD_RUNNING)?.execute( + params![ts, id, &data, &backoff_schedule, &keys_if_undelivered], + )?; + assert_eq!(changed, 1); + } + tx.commit()?; + + Ok( + messages + .into_iter() + .map(|(_, id, data, _, _)| (id, data)) + .collect::<Vec<_>>(), + ) + }) + .await?; + + let busy = !messages.is_empty(); + + for (id, data) in messages { + if dequeue_tx.send((data, id)).await.is_err() { + // Queue receiver was dropped. Stop the dequeue loop. + return Ok(()); + } + } + + if !busy { + // There's nothing to dequeue right now; sleep until one of the + // following happens: + // - It's time to dequeue the next message based on its timestamp + // - A new message is added to the queue + // - The database is closed + let sleep_fut = { + match Self::get_earliest_ready_ts(conn.clone()).await? { + Some(ts) => { + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + if ts <= now { + continue; + } + tokio::time::sleep(Duration::from_millis(ts - now)).boxed() + } + None => futures::future::pending().boxed(), + } + }; + tokio::select! { + _ = sleep_fut => {} + _ = waker_rx.recv() => {} + _ = shutdown_rx.changed() => return Ok(()) + } + } + } + } + + async fn get_earliest_ready_ts( + conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>, + ) -> Result<Option<u64>, AnyError> { + SqliteDb::run_tx(conn.clone(), move |tx| { + let ts = tx + .prepare_cached(STATEMENT_QUEUE_GET_EARLIEST_READY)? + .query_row([], |row| { + let ts: u64 = row.get(0)?; + Ok(ts) + }) + .optional()?; + Ok(ts) + }) + .await + } + + async fn requeue_inflight_messages( + conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>, + ) -> Result<(), AnyError> { + loop { + let done = SqliteDb::run_tx(conn.clone(), move |tx| { + let entries = tx + .prepare_cached(STATEMENT_QUEUE_GET_RUNNING)? + .query_map([], |row| { + let id: String = row.get(0)?; + Ok(id) + })? + .collect::<Result<Vec<_>, rusqlite::Error>>()?; + for id in &entries { + Self::requeue_message(id, &tx)?; + } + tx.commit()?; + Ok(entries.is_empty()) + }) + .await?; + if done { + return Ok(()); + } + } + } + + fn requeue_message( + id: &str, + tx: &rusqlite::Transaction<'_>, + ) -> Result<bool, AnyError> { + let Some((_, id, data, backoff_schedule, keys_if_undelivered)) = tx + .prepare_cached(STATEMENT_QUEUE_GET_RUNNING_BY_ID)? + .query_row([id], |row| { + let deadline: u64 = row.get(0)?; + let id: String = row.get(1)?; + let data: Vec<u8> = row.get(2)?; + let backoff_schedule: String = row.get(3)?; + let keys_if_undelivered: String = row.get(4)?; + Ok((deadline, id, data, backoff_schedule, keys_if_undelivered)) + }) + .optional()? else { + return Ok(false); + }; + + let backoff_schedule = { + let backoff_schedule = + serde_json::from_str::<Option<Vec<u64>>>(&backoff_schedule)?; + backoff_schedule.unwrap_or_default() + }; + + let mut requeued = false; + if !backoff_schedule.is_empty() { + // Requeue based on backoff schedule + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + let new_ts = now + backoff_schedule[0]; + let new_backoff_schedule = serde_json::to_string(&backoff_schedule[1..])?; + let changed = tx + .prepare_cached(STATEMENT_QUEUE_ADD_READY)? + .execute(params![ + new_ts, + id, + &data, + &new_backoff_schedule, + &keys_if_undelivered + ]) + .unwrap(); + assert_eq!(changed, 1); + requeued = true; + } else if !keys_if_undelivered.is_empty() { + // No more requeues. Insert the message into the undelivered queue. + let keys_if_undelivered = + serde_json::from_str::<Vec<Vec<u8>>>(&keys_if_undelivered)?; + + let version: i64 = tx + .prepare_cached(STATEMENT_INC_AND_GET_DATA_VERSION)? + .query_row([], |row| row.get(0))?; + + for key in keys_if_undelivered { + let changed = tx + .prepare_cached(STATEMENT_KV_POINT_SET)? + .execute(params![key, &data, &VALUE_ENCODING_V8, &version])?; + assert_eq!(changed, 1); + } + } + + // Remove from running + let changed = tx + .prepare_cached(STATEMENT_QUEUE_REMOVE_RUNNING)? + .execute(params![id])?; + assert_eq!(changed, 1); + + Ok(requeued) + } +} + #[async_trait(?Send)] impl Database for SqliteDb { + type QMH = DequeuedMessage; + async fn snapshot_read( &self, requests: Vec<ReadRange>, _options: SnapshotReadOptions, ) -> Result<Vec<ReadRangeOutput>, AnyError> { - self - .run_tx(move |tx| { - let mut responses = Vec::with_capacity(requests.len()); - for request in requests { - let mut stmt = tx.prepare_cached(if request.reverse { - STATEMENT_KV_RANGE_SCAN_REVERSE - } else { - STATEMENT_KV_RANGE_SCAN - })?; - let entries = stmt - .query_map( - ( - request.start.as_slice(), - request.end.as_slice(), - request.limit.get(), - ), - |row| { - let key: Vec<u8> = row.get(0)?; - let value: Vec<u8> = row.get(1)?; - let encoding: i64 = row.get(2)?; - - let value = decode_value(value, encoding); - - let version: i64 = row.get(3)?; - Ok(KvEntry { - key, - value, - versionstamp: version_to_versionstamp(version), - }) - }, - )? - .collect::<Result<Vec<_>, rusqlite::Error>>()?; - responses.push(ReadRangeOutput { entries }); - } + Self::run_tx(self.conn.clone(), move |tx| { + let mut responses = Vec::with_capacity(requests.len()); + for request in requests { + let mut stmt = tx.prepare_cached(if request.reverse { + STATEMENT_KV_RANGE_SCAN_REVERSE + } else { + STATEMENT_KV_RANGE_SCAN + })?; + let entries = stmt + .query_map( + ( + request.start.as_slice(), + request.end.as_slice(), + request.limit.get(), + ), + |row| { + let key: Vec<u8> = row.get(0)?; + let value: Vec<u8> = row.get(1)?; + let encoding: i64 = row.get(2)?; + + let value = decode_value(value, encoding); + + let version: i64 = row.get(3)?; + Ok(KvEntry { + key, + value, + versionstamp: version_to_versionstamp(version), + }) + }, + )? + .collect::<Result<Vec<_>, rusqlite::Error>>()?; + responses.push(ReadRangeOutput { entries }); + } - Ok(responses) - }) - .await + Ok(responses) + }) + .await } async fn atomic_write( &self, write: AtomicWrite, ) -> Result<Option<CommitResult>, AnyError> { - self - .run_tx(move |tx| { + let (has_enqueues, commit_result) = + Self::run_tx(self.conn.clone(), move |tx| { for check in write.checks { let real_versionstamp = tx .prepare_cached(STATEMENT_KV_POINT_GET_VERSION_ONLY)? @@ -280,7 +631,7 @@ impl Database for SqliteDb { .optional()? .map(version_to_versionstamp); if real_versionstamp != check.versionstamp { - return Ok(None); + return Ok((false, None)); } } @@ -336,17 +687,67 @@ impl Database for SqliteDb { } } - // TODO(@losfair): enqueues + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_millis() as u64; + + let has_enqueues = !write.enqueues.is_empty(); + for enqueue in write.enqueues { + let id = Uuid::new_v4().to_string(); + let backoff_schedule = serde_json::to_string( + &enqueue + .backoff_schedule + .or_else(|| Some(DEFAULT_BACKOFF_SCHEDULE.to_vec())), + )?; + let keys_if_undelivered = + serde_json::to_string(&enqueue.keys_if_undelivered)?; + + let changed = + tx.prepare_cached(STATEMENT_QUEUE_ADD_READY)? + .execute(params![ + now + enqueue.delay_ms, + id, + &enqueue.payload, + &backoff_schedule, + &keys_if_undelivered + ])?; + assert_eq!(changed, 1) + } tx.commit()?; - let new_vesionstamp = version_to_versionstamp(version); - Ok(Some(CommitResult { - versionstamp: new_vesionstamp, - })) + Ok(( + has_enqueues, + Some(CommitResult { + versionstamp: new_vesionstamp, + }), + )) }) - .await + .await?; + + if has_enqueues { + if let Some(queue) = self.queue.get() { + queue.wake().await?; + } + } + Ok(commit_result) + } + + async fn dequeue_next_message(&self) -> Result<Self::QMH, AnyError> { + let queue = self + .queue + .get_or_init(|| async move { SqliteQueue::new(self.conn.clone()) }) + .await; + let handle = queue.dequeue().await?; + Ok(handle) + } + + fn close(&self) { + if let Some(queue) = self.queue.get() { + queue.shutdown(); + } } } |