1
use std::borrow::Cow;
2
use std::collections::{btree_map, BTreeMap, VecDeque};
3
use std::sync::{Arc, Weak};
4
use std::time::Duration;
5

            
6
use bonsaidb_core::connection::{Connection, HasSession};
7
use bonsaidb_core::keyvalue::{
8
    Command, KeyCheck, KeyOperation, KeyStatus, KeyValue, Numeric, Output, SetCommand, Timestamp,
9
    Value,
10
};
11
use bonsaidb_core::permissions::bonsai::{
12
    keyvalue_key_resource_name, BonsaiAction, DatabaseAction, KeyValueAction,
13
};
14
use bonsaidb_core::transaction::{ChangedKey, Changes};
15
use nebari::io::any::AnyFile;
16
use nebari::tree::{CompareSwap, Operation, Root, ScanEvaluation, Unversioned};
17
use nebari::{AbortError, ArcBytes, Roots};
18
use parking_lot::Mutex;
19
use serde::{Deserialize, Serialize};
20
use watchable::{Watchable, Watcher};
21

            
22
use crate::config::KeyValuePersistence;
23
use crate::database::compat;
24
use crate::storage::StorageLock;
25
use crate::tasks::{Job, Keyed, Task};
26
use crate::{Database, DatabaseNonBlocking, Error};
27

            
28
1461953
/// A single key-value entry as stored (bincode-encoded) in the [`KEY_TREE`]
/// tree and as held in memory while dirty or being persisted.
///
/// NOTE(review): bincode encodes struct fields positionally, so the field
/// order below is part of the on-disk format — do not reorder fields.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Entry {
    /// The stored value.
    pub value: Value,
    /// When this entry expires; `None` means it never expires.
    pub expiration: Option<Timestamp>,
    /// When this entry was last written.
    ///
    /// `#[serde(default)]` falls back to `Timestamp::default()` when the field
    /// is absent from the serialized data. NOTE(review): whether this actually
    /// helps bincode-encoded entries written before the field existed is
    /// unverified.
    #[serde(default)]
    pub last_updated: Timestamp,
}
35

            
36
impl Entry {
37
    pub(crate) fn restore(
38
        self,
39
        namespace: Option<String>,
40
        key: String,
41
        database: &Database,
42
    ) -> Result<(), bonsaidb_core::Error> {
43
3
        database.execute_key_operation(KeyOperation {
44
3
            namespace,
45
3
            key,
46
3
            command: Command::Set(SetCommand {
47
3
                value: self.value,
48
3
                expiration: self.expiration,
49
3
                keep_existing_expiration: false,
50
3
                check: None,
51
3
                return_previous_value: false,
52
3
            }),
53
3
        })?;
54
3
        Ok(())
55
3
    }
56
}
57

            
58
impl KeyValue for Database {
59
    fn execute_key_operation(&self, op: KeyOperation) -> Result<Output, bonsaidb_core::Error> {
60
1496676
        self.check_permission(
61
1496676
            keyvalue_key_resource_name(self.name(), op.namespace.as_deref(), &op.key),
62
1496676
            &BonsaiAction::Database(DatabaseAction::KeyValue(KeyValueAction::ExecuteOperation)),
63
1496676
        )?;
64
1496676
        self.data.context.perform_kv_operation(op)
65
1496676
    }
66
}
67

            
68
impl Database {
69
17540
    pub(crate) fn all_key_value_entries(
70
17540
        &self,
71
17540
    ) -> Result<BTreeMap<(Option<String>, String), Entry>, Error> {
72
17540
        // Lock the state so that new new modifications can be made while we gather this snapshot.
73
17540
        let state = self.data.context.key_value_state.lock();
74
17540
        let database = self.clone();
75
17540
        // Initialize our entries with any dirty keys and any keys that are about to be persisted.
76
17540
        let mut all_entries = BTreeMap::new();
77
17540
        database
78
17540
            .roots()
79
17540
            .tree(Unversioned::tree(KEY_TREE))?
80
17537
            .scan::<Error, _, _, _, _>(
81
17537
                &(..),
82
17537
                true,
83
17537
                |_, _, _| ScanEvaluation::ReadData,
84
17537
                |_, _| ScanEvaluation::ReadData,
85
17537
                |key, _, entry: ArcBytes<'static>| {
86
25
                    let entry = bincode::deserialize::<Entry>(&entry)
87
25
                        .map_err(|err| AbortError::Other(Error::from(err)))?;
88
25
                    let full_key = std::str::from_utf8(&key)
89
25
                        .map_err(|err| AbortError::Other(Error::from(err)))?;
90

            
91
25
                    if let Some(split_key) = split_key(full_key) {
92
25
                        // Do not overwrite the existing key
93
25
                        all_entries.entry(split_key).or_insert(entry);
94
25
                    }
95

            
96
25
                    Ok(())
97
17537
                },
98
17537
            )?;
99

            
100
        // Apply the pending writes first
101
17536
        if let Some(pending_keys) = &state.keys_being_persisted {
102
46
            for (key, possible_entry) in pending_keys.iter() {
103
46
                let (namespace, key) = split_key(key).unwrap();
104
46
                if let Some(updated_entry) = possible_entry {
105
44
                    all_entries.insert((namespace, key), updated_entry.clone());
106
44
                } else {
107
2
                    all_entries.remove(&(namespace, key));
108
2
                }
109
            }
110
17491
        }
111

            
112
17536
        for (key, possible_entry) in &state.dirty_keys {
113
12
            let (namespace, key) = split_key(key).unwrap();
114
12
            if let Some(updated_entry) = possible_entry {
115
11
                all_entries.insert((namespace, key), updated_entry.clone());
116
11
            } else {
117
1
                all_entries.remove(&(namespace, key));
118
1
            }
119
        }
120

            
121
17536
        Ok(all_entries)
122
17540
    }
