1
use std::borrow::Cow;
2
use std::collections::{btree_map, BTreeMap, VecDeque};
3
use std::sync::{Arc, Weak};
4
use std::time::Duration;
5

            
6
use bonsaidb_core::connection::{Connection, HasSession};
7
use bonsaidb_core::keyvalue::{
8
    Command, KeyCheck, KeyOperation, KeyStatus, KeyValue, Numeric, Output, SetCommand, Timestamp,
9
    Value,
10
};
11
use bonsaidb_core::permissions::bonsai::{
12
    keyvalue_key_resource_name, BonsaiAction, DatabaseAction, KeyValueAction,
13
};
14
use bonsaidb_core::transaction::{ChangedKey, Changes};
15
use nebari::io::any::AnyFile;
16
use nebari::tree::{CompareSwap, Operation, Root, ScanEvaluation, Unversioned};
17
use nebari::{AbortError, ArcBytes, Roots};
18
use parking_lot::Mutex;
19
use serde::{Deserialize, Serialize};
20
use watchable::{Watchable, Watcher};
21

            
22
use crate::config::KeyValuePersistence;
23
use crate::database::compat;
24
use crate::storage::StorageLock;
25
use crate::tasks::{Job, Keyed, Task};
26
use crate::{Database, DatabaseNonBlocking, Error};
27

            
28
1501680
#[derive(Serialize, Deserialize, Debug, Clone)]
29
pub struct Entry {
30
    pub value: Value,
31
    pub expiration: Option<Timestamp>,
32
    #[serde(default)]
33
    pub last_updated: Timestamp,
34
}
35

            
36
impl Entry {
37
    pub(crate) fn restore(
38
        self,
39
        namespace: Option<String>,
40
        key: String,
41
        database: &Database,
42
    ) -> Result<(), bonsaidb_core::Error> {
43
3
        database.execute_key_operation(KeyOperation {
44
3
            namespace,
45
3
            key,
46
3
            command: Command::Set(SetCommand {
47
3
                value: self.value,
48
3
                expiration: self.expiration,
49
3
                keep_existing_expiration: false,
50
3
                check: None,
51
3
                return_previous_value: false,
52
3
            }),
53
3
        })?;
54
3
        Ok(())
55
3
    }
56
}
57

            
58
impl KeyValue for Database {
59
    fn execute_key_operation(&self, op: KeyOperation) -> Result<Output, bonsaidb_core::Error> {
60
1537127
        self.check_permission(
61
1537127
            keyvalue_key_resource_name(self.name(), op.namespace.as_deref(), &op.key),
62
1537127
            &BonsaiAction::Database(DatabaseAction::KeyValue(KeyValueAction::ExecuteOperation)),
63
1537127
        )?;
64
1537127
        self.data.context.perform_kv_operation(op)
65
1537127
    }
66
}
67

            
68
impl Database {
69
18018
    pub(crate) fn all_key_value_entries(
70
18018
        &self,
71
18018
    ) -> Result<BTreeMap<(Option<String>, String), Entry>, Error> {
72
18018
        // Lock the state so that new new modifications can be made while we gather this snapshot.
73
18018
        let state = self.data.context.key_value_state.lock();
74
18018
        let database = self.clone();
75
18018
        // Initialize our entries with any dirty keys and any keys that are about to be persisted.
76
18018
        let mut all_entries = BTreeMap::new();
77
18018
        database
78
18018
            .roots()
79
18018
            .tree(Unversioned::tree(KEY_TREE))?
80
17900
            .scan::<Error, _, _, _, _>(
81
17900
                &(..),
82
17900
                true,
83
17900
                |_, _, _| ScanEvaluation::ReadData,
84
17900
                |_, _| ScanEvaluation::ReadData,
85
17900
                |key, _, entry: ArcBytes<'static>| {
86
25
                    let entry = bincode::deserialize::<Entry>(&entry)
87
25
                        .map_err(|err| AbortError::Other(Error::from(err)))?;
88
25
                    let full_key = std::str::from_utf8(&key)
89
25
                        .map_err(|err| AbortError::Other(Error::from(err)))?;
90

            
91
25
                    if let Some(split_key) = split_key(full_key) {
92
25
                        // Do not overwrite the existing key
93
25
                        all_entries.entry(split_key).or_insert(entry);
94
25
                    }
95

            
96
25
                    Ok(())
97
17900
                },
98
17900
            )?;
99

            
100
        // Apply the pending writes first
101
17863
        if let Some(pending_keys) = &state.keys_being_persisted {
102
14
            for (key, possible_entry) in pending_keys.iter() {
103
14
                let (namespace, key) = split_key(key).unwrap();
104
14
                if let Some(updated_entry) = possible_entry {
105
11
                    all_entries.insert((namespace, key), updated_entry.clone());
106
11
                } else {
107
3
                    all_entries.remove(&(namespace, key));
108
3
                }
109
            }
110
17850
        }
111

            
112
17863
        for (key, possible_entry) in &state.dirty_keys {
113
22
            let (namespace, key) = split_key(key).unwrap();
114
22
            if let Some(updated_entry) = possible_entry {
115
20
                all_entries.insert((namespace, key), updated_entry.clone());
116
20
            } else {
117
2
                all_entries.remove(&(namespace, key));
118
2
            }
119
        }
120

            
121
17863
        Ok(all_entries)
122
18018
    }
