Diffstat (limited to 'ext')
-rw-r--r--   ext/kv/01_db.ts     | 109
-rw-r--r--   ext/kv/Cargo.toml   |   4
-rw-r--r--   ext/kv/interface.rs |  14
-rw-r--r--   ext/kv/lib.rs       |  65
-rw-r--r--   ext/kv/sqlite.rs    | 501
5 files changed, 640 insertions, 53 deletions
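This patch adds a queue implementation on top of the SQLite backend of Deno KV: 01_db.ts exposes enqueue(), listenQueue(), and AtomicOperation.enqueue() backed by two new ops (op_kv_dequeue_next_message, op_kv_finish_dequeued_message), and sqlite.rs gains the SQL statements and background dequeue loop that drive the existing queue and queue_running tables, including backoff-based retries and dead-letter writes via keys_if_undelivered. As a rough orientation before the raw diff, a hedged usage sketch of the JavaScript surface follows; the methods and option names (delay, keysIfUndelivered) come from the diff, while the payloads and key names are invented for illustration.

    // Sketch of the queue API added in this patch. Payloads and key names
    // are made up; the methods and options are taken from 01_db.ts.
    const kv = await Deno.openKv();

    // Deliver a message about a minute from now. If every retry fails, the
    // payload is written to the fallback key instead of being dropped.
    await kv.enqueue({ kind: "send-email", to: "user@example.com" }, {
      delay: 60_000,
      keysIfUndelivered: [["failed_messages", "send-email"]],
    });

    // listenQueue() loops on op_kv_dequeue_next_message and reports the
    // handler outcome back through op_kv_finish_dequeued_message.
    kv.listenQueue(async (message) => {
      console.log("received", message);
    });

    // Enqueues can also ride along with atomic checks and mutations.
    await kv.atomic()
      .check({ key: ["config"], versionstamp: null })
      .enqueue({ kind: "config-initialized" })
      .commit();
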
diff --git a/ext/kv/01_db.ts b/ext/kv/01_db.ts
index f8181cc2e..eb103ae0c 100644
--- a/ext/kv/01_db.ts
+++ b/ext/kv/01_db.ts
@@ -26,6 +26,20 @@ async function openKv(path: string) {
   return new Kv(rid, kvSymbol);
 }
 
+const millisecondsInOneWeek = 7 * 24 * 60 * 60 * 1000;
+
+function validateQueueDelay(delay: number) {
+  if (delay < 0) {
+    throw new TypeError("delay cannot be negative");
+  }
+  if (delay > millisecondsInOneWeek) {
+    throw new TypeError("delay cannot be greater than one week");
+  }
+  if (isNaN(delay)) {
+    throw new TypeError("delay cannot be NaN");
+  }
+}
+
 interface RawKvEntry {
   key: Deno.KvKey;
   value: RawValue;
@@ -47,6 +61,7 @@ const kvSymbol = Symbol("KvRid");
 
 class Kv {
   #rid: number;
+  #closed: boolean;
 
   constructor(rid: number = undefined, symbol: symbol = undefined) {
     if (kvSymbol !== symbol) {
@@ -55,6 +70,7 @@ class Kv {
       );
     }
     this.#rid = rid;
+    this.#closed = false;
   }
 
   atomic() {
@@ -203,8 +219,82 @@ class Kv {
     };
   }
 
+  async enqueue(
+    message: unknown,
+    opts?: { delay?: number; keysIfUndelivered?: Deno.KvKey[] },
+  ) {
+    if (opts?.delay !== undefined) {
+      validateQueueDelay(opts?.delay);
+    }
+
+    const enqueues = [
+      [
+        core.serialize(message, { forStorage: true }),
+        opts?.delay ?? 0,
+        opts?.keysIfUndelivered ?? [],
+        null,
+      ],
+    ];
+
+    const versionstamp = await core.opAsync(
+      "op_kv_atomic_write",
+      this.#rid,
+      [],
+      [],
+      enqueues,
+    );
+    if (versionstamp === null) throw new TypeError("Failed to enqueue value");
+    return { ok: true, versionstamp };
+  }
+
+  async listenQueue(
+    handler: (message: unknown) => Promise<void> | void,
+  ): Promise<void> {
+    while (!this.#closed) {
+      // Wait for the next message.
+      let next: { 0: Uint8Array; 1: number };
+      try {
+        next = await core.opAsync(
+          "op_kv_dequeue_next_message",
+          this.#rid,
+        );
+      } catch (error) {
+        if (this.#closed) {
+          break;
+        } else {
+          throw error;
+        }
+      }
+
+      // Deserialize the payload.
+      const { 0: payload, 1: handleId } = next;
+      const deserializedPayload = core.deserialize(payload, {
+        forStorage: true,
+      });
+
+      // Dispatch the payload.
+      (async () => {
+        let success = false;
+        try {
+          const result = handler(deserializedPayload);
+          const _res = result instanceof Promise ? (await result) : result;
+          success = true;
+        } catch (error) {
+          console.error("Exception in queue handler", error);
+        } finally {
+          await core.opAsync(
+            "op_kv_finish_dequeued_message",
+            handleId,
+            success,
+          );
+        }
+      })();
+    }
+  }
+
   close() {
     core.close(this.#rid);
+    this.#closed = true;
   }
 }
 
@@ -213,6 +303,7 @@ class AtomicOperation {
 
   #checks: [Deno.KvKey, string | null][] = [];
   #mutations: [Deno.KvKey, string, RawValue | null][] = [];
+  #enqueues: [Uint8Array, number, Deno.KvKey[], number[] | null][] = [];
 
   constructor(rid: number) {
     this.#rid = rid;
@@ -280,13 +371,29 @@ class AtomicOperation {
     return this;
   }
 
+  enqueue(
+    message: unknown,
+    opts?: { delay?: number; keysIfUndelivered?: Deno.KvKey[] },
+  ): this {
+    if (opts?.delay !== undefined) {
+      validateQueueDelay(opts?.delay);
+    }
+    this.#enqueues.push([
+      core.serialize(message, { forStorage: true }),
+      opts?.delay ?? 0,
+      opts?.keysIfUndelivered ?? [],
+      null,
+    ]);
+    return this;
+  }
+
   async commit(): Promise<Deno.KvCommitResult | Deno.KvCommitError> {
     const versionstamp = await core.opAsync(
       "op_kv_atomic_write",
       this.#rid,
       this.#checks,
       this.#mutations,
-      [], // TODO(@losfair): enqueue
+      this.#enqueues,
     );
     if (versionstamp === null) return { ok: false };
     return { ok: true, versionstamp };
diff --git a/ext/kv/Cargo.toml b/ext/kv/Cargo.toml
index 1cb64c099..b25837143 100644
--- a/ext/kv/Cargo.toml
+++ b/ext/kv/Cargo.toml
@@ -20,5 +20,9 @@ base64.workspace = true
 deno_core.workspace = true
 hex.workspace = true
 num-bigint.workspace = true
+rand.workspace = true
 rusqlite.workspace = true
 serde.workspace = true
+serde_json.workspace = true
+tokio.workspace = true
+uuid.workspace = true
diff --git a/ext/kv/interface.rs b/ext/kv/interface.rs
index 31b7638b4..b67ee1243 100644
--- a/ext/kv/interface.rs
+++ b/ext/kv/interface.rs
@@ -25,6 +25,8 @@ pub trait DatabaseHandler {
 
 #[async_trait(?Send)]
 pub trait Database {
+  type QMH: QueueMessageHandle + 'static;
+
   async fn snapshot_read(
     &self,
     requests: Vec<ReadRange>,
@@ -35,6 +37,16 @@ pub trait Database {
     &self,
     write: AtomicWrite,
   ) -> Result<Option<CommitResult>, AnyError>;
+
+  async fn dequeue_next_message(&self) -> Result<Self::QMH, AnyError>;
+
+  fn close(&self);
+}
+
+#[async_trait(?Send)]
+pub trait QueueMessageHandle {
+  async fn take_payload(&mut self) -> Result<Vec<u8>, AnyError>;
+  async fn finish(&self, success: bool) -> Result<(), AnyError>;
 }
 
 /// Options for a snapshot read.
@@ -242,7 +254,7 @@ pub struct KvMutation {
 /// keys specified in `keys_if_undelivered`.
 pub struct Enqueue {
   pub payload: Vec<u8>,
-  pub deadline_ms: u64,
+  pub delay_ms: u64,
   pub keys_if_undelivered: Vec<Vec<u8>>,
   pub backoff_schedule: Option<Vec<u32>>,
 }
diff --git a/ext/kv/lib.rs b/ext/kv/lib.rs
index dbc626225..2763fcf50 100644
--- a/ext/kv/lib.rs
+++ b/ext/kv/lib.rs
@@ -8,6 +8,7 @@ use std::borrow::Cow;
 use std::cell::RefCell;
 use std::num::NonZeroU32;
 use std::rc::Rc;
+use std::vec;
 
 use codec::decode_key;
 use codec::encode_key;
@@ -60,6 +61,8 @@ deno_core::extension!(deno_kv,
     op_kv_snapshot_read<DBH>,
     op_kv_atomic_write<DBH>,
     op_kv_encode_cursor,
+    op_kv_dequeue_next_message<DBH>,
+    op_kv_finish_dequeued_message<DBH>,
   ],
   esm = [ "01_db.ts" ],
   options = {
@@ -80,6 +83,10 @@ impl<DB: Database + 'static> Resource for DatabaseResource<DB> {
   fn name(&self) -> Cow<str> {
     "database".into()
   }
+
+  fn close(self: Rc<Self>) {
+    self.db.close();
+  }
 }
 
 #[op]
@@ -280,6 +287,62 @@ where
   Ok(output_ranges)
 }
 
+struct QueueMessageResource<QPH: QueueMessageHandle + 'static> {
+  handle: QPH,
+}
+
+impl<QMH: QueueMessageHandle + 'static> Resource for QueueMessageResource<QMH> {
+  fn name(&self) -> Cow<str> {
+    "queue_message".into()
+  }
+}
+
+#[op]
+async fn op_kv_dequeue_next_message<DBH>(
+  state: Rc<RefCell<OpState>>,
+  rid: ResourceId,
+) -> Result<(ZeroCopyBuf, ResourceId), AnyError>
+where
+  DBH: DatabaseHandler + 'static,
+{
+  let db = {
+    let state = state.borrow();
+    let resource =
+      state.resource_table.get::<DatabaseResource<DBH::DB>>(rid)?;
+    resource.db.clone()
+  };
+
+  let mut handle = db.dequeue_next_message().await?;
+  let payload = handle.take_payload().await?.into();
+  let handle_rid = {
+    let mut state = state.borrow_mut();
+    state.resource_table.add(QueueMessageResource { handle })
+  };
+  Ok((payload, handle_rid))
+}
+
+#[op]
+async fn op_kv_finish_dequeued_message<DBH>(
+  state: Rc<RefCell<OpState>>,
+  handle_rid: ResourceId,
+  success: bool,
+) -> Result<(), AnyError>
+where
+  DBH: DatabaseHandler + 'static,
+{
+  let handle = {
+    let mut state = state.borrow_mut();
+    let handle = state
+      .resource_table
+      .take::<QueueMessageResource<<<DBH>::DB as Database>::QMH>>(handle_rid)
+      .map_err(|_| type_error("Queue message not found"))?;
+    Rc::try_unwrap(handle)
+      .map_err(|_| type_error("Queue message not found"))?
+      .handle
+  };
+  handle.finish(success).await
+}
+
 type V8KvCheck = (KvKey, Option<ByteString>);
 
 impl TryFrom<V8KvCheck> for KvCheck {
@@ -333,7 +396,7 @@ impl TryFrom<V8Enqueue> for Enqueue {
   fn try_from(value: V8Enqueue) -> Result<Self, AnyError> {
     Ok(Enqueue {
       payload: value.0.to_vec(),
-      deadline_ms: value.1,
+      delay_ms: value.1,
       keys_if_undelivered: value
         .2
        .into_iter()
diff --git a/ext/kv/sqlite.rs b/ext/kv/sqlite.rs
index 80d230ab1..6cff3145d 100644
--- a/ext/kv/sqlite.rs
+++ b/ext/kv/sqlite.rs
@@ -7,10 +7,17 @@ use std::marker::PhantomData;
 use std::path::Path;
 use std::path::PathBuf;
 use std::rc::Rc;
+use std::rc::Weak;
+use std::sync::Arc;
+use std::time::Duration;
+use std::time::SystemTime;
 
 use async_trait::async_trait;
 use deno_core::error::type_error;
 use deno_core::error::AnyError;
+use deno_core::futures;
+use deno_core::futures::FutureExt;
+use deno_core::task::spawn;
 use deno_core::task::spawn_blocking;
 use deno_core::AsyncRefCell;
 use deno_core::OpState;
@@ -18,6 +25,12 @@ use rusqlite::params;
 use rusqlite::OpenFlags;
 use rusqlite::OptionalExtension;
 use rusqlite::Transaction;
+use tokio::sync::mpsc;
+use tokio::sync::watch;
+use tokio::sync::OnceCell;
+use tokio::sync::OwnedSemaphorePermit;
+use tokio::sync::Semaphore;
+use uuid::Uuid;
 
 use crate::AtomicWrite;
 use crate::CommitResult;
@@ -25,6 +38,7 @@ use crate::Database;
 use crate::DatabaseHandler;
 use crate::KvEntry;
 use crate::MutationKind;
+use crate::QueueMessageHandle;
 use crate::ReadRange;
 use crate::ReadRangeOutput;
 use crate::SnapshotReadOptions;
@@ -44,6 +58,18 @@ const STATEMENT_KV_POINT_SET: &str =
   "insert into kv (k, v, v_encoding, version) values (:k, :v, :v_encoding, :version) on conflict(k) do update set v = :v, v_encoding = :v_encoding, version = :version";
 const STATEMENT_KV_POINT_DELETE: &str = "delete from kv where k = ?";
 
+const STATEMENT_QUEUE_ADD_READY: &str = "insert into queue (ts, id, data, backoff_schedule, keys_if_undelivered) values(?, ?, ?, ?, ?)";
+const STATEMENT_QUEUE_GET_NEXT_READY: &str = "select ts, id, data, backoff_schedule, keys_if_undelivered from queue where ts <= ? order by ts limit 100";
+const STATEMENT_QUEUE_GET_EARLIEST_READY: &str =
+  "select ts from queue order by ts limit 1";
+const STATEMENT_QUEUE_REMOVE_READY: &str = "delete from queue where id = ?";
+const STATEMENT_QUEUE_ADD_RUNNING: &str = "insert into queue_running (deadline, id, data, backoff_schedule, keys_if_undelivered) values(?, ?, ?, ?, ?)";
+const STATEMENT_QUEUE_REMOVE_RUNNING: &str =
+  "delete from queue_running where id = ?";
+const STATEMENT_QUEUE_GET_RUNNING_BY_ID: &str = "select deadline, id, data, backoff_schedule, keys_if_undelivered from queue_running where id = ?";
+const STATEMENT_QUEUE_GET_RUNNING: &str =
+  "select id from queue_running order by deadline limit 100";
+
 const STATEMENT_CREATE_MIGRATION_TABLE: &str = "
 create table if not exists migration_state(
   k integer not null primary key,
@@ -87,6 +113,9 @@ create table queue_running(
 ",
 ];
 
+const DISPATCH_CONCURRENCY_LIMIT: usize = 100;
+const DEFAULT_BACKOFF_SCHEDULE: [u32; 5] = [100, 1000, 5000, 30000, 60000];
+
 pub struct SqliteDbHandler<P: SqliteDbHandlerPermissions + 'static> {
   pub default_storage_dir: Option<PathBuf>,
   _permissions: PhantomData<P>,
@@ -182,14 +211,23 @@ impl<P: SqliteDbHandlerPermissions> DatabaseHandler for SqliteDbHandler<P> {
     .await
     .unwrap()?;
 
-    Ok(SqliteDb(Rc::new(AsyncRefCell::new(Cell::new(Some(conn))))))
+    Ok(SqliteDb {
+      conn: Rc::new(AsyncRefCell::new(Cell::new(Some(conn)))),
+      queue: OnceCell::new(),
+    })
   }
 }
 
-pub struct SqliteDb(Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>);
+pub struct SqliteDb {
+  conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
+  queue: OnceCell<SqliteQueue>,
+}
 
 impl SqliteDb {
-  async fn run_tx<F, R>(&self, f: F) -> Result<R, AnyError>
+  async fn run_tx<F, R>(
+    conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
+    f: F,
+  ) -> Result<R, AnyError>
   where
     F: (FnOnce(rusqlite::Transaction<'_>) -> Result<R, AnyError>)
       + Send
@@ -198,7 +236,7 @@ impl SqliteDb {
   {
     // Transactions need exclusive access to the connection. Wait until
     // we can borrow_mut the connection.
-    let cell = self.0.borrow_mut().await;
+    let cell = conn.borrow_mut().await;
 
     // Take the db out of the cell and run the transaction via spawn_blocking.
     let mut db = cell.take().unwrap();
@@ -220,59 +258,372 @@ impl SqliteDb {
   }
 }
 
+pub struct DequeuedMessage {
+  conn: Weak<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
+  id: String,
+  payload: Option<Vec<u8>>,
+  waker_tx: mpsc::Sender<()>,
+  _permit: OwnedSemaphorePermit,
+}
+
+#[async_trait(?Send)]
+impl QueueMessageHandle for DequeuedMessage {
+  async fn finish(&self, success: bool) -> Result<(), AnyError> {
+    let Some(conn) = self.conn.upgrade() else {
+      return Ok(());
+    };
+    let id = self.id.clone();
+    let requeued = SqliteDb::run_tx(conn, move |tx| {
+      let requeued = {
+        if success {
+          let changed = tx
+            .prepare_cached(STATEMENT_QUEUE_REMOVE_RUNNING)?
+            .execute([&id])?;
+          assert!(changed <= 1);
+          false
+        } else {
+          SqliteQueue::requeue_message(&id, &tx)?
+        }
+      };
+      tx.commit()?;
+      Ok(requeued)
+    })
+    .await?;
+    if requeued {
+      // If the message was requeued, wake up the dequeue loop.
+      self.waker_tx.send(()).await?;
+    }
+    Ok(())
+  }
+
+  async fn take_payload(&mut self) -> Result<Vec<u8>, AnyError> {
+    self
+      .payload
+      .take()
+      .ok_or_else(|| type_error("Payload already consumed"))
+  }
+}
+
+type DequeueReceiver = mpsc::Receiver<(Vec<u8>, String)>;
+
+struct SqliteQueue {
+  conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
+  dequeue_rx: Rc<AsyncRefCell<DequeueReceiver>>,
+  concurrency_limiter: Arc<Semaphore>,
+  waker_tx: mpsc::Sender<()>,
+  shutdown_tx: watch::Sender<()>,
+}
+
+impl SqliteQueue {
+  fn new(conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>) -> Self {
+    let conn_clone = conn.clone();
+    let (shutdown_tx, shutdown_rx) = watch::channel::<()>(());
+    let (waker_tx, waker_rx) = mpsc::channel::<()>(1);
+    let (dequeue_tx, dequeue_rx) = mpsc::channel::<(Vec<u8>, String)>(64);
+
+    spawn(async move {
+      // Oneshot requeue of all inflight messages.
+      Self::requeue_inflight_messages(conn.clone()).await.unwrap();
+
+      // Continous dequeue loop.
+      Self::dequeue_loop(conn.clone(), dequeue_tx, shutdown_rx, waker_rx)
+        .await
+        .unwrap();
+    });
+
+    Self {
+      conn: conn_clone,
+      dequeue_rx: Rc::new(AsyncRefCell::new(dequeue_rx)),
+      waker_tx,
+      shutdown_tx,
+      concurrency_limiter: Arc::new(Semaphore::new(DISPATCH_CONCURRENCY_LIMIT)),
+    }
+  }
+
+  async fn dequeue(&self) -> Result<DequeuedMessage, AnyError> {
+    // Wait for the next message to be available from dequeue_rx.
+    let (payload, id) = {
+      let mut queue_rx = self.dequeue_rx.borrow_mut().await;
+      let Some(msg) = queue_rx.recv().await else {
+        return Err(type_error("Database closed"));
+      };
+      msg
+    };
+
+    let permit = self.concurrency_limiter.clone().acquire_owned().await?;
+
+    Ok(DequeuedMessage {
+      conn: Rc::downgrade(&self.conn),
+      id,
+      payload: Some(payload),
+      waker_tx: self.waker_tx.clone(),
+      _permit: permit,
+    })
+  }
+
+  async fn wake(&self) -> Result<(), AnyError> {
+    self.waker_tx.send(()).await?;
+    Ok(())
+  }
+
+  fn shutdown(&self) {
+    self.shutdown_tx.send(()).unwrap();
+  }
+
+  async fn dequeue_loop(
+    conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
+    dequeue_tx: mpsc::Sender<(Vec<u8>, String)>,
+    mut shutdown_rx: watch::Receiver<()>,
+    mut waker_rx: mpsc::Receiver<()>,
+  ) -> Result<(), AnyError> {
+    loop {
+      let messages = SqliteDb::run_tx(conn.clone(), move |tx| {
+        let now = SystemTime::now()
+          .duration_since(SystemTime::UNIX_EPOCH)
+          .unwrap()
+          .as_millis() as u64;
+
+        let messages = tx
+          .prepare_cached(STATEMENT_QUEUE_GET_NEXT_READY)?
+          .query_map([now], |row| {
+            let ts: u64 = row.get(0)?;
+            let id: String = row.get(1)?;
+            let data: Vec<u8> = row.get(2)?;
+            let backoff_schedule: String = row.get(3)?;
+            let keys_if_undelivered: String = row.get(4)?;
+            Ok((ts, id, data, backoff_schedule, keys_if_undelivered))
+          })?
+          .collect::<Result<Vec<_>, rusqlite::Error>>()?;
+
+        for (ts, id, data, backoff_schedule, keys_if_undelivered) in &messages {
+          let changed = tx
+            .prepare_cached(STATEMENT_QUEUE_REMOVE_READY)?
+            .execute(params![id])?;
+          assert_eq!(changed, 1);
+
+          let changed =
+            tx.prepare_cached(STATEMENT_QUEUE_ADD_RUNNING)?.execute(
+              params![ts, id, &data, &backoff_schedule, &keys_if_undelivered],
+            )?;
+          assert_eq!(changed, 1);
+        }
+        tx.commit()?;
+
+        Ok(
+          messages
+            .into_iter()
+            .map(|(_, id, data, _, _)| (id, data))
+            .collect::<Vec<_>>(),
+        )
+      })
+      .await?;
+
+      let busy = !messages.is_empty();
+
+      for (id, data) in messages {
+        if dequeue_tx.send((data, id)).await.is_err() {
+          // Queue receiver was dropped. Stop the dequeue loop.
+          return Ok(());
+        }
+      }
+
+      if !busy {
+        // There's nothing to dequeue right now; sleep until one of the
+        // following happens:
+        // - It's time to dequeue the next message based on its timestamp
+        // - A new message is added to the queue
+        // - The database is closed
+        let sleep_fut = {
+          match Self::get_earliest_ready_ts(conn.clone()).await? {
+            Some(ts) => {
+              let now = SystemTime::now()
+                .duration_since(SystemTime::UNIX_EPOCH)
+                .unwrap()
+                .as_millis() as u64;
+              if ts <= now {
+                continue;
+              }
+              tokio::time::sleep(Duration::from_millis(ts - now)).boxed()
+            }
+            None => futures::future::pending().boxed(),
+          }
+        };
+        tokio::select! {
+          _ = sleep_fut => {}
+          _ = waker_rx.recv() => {}
+          _ = shutdown_rx.changed() => return Ok(())
+        }
+      }
+    }
+  }
+
+  async fn get_earliest_ready_ts(
+    conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
+  ) -> Result<Option<u64>, AnyError> {
+    SqliteDb::run_tx(conn.clone(), move |tx| {
+      let ts = tx
+        .prepare_cached(STATEMENT_QUEUE_GET_EARLIEST_READY)?
+        .query_row([], |row| {
+          let ts: u64 = row.get(0)?;
+          Ok(ts)
+        })
+        .optional()?;
+      Ok(ts)
+    })
+    .await
+  }
+
+  async fn requeue_inflight_messages(
+    conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
+  ) -> Result<(), AnyError> {
+    loop {
+      let done = SqliteDb::run_tx(conn.clone(), move |tx| {
+        let entries = tx
+          .prepare_cached(STATEMENT_QUEUE_GET_RUNNING)?
+          .query_map([], |row| {
+            let id: String = row.get(0)?;
+            Ok(id)
+          })?
+          .collect::<Result<Vec<_>, rusqlite::Error>>()?;
+        for id in &entries {
+          Self::requeue_message(id, &tx)?;
+        }
+        tx.commit()?;
+        Ok(entries.is_empty())
+      })
+      .await?;
+      if done {
+        return Ok(());
+      }
+    }
+  }
+
+  fn requeue_message(
+    id: &str,
+    tx: &rusqlite::Transaction<'_>,
+  ) -> Result<bool, AnyError> {
+    let Some((_, id, data, backoff_schedule, keys_if_undelivered)) = tx
+      .prepare_cached(STATEMENT_QUEUE_GET_RUNNING_BY_ID)?
+      .query_row([id], |row| {
+        let deadline: u64 = row.get(0)?;
+        let id: String = row.get(1)?;
+        let data: Vec<u8> = row.get(2)?;
+        let backoff_schedule: String = row.get(3)?;
+        let keys_if_undelivered: String = row.get(4)?;
+        Ok((deadline, id, data, backoff_schedule, keys_if_undelivered))
+      })
+      .optional()? else {
+        return Ok(false);
+      };
+
+    let backoff_schedule = {
+      let backoff_schedule =
+        serde_json::from_str::<Option<Vec<u64>>>(&backoff_schedule)?;
+      backoff_schedule.unwrap_or_default()
+    };
+
+    let mut requeued = false;
+    if !backoff_schedule.is_empty() {
+      // Requeue based on backoff schedule
+      let now = SystemTime::now()
+        .duration_since(SystemTime::UNIX_EPOCH)
+        .unwrap()
+        .as_millis() as u64;
+      let new_ts = now + backoff_schedule[0];
+      let new_backoff_schedule = serde_json::to_string(&backoff_schedule[1..])?;
+      let changed = tx
+        .prepare_cached(STATEMENT_QUEUE_ADD_READY)?
+        .execute(params![
+          new_ts,
+          id,
+          &data,
+          &new_backoff_schedule,
+          &keys_if_undelivered
+        ])
+        .unwrap();
+      assert_eq!(changed, 1);
+      requeued = true;
+    } else if !keys_if_undelivered.is_empty() {
+      // No more requeues. Insert the message into the undelivered queue.
+      let keys_if_undelivered =
+        serde_json::from_str::<Vec<Vec<u8>>>(&keys_if_undelivered)?;
+
+      let version: i64 = tx
+        .prepare_cached(STATEMENT_INC_AND_GET_DATA_VERSION)?
+        .query_row([], |row| row.get(0))?;
+
+      for key in keys_if_undelivered {
+        let changed = tx
+          .prepare_cached(STATEMENT_KV_POINT_SET)?
+          .execute(params![key, &data, &VALUE_ENCODING_V8, &version])?;
+        assert_eq!(changed, 1);
+      }
+    }
+
+    // Remove from running
+    let changed = tx
+      .prepare_cached(STATEMENT_QUEUE_REMOVE_RUNNING)?
+      .execute(params![id])?;
+    assert_eq!(changed, 1);
+
+    Ok(requeued)
+  }
+}
+
 #[async_trait(?Send)]
 impl Database for SqliteDb {
+  type QMH = DequeuedMessage;
+
   async fn snapshot_read(
     &self,
     requests: Vec<ReadRange>,
     _options: SnapshotReadOptions,
   ) -> Result<Vec<ReadRangeOutput>, AnyError> {
-    self
-      .run_tx(move |tx| {
-        let mut responses = Vec::with_capacity(requests.len());
-        for request in requests {
-          let mut stmt = tx.prepare_cached(if request.reverse {
-            STATEMENT_KV_RANGE_SCAN_REVERSE
-          } else {
-            STATEMENT_KV_RANGE_SCAN
-          })?;
-          let entries = stmt
-            .query_map(
-              (
-                request.start.as_slice(),
-                request.end.as_slice(),
-                request.limit.get(),
-              ),
-              |row| {
-                let key: Vec<u8> = row.get(0)?;
-                let value: Vec<u8> = row.get(1)?;
-                let encoding: i64 = row.get(2)?;
-
-                let value = decode_value(value, encoding);
-
-                let version: i64 = row.get(3)?;
-                Ok(KvEntry {
-                  key,
-                  value,
-                  versionstamp: version_to_versionstamp(version),
-                })
-              },
-            )?
-            .collect::<Result<Vec<_>, rusqlite::Error>>()?;
-          responses.push(ReadRangeOutput { entries });
-        }
+    Self::run_tx(self.conn.clone(), move |tx| {
+      let mut responses = Vec::with_capacity(requests.len());
+      for request in requests {
+        let mut stmt = tx.prepare_cached(if request.reverse {
+          STATEMENT_KV_RANGE_SCAN_REVERSE
+        } else {
+          STATEMENT_KV_RANGE_SCAN
+        })?;
+        let entries = stmt
+          .query_map(
+            (
+              request.start.as_slice(),
+              request.end.as_slice(),
+              request.limit.get(),
+            ),
+            |row| {
+              let key: Vec<u8> = row.get(0)?;
+              let value: Vec<u8> = row.get(1)?;
+              let encoding: i64 = row.get(2)?;
+
+              let value = decode_value(value, encoding);
+
+              let version: i64 = row.get(3)?;
+              Ok(KvEntry {
+                key,
+                value,
+                versionstamp: version_to_versionstamp(version),
+              })
+            },
+          )?
+          .collect::<Result<Vec<_>, rusqlite::Error>>()?;
+        responses.push(ReadRangeOutput { entries });
+      }
 
-        Ok(responses)
-      })
-      .await
+      Ok(responses)
+    })
+    .await
   }
 
   async fn atomic_write(
     &self,
     write: AtomicWrite,
   ) -> Result<Option<CommitResult>, AnyError> {
-    self
-      .run_tx(move |tx| {
+    let (has_enqueues, commit_result) =
+      Self::run_tx(self.conn.clone(), move |tx| {
         for check in write.checks {
           let real_versionstamp = tx
             .prepare_cached(STATEMENT_KV_POINT_GET_VERSION_ONLY)?
@@ -280,7 +631,7 @@ impl Database for SqliteDb {
             .optional()?
             .map(version_to_versionstamp);
           if real_versionstamp != check.versionstamp {
-            return Ok(None);
+            return Ok((false, None));
           }
         }
@@ -336,17 +687,67 @@ impl Database for SqliteDb {
           }
         }
 
-        // TODO(@losfair): enqueues
+        let now = SystemTime::now()
+          .duration_since(SystemTime::UNIX_EPOCH)
+          .unwrap()
+          .as_millis() as u64;
+
+        let has_enqueues = !write.enqueues.is_empty();
+        for enqueue in write.enqueues {
+          let id = Uuid::new_v4().to_string();
+          let backoff_schedule = serde_json::to_string(
+            &enqueue
+              .backoff_schedule
+              .or_else(|| Some(DEFAULT_BACKOFF_SCHEDULE.to_vec())),
+          )?;
+          let keys_if_undelivered =
+            serde_json::to_string(&enqueue.keys_if_undelivered)?;
+
+          let changed =
+            tx.prepare_cached(STATEMENT_QUEUE_ADD_READY)?
+              .execute(params![
+                now + enqueue.delay_ms,
+                id,
+                &enqueue.payload,
+                &backoff_schedule,
+                &keys_if_undelivered
+              ])?;
+          assert_eq!(changed, 1)
+        }
 
         tx.commit()?;
         let new_vesionstamp = version_to_versionstamp(version);
 
-        Ok(Some(CommitResult {
-          versionstamp: new_vesionstamp,
-        }))
+        Ok((
+          has_enqueues,
+          Some(CommitResult {
+            versionstamp: new_vesionstamp,
+          }),
+        ))
       })
-      .await
+      .await?;
+
+    if has_enqueues {
+      if let Some(queue) = self.queue.get() {
+        queue.wake().await?;
+      }
+    }
+    Ok(commit_result)
+  }
+
+  async fn dequeue_next_message(&self) -> Result<Self::QMH, AnyError> {
+    let queue = self
+      .queue
+      .get_or_init(|| async move { SqliteQueue::new(self.conn.clone()) })
+      .await;
+    let handle = queue.dequeue().await?;
+    Ok(handle)
+  }
+
+  fn close(&self) {
+    if let Some(queue) = self.queue.get() {
+      queue.shutdown();
+    }
+  }
 }