123
}
124

            
125
/// Name of the nebari tree that holds persisted key-value entries.
pub(crate) const KEY_TREE: &str = "kv";
126

            
127
1496546
/// Builds the tree key for `key` within `namespace`.
///
/// The layout is `<namespace>\0<key>`; a missing namespace produces an empty
/// prefix, so the result always starts with or contains the `\0` separator.
fn full_key(namespace: Option<&str>, key: &str) -> String {
    let prefix = namespace.unwrap_or_default();
    let mut combined = String::with_capacity(prefix.len() + 1 + key.len());
    combined.push_str(prefix);
    combined.push('\0');
    combined.push_str(key);
    combined
}
137

            
138
/// Splits a tree key produced by `full_key` back into `(namespace, key)`.
///
/// The split happens at the first `\0`. An empty namespace becomes `None`.
/// Returns `None` when the input contains no separator at all.
fn split_key(full_key: &str) -> Option<(Option<String>, String)> {
    let (namespace, key) = full_key.split_once('\0')?;
    let namespace = (!namespace.is_empty()).then(|| namespace.to_string());
    Some((namespace, key.to_string()))
}
150

            
151
1481804
fn increment(existing: &Numeric, amount: &Numeric, saturating: bool) -> Numeric {
152
1481804
    match amount {
153
444
        Numeric::Integer(amount) => {
154
444
            let existing_value = existing.as_i64_lossy(saturating);
155
444
            let new_value = if saturating {
156
296
                existing_value.saturating_add(*amount)
157
            } else {
158
148
                existing_value.wrapping_add(*amount)
159
            };
160
444
            Numeric::Integer(new_value)
161
        }
162
1480916
        Numeric::UnsignedInteger(amount) => {
163
1480916
            let existing_value = existing.as_u64_lossy(saturating);
164
1480916
            let new_value = if saturating {
165
1480768
                existing_value.saturating_add(*amount)
166
            } else {
167
148
                existing_value.wrapping_add(*amount)
168
            };
169
1480916
            Numeric::UnsignedInteger(new_value)
170
        }
171
444
        Numeric::Float(amount) => {
172
444
            let existing_value = existing.as_f64_lossy();
173
444
            let new_value = existing_value + *amount;
174
444
            Numeric::Float(new_value)
175
        }
176
    }
177
1481804
}
178

            
179
1332
fn decrement(existing: &Numeric, amount: &Numeric, saturating: bool) -> Numeric {
180
1332
    match amount {
181
444
        Numeric::Integer(amount) => {
182
444
            let existing_value = existing.as_i64_lossy(saturating);
183
444
            let new_value = if saturating {
184
296
                existing_value.saturating_sub(*amount)
185
            } else {
186
148
                existing_value.wrapping_sub(*amount)
187
            };
188
444
            Numeric::Integer(new_value)
189
        }
190
592
        Numeric::UnsignedInteger(amount) => {
191
592
            let existing_value = existing.as_u64_lossy(saturating);
192
592
            let new_value = if saturating {
193
296
                existing_value.saturating_sub(*amount)
194
            } else {
195
296
                existing_value.wrapping_sub(*amount)
196
            };
197
592
            Numeric::UnsignedInteger(new_value)
198
        }
199
296
        Numeric::Float(amount) => {
200
296
            let existing_value = existing.as_f64_lossy();
201
296
            let new_value = existing_value - *amount;
202
296
            Numeric::Float(new_value)
203
        }
204
    }
205
1332
}
206

            
207
/// In-memory state of a database's key-value store.
///
/// Writes are recorded in `dirty_keys`. A commit moves that map into
/// `keys_being_persisted` and hands it to a background thread that writes it
/// into `roots`.
#[derive(Debug)]
pub struct KeyValueState {
    /// Storage used both to read clean keys and to persist committed batches.
    roots: Roots<AnyFile>,
    /// Policy deciding when dirty keys should be committed.
    persistence: KeyValuePersistence,
    /// When the most recent batch of dirty keys was handed off for persisting.
    last_commit: Timestamp,
    /// Published wake-up target for the background worker.
    background_worker_target: Watchable<BackgroundWorkerProcessTarget>,
    /// Expiration time of every full key that has one.
    expiring_keys: BTreeMap<String, Timestamp>,
    /// The keys of `expiring_keys`, sorted by ascending expiration.
    expiration_order: VecDeque<String>,
    /// Changes not yet handed off for persistence; `None` marks a deletion.
    dirty_keys: BTreeMap<String, Option<Entry>>,
    /// The batch a persistence thread is currently writing, if any.
    keys_being_persisted: Option<Arc<BTreeMap<String, Option<Entry>>>>,
    /// When the most recently completed persistence pass finished.
    last_persistence: Watchable<Timestamp>,
    /// Set while shutting down; signaled once no keys remain to be persisted.
    shutdown: Option<flume::Sender<()>>,
}
220

            
221
impl KeyValueState {
222
34893
    pub fn new(
223
34893
        persistence: KeyValuePersistence,
224
34893
        roots: Roots<AnyFile>,
225
34893
        background_worker_target: Watchable<BackgroundWorkerProcessTarget>,
226
34893
    ) -> Self {
227
34893
        Self {
228
34893
            roots,
229
34893
            persistence,
230
34893
            last_commit: Timestamp::now(),
231
34893
            expiring_keys: BTreeMap::new(),
232
34893
            background_worker_target,
233
34893
            expiration_order: VecDeque::new(),
234
34893
            dirty_keys: BTreeMap::new(),
235
34893
            keys_being_persisted: None,
236
34893
            last_persistence: Watchable::new(Timestamp::MIN),
237
34893
            shutdown: None,
238
34893
        }
239
34893
    }
240

            
241
31041
    pub fn shutdown(&mut self, state: &Arc<Mutex<KeyValueState>>) -> Option<flume::Receiver<()>> {
242
31041
        if self.keys_being_persisted.is_none() && self.commit_dirty_keys(state) {
243
361
            let (shutdown_sender, shutdown_receiver) = flume::bounded(1);
244
361
            self.shutdown = Some(shutdown_sender);
245
361
            Some(shutdown_receiver)
246
        } else {
247
30680
            None
248
        }
249
31041
    }
250

            
251
1496680
    pub fn perform_kv_operation(
252
1496680
        &mut self,
253
1496680
        op: KeyOperation,
254
1496680
        state: &Arc<Mutex<KeyValueState>>,
255
1496680
    ) -> Result<Output, bonsaidb_core::Error> {
256
1496680
        let now = Timestamp::now();
257
1496680
        // If there are any keys that have expired, clear them before executing any operations.
258
1496680
        self.remove_expired_keys(now);
259
1496680
        let result = match op.command {
260
6355
            Command::Set(command) => {
261
6355
                self.execute_set_operation(op.namespace.as_deref(), &op.key, command, now)
262
            }
263
5709
            Command::Get { delete } => {
264
5709
                self.execute_get_operation(op.namespace.as_deref(), &op.key, delete)
265
            }
266
1184
            Command::Delete => self.execute_delete_operation(op.namespace.as_deref(), &op.key),
267
1481952
            Command::Increment { amount, saturating } => self.execute_increment_operation(
268
1481952
                op.namespace.as_deref(),
269
1481952
                &op.key,
270
1481952
                &amount,
271
1481952
                saturating,
272
1481952
                now,
273
1481952
            ),
274
1480
            Command::Decrement { amount, saturating } => self.execute_decrement_operation(
275
1480
                op.namespace.as_deref(),
276
1480
                &op.key,
277
1480
                &amount,
278
1480
                saturating,
279
1480
                now,
280
1480
            ),
281
        };
282
1496680
        if result.is_ok() {
283
1496088
            if self.needs_commit(now) {
284
35893
                self.commit_dirty_keys(state);
285
1460195
            }
286
1496088
            self.update_background_worker_target();
287
592
        }
288
1496680
        result
289
1496680
    }
290

            
291
    #[cfg_attr(
292
        feature = "tracing",
293
        tracing::instrument(level = "trace", skip(self, set, now),)
294
    )]