123
}
124

            
125
/// Name of the tree in which key-value entries are stored.
pub(crate) const KEY_TREE: &str = "kv";
126

            
127
1536993
/// Builds the storage key for `key` inside `namespace`.
///
/// The result is the namespace (nothing when `namespace` is `None`), a NUL
/// separator, and then the key.
fn full_key(namespace: Option<&str>, key: &str) -> String {
    let prefix = namespace.unwrap_or_default();
    // One extra byte for the NUL separator.
    let mut combined = String::with_capacity(prefix.len() + key.len() + 1);
    combined.push_str(prefix);
    combined.push('\0');
    combined.push_str(key);
    combined
}
137

            
138
/// Splits a storage key into its `(namespace, key)` parts at the first NUL
/// separator.
///
/// An empty namespace is reported as `None`. Returns `None` when `full_key`
/// contains no separator at all.
fn split_key(full_key: &str) -> Option<(Option<String>, String)> {
    let (namespace, key) = full_key.split_once('\0')?;
    let namespace = (!namespace.is_empty()).then(|| namespace.to_string());
    Some((namespace, key.to_string()))
}
150

            
151
1521853
fn increment(existing: &Numeric, amount: &Numeric, saturating: bool) -> Numeric {
152
1521853
    match amount {
153
456
        Numeric::Integer(amount) => {
154
456
            let existing_value = existing.as_i64_lossy(saturating);
155
456
            let new_value = if saturating {
156
304
                existing_value.saturating_add(*amount)
157
            } else {
158
152
                existing_value.wrapping_add(*amount)
159
            };
160
456
            Numeric::Integer(new_value)
161
        }
162
1520941
        Numeric::UnsignedInteger(amount) => {
163
1520941
            let existing_value = existing.as_u64_lossy(saturating);
164
1520941
            let new_value = if saturating {
165
1520789
                existing_value.saturating_add(*amount)
166
            } else {
167
152
                existing_value.wrapping_add(*amount)
168
            };
169
1520941
            Numeric::UnsignedInteger(new_value)
170
        }
171
456
        Numeric::Float(amount) => {
172
456
            let existing_value = existing.as_f64_lossy();
173
456
            let new_value = existing_value + *amount;
174
456
            Numeric::Float(new_value)
175
        }
176
    }
177
1521853
}
178

            
179
1368
fn decrement(existing: &Numeric, amount: &Numeric, saturating: bool) -> Numeric {
180
1368
    match amount {
181
456
        Numeric::Integer(amount) => {
182
456
            let existing_value = existing.as_i64_lossy(saturating);
183
456
            let new_value = if saturating {
184
304
                existing_value.saturating_sub(*amount)
185
            } else {
186
152
                existing_value.wrapping_sub(*amount)
187
            };
188
456
            Numeric::Integer(new_value)
189
        }
190
608
        Numeric::UnsignedInteger(amount) => {
191
608
            let existing_value = existing.as_u64_lossy(saturating);
192
608
            let new_value = if saturating {
193
304
                existing_value.saturating_sub(*amount)
194
            } else {
195
304
                existing_value.wrapping_sub(*amount)
196
            };
197
608
            Numeric::UnsignedInteger(new_value)
198
        }
199
304
        Numeric::Float(amount) => {
200
304
            let existing_value = existing.as_f64_lossy();
201
304
            let new_value = existing_value - *amount;
202
304
            Numeric::Float(new_value)
203
        }
204
    }
205
1368
}
206

            
207
#[derive(Debug)]
208
pub struct KeyValueState {
209
    roots: Roots<AnyFile>,
210
    persistence: KeyValuePersistence,
211
    last_commit: Timestamp,
212
    background_worker_target: Watchable<BackgroundWorkerProcessTarget>,
213
    expiring_keys: BTreeMap<String, Timestamp>,
214
    expiration_order: VecDeque<String>,
215
    dirty_keys: BTreeMap<String, Option<Entry>>,
216
    keys_being_persisted: Option<Arc<BTreeMap<String, Option<Entry>>>>,
217
    last_persistence: Watchable<Timestamp>,
218
    shutdown: Option<flume::Sender<()>>,
219
}
220

            
221
impl KeyValueState {
222
35853
    pub fn new(
223
35853
        persistence: KeyValuePersistence,
224
35853
        roots: Roots<AnyFile>,
225
35853
        background_worker_target: Watchable<BackgroundWorkerProcessTarget>,
226
35853
    ) -> Self {
227
35853
        Self {
228
35853
            roots,
229
35853
            persistence,
230
35853
            last_commit: Timestamp::now(),
231
35853
            expiring_keys: BTreeMap::new(),
232
35853
            background_worker_target,
233
35853
            expiration_order: VecDeque::new(),
234
35853
            dirty_keys: BTreeMap::new(),
235
35853
            keys_being_persisted: None,
236
35853
            last_persistence: Watchable::new(Timestamp::MIN),
237
35853
            shutdown: None,
238
35853
        }
239
35853
    }
240

            
241
31894
    pub fn shutdown(&mut self, state: &Arc<Mutex<KeyValueState>>) -> Option<flume::Receiver<()>> {
242
31894
        if self.keys_being_persisted.is_none() && self.commit_dirty_keys(state) {
243
372
            let (shutdown_sender, shutdown_receiver) = flume::bounded(1);
244
372
            self.shutdown = Some(shutdown_sender);
245
372
            Some(shutdown_receiver)
246
        } else {
247
31522
            None
248
        }
249
31894
    }
250

            
251
1537131
    pub fn perform_kv_operation(
252
1537131
        &mut self,
253
1537131
        op: KeyOperation,
254
1537131
        state: &Arc<Mutex<KeyValueState>>,
255
1537131
    ) -> Result<Output, bonsaidb_core::Error> {
256
1537131
        let now = Timestamp::now();
257
1537131
        // If there are any keys that have expired, clear them before executing any operations.
258
1537131
        self.remove_expired_keys(now);
259
1537131
        let result = match op.command {
260
6527
            Command::Set(command) => {
261
6527
                self.execute_set_operation(op.namespace.as_deref(), &op.key, command, now)
262
            }
263
5863
            Command::Get { delete } => {
264
5863
                self.execute_get_operation(op.namespace.as_deref(), &op.key, delete)
265
            }
266
1216
            Command::Delete => self.execute_delete_operation(op.namespace.as_deref(), &op.key),
267
1522005
            Command::Increment { amount, saturating } => self.execute_increment_operation(
268
1522005
                op.namespace.as_deref(),
269
1522005
                &op.key,
270
1522005
                &amount,
271
1522005
                saturating,
272
1522005
                now,
273
1522005
            ),
274
1520
            Command::Decrement { amount, saturating } => self.execute_decrement_operation(
275
1520
                op.namespace.as_deref(),
276
1520
                &op.key,
277
1520
                &amount,
278
1520
                saturating,
279
1520
                now,
280
1520
            ),
281
        };
282
1537131
        if result.is_ok() {
283
1536523
            if self.needs_commit(now) {
284
34557
                self.commit_dirty_keys(state);
285
1501966
            }
286
1536523
            self.update_background_worker_target();
287
608
        }
288
1537131
        result
289
1537131
    }
290

            
291
    #[cfg_attr(
292
        feature = "tracing",
293
        tracing::instrument(level = "trace", skip(self, set, now),)
294
    )]
