
feat(kv) queue implementation (#19459)

Extend the unstable `Deno.Kv` API to support queues.
Igor Zinkovsky 2023-06-13 17:49:57 -07:00 committed by Bartek Iwańczuk
parent fccec654cb
commit 116972f3fc
No known key found for this signature in database
GPG key ID: 0C6BCDDC3B3AD750
10 changed files with 1203 additions and 51 deletions
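
For orientation before the per-file diffs, here is a minimal usage sketch of the queue API this commit adds, assembled from the `lib.deno.unstable.d.ts` declarations and the tests below. The key names and message payloads (`["last_message"]`, `["failed", "later"]`, etc.) are illustrative only; the API was unstable at the time of this commit.

```ts
// Sketch only: exercises Deno.Kv.enqueue / listenQueue as declared in this diff.
const db = await Deno.openKv();

// The promise returned by listenQueue resolves once the database is closed
// (the tests below follow db.close() with `await listener`).
const listener = db.listenQueue(async (msg: unknown) => {
  // A throwing handler triggers redelivery; once retries are exhausted,
  // the message value is written to the keys given in `keysIfUndelivered`.
  await db.set(["last_message"], msg);
});

// Immediate delivery.
await db.enqueue("hello");

// Delayed delivery (milliseconds), with fallback keys if delivery keeps failing.
await db.enqueue("later", {
  delay: 60_000,
  keysIfUndelivered: [["failed", "later"]],
});

// Enqueue atomically together with other mutations.
const res = await db.atomic()
  .set(["counter"], 1)
  .enqueue("counted")
  .commit();
console.log(res.ok);

db.close();
await listener;
```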

Cargo.lock generated
View file

@@ -1170,8 +1170,12 @@ dependencies = [
"deno_core",
"hex",
"num-bigint",
"rand",
"rusqlite",
"serde",
"serde_json",
"tokio",
"uuid",
]
[[package]]

View file

@@ -46,6 +46,7 @@ util::unit_test_factory!(
intl_test,
io_test,
kv_test,
kv_queue_undelivered_test,
link_test,
make_temp_test,
message_channel_test,

View file

@@ -0,0 +1,56 @@
// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
import { assertEquals } from "./test_util.ts";
const sleep = (time: number) => new Promise((r) => setTimeout(r, time));
let isCI: boolean;
try {
isCI = Deno.env.get("CI") !== undefined;
} catch {
isCI = true;
}
function queueTest(name: string, fn: (db: Deno.Kv) => Promise<void>) {
Deno.test({
name,
// https://github.com/denoland/deno/issues/18363
ignore: Deno.build.os === "darwin" && isCI,
async fn() {
const db: Deno.Kv = await Deno.openKv(
":memory:",
);
await fn(db);
},
});
}
async function collect<T>(
iter: Deno.KvListIterator<T>,
): Promise<Deno.KvEntry<T>[]> {
const entries: Deno.KvEntry<T>[] = [];
for await (const entry of iter) {
entries.push(entry);
}
return entries;
}
queueTest("queue with undelivered", async (db) => {
const listener = db.listenQueue((_msg) => {
throw new TypeError("dequeue error");
});
try {
await db.enqueue("test", {
keysIfUndelivered: [["queue_failed", "a"], ["queue_failed", "b"]],
});
await sleep(100000);
const undelivered = await collect(db.list({ prefix: ["queue_failed"] }));
assertEquals(undelivered.length, 2);
assertEquals(undelivered[0].key, ["queue_failed", "a"]);
assertEquals(undelivered[0].value, "test");
assertEquals(undelivered[1].key, ["queue_failed", "b"]);
assertEquals(undelivered[1].value, "test");
} finally {
db.close();
await listener;
}
});

View file

@@ -3,11 +3,16 @@ import {
assert,
assertEquals,
AssertionError,
assertNotEquals,
assertRejects,
assertThrows,
Deferred,
deferred,
} from "./test_util.ts";
import { assertType, IsExact } from "../../../test_util/std/testing/types.ts";
const sleep = (time: number) => new Promise((r) => setTimeout(r, time));
let isCI: boolean;
try {
isCI = Deno.env.get("CI") !== undefined;
@@ -59,6 +64,20 @@ function dbTest(name: string, fn: (db: Deno.Kv) => Promise<void>) {
});
}
function queueTest(name: string, fn: (db: Deno.Kv) => Promise<void>) {
Deno.test({
name,
// https://github.com/denoland/deno/issues/18363
ignore: Deno.build.os === "darwin" && isCI,
async fn() {
const db: Deno.Kv = await Deno.openKv(
":memory:",
);
await fn(db);
},
});
}
dbTest("basic read-write-delete and versionstamps", async (db) => {
const result1 = await db.get(["a"]);
assertEquals(result1.key, ["a"]);
@@ -1304,3 +1323,429 @@ async function _typeCheckingTests() {
assert(!j.done);
assertType<IsExact<typeof j.value, Deno.KvEntry<string>>>(true);
}
queueTest("basic listenQueue and enqueue", async (db) => {
const promise = deferred();
let dequeuedMessage: unknown = null;
const listener = db.listenQueue((msg) => {
dequeuedMessage = msg;
promise.resolve();
});
try {
const res = await db.enqueue("test");
assert(res.ok);
assertNotEquals(res.versionstamp, null);
await promise;
assertEquals(dequeuedMessage, "test");
} finally {
db.close();
await listener;
}
});
for (const { name, value } of VALUE_CASES) {
queueTest(`listenQueue and enqueue ${name}`, async (db) => {
const numEnqueues = 10;
let count = 0;
const promises: Deferred<void>[] = [];
const dequeuedMessages: unknown[] = [];
const listeners: Promise<void>[] = [];
listeners.push(db.listenQueue((msg) => {
dequeuedMessages.push(msg);
promises[count++].resolve();
}));
try {
for (let i = 0; i < numEnqueues; i++) {
promises.push(deferred());
await db.enqueue(value);
}
for (let i = 0; i < numEnqueues; i++) {
await promises[i];
}
for (let i = 0; i < numEnqueues; i++) {
assertEquals(dequeuedMessages[i], value);
}
} finally {
db.close();
for (const listener of listeners) {
await listener;
}
}
});
}
queueTest("queue mixed types", async (db) => {
let promise: Deferred<void>;
let dequeuedMessage: unknown = null;
const listener = db.listenQueue((msg) => {
dequeuedMessage = msg;
promise.resolve();
});
try {
for (const item of VALUE_CASES) {
promise = deferred();
await db.enqueue(item.value);
await promise;
assertEquals(dequeuedMessage, item.value);
}
} finally {
db.close();
await listener;
}
});
queueTest("queue delay", async (db) => {
let dequeueTime: number | undefined;
const promise = deferred();
let dequeuedMessage: unknown = null;
const listener = db.listenQueue((msg) => {
dequeueTime = Date.now();
dequeuedMessage = msg;
promise.resolve();
});
try {
const enqueueTime = Date.now();
await db.enqueue("test", { delay: 1000 });
await promise;
assertEquals(dequeuedMessage, "test");
assert(dequeueTime !== undefined);
assert(dequeueTime - enqueueTime >= 1000);
} finally {
db.close();
await listener;
}
});
queueTest("queue delay with atomic", async (db) => {
let dequeueTime: number | undefined;
const promise = deferred();
let dequeuedMessage: unknown = null;
const listener = db.listenQueue((msg) => {
dequeueTime = Date.now();
dequeuedMessage = msg;
promise.resolve();
});
try {
const enqueueTime = Date.now();
const res = await db.atomic()
.enqueue("test", { delay: 1000 })
.commit();
assert(res.ok);
await promise;
assertEquals(dequeuedMessage, "test");
assert(dequeueTime !== undefined);
assert(dequeueTime - enqueueTime >= 1000);
} finally {
db.close();
await listener;
}
});
queueTest("queue delay and now", async (db) => {
let count = 0;
let dequeueTime: number | undefined;
const promise = deferred();
let dequeuedMessage: unknown = null;
const listener = db.listenQueue((msg) => {
count += 1;
if (count == 2) {
dequeueTime = Date.now();
dequeuedMessage = msg;
promise.resolve();
}
});
try {
const enqueueTime = Date.now();
await db.enqueue("test-1000", { delay: 1000 });
await db.enqueue("test");
await promise;
assertEquals(dequeuedMessage, "test-1000");
assert(dequeueTime !== undefined);
assert(dequeueTime - enqueueTime >= 1000);
} finally {
db.close();
await listener;
}
});
dbTest("queue negative delay", async (db) => {
await assertRejects(async () => {
await db.enqueue("test", { delay: -100 });
}, TypeError);
});
dbTest("queue nan delay", async (db) => {
await assertRejects(async () => {
await db.enqueue("test", { delay: Number.NaN });
}, TypeError);
});
dbTest("queue large delay", async (db) => {
await db.enqueue("test", { delay: 7 * 24 * 60 * 60 * 1000 });
await assertRejects(async () => {
await db.enqueue("test", { delay: 7 * 24 * 60 * 60 * 1000 + 1 });
}, TypeError);
});
queueTest("listenQueue with async callback", async (db) => {
const promise = deferred();
let dequeuedMessage: unknown = null;
const listener = db.listenQueue(async (msg) => {
dequeuedMessage = msg;
await sleep(100);
promise.resolve();
});
try {
await db.enqueue("test");
await promise;
assertEquals(dequeuedMessage, "test");
} finally {
db.close();
await listener;
}
});
queueTest("queue retries", async (db) => {
let count = 0;
const listener = db.listenQueue(async (_msg) => {
count += 1;
await sleep(10);
throw new TypeError("dequeue error");
});
try {
await db.enqueue("test");
await sleep(10000);
} finally {
db.close();
await listener;
}
// There should have been 1 attempt + 3 retries in the 10 seconds: the default
// backoff schedule is [100, 1000, 5000, 30000, 60000] ms, so only the 100 ms,
// 1 s, and 5 s retries land within the window.
assertEquals(4, count);
});
queueTest("multiple listenQueues", async (db) => {
const numListens = 10;
let count = 0;
const promises: Deferred<void>[] = [];
const dequeuedMessages: unknown[] = [];
const listeners: Promise<void>[] = [];
for (let i = 0; i < numListens; i++) {
listeners.push(db.listenQueue((msg) => {
dequeuedMessages.push(msg);
promises[count++].resolve();
}));
}
try {
for (let i = 0; i < numListens; i++) {
promises.push(deferred());
await db.enqueue("msg_" + i);
await promises[i];
const msg = dequeuedMessages[i];
assertEquals("msg_" + i, msg);
}
} finally {
db.close();
for (let i = 0; i < numListens; i++) {
await listeners[i];
}
}
});
queueTest("enqueue with atomic", async (db) => {
const promise = deferred();
let dequeuedMessage: unknown = null;
const listener = db.listenQueue((msg) => {
dequeuedMessage = msg;
promise.resolve();
});
try {
await db.set(["t"], "1");
let currentValue = await db.get(["t"]);
assertEquals("1", currentValue.value);
const res = await db.atomic()
.check(currentValue)
.set(currentValue.key, "2")
.enqueue("test")
.commit();
assert(res.ok);
await promise;
assertEquals("test", dequeuedMessage);
currentValue = await db.get(["t"]);
assertEquals("2", currentValue.value);
} finally {
db.close();
await listener;
}
});
queueTest("enqueue with atomic nonce", async (db) => {
const promise = deferred();
let dequeuedMessage: unknown = null;
const nonce = crypto.randomUUID();
const listener = db.listenQueue(async (val) => {
const message = val as { msg: string; nonce: string };
const nonce = message.nonce;
const nonceValue = await db.get(["nonces", nonce]);
if (nonceValue.versionstamp === null) {
dequeuedMessage = message.msg;
promise.resolve();
return;
}
assertNotEquals(nonceValue.versionstamp, null);
const res = await db.atomic()
.check(nonceValue)
.delete(["nonces", nonce])
.set(["a", "b"], message.msg)
.commit();
if (res.ok) {
// Simulate an error so that the message has to be redelivered
throw new Error("injected error");
}
});
try {
const res = await db.atomic()
.check({ key: ["nonces", nonce], versionstamp: null })
.set(["nonces", nonce], true)
.enqueue({ msg: "test", nonce })
.commit();
assert(res.ok);
await promise;
assertEquals("test", dequeuedMessage);
const currentValue = await db.get(["a", "b"]);
assertEquals("test", currentValue.value);
const nonceValue = await db.get(["nonces", nonce]);
assertEquals(nonceValue.versionstamp, null);
} finally {
db.close();
await listener;
}
});
Deno.test({
name: "queue persistence with inflight messages",
sanitizeOps: false,
sanitizeResources: false,
async fn() {
const filename = "cli/tests/testdata/queue.db";
try {
await Deno.remove(filename);
} catch {
// pass
}
try {
let db: Deno.Kv = await Deno.openKv(filename);
let count = 0;
let promise = deferred();
// Register long-running handler.
let listener = db.listenQueue(async (_msg) => {
count += 1;
if (count == 3) {
promise.resolve();
}
await sleep(60000);
});
// Enqueue 3 messages.
await db.enqueue("msg0");
await db.enqueue("msg1");
await db.enqueue("msg2");
await promise;
// Close the database and wait for the listener to finish.
db.close();
await listener;
// Now reopen the database.
db = await Deno.openKv(filename);
count = 0;
promise = deferred();
// Register a handler that will complete quickly.
listener = db.listenQueue((_msg) => {
count += 1;
if (count == 3) {
promise.resolve();
}
});
// Wait for the handlers to finish.
await promise;
assertEquals(3, count);
db.close();
await listener;
} finally {
await Deno.remove(filename);
}
},
});
Deno.test({
name: "queue persistence with delay messages",
sanitizeOps: false,
sanitizeResources: false,
async fn() {
const filename = "cli/tests/testdata/queue.db";
try {
await Deno.remove(filename);
} catch {
// pass
}
try {
let db: Deno.Kv = await Deno.openKv(filename);
let count = 0;
let promise = deferred();
// Register a no-op handler.
let listener = db.listenQueue((_msg) => {});
// Enqueue 3 messages into the future.
await db.enqueue("msg0", { delay: 10000 });
await db.enqueue("msg1", { delay: 10000 });
await db.enqueue("msg2", { delay: 10000 });
// Close the database and wait for the listener to finish.
db.close();
await listener;
// Now reopen the database.
db = await Deno.openKv(filename);
count = 0;
promise = deferred();
// Register a handler that will complete quickly.
listener = db.listenQueue((_msg) => {
count += 1;
if (count == 3) {
promise.resolve();
}
});
// Wait for the handlers to finish.
await promise;
assertEquals(3, count);
db.close();
await listener;
} finally {
await Deno.remove(filename);
}
},
});

View file

@@ -1914,6 +1914,14 @@ declare namespace Deno {
* checks pass during the commit.
*/
delete(key: KvKey): this;
/**
* Add to the operation a mutation that enqueues a value into the queue
* if all checks pass during the commit.
*/
enqueue(
value: unknown,
options?: { delay?: number; keysIfUndelivered?: Deno.KvKey[] },
): this;
/**
* Commit the operation to the KV store. Returns a value indicating whether
* checks passed and mutations were performed. If the operation failed
@@ -2087,6 +2095,57 @@ declare namespace Deno {
options?: KvListOptions,
): KvListIterator<T>;
/**
* Add a value into the database queue to be delivered to the queue
* listener via {@linkcode Deno.Kv.listenQueue}.
*
* ```ts
* const db = await Deno.openKv();
* await db.enqueue("bar");
* ```
*
* The `delay` option can be used to specify the delay (in milliseconds)
* of the value delivery. The default delay is 0, which means immediate
* delivery.
*
* ```ts
* const db = await Deno.openKv();
* await db.enqueue("bar", { delay: 60000 });
* ```
*
* The `keysIfUndelivered` option can be used to specify the keys to
* be set if the value is not successfully delivered to the queue
* listener after several attempts. The values are set to the value of
* the queued message.
*
* ```ts
* const db = await Deno.openKv();
* await db.enqueue("bar", { keysIfUndelivered: [["foo", "bar"]] });
* ```
*/
enqueue(
value: unknown,
options?: { delay?: number; keysIfUndelivered?: Deno.KvKey[] },
): Promise<KvCommitResult>;
/**
* Listen for queue values to be delivered from the database queue, which
* were enqueued with {@linkcode Deno.Kv.enqueue}. The provided handler
* callback is invoked on every dequeued value. A failed callback
* invocation is automatically retried multiple times until it succeeds
* or until the maximum number of retries is reached.
*
* ```ts
* const db = await Deno.openKv();
* db.listenQueue(async (msg: unknown) => {
* await db.set(["foo"], msg);
* });
* ```
*/
listenQueue(
handler: (value: unknown) => Promise<void> | void,
): Promise<void>;
/**
* Create a new {@linkcode Deno.AtomicOperation} object which can be used to
* perform an atomic transaction on the database. This does not perform any

View file

@@ -26,6 +26,20 @@ async function openKv(path: string) {
return new Kv(rid, kvSymbol);
}
const millisecondsInOneWeek = 7 * 24 * 60 * 60 * 1000;
function validateQueueDelay(delay: number) {
if (delay < 0) {
throw new TypeError("delay cannot be negative");
}
if (delay > millisecondsInOneWeek) {
throw new TypeError("delay cannot be greater than one week");
}
if (isNaN(delay)) {
throw new TypeError("delay cannot be NaN");
}
}
interface RawKvEntry {
key: Deno.KvKey;
value: RawValue;
@@ -47,6 +61,7 @@ const kvSymbol = Symbol("KvRid");
class Kv {
#rid: number;
#closed: boolean;
constructor(rid: number = undefined, symbol: symbol = undefined) {
if (kvSymbol !== symbol) {
@@ -55,6 +70,7 @@ class Kv {
);
}
this.#rid = rid;
this.#closed = false;
}
atomic() {
@@ -203,8 +219,82 @@ class Kv {
};
}
async enqueue(
message: unknown,
opts?: { delay?: number; keysIfUndelivered?: Deno.KvKey[] },
) {
if (opts?.delay !== undefined) {
validateQueueDelay(opts?.delay);
}
const enqueues = [
[
core.serialize(message, { forStorage: true }),
opts?.delay ?? 0,
opts?.keysIfUndelivered ?? [],
null,
],
];
const versionstamp = await core.opAsync(
"op_kv_atomic_write",
this.#rid,
[],
[],
enqueues,
);
if (versionstamp === null) throw new TypeError("Failed to enqueue value");
return { ok: true, versionstamp };
}
async listenQueue(
handler: (message: unknown) => Promise<void> | void,
): Promise<void> {
while (!this.#closed) {
// Wait for the next message.
let next: { 0: Uint8Array; 1: number };
try {
next = await core.opAsync(
"op_kv_dequeue_next_message",
this.#rid,
);
} catch (error) {
if (this.#closed) {
break;
} else {
throw error;
}
}
// Deserialize the payload.
const { 0: payload, 1: handleId } = next;
const deserializedPayload = core.deserialize(payload, {
forStorage: true,
});
// Dispatch the payload.
(async () => {
let success = false;
try {
const result = handler(deserializedPayload);
const _res = result instanceof Promise ? (await result) : result;
success = true;
} catch (error) {
console.error("Exception in queue handler", error);
} finally {
await core.opAsync(
"op_kv_finish_dequeued_message",
handleId,
success,
);
}
})();
}
}
close() {
core.close(this.#rid);
this.#closed = true;
}
}
@@ -213,6 +303,7 @@ class AtomicOperation {
#checks: [Deno.KvKey, string | null][] = [];
#mutations: [Deno.KvKey, string, RawValue | null][] = [];
#enqueues: [Uint8Array, number, Deno.KvKey[], number[] | null][] = [];
constructor(rid: number) {
this.#rid = rid;
@@ -280,13 +371,29 @@ class AtomicOperation {
return this;
}
enqueue(
message: unknown,
opts?: { delay?: number; keysIfUndelivered?: Deno.KvKey[] },
): this {
if (opts?.delay !== undefined) {
validateQueueDelay(opts?.delay);
}
this.#enqueues.push([
core.serialize(message, { forStorage: true }),
opts?.delay ?? 0,
opts?.keysIfUndelivered ?? [],
null,
]);
return this;
}
async commit(): Promise<Deno.KvCommitResult | Deno.KvCommitError> {
const versionstamp = await core.opAsync(
"op_kv_atomic_write",
this.#rid,
this.#checks,
this.#mutations,
[], // TODO(@losfair): enqueue
this.#enqueues,
);
if (versionstamp === null) return { ok: false };
return { ok: true, versionstamp };

View file

@@ -20,5 +20,9 @@ base64.workspace = true
deno_core.workspace = true
hex.workspace = true
num-bigint.workspace = true
rand.workspace = true
rusqlite.workspace = true
serde.workspace = true
serde_json.workspace = true
tokio.workspace = true
uuid.workspace = true

View file

@@ -25,6 +25,8 @@ pub trait DatabaseHandler {
#[async_trait(?Send)]
pub trait Database {
type QMH: QueueMessageHandle + 'static;
async fn snapshot_read(
&self,
requests: Vec<ReadRange>,
@@ -35,6 +37,16 @@ pub trait Database {
&self,
write: AtomicWrite,
) -> Result<Option<CommitResult>, AnyError>;
async fn dequeue_next_message(&self) -> Result<Self::QMH, AnyError>;
fn close(&self);
}
#[async_trait(?Send)]
pub trait QueueMessageHandle {
async fn take_payload(&mut self) -> Result<Vec<u8>, AnyError>;
async fn finish(&self, success: bool) -> Result<(), AnyError>;
}
/// Options for a snapshot read.
@@ -242,7 +254,7 @@ pub struct KvMutation {
/// keys specified in `keys_if_undelivered`.
pub struct Enqueue {
pub payload: Vec<u8>,
pub deadline_ms: u64,
pub delay_ms: u64,
pub keys_if_undelivered: Vec<Vec<u8>>,
pub backoff_schedule: Option<Vec<u32>>,
}

View file

@@ -8,6 +8,7 @@ use std::borrow::Cow;
use std::cell::RefCell;
use std::num::NonZeroU32;
use std::rc::Rc;
use std::vec;
use codec::decode_key;
use codec::encode_key;
@@ -60,6 +61,8 @@ deno_core::extension!(deno_kv,
op_kv_snapshot_read<DBH>,
op_kv_atomic_write<DBH>,
op_kv_encode_cursor,
op_kv_dequeue_next_message<DBH>,
op_kv_finish_dequeued_message<DBH>,
],
esm = [ "01_db.ts" ],
options = {
@@ -80,6 +83,10 @@ impl<DB: Database + 'static> Resource for DatabaseResource<DB> {
fn name(&self) -> Cow<str> {
"database".into()
}
fn close(self: Rc<Self>) {
self.db.close();
}
}
#[op]
@@ -280,6 +287,62 @@ where
Ok(output_ranges)
}
struct QueueMessageResource<QPH: QueueMessageHandle + 'static> {
handle: QPH,
}
impl<QMH: QueueMessageHandle + 'static> Resource for QueueMessageResource<QMH> {
fn name(&self) -> Cow<str> {
"queue_message".into()
}
}
#[op]
async fn op_kv_dequeue_next_message<DBH>(
state: Rc<RefCell<OpState>>,
rid: ResourceId,
) -> Result<(ZeroCopyBuf, ResourceId), AnyError>
where
DBH: DatabaseHandler + 'static,
{
let db = {
let state = state.borrow();
let resource =
state.resource_table.get::<DatabaseResource<DBH::DB>>(rid)?;
resource.db.clone()
};
let mut handle = db.dequeue_next_message().await?;
let payload = handle.take_payload().await?.into();
let handle_rid = {
let mut state = state.borrow_mut();
state.resource_table.add(QueueMessageResource { handle })
};
Ok((payload, handle_rid))
}
#[op]
async fn op_kv_finish_dequeued_message<DBH>(
state: Rc<RefCell<OpState>>,
handle_rid: ResourceId,
success: bool,
) -> Result<(), AnyError>
where
DBH: DatabaseHandler + 'static,
{
let handle = {
let mut state = state.borrow_mut();
let handle = state
.resource_table
.take::<QueueMessageResource<<<DBH>::DB as Database>::QMH>>(handle_rid)
.map_err(|_| type_error("Queue message not found"))?;
Rc::try_unwrap(handle)
.map_err(|_| type_error("Queue message not found"))?
.handle
};
handle.finish(success).await
}
type V8KvCheck = (KvKey, Option<ByteString>);
impl TryFrom<V8KvCheck> for KvCheck {
@@ -333,7 +396,7 @@ impl TryFrom<V8Enqueue> for Enqueue {
fn try_from(value: V8Enqueue) -> Result<Self, AnyError> {
Ok(Enqueue {
payload: value.0.to_vec(),
deadline_ms: value.1,
delay_ms: value.1,
keys_if_undelivered: value
.2
.into_iter()

View file

@@ -7,10 +7,17 @@ use std::marker::PhantomData;
use std::path::Path;
use std::path::PathBuf;
use std::rc::Rc;
use std::rc::Weak;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;
use async_trait::async_trait;
use deno_core::error::type_error;
use deno_core::error::AnyError;
use deno_core::futures;
use deno_core::futures::FutureExt;
use deno_core::task::spawn;
use deno_core::task::spawn_blocking;
use deno_core::AsyncRefCell;
use deno_core::OpState;
@@ -18,6 +25,12 @@ use rusqlite::params;
use rusqlite::OpenFlags;
use rusqlite::OptionalExtension;
use rusqlite::Transaction;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio::sync::OnceCell;
use tokio::sync::OwnedSemaphorePermit;
use tokio::sync::Semaphore;
use uuid::Uuid;
use crate::AtomicWrite;
use crate::CommitResult;
@@ -25,6 +38,7 @@ use crate::Database;
use crate::DatabaseHandler;
use crate::KvEntry;
use crate::MutationKind;
use crate::QueueMessageHandle;
use crate::ReadRange;
use crate::ReadRangeOutput;
use crate::SnapshotReadOptions;
@@ -44,6 +58,18 @@ const STATEMENT_KV_POINT_SET: &str =
"insert into kv (k, v, v_encoding, version) values (:k, :v, :v_encoding, :version) on conflict(k) do update set v = :v, v_encoding = :v_encoding, version = :version";
const STATEMENT_KV_POINT_DELETE: &str = "delete from kv where k = ?";
const STATEMENT_QUEUE_ADD_READY: &str = "insert into queue (ts, id, data, backoff_schedule, keys_if_undelivered) values(?, ?, ?, ?, ?)";
const STATEMENT_QUEUE_GET_NEXT_READY: &str = "select ts, id, data, backoff_schedule, keys_if_undelivered from queue where ts <= ? order by ts limit 100";
const STATEMENT_QUEUE_GET_EARLIEST_READY: &str =
"select ts from queue order by ts limit 1";
const STATEMENT_QUEUE_REMOVE_READY: &str = "delete from queue where id = ?";
const STATEMENT_QUEUE_ADD_RUNNING: &str = "insert into queue_running (deadline, id, data, backoff_schedule, keys_if_undelivered) values(?, ?, ?, ?, ?)";
const STATEMENT_QUEUE_REMOVE_RUNNING: &str =
"delete from queue_running where id = ?";
const STATEMENT_QUEUE_GET_RUNNING_BY_ID: &str = "select deadline, id, data, backoff_schedule, keys_if_undelivered from queue_running where id = ?";
const STATEMENT_QUEUE_GET_RUNNING: &str =
"select id from queue_running order by deadline limit 100";
const STATEMENT_CREATE_MIGRATION_TABLE: &str = "
create table if not exists migration_state(
k integer not null primary key,
@@ -87,6 +113,9 @@ create table queue_running(
",
];
const DISPATCH_CONCURRENCY_LIMIT: usize = 100;
const DEFAULT_BACKOFF_SCHEDULE: [u32; 5] = [100, 1000, 5000, 30000, 60000];
pub struct SqliteDbHandler<P: SqliteDbHandlerPermissions + 'static> {
pub default_storage_dir: Option<PathBuf>,
_permissions: PhantomData<P>,
@@ -182,14 +211,23 @@ impl<P: SqliteDbHandlerPermissions> DatabaseHandler for SqliteDbHandler<P> {
.await
.unwrap()?;
Ok(SqliteDb(Rc::new(AsyncRefCell::new(Cell::new(Some(conn))))))
Ok(SqliteDb {
conn: Rc::new(AsyncRefCell::new(Cell::new(Some(conn)))),
queue: OnceCell::new(),
})
}
}
pub struct SqliteDb(Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>);
pub struct SqliteDb {
conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
queue: OnceCell<SqliteQueue>,
}
impl SqliteDb {
async fn run_tx<F, R>(&self, f: F) -> Result<R, AnyError>
async fn run_tx<F, R>(
conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
f: F,
) -> Result<R, AnyError>
where
F: (FnOnce(rusqlite::Transaction<'_>) -> Result<R, AnyError>)
+ Send
@@ -198,7 +236,7 @@ impl SqliteDb {
{
// Transactions need exclusive access to the connection. Wait until
// we can borrow_mut the connection.
let cell = self.0.borrow_mut().await;
let cell = conn.borrow_mut().await;
// Take the db out of the cell and run the transaction via spawn_blocking.
let mut db = cell.take().unwrap();
@@ -220,59 +258,372 @@ impl SqliteDb {
}
}
pub struct DequeuedMessage {
conn: Weak<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
id: String,
payload: Option<Vec<u8>>,
waker_tx: mpsc::Sender<()>,
_permit: OwnedSemaphorePermit,
}
#[async_trait(?Send)]
impl QueueMessageHandle for DequeuedMessage {
async fn finish(&self, success: bool) -> Result<(), AnyError> {
let Some(conn) = self.conn.upgrade() else {
return Ok(());
};
let id = self.id.clone();
let requeued = SqliteDb::run_tx(conn, move |tx| {
let requeued = {
if success {
let changed = tx
.prepare_cached(STATEMENT_QUEUE_REMOVE_RUNNING)?
.execute([&id])?;
assert!(changed <= 1);
false
} else {
SqliteQueue::requeue_message(&id, &tx)?
}
};
tx.commit()?;
Ok(requeued)
})
.await?;
if requeued {
// If the message was requeued, wake up the dequeue loop.
self.waker_tx.send(()).await?;
}
Ok(())
}
async fn take_payload(&mut self) -> Result<Vec<u8>, AnyError> {
self
.payload
.take()
.ok_or_else(|| type_error("Payload already consumed"))
}
}
type DequeueReceiver = mpsc::Receiver<(Vec<u8>, String)>;
struct SqliteQueue {
conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
dequeue_rx: Rc<AsyncRefCell<DequeueReceiver>>,
concurrency_limiter: Arc<Semaphore>,
waker_tx: mpsc::Sender<()>,
shutdown_tx: watch::Sender<()>,
}
impl SqliteQueue {
fn new(conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>) -> Self {
let conn_clone = conn.clone();
let (shutdown_tx, shutdown_rx) = watch::channel::<()>(());
let (waker_tx, waker_rx) = mpsc::channel::<()>(1);
let (dequeue_tx, dequeue_rx) = mpsc::channel::<(Vec<u8>, String)>(64);
spawn(async move {
// Oneshot requeue of all inflight messages.
Self::requeue_inflight_messages(conn.clone()).await.unwrap();
// Continuous dequeue loop.
Self::dequeue_loop(conn.clone(), dequeue_tx, shutdown_rx, waker_rx)
.await
.unwrap();
});
Self {
conn: conn_clone,
dequeue_rx: Rc::new(AsyncRefCell::new(dequeue_rx)),
waker_tx,
shutdown_tx,
concurrency_limiter: Arc::new(Semaphore::new(DISPATCH_CONCURRENCY_LIMIT)),
}
}
async fn dequeue(&self) -> Result<DequeuedMessage, AnyError> {
// Wait for the next message to be available from dequeue_rx.
let (payload, id) = {
let mut queue_rx = self.dequeue_rx.borrow_mut().await;
let Some(msg) = queue_rx.recv().await else {
return Err(type_error("Database closed"));
};
msg
};
let permit = self.concurrency_limiter.clone().acquire_owned().await?;
Ok(DequeuedMessage {
conn: Rc::downgrade(&self.conn),
id,
payload: Some(payload),
waker_tx: self.waker_tx.clone(),
_permit: permit,
})
}
async fn wake(&self) -> Result<(), AnyError> {
self.waker_tx.send(()).await?;
Ok(())
}
fn shutdown(&self) {
self.shutdown_tx.send(()).unwrap();
}
async fn dequeue_loop(
conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
dequeue_tx: mpsc::Sender<(Vec<u8>, String)>,
mut shutdown_rx: watch::Receiver<()>,
mut waker_rx: mpsc::Receiver<()>,
) -> Result<(), AnyError> {
loop {
let messages = SqliteDb::run_tx(conn.clone(), move |tx| {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
let messages = tx
.prepare_cached(STATEMENT_QUEUE_GET_NEXT_READY)?
.query_map([now], |row| {
let ts: u64 = row.get(0)?;
let id: String = row.get(1)?;
let data: Vec<u8> = row.get(2)?;
let backoff_schedule: String = row.get(3)?;
let keys_if_undelivered: String = row.get(4)?;
Ok((ts, id, data, backoff_schedule, keys_if_undelivered))
})?
.collect::<Result<Vec<_>, rusqlite::Error>>()?;
for (ts, id, data, backoff_schedule, keys_if_undelivered) in &messages {
let changed = tx
.prepare_cached(STATEMENT_QUEUE_REMOVE_READY)?
.execute(params![id])?;
assert_eq!(changed, 1);
let changed =
tx.prepare_cached(STATEMENT_QUEUE_ADD_RUNNING)?.execute(
params![ts, id, &data, &backoff_schedule, &keys_if_undelivered],
)?;
assert_eq!(changed, 1);
}
tx.commit()?;
Ok(
messages
.into_iter()
.map(|(_, id, data, _, _)| (id, data))
.collect::<Vec<_>>(),
)
})
.await?;
let busy = !messages.is_empty();
for (id, data) in messages {
if dequeue_tx.send((data, id)).await.is_err() {
// Queue receiver was dropped. Stop the dequeue loop.
return Ok(());
}
}
if !busy {
// There's nothing to dequeue right now; sleep until one of the
// following happens:
// - It's time to dequeue the next message based on its timestamp
// - A new message is added to the queue
// - The database is closed
let sleep_fut = {
match Self::get_earliest_ready_ts(conn.clone()).await? {
Some(ts) => {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
if ts <= now {
continue;
}
tokio::time::sleep(Duration::from_millis(ts - now)).boxed()
}
None => futures::future::pending().boxed(),
}
};
tokio::select! {
_ = sleep_fut => {}
_ = waker_rx.recv() => {}
_ = shutdown_rx.changed() => return Ok(())
}
}
}
}
async fn get_earliest_ready_ts(
conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
) -> Result<Option<u64>, AnyError> {
SqliteDb::run_tx(conn.clone(), move |tx| {
let ts = tx
.prepare_cached(STATEMENT_QUEUE_GET_EARLIEST_READY)?
.query_row([], |row| {
let ts: u64 = row.get(0)?;
Ok(ts)
})
.optional()?;
Ok(ts)
})
.await
}
async fn requeue_inflight_messages(
conn: Rc<AsyncRefCell<Cell<Option<rusqlite::Connection>>>>,
) -> Result<(), AnyError> {
loop {
let done = SqliteDb::run_tx(conn.clone(), move |tx| {
let entries = tx
.prepare_cached(STATEMENT_QUEUE_GET_RUNNING)?
.query_map([], |row| {
let id: String = row.get(0)?;
Ok(id)
})?
.collect::<Result<Vec<_>, rusqlite::Error>>()?;
for id in &entries {
Self::requeue_message(id, &tx)?;
}
tx.commit()?;
Ok(entries.is_empty())
})
.await?;
if done {
return Ok(());
}
}
}
fn requeue_message(
id: &str,
tx: &rusqlite::Transaction<'_>,
) -> Result<bool, AnyError> {
let Some((_, id, data, backoff_schedule, keys_if_undelivered)) = tx
.prepare_cached(STATEMENT_QUEUE_GET_RUNNING_BY_ID)?
.query_row([id], |row| {
let deadline: u64 = row.get(0)?;
let id: String = row.get(1)?;
let data: Vec<u8> = row.get(2)?;
let backoff_schedule: String = row.get(3)?;
let keys_if_undelivered: String = row.get(4)?;
Ok((deadline, id, data, backoff_schedule, keys_if_undelivered))
})
.optional()? else {
return Ok(false);
};
let backoff_schedule = {
let backoff_schedule =
serde_json::from_str::<Option<Vec<u64>>>(&backoff_schedule)?;
backoff_schedule.unwrap_or_default()
};
let mut requeued = false;
if !backoff_schedule.is_empty() {
// Requeue based on backoff schedule
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
let new_ts = now + backoff_schedule[0];
let new_backoff_schedule = serde_json::to_string(&backoff_schedule[1..])?;
let changed = tx
.prepare_cached(STATEMENT_QUEUE_ADD_READY)?
.execute(params![
new_ts,
id,
&data,
&new_backoff_schedule,
&keys_if_undelivered
])
.unwrap();
assert_eq!(changed, 1);
requeued = true;
} else if !keys_if_undelivered.is_empty() {
// No more requeues. Insert the message into the undelivered queue.
let keys_if_undelivered =
serde_json::from_str::<Vec<Vec<u8>>>(&keys_if_undelivered)?;
let version: i64 = tx
.prepare_cached(STATEMENT_INC_AND_GET_DATA_VERSION)?
.query_row([], |row| row.get(0))?;
for key in keys_if_undelivered {
let changed = tx
.prepare_cached(STATEMENT_KV_POINT_SET)?
.execute(params![key, &data, &VALUE_ENCODING_V8, &version])?;
assert_eq!(changed, 1);
}
}
// Remove from running
let changed = tx
.prepare_cached(STATEMENT_QUEUE_REMOVE_RUNNING)?
.execute(params![id])?;
assert_eq!(changed, 1);
Ok(requeued)
}
}
#[async_trait(?Send)]
impl Database for SqliteDb {
type QMH = DequeuedMessage;
async fn snapshot_read(
&self,
requests: Vec<ReadRange>,
_options: SnapshotReadOptions,
) -> Result<Vec<ReadRangeOutput>, AnyError> {
self
.run_tx(move |tx| {
let mut responses = Vec::with_capacity(requests.len());
for request in requests {
let mut stmt = tx.prepare_cached(if request.reverse {
STATEMENT_KV_RANGE_SCAN_REVERSE
} else {
STATEMENT_KV_RANGE_SCAN
})?;
let entries = stmt
.query_map(
(
request.start.as_slice(),
request.end.as_slice(),
request.limit.get(),
),
|row| {
let key: Vec<u8> = row.get(0)?;
let value: Vec<u8> = row.get(1)?;
let encoding: i64 = row.get(2)?;
Self::run_tx(self.conn.clone(), move |tx| {
let mut responses = Vec::with_capacity(requests.len());
for request in requests {
let mut stmt = tx.prepare_cached(if request.reverse {
STATEMENT_KV_RANGE_SCAN_REVERSE
} else {
STATEMENT_KV_RANGE_SCAN
})?;
let entries = stmt
.query_map(
(
request.start.as_slice(),
request.end.as_slice(),
request.limit.get(),
),
|row| {
let key: Vec<u8> = row.get(0)?;
let value: Vec<u8> = row.get(1)?;
let encoding: i64 = row.get(2)?;
let value = decode_value(value, encoding);
let value = decode_value(value, encoding);
let version: i64 = row.get(3)?;
Ok(KvEntry {
key,
value,
versionstamp: version_to_versionstamp(version),
})
},
)?
.collect::<Result<Vec<_>, rusqlite::Error>>()?;
responses.push(ReadRangeOutput { entries });
}
let version: i64 = row.get(3)?;
Ok(KvEntry {
key,
value,
versionstamp: version_to_versionstamp(version),
})
},
)?
.collect::<Result<Vec<_>, rusqlite::Error>>()?;
responses.push(ReadRangeOutput { entries });
}
Ok(responses)
})
.await
Ok(responses)
})
.await
}
async fn atomic_write(
&self,
write: AtomicWrite,
) -> Result<Option<CommitResult>, AnyError> {
self
.run_tx(move |tx| {
let (has_enqueues, commit_result) =
Self::run_tx(self.conn.clone(), move |tx| {
for check in write.checks {
let real_versionstamp = tx
.prepare_cached(STATEMENT_KV_POINT_GET_VERSION_ONLY)?
@@ -280,7 +631,7 @@ impl Database for SqliteDb {
.optional()?
.map(version_to_versionstamp);
if real_versionstamp != check.versionstamp {
return Ok(None);
return Ok((false, None));
}
}
@@ -336,17 +687,67 @@ impl Database for SqliteDb {
}
}
// TODO(@losfair): enqueues
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
let has_enqueues = !write.enqueues.is_empty();
for enqueue in write.enqueues {
let id = Uuid::new_v4().to_string();
let backoff_schedule = serde_json::to_string(
&enqueue
.backoff_schedule
.or_else(|| Some(DEFAULT_BACKOFF_SCHEDULE.to_vec())),
)?;
let keys_if_undelivered =
serde_json::to_string(&enqueue.keys_if_undelivered)?;
let changed =
tx.prepare_cached(STATEMENT_QUEUE_ADD_READY)?
.execute(params![
now + enqueue.delay_ms,
id,
&enqueue.payload,
&backoff_schedule,
&keys_if_undelivered
])?;
assert_eq!(changed, 1)
}
tx.commit()?;
let new_vesionstamp = version_to_versionstamp(version);
Ok(Some(CommitResult {
versionstamp: new_vesionstamp,
}))
Ok((
has_enqueues,
Some(CommitResult {
versionstamp: new_vesionstamp,
}),
))
})
.await
.await?;
if has_enqueues {
if let Some(queue) = self.queue.get() {
queue.wake().await?;
}
}
Ok(commit_result)
}
async fn dequeue_next_message(&self) -> Result<Self::QMH, AnyError> {
let queue = self
.queue
.get_or_init(|| async move { SqliteQueue::new(self.conn.clone()) })
.await;
let handle = queue.dequeue().await?;
Ok(handle)
}
fn close(&self) {
if let Some(queue) = self.queue.get() {
queue.shutdown();
}
}
}