295
6355
    fn execute_set_operation(
296
6355
        &mut self,
297
6355
        namespace: Option<&str>,
298
6355
        key: &str,
299
6355
        set: SetCommand,
300
6355
        now: Timestamp,
301
6355
    ) -> Result<Output, bonsaidb_core::Error> {
302
6207
        let mut entry = Entry {
303
6355
            value: set.value.validate()?,
304
6207
            expiration: set.expiration,
305
6207
            last_updated: now,
306
6207
        };
307
6207
        let full_key = full_key(namespace, key);
308
6207
        let possible_existing_value =
309
6207
            if set.check.is_some() || set.return_previous_value || set.keep_existing_expiration {
310
1400
                Some(self.get(&full_key).map_err(Error::from)?)
311
            } else {
312
4807
                None
313
            };
314
6207
        let existing_value_ref = possible_existing_value.as_ref().and_then(Option::as_ref);
315

            
316
6207
        let updating = match set.check {
317
368
            Some(KeyCheck::OnlyIfPresent) => existing_value_ref.is_some(),
318
440
            Some(KeyCheck::OnlyIfVacant) => existing_value_ref.is_none(),
319
5399
            None => true,
320
        };
321
6207
        if updating {
322
5839
            if set.keep_existing_expiration {
323
148
                if let Some(existing_value) = existing_value_ref {
324
148
                    entry.expiration = existing_value.expiration;
325
148
                }
326
5691
            }
327
5839
            self.update_key_expiration(&full_key, entry.expiration);
328

            
329
5839
            let previous_value = if let Some(existing_value) = possible_existing_value {
330
                // we already fetched, no need to ask for the existing value back
331
1032
                self.set(full_key, entry);
332
1032
                existing_value
333
            } else {
334
4807
                self.replace(full_key, entry).map_err(Error::from)?
335
            };
336
5839
            if set.return_previous_value {
337
516
                Ok(Output::Value(previous_value.map(|entry| entry.value)))
338
5323
            } else if previous_value.is_none() {
339
2367
                Ok(Output::Status(KeyStatus::Inserted))
340
            } else {
341
2956
                Ok(Output::Status(KeyStatus::Updated))
342
            }
343
        } else {
344
368
            Ok(Output::Status(KeyStatus::NotChanged))
345
        }
346
6355
    }
347

            
348
    #[cfg_attr(
349
        feature = "tracing",
350
        tracing::instrument(level = "trace", skip(self, tree_key, expiration))
351
    )]
352
7333
    pub fn update_key_expiration<'key>(
353
7333
        &mut self,
354
7333
        tree_key: impl Into<Cow<'key, str>>,
355
7333
        expiration: Option<Timestamp>,
356
7333
    ) {
357
7333
        let tree_key = tree_key.into();
358
7333
        let mut changed_first_expiration = false;
359
7333
        if let Some(expiration) = expiration {
360
971
            let key = if self.expiring_keys.contains_key(tree_key.as_ref()) {
361
                // Update the existing entry.
362
297
                let existing_entry_index = self
363
297
                    .expiration_order
364
297
                    .iter()
365
297
                    .enumerate()
366
297
                    .find_map(
367
297
                        |(index, key)| {
368
297
                            if &tree_key == key {
369
297
                                Some(index)
370
                            } else {
371
                                None
372
                            }
373
297
                        },
374
297
                    )
375
297
                    .unwrap();
376
297
                changed_first_expiration = existing_entry_index == 0;
377
297
                self.expiration_order.remove(existing_entry_index).unwrap()
378
            } else {
379
674
                tree_key.into_owned()
380
            };
381

            
382
            // Insert the key into the expiration_order queue
383
971
            let mut insert_at = None;
384
971
            for (index, expiring_key) in self.expiration_order.iter().enumerate() {
385
448
                if self.expiring_keys.get(expiring_key).unwrap() > &expiration {
386
150
                    insert_at = Some(index);
387
150
                    break;
388
298
                }
389
            }
390
971
            if let Some(insert_at) = insert_at {
391
150
                changed_first_expiration |= insert_at == 0;
392
150

            
393
150
                self.expiration_order.insert(insert_at, key.clone());
394
821
            } else {
395
821
                changed_first_expiration |= self.expiration_order.is_empty();
396
821
                self.expiration_order.push_back(key.clone());
397
821
            }
398
971
            self.expiring_keys.insert(key, expiration);
399
6362
        } else if self.expiring_keys.remove(tree_key.as_ref()).is_some() {
400
149
            let index = self
401
149
                .expiration_order
402
149
                .iter()
403
149
                .enumerate()
404
149
                .find_map(|(index, key)| {
405
149
                    if tree_key.as_ref() == key {
406
149
                        Some(index)
407
                    } else {
408
                        None
409
                    }
410
149
                })
411
149
                .unwrap();
412
149

            
413
149
            changed_first_expiration |= index == 0;
414
149
            self.expiration_order.remove(index);
415
6213
        }
416

            
417
7333
        if changed_first_expiration {
418
970
            self.update_background_worker_target();
419
6363
        }
420
7333
    }