295
6527
    fn execute_set_operation(
296
6527
        &mut self,
297
6527
        namespace: Option<&str>,
298
6527
        key: &str,
299
6527
        set: SetCommand,
300
6527
        now: Timestamp,
301
6527
    ) -> Result<Output, bonsaidb_core::Error> {
302
6375
        let mut entry = Entry {
303
6527
            value: set.value.validate()?,
304
6375
            expiration: set.expiration,
305
6375
            last_updated: now,
306
6375
        };
307
6375
        let full_key = full_key(namespace, key);
308
6375
        let possible_existing_value =
309
6375
            if set.check.is_some() || set.return_previous_value || set.keep_existing_expiration {
310
1438
                Some(self.get(&full_key).map_err(Error::from)?)
311
            } else {
312
4937
                None
313
            };
314
6375
        let existing_value_ref = possible_existing_value.as_ref().and_then(Option::as_ref);
315

            
316
6375
        let updating = match set.check {
317
378
            Some(KeyCheck::OnlyIfPresent) => existing_value_ref.is_some(),
318
452
            Some(KeyCheck::OnlyIfVacant) => existing_value_ref.is_none(),
319
5545
            None => true,
320
        };
321
6375
        if updating {
322
5997
            if set.keep_existing_expiration {
323
152
                if let Some(existing_value) = existing_value_ref {
324
152
                    entry.expiration = existing_value.expiration;
325
152
                }
326
5845
            }
327
5997
            self.update_key_expiration(&full_key, entry.expiration);
328

            
329
5997
            let previous_value = if let Some(existing_value) = possible_existing_value {
330
                // we already fetched, no need to ask for the existing value back
331
1060
                self.set(full_key, entry);
332
1060
                existing_value
333
            } else {
334
4937
                self.replace(full_key, entry).map_err(Error::from)?
335
            };
336
5997
            if set.return_previous_value {
337
530
                Ok(Output::Value(previous_value.map(|entry| entry.value)))
338
5467
            } else if previous_value.is_none() {
339
2431
                Ok(Output::Status(KeyStatus::Inserted))
340
            } else {
341
3036
                Ok(Output::Status(KeyStatus::Updated))
342
            }
343
        } else {
344
378
            Ok(Output::Status(KeyStatus::NotChanged))
345
        }
346
6527
    }
347

            
348
    #[cfg_attr(
349
        feature = "tracing",
350
        tracing::instrument(level = "trace", skip(self, tree_key, expiration))
351
    )]
352
7531
    pub fn update_key_expiration<'key>(
353
7531
        &mut self,
354
7531
        tree_key: impl Into<Cow<'key, str>>,
355
7531
        expiration: Option<Timestamp>,
356
7531
    ) {
357
7531
        let tree_key = tree_key.into();
358
7531
        let mut changed_first_expiration = false;
359
7531
        if let Some(expiration) = expiration {
360
997
            let key = if self.expiring_keys.contains_key(tree_key.as_ref()) {
361
                // Update the existing entry.
362
305
                let existing_entry_index = self
363
305
                    .expiration_order
364
305
                    .iter()
365
305
                    .enumerate()
366
305
                    .find_map(
367
306
                        |(index, key)| {
368
306
                            if &tree_key == key {
369
305
                                Some(index)
370
                            } else {
371
1
                                None
372
                            }
373
306
                        },
374
305
                    )
375
305
                    .unwrap();
376
305
                changed_first_expiration = existing_entry_index == 0;
377
305
                self.expiration_order.remove(existing_entry_index).unwrap()
378
            } else {
379
692
                tree_key.into_owned()
380
            };
381

            
382
            // Insert the key into the expiration_order queue
383
997
            let mut insert_at = None;
384
997
            for (index, expiring_key) in self.expiration_order.iter().enumerate() {
385
460
                if self.expiring_keys.get(expiring_key).unwrap() > &expiration {
386
153
                    insert_at = Some(index);
387
153
                    break;
388
307
                }
389
            }
390
997
            if let Some(insert_at) = insert_at {
391
153
                changed_first_expiration |= insert_at == 0;
392
153

            
393
153
                self.expiration_order.insert(insert_at, key.clone());
394
844
            } else {
395
844
                changed_first_expiration |= self.expiration_order.is_empty();
396
844
                self.expiration_order.push_back(key.clone());
397
844
            }
398
997
            self.expiring_keys.insert(key, expiration);
399
6534
        } else if self.expiring_keys.remove(tree_key.as_ref()).is_some() {
400
153
            let index = self
401
153
                .expiration_order
402
153
                .iter()
403
153
                .enumerate()
404
153
                .find_map(|(index, key)| {
405
153
                    if tree_key.as_ref() == key {
406
153
                        Some(index)
407
                    } else {
408
                        None
409
                    }
410
153
                })
411
153
                .unwrap();
412
153

            
413
153
            changed_first_expiration |= index == 0;
414
153
            self.expiration_order.remove(index);
415
6381
        }
416

            
417
7531
        if changed_first_expiration {
418
995
            self.update_background_worker_target();
419
6536
        }
420
7531
    }