421

            
422
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
423
5709
    fn execute_get_operation(
424
5709
        &mut self,
425
5709
        namespace: Option<&str>,
426
5709
        key: &str,
427
5709
        delete: bool,
428
5709
    ) -> Result<Output, bonsaidb_core::Error> {
429
5709
        let full_key = full_key(namespace, key);
430
5709
        let entry = if delete {
431
296
            self.remove(full_key).map_err(Error::from)?
432
        } else {
433
5413
            self.get(&full_key).map_err(Error::from)?
434
        };
435

            
436
5709
        Ok(Output::Value(entry.map(|e| e.value)))
437
5709
    }
438

            
439
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
440
1184
    fn execute_delete_operation(
441
1184
        &mut self,
442
1184
        namespace: Option<&str>,
443
1184
        key: &str,
444
1184
    ) -> Result<Output, bonsaidb_core::Error> {
445
1184
        let full_key = full_key(namespace, key);
446
1184
        let value = self.remove(full_key).map_err(Error::from)?;
447
1184
        if value.is_some() {
448
444
            Ok(Output::Status(KeyStatus::Deleted))
449
        } else {
450
740
            Ok(Output::Status(KeyStatus::NotChanged))
451
        }
452
1184
    }
453

            
454
    #[cfg_attr(
455
        feature = "tracing",
456
        tracing::instrument(level = "trace", skip(self, amount, saturating, now))
457
    )]
458
1481952
    fn execute_increment_operation(
459
1481952
        &mut self,
460
1481952
        namespace: Option<&str>,
461
1481952
        key: &str,
462
1481952
        amount: &Numeric,
463
1481952
        saturating: bool,
464
1481952
        now: Timestamp,
465
1481952
    ) -> Result<Output, bonsaidb_core::Error> {
466
1481952
        self.execute_numeric_operation(namespace, key, amount, saturating, now, increment)
467
1481952
    }
468

            
469
    #[cfg_attr(
470
        feature = "tracing",
471
        tracing::instrument(level = "trace", skip(self, amount, saturating, now))
472
    )]
473
1480
    fn execute_decrement_operation(
474
1480
        &mut self,
475
1480
        namespace: Option<&str>,
476
1480
        key: &str,
477
1480
        amount: &Numeric,
478
1480
        saturating: bool,
479
1480
        now: Timestamp,
480
1480
    ) -> Result<Output, bonsaidb_core::Error> {
481
1480
        self.execute_numeric_operation(namespace, key, amount, saturating, now, decrement)
482
1480
    }
483

            
484
1483432
    fn execute_numeric_operation<F: Fn(&Numeric, &Numeric, bool) -> Numeric>(
485
1483432
        &mut self,
486
1483432
        namespace: Option<&str>,
487
1483432
        key: &str,
488
1483432
        amount: &Numeric,
489
1483432
        saturating: bool,
490
1483432
        now: Timestamp,
491
1483432
        op: F,
492
1483432
    ) -> Result<Output, bonsaidb_core::Error> {
493
1483432
        let full_key = full_key(namespace, key);
494
1483432
        let current = self.get(&full_key).map_err(Error::from)?;
495
1483432
        let mut entry = current.unwrap_or(Entry {
496
1483432
            value: Value::Numeric(Numeric::UnsignedInteger(0)),
497
1483432
            expiration: None,
498
1483432
            last_updated: now,
499
1483432
        });
500
1483432

            
501
1483432
        match entry.value {
502
1483136
            Value::Numeric(existing) => {
503
1483136
                let value = Value::Numeric(op(&existing, amount, saturating).validate()?);
504
1482988
                entry.value = value.clone();
505
1482988

            
506
1482988
                self.set(full_key, entry);
507
1482988
                Ok(Output::Value(Some(value)))
508
            }
509
296
            Value::Bytes(_) => Err(bonsaidb_core::Error::other(
510
296
                "bonsaidb-local",
511
296
                "type of stored `Value` is not `Numeric`",
512
296
            )),
513
        }
514
1483432
    }
515

            
516
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
517
1480
    fn remove(&mut self, key: String) -> Result<Option<Entry>, nebari::Error> {
518
1480
        self.update_key_expiration(&key, None);
519

            
520
1480
        if let Some(dirty_entry) = self.dirty_keys.get_mut(&key) {
521
196
            Ok(dirty_entry.take())
522
1284
        } else if let Some(persisting_entry) = self
523
1284
            .keys_being_persisted
524
1284
            .as_ref()
525
1284
            .and_then(|keys| keys.get(&key))
526
        {
527
108
            self.dirty_keys.insert(key, None);
528
108
            Ok(persisting_entry.clone())
529
        } else {
530
            // There might be a value on-disk we need to remove.
531
1176
            let previous_value = Self::retrieve_key_from_disk(&self.roots, &key)?;
532
1176
            self.dirty_keys.insert(key, None);
533
1176
            Ok(previous_value)
534
        }
535
1480
    }
536

            
537
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
538
    fn get(&self, key: &str) -> Result<Option<Entry>, nebari::Error> {
539
1490245
        if let Some(entry) = self.dirty_keys.get(key) {
540
1393209
            Ok(entry.clone())
541
97036
        } else if let Some(persisting_entry) = self
542
97036
            .keys_being_persisted
543
97036
            .as_ref()
544
97036
            .and_then(|keys| keys.get(key))
545
        {
546
67900
            Ok(persisting_entry.clone())
547
        } else {
548
29136
            Self::retrieve_key_from_disk(&self.roots, key)
549
        }
550
1490245
    }
551

            
552
1484020
    fn set(&mut self, key: String, value: Entry) {
553
1484020
        self.dirty_keys.insert(key, Some(value));
554
1484020
    }
555

            
556
4807
    fn replace(&mut self, key: String, value: Entry) -> Result<Option<Entry>, nebari::Error> {
557
4807
        let mut value = Some(value);
558
4807
        let map_entry = self.dirty_keys.entry(key);
559
4807
        if matches!(map_entry, btree_map::Entry::Vacant(_)) {
560
            // This key is clean, and the caller is expecting the previous
561
            // value.
562
3549
            let stored_value = if let Some(persisting_entry) = self
563
3549
                .keys_being_persisted
564
3549
                .as_ref()
565
3549
                .and_then(|keys| keys.get(map_entry.key()))
566
            {
567
949
                persisting_entry.clone()
568
            } else {
569
2600
                Self::retrieve_key_from_disk(&self.roots, map_entry.key())?
570
            };
571
3549
            map_entry.or_insert(value);
572
3549
            Ok(stored_value)
573
        } else {
574
            // This key is already dirty, we can just replace the value and
575
            // return the old value.
576
1258
            map_entry.and_modify(|map_entry| {
577
1258
                std::mem::swap(&mut value, map_entry);
578
1258
            });
579
1258
            Ok(value)
580
        }
581
4807
    }
582

            
583
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(roots)))]
584
32912
    fn retrieve_key_from_disk(
585
32912
        roots: &Roots<AnyFile>,
586
32912
        key: &str,
587
32912
    ) -> Result<Option<Entry>, nebari::Error> {
588
32912
        roots
589
32912
            .tree(Unversioned::tree(KEY_TREE))?
590
32912
            .get(key.as_bytes())
591
32912
            .map(|current| current.and_then(|current| bincode::deserialize::<Entry>(&current).ok()))
592
32912
    }
593

            
594
1660676
    fn update_background_worker_target(&mut self) {
595
1660676
        let key_expiration_target = self.expiration_order.get(0).map(|key| {
596
4401
            let expiration_timeout = self.expiring_keys.get(key).unwrap();
597
4401
            *expiration_timeout
598
1660676
        });
599
1660676
        let now = Timestamp::now();
600
1660676
        let persisting = self.keys_being_persisted.is_some();
601
1660676
        let commit_target = (!persisting)
602
1660676
            .then(|| {
603
99465
                self.persistence.duration_until_next_commit(
604
99465
                    self.dirty_keys.len(),
605
99465
                    (now - self.last_commit).unwrap_or_default(),
606
99465
                )
607
1660676
            })
608
1660676
            .flatten()
609
1660676
            .map(|duration| now + duration);
610
1660676
        match (commit_target, key_expiration_target) {
611
71767
            (Some(target), _) | (_, Some(target)) if target <= now => {
612
67840
                self.background_worker_target
613
67840
                    .replace(BackgroundWorkerProcessTarget::Now);
614
67840
            }
615
            (Some(commit_target), Some(key_target)) => {
616
                let closest_target = key_target.min(commit_target);
617
                let new_target = BackgroundWorkerProcessTarget::Timestamp(closest_target);
618
                let _: Result<_, _> = self.background_worker_target.update(new_target);
619
            }
620
3927
            (Some(target), None) | (None, Some(target)) => {
621
3927
                let _: Result<_, _> = self
622
3927
                    .background_worker_target
623
3927
                    .update(BackgroundWorkerProcessTarget::Timestamp(target));
624
3927
            }
625
1588909
            (None, None) => {
626
1588909
                let _: Result<_, _> = self
627
1588909
                    .background_worker_target
628
1588909
                    .update(BackgroundWorkerProcessTarget::Never);
629
1588909
            }
630
        }
631
1660676
    }
632

            
633
1563270
    /// Pops every key whose expiration is at or before `now` off the front of
    /// the expiration queue and records a deletion tombstone for it.
    fn remove_expired_keys(&mut self, now: Timestamp) {
        loop {
            let head_expired = match self.expiration_order.front() {
                Some(head) => self.expiring_keys.get(head).unwrap() <= &now,
                None => false,
            };
            if !head_expired {
                break;
            }

            let expired = self.expiration_order.pop_front().unwrap();
            self.expiring_keys.remove(&expired);
            self.dirty_keys.insert(expired, None);
        }
    }
642

            
643
1562678
    fn needs_commit(&mut self, now: Timestamp) -> bool {
644
1562678
        if self.keys_being_persisted.is_some() {
645
1463984
            false
646
        } else {
647
98694
            let since_last_commit = (now - self.last_commit).unwrap_or_default();
648
98694
            self.persistence
649
98694
                .should_commit(self.dirty_keys.len(), since_last_commit)
650
        }
651
1562678
    }
652

            
653
127691
    fn stage_dirty_keys(&mut self) -> Option<Arc<BTreeMap<String, Option<Entry>>>> {
654
127691
        if !self.dirty_keys.is_empty() && self.keys_being_persisted.is_none() {
655
97028
            let keys = Arc::new(std::mem::take(&mut self.dirty_keys));
656
97028
            self.keys_being_persisted = Some(keys.clone());
657
97028
            Some(keys)
658
        } else {
659
30663
            None
660
        }
661
127691
    }
662

            
663
    pub fn commit_dirty_keys(&mut self, state: &Arc<Mutex<KeyValueState>>) -> bool {
664
127330
        if let Some(keys) = self.stage_dirty_keys() {
665
97028
            let roots = self.roots.clone();
666
97028
            let state = state.clone();
667
97028
            std::thread::Builder::new()
668
97028
                .name(String::from("keyvalue-persist"))
669
97028
                .spawn(move || Self::persist_keys(&state, &roots, &keys))
670
97028
                .unwrap();
671
97028
            self.last_commit = Timestamp::now();
672
97028
            true
673
        } else {
674
30302
            false
675
        }
676
127330
    }
677

            
678
    /// Test-only: returns a watcher that observes the timestamp recorded each
    /// time a persistence pass completes.
    #[cfg(test)]
    pub fn persistence_watcher(&self) -> Watcher<Timestamp> {
        Watchable::watch(&self.last_persistence)
    }
682

            
683
194056
    #[cfg_attr(feature = "instrument", tracing::instrument(level = "trace", skip_all))]