421

            
422
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
423
5863
    fn execute_get_operation(
424
5863
        &mut self,
425
5863
        namespace: Option<&str>,
426
5863
        key: &str,
427
5863
        delete: bool,
428
5863
    ) -> Result<Output, bonsaidb_core::Error> {
429
5863
        let full_key = full_key(namespace, key);
430
5863
        let entry = if delete {
431
304
            self.remove(full_key).map_err(Error::from)?
432
        } else {
433
5559
            self.get(&full_key).map_err(Error::from)?
434
        };
435

            
436
5863
        Ok(Output::Value(entry.map(|e| e.value)))
437
5863
    }
438

            
439
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
440
1216
    fn execute_delete_operation(
441
1216
        &mut self,
442
1216
        namespace: Option<&str>,
443
1216
        key: &str,
444
1216
    ) -> Result<Output, bonsaidb_core::Error> {
445
1216
        let full_key = full_key(namespace, key);
446
1216
        let value = self.remove(full_key).map_err(Error::from)?;
447
1216
        if value.is_some() {
448
456
            Ok(Output::Status(KeyStatus::Deleted))
449
        } else {
450
760
            Ok(Output::Status(KeyStatus::NotChanged))
451
        }
452
1216
    }
453

            
454
    #[cfg_attr(
455
        feature = "tracing",
456
        tracing::instrument(level = "trace", skip(self, amount, saturating, now))
457
    )]
458
1522005
    fn execute_increment_operation(
459
1522005
        &mut self,
460
1522005
        namespace: Option<&str>,
461
1522005
        key: &str,
462
1522005
        amount: &Numeric,
463
1522005
        saturating: bool,
464
1522005
        now: Timestamp,
465
1522005
    ) -> Result<Output, bonsaidb_core::Error> {
466
1522005
        self.execute_numeric_operation(namespace, key, amount, saturating, now, increment)
467
1522005
    }
468

            
469
    #[cfg_attr(
470
        feature = "tracing",
471
        tracing::instrument(level = "trace", skip(self, amount, saturating, now))
472
    )]
473
1520
    fn execute_decrement_operation(
474
1520
        &mut self,
475
1520
        namespace: Option<&str>,
476
1520
        key: &str,
477
1520
        amount: &Numeric,
478
1520
        saturating: bool,
479
1520
        now: Timestamp,
480
1520
    ) -> Result<Output, bonsaidb_core::Error> {
481
1520
        self.execute_numeric_operation(namespace, key, amount, saturating, now, decrement)
482
1520
    }
483

            
484
1523525
    fn execute_numeric_operation<F: Fn(&Numeric, &Numeric, bool) -> Numeric>(
485
1523525
        &mut self,
486
1523525
        namespace: Option<&str>,
487
1523525
        key: &str,
488
1523525
        amount: &Numeric,
489
1523525
        saturating: bool,
490
1523525
        now: Timestamp,
491
1523525
        op: F,
492
1523525
    ) -> Result<Output, bonsaidb_core::Error> {
493
1523525
        let full_key = full_key(namespace, key);
494
1523525
        let current = self.get(&full_key).map_err(Error::from)?;
495
1523525
        let mut entry = current.unwrap_or(Entry {
496
1523525
            value: Value::Numeric(Numeric::UnsignedInteger(0)),
497
1523525
            expiration: None,
498
1523525
            last_updated: now,
499
1523525
        });
500
1523525

            
501
1523525
        match entry.value {
502
1523221
            Value::Numeric(existing) => {
503
1523221
                let value = Value::Numeric(op(&existing, amount, saturating).validate()?);
504
1523069
                entry.value = value.clone();
505
1523069

            
506
1523069
                self.set(full_key, entry);
507
1523069
                Ok(Output::Value(Some(value)))
508
            }
509
304
            Value::Bytes(_) => Err(bonsaidb_core::Error::other(
510
304
                "bonsaidb-local",
511
304
                "type of stored `Value` is not `Numeric`",
512
304
            )),
513
        }
514
1523525
    }
515

            
516
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
517
1520
    fn remove(&mut self, key: String) -> Result<Option<Entry>, nebari::Error> {
518
1520
        self.update_key_expiration(&key, None);
519

            
520
1520
        if let Some(dirty_entry) = self.dirty_keys.get_mut(&key) {
521
200
            Ok(dirty_entry.take())
522
1320
        } else if let Some(persisting_entry) = self
523
1320
            .keys_being_persisted
524
1320
            .as_ref()
525
1320
            .and_then(|keys| keys.get(&key))
526
        {
527
112
            self.dirty_keys.insert(key, None);
528
112
            Ok(persisting_entry.clone())
529
        } else {
530
            // There might be a value on-disk we need to remove.
531
1208
            let previous_value = Self::retrieve_key_from_disk(&self.roots, &key)?;
532
1208
            self.dirty_keys.insert(key, None);
533
1208
            Ok(previous_value)
534
        }
535
1520
    }
536

            
537
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
538
    fn get(&self, key: &str) -> Result<Option<Entry>, nebari::Error> {
539
1530522
        if let Some(entry) = self.dirty_keys.get(key) {
540
1431752
            Ok(entry.clone())
541
98770
        } else if let Some(persisting_entry) = self
542
98770
            .keys_being_persisted
543
98770
            .as_ref()
544
98770
            .and_then(|keys| keys.get(key))
545
        {
546
69228
            Ok(persisting_entry.clone())
547
        } else {
548
29542
            Self::retrieve_key_from_disk(&self.roots, key)
549
        }
550
1530522
    }
551

            
552
1524129
    fn set(&mut self, key: String, value: Entry) {
553
1524129
        self.dirty_keys.insert(key, Some(value));
554
1524129
    }
555

            
556
4937
    fn replace(&mut self, key: String, value: Entry) -> Result<Option<Entry>, nebari::Error> {
557
4937
        let mut value = Some(value);
558
4937
        let map_entry = self.dirty_keys.entry(key);
559
4937
        if matches!(map_entry, btree_map::Entry::Vacant(_)) {
560
            // This key is clean, and the caller is expecting the previous
561
            // value.
562
3397
            let stored_value = if let Some(persisting_entry) = self
563
3397
                .keys_being_persisted
564
3397
                .as_ref()
565
3397
                .and_then(|keys| keys.get(map_entry.key()))
566
            {
567
870
                persisting_entry.clone()
568
            } else {
569
2527
                Self::retrieve_key_from_disk(&self.roots, map_entry.key())?
570
            };
571
3397
            map_entry.or_insert(value);
572
3397
            Ok(stored_value)
573
        } else {
574
            // This key is already dirty, we can just replace the value and
575
            // return the old value.
576
1540
            map_entry.and_modify(|map_entry| {
577
1540
                std::mem::swap(&mut value, map_entry);
578
1540
            });
579
1540
            Ok(value)
580
        }
581
4937
    }
582

            
583
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(roots)))]
584
    fn retrieve_key_from_disk(
585
        roots: &Roots<AnyFile>,
586
        key: &str,
587
    ) -> Result<Option<Entry>, nebari::Error> {
588
33277
        roots
589
33277
            .tree(Unversioned::tree(KEY_TREE))?
590
33277
            .get(key.as_bytes())
591
33277
            .map(|current| current.and_then(|current| bincode::deserialize::<Entry>(&current).ok()))
592
33277
    }
593

            
594
1703940
    fn update_background_worker_target(&mut self) {
595
1703940
        let key_expiration_target = self.expiration_order.get(0).map(|key| {
596
4634
            let expiration_timeout = self.expiring_keys.get(key).unwrap();
597
4634
            *expiration_timeout
598
1703940
        });
599
1703940
        let now = Timestamp::now();
600
1703940
        let persisting = self.keys_being_persisted.is_some();
601
1703940
        let commit_target = (!persisting)
602
1703940
            .then(|| {
603
101118
                self.persistence.duration_until_next_commit(
604
101118
                    self.dirty_keys.len(),
605
101118
                    (now - self.last_commit).unwrap_or_default(),
606
101118
                )
607
1703940
            })
608
1703940
            .flatten()
609
1703940
            .map(|duration| now + duration);
610
1703940
        match (commit_target, key_expiration_target) {
611
73019
            (Some(target), _) | (_, Some(target)) if target <= now => {
612
68910
                self.background_worker_target
613
68910
                    .replace(BackgroundWorkerProcessTarget::Now);
614
68910
            }
615
            (Some(commit_target), Some(key_target)) => {
616
                let closest_target = key_target.min(commit_target);
617
                let new_target = BackgroundWorkerProcessTarget::Timestamp(closest_target);
618
                let _: Result<_, _> = self.background_worker_target.update(new_target);
619
            }
620
4109
            (Some(target), None) | (None, Some(target)) => {
621
4109
                let _: Result<_, _> = self
622
4109
                    .background_worker_target
623
4109
                    .update(BackgroundWorkerProcessTarget::Timestamp(target));
624
4109
            }
625
1630921
            (None, None) => {
626
1630921
                let _: Result<_, _> = self
627
1630921
                    .background_worker_target
628
1630921
                    .update(BackgroundWorkerProcessTarget::Never);
629
1630921
            }
630
        }
631
1703940
    }
632

            
633
1605049
    /// Drops every key whose expiration is at or before `now`.
    ///
    /// Each expired key is taken off the front of the expiration queue and
    /// recorded as a pending removal in the dirty keys.
    ///
    /// # Panics
    ///
    /// Panics if a queued key has no recorded expiration.
    fn remove_expired_keys(&mut self, now: Timestamp) {
        while let Some(next_key) = self.expiration_order.front() {
            let expired = self.expiring_keys.get(next_key).unwrap() <= &now;
            if !expired {
                // The queue is ordered, so nothing after this has expired either.
                break;
            }

            let key = self.expiration_order.pop_front().unwrap();
            self.expiring_keys.remove(&key);
            self.dirty_keys.insert(key, None);
        }
    }
642

            
643
1604441
    fn needs_commit(&mut self, now: Timestamp) -> bool {
644
1604441
        if self.keys_being_persisted.is_some() {
645
1504154
            false
646
        } else {
647
100287
            let since_last_commit = (now - self.last_commit).unwrap_or_default();
648
100287
            self.persistence
649
100287
                .should_commit(self.dirty_keys.len(), since_last_commit)
650
        }
651
1604441
    }
652

            
653
130010
    fn stage_dirty_keys(&mut self) -> Option<Arc<BTreeMap<String, Option<Entry>>>> {
654
130010
        if !self.dirty_keys.is_empty() && self.keys_being_persisted.is_none() {
655
98504
            let keys = Arc::new(std::mem::take(&mut self.dirty_keys));
656
98504
            self.keys_being_persisted = Some(keys.clone());
657
98504
            Some(keys)
658
        } else {
659
31506
            None
660
        }
661
130010
    }
662

            
663
    pub fn commit_dirty_keys(&mut self, state: &Arc<Mutex<KeyValueState>>) -> bool {
664
129638
        if let Some(keys) = self.stage_dirty_keys() {
665
98504
            let roots = self.roots.clone();
666
98504
            let state = state.clone();
667
98504
            std::thread::Builder::new()
668
98504
                .name(String::from("keyvalue-persist"))
669
98504
                .spawn(move || Self::persist_keys(&state, &roots, &keys))
670
98504
                .unwrap();
671
98504
            self.last_commit = Timestamp::now();
672
98504
            true
673
        } else {
674
31134
            false
675
        }
676
129638
    }
677

            
678
    /// Returns a watcher over the timestamp of the last completed
    /// persistence pass. Only compiled for tests.
    #[cfg(test)]
    pub fn persistence_watcher(&self) -> Watcher<Timestamp> {
        let last_persistence = &self.last_persistence;
        last_persistence.watch()
    }
682

            
683
197008
    #[cfg_attr(feature = "instrument", tracing::instrument(level = "trace", skip_all))]