684
    fn persist_keys(
685
        key_value_state: &Arc<Mutex<KeyValueState>>,
686
        roots: &Roots<AnyFile>,
687
        keys: &BTreeMap<String, Option<Entry>>,
688
    ) -> Result<(), bonsaidb_core::Error> {
689
        let mut transaction = roots
690
            .transaction(&[Unversioned::tree(KEY_TREE)])
691
            .map_err(Error::from)?;
692
        let all_keys = keys
693
            .keys()
694
98531
            .map(|key| ArcBytes::from(key.as_bytes().to_vec()))
695
            .collect();
696
        let mut changed_keys = Vec::new();
697
        transaction
698
            .tree::<Unversioned>(0)
699
            .unwrap()
700
            .modify(
701
                all_keys,
702
98531
                Operation::CompareSwap(CompareSwap::new(&mut |key, existing_value| {
703
98531
                    let full_key = std::str::from_utf8(key).unwrap();
704
98531
                    let (namespace, key) = split_key(full_key).unwrap();
705

            
706
98531
                    if let Some(new_value) = keys.get(full_key).unwrap() {
707
96833
                        changed_keys.push(ChangedKey {
708
96833
                            namespace,
709
96833
                            key,
710
96833
                            deleted: false,
711
96833
                        });
712
96833
                        let bytes = bincode::serialize(new_value).unwrap();
713
96833
                        nebari::tree::KeyOperation::Set(ArcBytes::from(bytes))
714
1698
                    } else if existing_value.is_some() {
715
1038
                        changed_keys.push(ChangedKey {
716
1038
                            namespace,
717
1038
                            key,
718
1038
                            deleted: existing_value.is_some(),
719
1038
                        });
720
1038
                        nebari::tree::KeyOperation::Remove
721
                    } else {
722
660
                        nebari::tree::KeyOperation::Skip
723
                    }
724
98531
                })),
725
            )
726
            .map_err(Error::from)?;
727

            
728
        if !changed_keys.is_empty() {
729
            transaction
730
                .entry_mut()
731
                .set_data(compat::serialize_executed_transaction_changes(
732
                    &Changes::Keys(changed_keys),
733
                )?)
734
                .map_err(Error::from)?;
735
            transaction.commit().map_err(Error::from)?;
736
        }
737

            
738
        // If we are shutting down, check if we still have dirty keys.
739
        let final_keys = {
740
            let mut state = key_value_state.lock();
741
            state.last_persistence.replace(Timestamp::now());
742
            state.keys_being_persisted = None;
743
            state.update_background_worker_target();
744
            // This block is a little ugly to avoid having to acquire the lock
745
            // twice. If we're shutting down and have no dirty keys, we notify
746
            // the waiting shutdown task. If we have any dirty keys, we wait do
747
            // to that step because we're going to recurse and reach this spot
748
            // again.
749
            if state.shutdown.is_some() {
750
                let staged_keys = state.stage_dirty_keys();
751
                if staged_keys.is_none() {
752
                    let shutdown = state.shutdown.take().unwrap();
753
                    let _: Result<_, _> = shutdown.send(());
754
                }
755
                staged_keys
756
            } else {
757
                None
758
            }
759
        };
760
        if let Some(final_keys) = final_keys {
761
            Self::persist_keys(key_value_state, roots, &final_keys)?;
762
        }
763
        Ok(())
764
    }