684
    fn persist_keys(
685
        key_value_state: &Arc<Mutex<KeyValueState>>,
686
        roots: &Roots<AnyFile>,
687
        keys: &BTreeMap<String, Option<Entry>>,
688
    ) -> Result<(), bonsaidb_core::Error> {
689
        let mut transaction = roots
690
            .transaction(&[Unversioned::tree(KEY_TREE)])
691
            .map_err(Error::from)?;
692
        let all_keys = keys
693
            .keys()
694
99875
            .map(|key| ArcBytes::from(key.as_bytes().to_vec()))
695
            .collect();
696
        let mut changed_keys = Vec::new();
697
        transaction
698
            .tree::<Unversioned>(0)
699
            .unwrap()
700
            .modify(
701
                all_keys,
702
99875
                Operation::CompareSwap(CompareSwap::new(&mut |key, existing_value| {
703
99875
                    let full_key = std::str::from_utf8(key).unwrap();
704
99875
                    let (namespace, key) = split_key(full_key).unwrap();
705

            
706
99875
                    if let Some(new_value) = keys.get(full_key).unwrap() {
707
98203
                        changed_keys.push(ChangedKey {
708
98203
                            namespace,
709
98203
                            key,
710
98203
                            deleted: false,
711
98203
                        });
712
98203
                        let bytes = bincode::serialize(new_value).unwrap();
713
98203
                        nebari::tree::KeyOperation::Set(ArcBytes::from(bytes))
714
1672
                    } else if existing_value.is_some() {
715
1028
                        changed_keys.push(ChangedKey {
716
1028
                            namespace,
717
1028
                            key,
718
1028
                            deleted: existing_value.is_some(),
719
1028
                        });
720
1028
                        nebari::tree::KeyOperation::Remove
721
                    } else {
722
644
                        nebari::tree::KeyOperation::Skip
723
                    }
724
99875
                })),
725
            )
726
            .map_err(Error::from)?;
727

            
728
        if !changed_keys.is_empty() {
729
            transaction
730
                .entry_mut()
731
                .set_data(compat::serialize_executed_transaction_changes(
732
                    &Changes::Keys(changed_keys),
733
                )?)
734
                .map_err(Error::from)?;
735
            transaction.commit().map_err(Error::from)?;
736
        }
737

            
738
        // If we are shutting down, check if we still have dirty keys.
739
        let final_keys = {
740
            let mut state = key_value_state.lock();
741
            state.last_persistence.replace(Timestamp::now());
742
            state.keys_being_persisted = None;
743
            state.update_background_worker_target();
744
            // This block is a little ugly to avoid having to acquire the lock
745
            // twice. If we're shutting down and have no dirty keys, we notify
746
            // the waiting shutdown task. If we have any dirty keys, we wait do
747
            // to that step because we're going to recurse and reach this spot
748
            // again.
749
            if state.shutdown.is_some() {
750
                let staged_keys = state.stage_dirty_keys();
751
                if staged_keys.is_none() {
752
                    let shutdown = state.shutdown.take().unwrap();
753
                    let _: Result<_, _> = shutdown.send(());
754
                }
755
                staged_keys
756
            } else {
757
                None
758
            }
759
        };
760
        if let Some(final_keys) = final_keys {
761
            Self::persist_keys(key_value_state, roots, &final_keys)?;
762
        }
763
        Ok(())
764
    }
765
}
766

            
767
35853
pub fn background_worker(
768
35853
    key_value_state: &Weak<Mutex<KeyValueState>>,
769
35853
    timestamp_receiver: &mut Watcher<BackgroundWorkerProcessTarget>,
770
35853
    storage_lock: Option<StorageLock>,
771
35853
) {
772
172416
    loop {
773
172416
        let mut perform_operations = false;
774
172416
        let current_target = *timestamp_receiver.read();
775
172416
        match current_target {
776
            // With no target, sleep until we receive a target.
777
            BackgroundWorkerProcessTarget::Never => {
778
104386
                if timestamp_receiver.watch().is_err() {
779
31698
                    break;
780
72688
                }
781
            }
782
1218
            BackgroundWorkerProcessTarget::Timestamp(target) => {
783
1218
                // With a target, we need to wait to receive a target only as
784
1218
                // long as there is time remaining.
785
1218
                let remaining = target - Timestamp::now();
786
1218
                if let Some(remaining) = remaining {
787
                    // recv_timeout panics if Instant::checked_add(remaining)
788
                    // fails. So, we will cap the sleep time at 1 day.
789
1218
                    let remaining = remaining.min(Duration::from_secs(60 * 60 * 24));
790
1218
                    match timestamp_receiver.watch_timeout(remaining) {
791
1217
                        Ok(_) | Err(watchable::TimeoutError::Timeout) => {
792
1217
                            perform_operations = true;
793
1217
                        }
794
1
                        Err(watchable::TimeoutError::Disconnected) => break,
795
                    }
796
                } else {
797
                    perform_operations = true;
798
                }
799
            }
800
66812
            BackgroundWorkerProcessTarget::Now => {
801
66812
                perform_operations = true;
802
66812
            }
803
        };
804

            
805
140717
        let Some(key_value_state) = key_value_state.upgrade() else {
806
195
            break;
807
        };
808

            
809
140522
        if perform_operations {
810
71877
            let mut state = key_value_state.lock();
811
71877
            let now = Timestamp::now();
812
71877
            state.remove_expired_keys(now);
813
71877
            if state.needs_commit(now) {
814
63575
                state.commit_dirty_keys(&key_value_state);
815
63580
            }
816
67918
            state.update_background_worker_target();
817
68645
        }
818
    }
819

            
820
    // The key-value store's delayed persistence can cause the key-value storage
821
    // to be written past when the last reference to the storage is still held.
822
    // The storage lock being held ensures that another reader/writer doesn't
823
    // begin accessing this same storage again.
824
31894
    drop(storage_lock);
825
31894
}
826

            
827
1635030
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
828
pub enum BackgroundWorkerProcessTarget {
829
    Now,
830
    Timestamp(Timestamp),
831
    Never,
832
}
833

            
834
#[derive(Debug)]
835
pub struct ExpirationLoader {
836
    pub database: Database,
837
    pub launched_at: Timestamp,
838
}
839

            
840
impl Keyed<Task> for ExpirationLoader {
841
18012
    fn key(&self) -> Task {
842
18012
        Task::ExpirationLoader(self.database.data.name.clone())
843
18012
    }
844
}
845

            
846
impl Job for ExpirationLoader {
847
    type Error = Error;
848
    type Output = ();
849

            
850
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
851
17716
    fn execute(&mut self) -> Result<Self::Output, Self::Error> {
852
17716
        let database = self.database.clone();
853
17716
        let launched_at = self.launched_at;
854

            
855
17716
        for ((namespace, key), entry) in database.all_key_value_entries()? {
856
48
            if entry.last_updated < launched_at && entry.expiration.is_some() {
857
1
                self.database
858
1
                    .update_key_expiration(full_key(namespace.as_deref(), &key), entry.expiration);
859
47
            }
860
        }
861

            
862
17561
        self.database
863
17561
            .storage()
864
17561
            .instance
865
17561
            .tasks()
866
17561
            .mark_key_value_expiration_loaded(self.database.data.name.clone());
867
17561

            
868
17561
        Ok(())
869
17716
    }
870
}
871

            
872
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use bonsaidb_core::arc_bytes::serde::Bytes;
    use bonsaidb_core::test_util::{TestDirectory, TimingTest};
    use nebari::io::any::{AnyFile, AnyFileManager};

    use super::*;
    use crate::config::PersistenceThreshold;
    use crate::database::Context;

    /// Opens a fresh nebari database in a test directory named `name`, wraps
    /// it in a `Context` using `persistence`, and runs `test_contents` with
    /// both.
    fn run_test_with_persistence<
        F: Fn(Context, nebari::Roots<AnyFile>) -> anyhow::Result<()> + Send,
    >(
        name: &str,
        persistence: KeyValuePersistence,
        test_contents: &F,
    ) -> anyhow::Result<()> {
        let dir = TestDirectory::new(name);
        let roots = nebari::Config::new(&dir)
            .file_manager(AnyFileManager::std())
            .open()?;

        let context = Context::new(roots.clone(), persistence, None);

        test_contents(context, roots)?;

        Ok(())
    }

    /// Runs `test_contents` with the default key-value persistence.
    fn run_test<F: Fn(Context, nebari::Roots<AnyFile>) -> anyhow::Result<()> + Send>(
        name: &str,
        test_contents: F,
    ) -> anyhow::Result<()> {
        run_test_with_persistence(name, KeyValuePersistence::default(), &test_contents)
    }

    /// Builds a `Set` operation without a namespace that stores empty bytes
    /// under `key`, with no expiration and no key check.
    fn set_empty_bytes(key: &str) -> KeyOperation {
        KeyOperation {
            namespace: None,
            key: String::from(key),
            command: Command::Set(SetCommand {
                value: Value::Bytes(Bytes::default()),
                expiration: None,
                keep_existing_expiration: false,
                check: None,
                return_previous_value: false,
            }),
        }
    }

    /// A key with an expiration is removed from the tree once it expires.
    #[test]
    fn basic_expiration() -> anyhow::Result<()> {
        run_test("kv-basic-expiration", |context, roots| {
            // Initialize the test state
            let mut persistence_watcher = context.kv_persistence_watcher();
            roots.delete_tree(KEY_TREE)?;
            let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
            tree.set(b"atree\0akey", b"somevalue")?;

            // Expire the existing key
            context.update_key_expiration(
                full_key(Some("atree"), "akey"),
                Some(Timestamp::now() + Duration::from_millis(100)),
            );
            // Wait for persistence.
            persistence_watcher.next_value()?;

            // Verify it is gone. The lookup must use the same full key that
            // was stored, otherwise the assertion passes vacuously.
            assert!(tree.get(b"atree\0akey")?.is_none());

            Ok(())
        })
    }

    /// Replacing an expiration with a later one delays the removal until the
    /// later time.
    #[test]
    fn updating_expiration() -> anyhow::Result<()> {
        run_test("kv-updating-expiration", |context, roots| {
            // Initialize the test state
            let mut persistence_watcher = context.kv_persistence_watcher();
            roots.delete_tree(KEY_TREE)?;
            let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
            tree.set(b"atree\0akey", b"somevalue")?;
            let start = Timestamp::now();

            // Set the expiration once.
            context.update_key_expiration(
                full_key(Some("atree"), "akey"),
                Some(start + Duration::from_millis(100)),
            );
            // Set the expiration to a longer value.
            let correct_expiration = start + Duration::from_secs(1);
            context
                .update_key_expiration(full_key(Some("atree"), "akey"), Some(correct_expiration));

            // Wait for persistence, and ensure that the next persistence is
            // after our expiration timestamp.
            assert!(persistence_watcher.next_value()? > correct_expiration);

            // Verify the key is gone now.
            assert_eq!(tree.get(b"atree\0akey")?, None);

            Ok(())
        })
    }

    /// Two keys with different expirations are removed one after the other.
    #[test]
    fn multiple_keys_expiration() -> anyhow::Result<()> {
        run_test("kv-multiple-keys-expiration", |context, roots| {
            // Initialize the test state
            let mut persistence_watcher = context.kv_persistence_watcher();
            roots.delete_tree(KEY_TREE)?;
            let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
            tree.set(b"atree\0akey", b"somevalue")?;
            tree.set(b"atree\0bkey", b"somevalue")?;

            // Expire both keys, one for a shorter time than the other.
            context.update_key_expiration(
                full_key(Some("atree"), "akey"),
                Some(Timestamp::now() + Duration::from_millis(100)),
            );
            context.update_key_expiration(
                full_key(Some("atree"), "bkey"),
                Some(Timestamp::now() + Duration::from_secs(1)),
            );

            // Wait for the first persistence.
            persistence_watcher.next_value()?;
            assert!(tree.get(b"atree\0akey")?.is_none());
            assert!(tree.get(b"atree\0bkey")?.is_some());

            // Wait for the second persistence.
            persistence_watcher.next_value()?;
            assert!(tree.get(b"atree\0bkey")?.is_none());

            Ok(())
        })
    }

    /// Clearing an expiration before it elapses keeps the key in the tree.
    #[test]
    fn clearing_expiration() -> anyhow::Result<()> {
        run_test("kv-clearing-expiration", |context, roots| {
            loop {
                roots.delete_tree(KEY_TREE)?;
                let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
                tree.set(b"atree\0akey", b"somevalue")?;
                let timing = TimingTest::new(Duration::from_millis(100));
                context.update_key_expiration(
                    full_key(Some("atree"), "akey"),
                    Some(Timestamp::now() + Duration::from_millis(100)),
                );
                context.update_key_expiration(full_key(Some("atree"), "akey"), None);
                if timing.elapsed() > Duration::from_millis(100) {
                    // Restart, took too long.
                    continue;
                }
                timing.wait_until(Duration::from_millis(150));
                assert!(tree.get(b"atree\0akey")?.is_some());
                break;
            }

            Ok(())
        })
    }

    /// Keys expire in order of their expiration time, not in the order in
    /// which the expirations were registered.
    #[test]
    fn out_of_order_expiration() -> anyhow::Result<()> {
        run_test("kv-out-of-order-expiration", |context, roots| loop {
            // Clear any expirations left over from a restarted attempt.
            context.update_key_expiration(full_key(Some("atree"), "akey"), None);
            context.update_key_expiration(full_key(Some("atree"), "bkey"), None);
            context.update_key_expiration(full_key(Some("atree"), "ckey"), None);
            let mut persistence_watcher = context.kv_persistence_watcher();
            drop(roots.delete_tree(KEY_TREE));
            let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
            tree.set(b"atree\0akey", b"somevalue")?;
            tree.set(b"atree\0bkey", b"somevalue")?;
            tree.set(b"atree\0ckey", b"somevalue")?;
            let timing = TimingTest::new(Duration::from_millis(100));
            context.update_key_expiration(
                full_key(Some("atree"), "akey"),
                Some(Timestamp::now() + Duration::from_secs(3)),
            );
            context.update_key_expiration(
                full_key(Some("atree"), "ckey"),
                Some(Timestamp::now() + Duration::from_secs(1)),
            );
            context.update_key_expiration(
                full_key(Some("atree"), "bkey"),
                Some(Timestamp::now() + Duration::from_secs(2)),
            );
            persistence_watcher.mark_read();
            if timing.elapsed() > Duration::from_millis(500) {
                println!("Restarting");
                continue;
            }

            // Wait for the first key to expire.
            persistence_watcher
                .watch_timeout(Duration::from_secs(5))
                .unwrap();
            persistence_watcher.mark_read();
            if timing.elapsed() > Duration::from_millis(1500) {
                println!("Restarting");
                continue;
            }
            assert!(tree.get(b"atree\0akey")?.is_some());
            assert!(tree.get(b"atree\0bkey")?.is_some());
            assert!(tree.get(b"atree\0ckey")?.is_none());

            // Wait for the next key to expire.
            persistence_watcher
                .watch_timeout(Duration::from_secs(5))
                .unwrap();
            persistence_watcher.mark_read();
            if timing.elapsed() > Duration::from_millis(2500) {
                println!("Restarting");
                continue;
            }
            assert!(tree.get(b"atree\0akey")?.is_some());
            assert!(tree.get(b"atree\0bkey")?.is_none());

            // Wait for the final key to expire.
            persistence_watcher
                .watch_timeout(Duration::from_secs(5))
                .unwrap();
            if timing.elapsed() > Duration::from_millis(3500) {
                println!("Restarting");
                continue;
            }
            assert!(tree.get(b"atree\0akey")?.is_none());

            return Ok(());
        })
    }

    /// Lazy persistence flushes once a change-count threshold is met, and
    /// flushes a single remaining change after the configured duration.
    #[test]
    fn basic_persistence() -> anyhow::Result<()> {
        run_test_with_persistence(
            "kv-basic-persistence",
            KeyValuePersistence::lazy([
                PersistenceThreshold::after_changes(2),
                PersistenceThreshold::after_changes(1).and_duration(Duration::from_secs(2)),
            ]),
            &|context, roots| {
                // Initialize the test state
                let mut persistence_watcher = context.kv_persistence_watcher();
                let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
                let start = Instant::now();
                // Set three keys in quick succession. The first two should
                // persist immediately after the second is set, and the
                // third should show up after 2 seconds.
                for key in ["key1", "key2", "key3"] {
                    context.perform_kv_operation(set_empty_bytes(key)).unwrap();
                }
                // Wait for the first persistence to occur.
                persistence_watcher.next_value()?;

                assert!(tree.get(b"\0key1").unwrap().is_some());
                assert!(tree.get(b"\0key2").unwrap().is_some());
                assert!(tree.get(b"\0key3").unwrap().is_none());

                // Wait for the second persistence
                persistence_watcher.next_value()?;
                assert!(tree.get(b"\0key3").unwrap().is_some());
                // The total operation should have taken *at least* two seconds,
                // since the second persistence should have delayed for two
                // seconds itself.
                assert!(start.elapsed() > Duration::from_secs(2));

                Ok(())
            },
        )
    }

    /// A change below the persistence threshold is still written when the
    /// context is dropped.
    #[test]
    fn saves_on_drop() -> anyhow::Result<()> {
        let dir = TestDirectory::new("saves-on-drop.bonsaidb");
        let roots = nebari::Config::new(&dir)
            .file_manager(AnyFileManager::std())
            .open()?;
        let tree = roots.tree(Unversioned::tree(KEY_TREE))?;

        let context = Context::new(
            roots,
            KeyValuePersistence::lazy([PersistenceThreshold::after_changes(2)]),
            None,
        );
        context
            .perform_kv_operation(set_empty_bytes("key1"))
            .unwrap();
        assert!(tree.get(b"\0key1").unwrap().is_none());
        drop(context);
        // Dropping spawns a task that should persist the keys. Give a moment
        // for the runtime to execute the task.
        std::thread::sleep(Duration::from_millis(100));
        assert!(tree.get(b"\0key1").unwrap().is_some());

        Ok(())
    }
}