765
}
766

            
767
34893
pub fn background_worker(
768
34893
    key_value_state: &Weak<Mutex<KeyValueState>>,
769
34893
    timestamp_receiver: &mut Watcher<BackgroundWorkerProcessTarget>,
770
34893
    storage_lock: Option<StorageLock>,
771
34893
) {
772
169099
    loop {
773
169099
        let mut perform_operations = false;
774
169099
        let current_target = *timestamp_receiver.read();
775
169099
        match current_target {
776
            // With no target, sleep until we receive a target.
777
            BackgroundWorkerProcessTarget::Never => {
778
102435
                if timestamp_receiver.watch().is_err() {
779
30887
                    break;
780
71548
                }
781
            }
782
1111
            BackgroundWorkerProcessTarget::Timestamp(target) => {
783
1111
                // With a target, we need to wait to receive a target only as
784
1111
                // long as there is time remaining.
785
1111
                let remaining = target - Timestamp::now();
786
1111
                if let Some(remaining) = remaining {
787
                    // recv_timeout panics if Instant::checked_add(remaining)
788
                    // fails. So, we will cap the sleep time at 1 day.
789
1111
                    let remaining = remaining.min(Duration::from_secs(60 * 60 * 24));
790
1111
                    match timestamp_receiver.watch_timeout(remaining) {
791
1110
                        Ok(_) | Err(watchable::TimeoutError::Timeout) => {
792
1110
                            perform_operations = true;
793
1110
                        }
794
1
                        Err(watchable::TimeoutError::Disconnected) => break,
795
                    }
796
                } else {
797
                    perform_operations = true;
798
                }
799
            }
800
65553
            BackgroundWorkerProcessTarget::Now => {
801
65553
                perform_operations = true;
802
65553
            }
803
        };
804

            
805
138211
        let Some(key_value_state) = key_value_state.upgrade() else {
806
153
            break;
807
        };
808

            
809
138058
        if perform_operations {
810
70442
            let mut state = key_value_state.lock();
811
70442
            let now = Timestamp::now();
812
70442
            state.remove_expired_keys(now);
813
70442
            if state.needs_commit(now) {
814
60774
                state.commit_dirty_keys(&key_value_state);
815
60774
            }
816
66590
            state.update_background_worker_target();
817
67616
        }
818
    }
819

            
820
    // The key-value store's delayed persistence can cause the key-value storage
821
    // to be written past when the last reference to the storage is still held.
822
    // The storage lock being held ensures that another reader/writer doesn't
823
    // begin accessing this same storage again.
824
31041
    drop(storage_lock);
825
31041
}
826

            
827
1592836
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
828
pub enum BackgroundWorkerProcessTarget {
829
    Now,
830
    Timestamp(Timestamp),
831
    Never,
832
}
833

            
834
#[derive(Debug)]
835
pub struct ExpirationLoader {
836
    pub database: Database,
837
    pub launched_at: Timestamp,
838
}
839

            
840
impl Keyed<Task> for ExpirationLoader {
841
17789
    fn key(&self) -> Task {
842
17789
        Task::ExpirationLoader(self.database.data.name.clone())
843
17789
    }
844
}
845

            
846
impl Job for ExpirationLoader {
847
    type Error = Error;
848
    type Output = ();
849

            
850
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
851
17246
    fn execute(&mut self) -> Result<Self::Output, Self::Error> {
852
17246
        let database = self.database.clone();
853
17246
        let launched_at = self.launched_at;
854

            
855
17246
        for ((namespace, key), entry) in database.all_key_value_entries()? {
856
75
            if entry.last_updated < launched_at && entry.expiration.is_some() {
857
1
                self.database
858
1
                    .update_key_expiration(full_key(namespace.as_deref(), &key), entry.expiration);
859
74
            }
860
        }
861

            
862
17242
        self.database
863
17242
            .storage()
864
17242
            .instance
865
17242
            .tasks()
866
17242
            .mark_key_value_expiration_loaded(self.database.data.name.clone());
867
17242

            
868
17242
        Ok(())
869
17246
    }
870
}
871

            
872
#[cfg(test)]
873
mod tests {
874
    use std::time::{Duration, Instant};
875

            
876
    use bonsaidb_core::arc_bytes::serde::Bytes;
877
    use bonsaidb_core::test_util::{TestDirectory, TimingTest};
878
    use nebari::io::any::{AnyFile, AnyFileManager};
879

            
880
    use super::*;
881
    use crate::config::PersistenceThreshold;
882
    use crate::database::Context;
883

            
884
6
    fn run_test_with_persistence<
885
6
        F: Fn(Context, nebari::Roots<AnyFile>) -> anyhow::Result<()> + Send,
886
6
    >(
887
6
        name: &str,
888
6
        persistence: KeyValuePersistence,
889
6
        test_contents: &F,
890
6
    ) -> anyhow::Result<()> {
891
6
        let dir = TestDirectory::new(name);
892
6
        let sled = nebari::Config::new(&dir)
893
6
            .file_manager(AnyFileManager::std())
894
6
            .open()?;
895

            
896
6
        let context = Context::new(sled.clone(), persistence, None);
897
6

            
898
6
        test_contents(context, sled)?;
899

            
900
6
        Ok(())
901
6
    }
902

            
903
5
    fn run_test<F: Fn(Context, nebari::Roots<AnyFile>) -> anyhow::Result<()> + Send>(
904
5
        name: &str,
905
5
        test_contents: F,
906
5
    ) -> anyhow::Result<()> {
907
5
        run_test_with_persistence(name, KeyValuePersistence::default(), &test_contents)
908
5
    }
909

            
910
1
    #[test]
911
1
    fn basic_expiration() -> anyhow::Result<()> {
912
1
        run_test("kv-basic-expiration", |context, roots| {
913
1
            // Initialize the test state
914
1
            let mut persistence_watcher = context.kv_persistence_watcher();
915
1
            roots.delete_tree(KEY_TREE)?;
916
1
            let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
917
1
            tree.set(b"atree\0akey", b"somevalue")?;
918

            
919
            // Expire the existing key
920
1
            context.update_key_expiration(
921
1
                full_key(Some("atree"), "akey"),
922
1
                Some(Timestamp::now() + Duration::from_millis(100)),
923
1
            );
924
1
            // Wait for persistence.
925
1
            persistence_watcher.next_value()?;
926

            
927
            // Verify it is gone.
928
1
            assert!(tree.get(b"akey")?.is_none());
929

            
930
1
            Ok(())
931
1
        })
932
1
    }
933

            
934
1
    #[test]
935
1
    fn updating_expiration() -> anyhow::Result<()> {
936
1
        run_test("kv-updating-expiration", |context, roots| {
937
1
            // Initialize the test state
938
1
            let mut persistence_watcher = context.kv_persistence_watcher();
939
1
            roots.delete_tree(KEY_TREE)?;
940
1
            let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
941
1
            tree.set(b"atree\0akey", b"somevalue")?;
942
1
            let start = Timestamp::now();
943
1

            
944
1
            // Set the expiration once.
945
1
            context.update_key_expiration(
946
1
                full_key(Some("atree"), "akey"),
947
1
                Some(start + Duration::from_millis(100)),
948
1
            );
949
1
            // Set the expiration to a longer value.
950
1
            let correct_expiration = start + Duration::from_secs(1);
951
1
            context
952
1
                .update_key_expiration(full_key(Some("atree"), "akey"), Some(correct_expiration));
953

            
954
            // Wait for persistence, and ensure that the next persistence is
955
            // after our expiration timestamp.
956
1
            assert!(persistence_watcher.next_value()? > correct_expiration);
957

            
958
            // Verify the key is gone now.
959
1
            assert_eq!(tree.get(b"atree\0akey")?, None);
960

            
961
1
            Ok(())
962
1
        })
963
1
    }
964

            
965
1
    #[test]
966
1
    fn multiple_keys_expiration() -> anyhow::Result<()> {
967
1
        run_test("kv-multiple-keys-expiration", |context, roots| {
968
1
            // Initialize the test state
969
1
            let mut persistence_watcher = context.kv_persistence_watcher();
970
1
            roots.delete_tree(KEY_TREE)?;
971
1
            let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
972
1
            tree.set(b"atree\0akey", b"somevalue")?;
973
1
            tree.set(b"atree\0bkey", b"somevalue")?;
974

            
975
            // Expire both keys, one for a shorter time than the other.
976
1
            context.update_key_expiration(
977
1
                full_key(Some("atree"), "akey"),
978
1
                Some(Timestamp::now() + Duration::from_millis(100)),
979
1
            );
980
1
            context.update_key_expiration(
981
1
                full_key(Some("atree"), "bkey"),
982
1
                Some(Timestamp::now() + Duration::from_secs(1)),
983
1
            );
984
1

            
985
1
            // Wait for the first persistence.
986
1
            persistence_watcher.next_value()?;
987
1
            assert!(tree.get(b"atree\0akey")?.is_none());
988
1
            assert!(tree.get(b"atree\0bkey")?.is_some());
989

            
990
            // Wait for the second persistence.
991
1
            persistence_watcher.next_value()?;
992
1
            assert!(tree.get(b"atree\0bkey")?.is_none());
993

            
994
1
            Ok(())
995
1
        })
996
1
    }
997

            
998
1
    #[test]
999
1
    fn clearing_expiration() -> anyhow::Result<()> {
1
        run_test("kv-clearing-expiration", |sender, sled| {
            loop {
1
                sled.delete_tree(KEY_TREE)?;
1
                let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
1
                tree.set(b"atree\0akey", b"somevalue")?;
1
                let timing = TimingTest::new(Duration::from_millis(100));
1
                sender.update_key_expiration(
1
                    full_key(Some("atree"), "akey"),
1
                    Some(Timestamp::now() + Duration::from_millis(100)),
1
                );
1
                sender.update_key_expiration(full_key(Some("atree"), "akey"), None);
1
                if timing.elapsed() > Duration::from_millis(100) {
                    // Restart, took too long.
                    continue;
1
                }
1
                timing.wait_until(Duration::from_millis(150));
1
                assert!(tree.get(b"atree\0akey")?.is_some());
1
                break;
1
            }
1

            
1
            Ok(())
1
        })
1
    }

            
1
    #[test]
1
    fn out_of_order_expiration() -> anyhow::Result<()> {
1
        run_test("kv-out-of-order-expiration", |context, roots| loop {
1
            context.update_key_expiration(full_key(Some("atree"), "akey"), None);
1
            context.update_key_expiration(full_key(Some("atree"), "bkey"), None);
1
            context.update_key_expiration(full_key(Some("atree"), "ckey"), None);
1
            let mut persistence_watcher = context.kv_persistence_watcher();
1
            drop(roots.delete_tree(KEY_TREE));
1
            let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
1
            tree.set(b"atree\0akey", b"somevalue")?;
1
            tree.set(b"atree\0bkey", b"somevalue")?;
1
            tree.set(b"atree\0ckey", b"somevalue")?;
1
            let timing = TimingTest::new(Duration::from_millis(100));
1
            context.update_key_expiration(
1
                full_key(Some("atree"), "akey"),
1
                Some(Timestamp::now() + Duration::from_secs(3)),
1
            );
1
            context.update_key_expiration(
1
                full_key(Some("atree"), "ckey"),
1
                Some(Timestamp::now() + Duration::from_secs(1)),
1
            );
1
            context.update_key_expiration(
1
                full_key(Some("atree"), "bkey"),
1
                Some(Timestamp::now() + Duration::from_secs(2)),
1
            );
1
            persistence_watcher.mark_read();
1
            if timing.elapsed() > Duration::from_millis(500) {
                println!("Restarting");
                continue;
1
            }
1

            
1
            // Wait for the first key to expire.
1
            persistence_watcher
1
                .watch_timeout(Duration::from_secs(5))
1
                .unwrap();
1
            persistence_watcher.mark_read();
1
            if timing.elapsed() > Duration::from_millis(1500) {
                println!("Restarting");
                continue;
1
            }
1
            assert!(tree.get(b"atree\0akey")?.is_some());
1
            assert!(tree.get(b"atree\0bkey")?.is_some());
1
            assert!(tree.get(b"atree\0ckey")?.is_none());

            
            // Wait for the next key to expire.
1
            persistence_watcher
1
                .watch_timeout(Duration::from_secs(5))
1
                .unwrap();
1
            persistence_watcher.mark_read();
1
            if timing.elapsed() > Duration::from_millis(2500) {
                println!("Restarting");
                continue;
1
            }
1
            assert!(tree.get(b"atree\0akey")?.is_some());
1
            assert!(tree.get(b"atree\0bkey")?.is_none());

            
            // Wait for the final key to expire.
1
            persistence_watcher
1
                .watch_timeout(Duration::from_secs(5))
1
                .unwrap();
1
            if timing.elapsed() > Duration::from_millis(3500) {
                println!("Restarting");
                continue;
1
            }
1
            assert!(tree.get(b"atree\0akey")?.is_none());

            
1
            return Ok(());
1
        })
1
    }

            
1
    #[test]
1
    fn basic_persistence() -> anyhow::Result<()> {
1
        run_test_with_persistence(
1
            "kv-basic-persistence",
1
            KeyValuePersistence::lazy([
1
                PersistenceThreshold::after_changes(2),
1
                PersistenceThreshold::after_changes(1).and_duration(Duration::from_secs(2)),
1
            ]),
1
            &|context, roots| {
1
                // Initialize the test state
1
                let mut persistence_watcher = context.kv_persistence_watcher();
1
                let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
1
                let start = Instant::now();
1
                // Set three keys in quick succession. The first two should
1
                // persist immediately after the second is set, and the
1
                // third should show up after 2 seconds.
1
                context
1
                    .perform_kv_operation(KeyOperation {
1
                        namespace: None,
1
                        key: String::from("key1"),
1
                        command: Command::Set(SetCommand {
1
                            value: Value::Bytes(Bytes::default()),
1
                            expiration: None,
1
                            keep_existing_expiration: false,
1
                            check: None,
1
                            return_previous_value: false,
1
                        }),
1
                    })
1
                    .unwrap();
1
                context
1
                    .perform_kv_operation(KeyOperation {
1
                        namespace: None,
1
                        key: String::from("key2"),
1
                        command: Command::Set(SetCommand {
1
                            value: Value::Bytes(Bytes::default()),
1
                            expiration: None,
1
                            keep_existing_expiration: false,
1
                            check: None,
1
                            return_previous_value: false,
1
                        }),
1
                    })
1
                    .unwrap();
1
                context
1
                    .perform_kv_operation(KeyOperation {
1
                        namespace: None,
1
                        key: String::from("key3"),
1
                        command: Command::Set(SetCommand {
1
                            value: Value::Bytes(Bytes::default()),
1
                            expiration: None,
1
                            keep_existing_expiration: false,
1
                            check: None,
1
                            return_previous_value: false,
1
                        }),
1
                    })
1
                    .unwrap();
1
                // Wait for the first persistence to occur.
1
                persistence_watcher.next_value()?;
1

            
1
                assert!(tree.get(b"\0key1").unwrap().is_some());
1
                assert!(tree.get(b"\0key2").unwrap().is_some());
1
                assert!(tree.get(b"\0key3").unwrap().is_none());
1

            
1
                // Wait for the second persistence
1
                persistence_watcher.next_value()?;
1
                assert!(tree.get(b"\0key3").unwrap().is_some());
1
                // The total operation should have taken *at least* two seconds,
1
                // since the second persistence should have delayed for two
1
                // seconds itself.
1
                assert!(start.elapsed() > Duration::from_secs(2));
1

            
1
                Ok(())
1
            },
1
        )
1
    }

            
1
    #[test]
1
    fn saves_on_drop() -> anyhow::Result<()> {
1
        let dir = TestDirectory::new("saves-on-drop.bonsaidb");
1
        let sled = nebari::Config::new(&dir)
1
            .file_manager(AnyFileManager::std())
1
            .open()?;
1
        let tree = sled.tree(Unversioned::tree(KEY_TREE))?;

            
1
        let context = Context::new(
1
            sled,
1
            KeyValuePersistence::lazy([PersistenceThreshold::after_changes(2)]),
1
            None,
1
        );
1
        context
1
            .perform_kv_operation(KeyOperation {
1
                namespace: None,
1
                key: String::from("key1"),
1
                command: Command::Set(SetCommand {
1
                    value: Value::Bytes(Bytes::default()),
1
                    expiration: None,
1
                    keep_existing_expiration: false,
1
                    check: None,
1
                    return_previous_value: false,
1
                }),
1
            })
1
            .unwrap();
1
        assert!(tree.get(b"\0key1").unwrap().is_none());
1
        drop(context);
1
        // Dropping spawns a task that should persist the keys. Give a moment
1
        // for the runtime to execute the task.
1
        std::thread::sleep(Duration::from_millis(100));
1
        assert!(tree.get(b"\0key1").unwrap().is_some());

            
1
        Ok(())
1
    }
}