1
use std::{
2
    borrow::Cow,
3
    collections::{btree_map, BTreeMap, VecDeque},
4
    sync::Arc,
5
    time::Duration,
6
};
7

            
8
use async_lock::Mutex;
9
use async_trait::async_trait;
10
use bonsaidb_core::{
11
    keyvalue::{
12
        Command, KeyCheck, KeyOperation, KeyStatus, KeyValue, Numeric, Output, SetCommand,
13
        Timestamp, Value,
14
    },
15
    transaction::{ChangedKey, Changes},
16
};
17
use bonsaidb_utils::fast_async_lock;
18
use nebari::{
19
    io::any::AnyFile,
20
    tree::{CompareSwap, KeyEvaluation, Operation, Root, Unversioned},
21
    AbortError, ArcBytes, Roots,
22
};
23
use serde::{Deserialize, Serialize};
24
use tokio::{
25
    runtime::Handle,
26
    sync::{oneshot, watch},
27
};
28

            
29
use crate::{
30
    config::KeyValuePersistence,
31
    database::compat,
32
    tasks::{Job, Keyed, Task},
33
    Database, Error,
34
};
35

            
36
751363
#[derive(Serialize, Deserialize, Debug, Clone)]
37
pub struct Entry {
38
    pub value: Value,
39
    pub expiration: Option<Timestamp>,
40
    #[serde(default)]
41
    pub last_updated: Timestamp,
42
}
43

            
44
impl Entry {
45
3
    pub(crate) async fn restore(
46
3
        self,
47
3
        namespace: Option<String>,
48
3
        key: String,
49
3
        database: &Database,
50
3
    ) -> Result<(), bonsaidb_core::Error> {
51
3
        database
52
3
            .execute_key_operation(KeyOperation {
53
3
                namespace,
54
3
                key,
55
3
                command: Command::Set(SetCommand {
56
3
                    value: self.value,
57
3
                    expiration: self.expiration,
58
3
                    keep_existing_expiration: false,
59
3
                    check: None,
60
3
                    return_previous_value: false,
61
3
                }),
62
3
            })
63
            .await?;
64
3
        Ok(())
65
3
    }
66
}
67

            
68
#[async_trait]
69
impl KeyValue for Database {
70
809134
    async fn execute_key_operation(
71
809134
        &self,
72
809134
        op: KeyOperation,
73
809134
    ) -> Result<Output, bonsaidb_core::Error> {
74
809134
        self.data.context.perform_kv_operation(op).await
75
1618268
    }
76
}
77

            
78
impl Database {
79
8991
    pub(crate) async fn all_key_value_entries(
80
8991
        &self,
81
8991
    ) -> Result<BTreeMap<(Option<String>, String), Entry>, Error> {
82
        // Lock the state so that new new modifications can be made while we gather this snapshot.
83
8991
        let state = fast_async_lock!(self.data.context.key_value_state);
84
8991
        let database = self.clone();
85
8991
        // Initialize our entries with any dirty keys and any keys that are about to be persisted.
86
8991
        let mut all_entries = BTreeMap::new();
87
8991
        let mut all_entries = tokio::task::spawn_blocking(move || {
88
8963
            database
89
8963
                .roots()
90
8963
                .tree(Unversioned::tree(KEY_TREE))?
91
8963
                .scan::<Error, _, _, _, _>(
92
8963
                    &(..),
93
8963
                    true,
94
8963
                    |_, _, _| true,
95
8963
                    |_, _| KeyEvaluation::ReadData,
96
8963
                    |key, _, entry: ArcBytes<'static>| {
97
212
                        let entry = bincode::deserialize::<Entry>(&entry)
98
212
                            .map_err(|err| AbortError::Other(Error::from(err)))?;
99
212
                        let full_key = std::str::from_utf8(&key)
100
212
                            .map_err(|err| AbortError::Other(Error::from(err)))?;
101

            
102
212
                        if let Some(split_key) = split_key(full_key) {
103
212
                            // Do not overwrite the existing key
104
212
                            all_entries.entry(split_key).or_insert(entry);
105
212
                        }
106

            
107
212
                        Ok(())
108
8963
                    },
109
8963
                )?;
110
8937
            Result::<_, Error>::Ok(all_entries)
111
8991
        })
112
8217
        .await??;
113

            
114
        // Apply the pending writes first
115
8907
        if let Some(pending_keys) = &state.keys_being_persisted {
116
583
            for (key, possible_entry) in pending_keys.iter() {
117
583
                let (namespace, key) = split_key(key).unwrap();
118
583
                if let Some(updated_entry) = possible_entry {
119
504
                    all_entries.insert((namespace, key), updated_entry.clone());
120
504
                } else {
121
79
                    all_entries.remove(&(namespace, key));
122
79
                }
123
            }
124
8431
        }
125

            
126
8907
        for (key, possible_entry) in &state.dirty_keys {
127
189
            let (namespace, key) = split_key(key).unwrap();
128
189
            if let Some(updated_entry) = possible_entry {
129
189
                all_entries.insert((namespace, key), updated_entry.clone());
130
189
            } else {
131
                all_entries.remove(&(namespace, key));
132
            }
133
        }
134

            
135
8907
        Ok(all_entries)
136
8907
    }
137
}
138

            
139
/// Name of the unversioned nebari tree that holds all key-value entries.
pub(crate) const KEY_TREE: &str = "kv";
140

            
141
809069
/// Builds the tree key for `key`: the namespace (empty when `None`), a NUL
/// separator, then `key`.
///
/// An empty namespace and no namespace produce the same tree key.
fn full_key(namespace: Option<&str>, key: &str) -> String {
    let prefix = namespace.unwrap_or_default();
    let mut combined = String::with_capacity(prefix.len() + 1 + key.len());
    combined.push_str(prefix);
    combined.push('\0');
    combined.push_str(key);
    combined
}
151

            
152
/// Splits a tree key produced by `full_key` back into `(namespace, key)`.
///
/// The split happens at the first NUL. An empty namespace becomes `None`.
/// Returns `None` if `full_key` contains no NUL separator.
fn split_key(full_key: &str) -> Option<(Option<String>, String)> {
    let (namespace, key) = full_key.split_once('\0')?;
    let namespace = (!namespace.is_empty()).then(|| namespace.to_string());
    Some((namespace, key.to_string()))
}
164

            
165
800930
fn increment(existing: &Numeric, amount: &Numeric, saturating: bool) -> Numeric {
166
800930
    match amount {
167
240
        Numeric::Integer(amount) => {
168
240
            let existing_value = existing.as_i64_lossy(saturating);
169
240
            let new_value = if saturating {
170
160
                existing_value.saturating_add(*amount)
171
            } else {
172
80
                existing_value.wrapping_add(*amount)
173
            };
174
240
            Numeric::Integer(new_value)
175
        }
176
800450
        Numeric::UnsignedInteger(amount) => {
177
800450
            let existing_value = existing.as_u64_lossy(saturating);
178
800450
            let new_value = if saturating {
179
800370
                existing_value.saturating_add(*amount)
180
            } else {
181
80
                existing_value.wrapping_add(*amount)
182
            };
183
800450
            Numeric::UnsignedInteger(new_value)
184
        }
185
240
        Numeric::Float(amount) => {
186
240
            let existing_value = existing.as_f64_lossy();
187
240
            let new_value = existing_value + *amount;
188
240
            Numeric::Float(new_value)
189
        }
190
    }
191
800930
}
192

            
193
720
fn decrement(existing: &Numeric, amount: &Numeric, saturating: bool) -> Numeric {
194
720
    match amount {
195
240
        Numeric::Integer(amount) => {
196
240
            let existing_value = existing.as_i64_lossy(saturating);
197
240
            let new_value = if saturating {
198
160
                existing_value.saturating_sub(*amount)
199
            } else {
200
80
                existing_value.wrapping_sub(*amount)
201
            };
202
240
            Numeric::Integer(new_value)
203
        }
204
320
        Numeric::UnsignedInteger(amount) => {
205
320
            let existing_value = existing.as_u64_lossy(saturating);
206
320
            let new_value = if saturating {
207
160
                existing_value.saturating_sub(*amount)
208
            } else {
209
160
                existing_value.wrapping_sub(*amount)
210
            };
211
320
            Numeric::UnsignedInteger(new_value)
212
        }
213
160
        Numeric::Float(amount) => {
214
160
            let existing_value = existing.as_f64_lossy();
215
160
            let new_value = existing_value - *amount;
216
160
            Numeric::Float(new_value)
217
        }
218
    }
219
720
}
220

            
221
#[derive(Debug)]
222
pub struct KeyValueState {
223
    roots: Roots<AnyFile>,
224
    persistence: KeyValuePersistence,
225
    last_commit: Timestamp,
226
    background_worker_target: watch::Sender<BackgroundWorkerProcessTarget>,
227
    expiring_keys: BTreeMap<String, Timestamp>,
228
    expiration_order: VecDeque<String>,
229
    dirty_keys: BTreeMap<String, Option<Entry>>,
230
    keys_being_persisted: Option<Arc<BTreeMap<String, Option<Entry>>>>,
231
    shutdown: Option<oneshot::Sender<()>>,
232
}
233

            
234
impl KeyValueState {
235
21693
    pub fn new(
236
21693
        persistence: KeyValuePersistence,
237
21693
        roots: Roots<AnyFile>,
238
21693
        background_worker_target: watch::Sender<BackgroundWorkerProcessTarget>,
239
21693
    ) -> Self {
240
21693
        Self {
241
21693
            roots,
242
21693
            persistence,
243
21693
            last_commit: Timestamp::now(),
244
21693
            expiring_keys: BTreeMap::new(),
245
21693
            background_worker_target,
246
21693
            expiration_order: VecDeque::new(),
247
21693
            dirty_keys: BTreeMap::new(),
248
21693
            keys_being_persisted: None,
249
21693
            shutdown: None,
250
21693
        }
251
21693
    }
252

            
253
15326
    pub async fn shutdown(
254
15326
        &mut self,
255
15326
        state: &Arc<Mutex<KeyValueState>>,
256
15326
    ) -> Result<(), oneshot::error::RecvError> {
257
15326
        let (shutdown_sender, shutdown_receiver) = oneshot::channel();
258
15326
        self.shutdown = Some(shutdown_sender);
259
15326
        if self.keys_being_persisted.is_none() {
260
15273
            self.commit_dirty_keys(state);
261
15273
        }
262
15326
        shutdown_receiver.await
263
    }
264

            
265
809138
    pub async fn perform_kv_operation(
266
809138
        &mut self,
267
809138
        op: KeyOperation,
268
809138
        state: &Arc<Mutex<KeyValueState>>,
269
809138
    ) -> Result<Output, bonsaidb_core::Error> {
270
809138
        let now = Timestamp::now();
271
809138
        // If there are any keys that have expired, clear them before executing any operations.
272
809138
        self.remove_expired_keys(now);
273
809138
        let result = match op.command {
274
3523
            Command::Set(command) => {
275
3523
                self.execute_set_operation(op.namespace.as_deref(), &op.key, command, now)
276
            }
277
3165
            Command::Get { delete } => {
278
3165
                self.execute_get_operation(op.namespace.as_deref(), &op.key, delete)
279
            }
280
640
            Command::Delete => self.execute_delete_operation(op.namespace.as_deref(), &op.key),
281
801010
            Command::Increment { amount, saturating } => self.execute_increment_operation(
282
801010
                op.namespace.as_deref(),
283
801010
                &op.key,
284
801010
                &amount,
285
801010
                saturating,
286
801010
                now,
287
801010
            ),
288
800
            Command::Decrement { amount, saturating } => self.execute_decrement_operation(
289
800
                op.namespace.as_deref(),
290
800
                &op.key,
291
800
                &amount,
292
800
                saturating,
293
800
                now,
294
800
            ),
295
        };
296
809138
        if result.is_ok() {
297
808818
            if self.needs_commit(now) {
298
109276
                self.commit_dirty_keys(state);
299
699542
            }
300
808818
            self.update_background_worker_target();
301
320
        }
302
809138
        result
303
809138
    }
304

            
305
3523
    fn execute_set_operation(
306
3523
        &mut self,
307
3523
        namespace: Option<&str>,
308
3523
        key: &str,
309
3523
        set: SetCommand,
310
3523
        now: Timestamp,
311
3523
    ) -> Result<Output, bonsaidb_core::Error> {
312
3443
        let mut entry = Entry {
313
3523
            value: set.value.validate()?,
314
3443
            expiration: set.expiration,
315
3443
            last_updated: now,
316
3443
        };
317
3443
        let full_key = full_key(namespace, key);
318
3443
        let possible_existing_value =
319
3443
            if set.check.is_some() || set.return_previous_value || set.keep_existing_expiration {
320
718
                Some(self.get(&full_key).map_err(Error::from)?)
321
            } else {
322
2725
                None
323
            };
324
3443
        let existing_value_ref = possible_existing_value.as_ref().and_then(Option::as_ref);
325

            
326
3443
        let updating = match set.check {
327
186
            Some(KeyCheck::OnlyIfPresent) => existing_value_ref.is_some(),
328
212
            Some(KeyCheck::OnlyIfVacant) => existing_value_ref.is_none(),
329
3045
            None => true,
330
        };
331
3443
        if updating {
332
3257
            if set.keep_existing_expiration {
333
80
                if let Some(existing_value) = existing_value_ref {
334
80
                    entry.expiration = existing_value.expiration;
335
80
                }
336
3177
            }
337
3257
            self.update_key_expiration(&full_key, entry.expiration);
338

            
339
3257
            let previous_value = if let Some(existing_value) = possible_existing_value {
340
                // we already fetched, no need to ask for the existing value back
341
532
                self.set(full_key, entry);
342
532
                existing_value
343
            } else {
344
2725
                self.replace(full_key, entry).map_err(Error::from)?
345
            };
346
3257
            if set.return_previous_value {
347
266
                Ok(Output::Value(previous_value.map(|entry| entry.value)))
348
2991
            } else if previous_value.is_none() {
349
1367
                Ok(Output::Status(KeyStatus::Inserted))
350
            } else {
351
1624
                Ok(Output::Status(KeyStatus::Updated))
352
            }
353
        } else {
354
186
            Ok(Output::Status(KeyStatus::NotChanged))
355
        }
356
3523
    }
357

            
358
4068
    pub fn update_key_expiration<'key>(
359
4068
        &mut self,
360
4068
        tree_key: impl Into<Cow<'key, str>>,
361
4068
        expiration: Option<Timestamp>,
362
4068
    ) {
363
4068
        let tree_key = tree_key.into();
364
4068
        let mut changed_first_expiration = false;
365
4068
        if let Some(expiration) = expiration {
366
517
            let key = if self.expiring_keys.contains_key(tree_key.as_ref()) {
367
                // Update the existing entry.
368
161
                let existing_entry_index = self
369
161
                    .expiration_order
370
161
                    .iter()
371
161
                    .enumerate()
372
161
                    .find_map(
373
161
                        |(index, key)| {
374
161
                            if &tree_key == key {
375
161
                                Some(index)
376
                            } else {
377
                                None
378
                            }
379
161
                        },
380
161
                    )
381
161
                    .unwrap();
382
161
                changed_first_expiration = existing_entry_index == 0;
383
161
                self.expiration_order.remove(existing_entry_index).unwrap()
384
            } else {
385
356
                tree_key.into_owned()
386
            };
387

            
388
            // Insert the key into the expiration_order queue
389
517
            let mut insert_at = None;
390
517
            for (index, expiring_key) in self.expiration_order.iter().enumerate() {
391
244
                if self.expiring_keys.get(expiring_key).unwrap() > &expiration {
392
82
                    insert_at = Some(index);
393
82
                    break;
394
162
                }
395
            }
396
517
            if let Some(insert_at) = insert_at {
397
82
                changed_first_expiration |= insert_at == 0;
398
82

            
399
82
                self.expiration_order.insert(insert_at, key.clone());
400
435
            } else {
401
435
                changed_first_expiration |= self.expiration_order.is_empty();
402
435
                self.expiration_order.push_back(key.clone());
403
435
            }
404
517
            self.expiring_keys.insert(key, expiration);
405
3551
        } else if self.expiring_keys.remove(tree_key.as_ref()).is_some() {
406
81
            let index = self
407
81
                .expiration_order
408
81
                .iter()
409
81
                .enumerate()
410
81
                .find_map(|(index, key)| {
411
81
                    if tree_key.as_ref() == key {
412
81
                        Some(index)
413
                    } else {
414
                        None
415
                    }
416
81
                })
417
81
                .unwrap();
418
81

            
419
81
            changed_first_expiration |= index == 0;
420
81
            self.expiration_order.remove(index);
421
3470
        }
422

            
423
4068
        if changed_first_expiration {
424
516
            self.update_background_worker_target();
425
3552
        }
426
4068
    }
427

            
428
3165
    fn execute_get_operation(
429
3165
        &mut self,
430
3165
        namespace: Option<&str>,
431
3165
        key: &str,
432
3165
        delete: bool,
433
3165
    ) -> Result<Output, bonsaidb_core::Error> {
434
3165
        let full_key = full_key(namespace, key);
435
3165
        let entry = if delete {
436
160
            self.remove(full_key).map_err(Error::from)?
437
        } else {
438
3005
            self.get(&full_key).map_err(Error::from)?
439
        };
440

            
441
3165
        Ok(Output::Value(entry.map(|e| e.value)))
442
3165
    }
443

            
444
640
    fn execute_delete_operation(
445
640
        &mut self,
446
640
        namespace: Option<&str>,
447
640
        key: &str,
448
640
    ) -> Result<Output, bonsaidb_core::Error> {
449
640
        let full_key = full_key(namespace, key);
450
640
        let value = self.remove(full_key).map_err(Error::from)?;
451
640
        if value.is_some() {
452
240
            Ok(Output::Status(KeyStatus::Deleted))
453
        } else {
454
400
            Ok(Output::Status(KeyStatus::NotChanged))
455
        }
456
640
    }
457

            
458
801010
    fn execute_increment_operation(
459
801010
        &mut self,
460
801010
        namespace: Option<&str>,
461
801010
        key: &str,
462
801010
        amount: &Numeric,
463
801010
        saturating: bool,
464
801010
        now: Timestamp,
465
801010
    ) -> Result<Output, bonsaidb_core::Error> {
466
801010
        self.execute_numeric_operation(namespace, key, amount, saturating, now, increment)
467
801010
    }
468

            
469
800
    fn execute_decrement_operation(
470
800
        &mut self,
471
800
        namespace: Option<&str>,
472
800
        key: &str,
473
800
        amount: &Numeric,
474
800
        saturating: bool,
475
800
        now: Timestamp,
476
800
    ) -> Result<Output, bonsaidb_core::Error> {
477
800
        self.execute_numeric_operation(namespace, key, amount, saturating, now, decrement)
478
800
    }
479

            
480
801810
    fn execute_numeric_operation<F: Fn(&Numeric, &Numeric, bool) -> Numeric>(
481
801810
        &mut self,
482
801810
        namespace: Option<&str>,
483
801810
        key: &str,
484
801810
        amount: &Numeric,
485
801810
        saturating: bool,
486
801810
        now: Timestamp,
487
801810
        op: F,
488
801810
    ) -> Result<Output, bonsaidb_core::Error> {
489
801810
        let full_key = full_key(namespace, key);
490
801810
        let current = self.get(&full_key).map_err(Error::from)?;
491
801810
        let mut entry = current.unwrap_or(Entry {
492
801810
            value: Value::Numeric(Numeric::UnsignedInteger(0)),
493
801810
            expiration: None,
494
801810
            last_updated: now,
495
801810
        });
496
801810

            
497
801810
        match entry.value {
498
801650
            Value::Numeric(existing) => {
499
801650
                let value = Value::Numeric(op(&existing, amount, saturating).validate()?);
500
801570
                entry.value = value.clone();
501
801570

            
502
801570
                self.set(full_key, entry);
503
801570
                Ok(Output::Value(Some(value)))
504
            }
505
160
            Value::Bytes(_) => Err(bonsaidb_core::Error::Database(String::from(
506
160
                "type of stored `Value` is not `Numeric`",
507
160
            ))),
508
        }
509
801810
    }
510

            
511
800
    fn remove(&mut self, key: String) -> Result<Option<Entry>, nebari::Error> {
512
800
        self.update_key_expiration(&key, None);
513

            
514
800
        if let Some(dirty_entry) = self.dirty_keys.get_mut(&key) {
515
190
            Ok(dirty_entry.take())
516
610
        } else if let Some(persisting_entry) = self
517
610
            .keys_being_persisted
518
610
            .as_ref()
519
610
            .and_then(|keys| keys.get(&key))
520
        {
521
            self.dirty_keys.insert(key, None);
522
            Ok(persisting_entry.clone())
523
        } else {
524
            // There might be a value on-disk we need to remove.
525
610
            let previous_value = Self::retrieve_key_from_disk(&self.roots, &key)?;
526
610
            self.dirty_keys.insert(key, None);
527
610
            Ok(previous_value)
528
        }
529
800
    }
530

            
531
    fn get(&self, key: &str) -> Result<Option<Entry>, nebari::Error> {
532
805533
        if let Some(entry) = self.dirty_keys.get(key) {
533
681039
            Ok(entry.clone())
534
124494
        } else if let Some(persisting_entry) = self
535
124494
            .keys_being_persisted
536
124494
            .as_ref()
537
124494
            .and_then(|keys| keys.get(key))
538
        {
539
69577
            Ok(persisting_entry.clone())
540
        } else {
541
54917
            Self::retrieve_key_from_disk(&self.roots, key)
542
        }
543
805533
    }
544

            
545
802102
    fn set(&mut self, key: String, value: Entry) {
546
802102
        self.dirty_keys.insert(key, Some(value));
547
802102
    }
548

            
549
2725
    fn replace(&mut self, key: String, value: Entry) -> Result<Option<Entry>, nebari::Error> {
550
2725
        let mut value = Some(value);
551
2725
        let map_entry = self.dirty_keys.entry(key);
552
2725
        if matches!(map_entry, btree_map::Entry::Vacant(_)) {
553
            // This key is clean, and the caller is expecting the previous
554
            // value.
555
2199
            let stored_value = if let Some(persisting_entry) = self
556
2199
                .keys_being_persisted
557
2199
                .as_ref()
558
2199
                .and_then(|keys| keys.get(map_entry.key()))
559
            {
560
218
                persisting_entry.clone()
561
            } else {
562
1981
                Self::retrieve_key_from_disk(&self.roots, map_entry.key())?
563
            };
564
2199
            map_entry.or_insert(value);
565
2199
            Ok(stored_value)
566
        } else {
567
            // This key is already dirty, we can just replace the value and
568
            // return the old value.
569
526
            map_entry.and_modify(|map_entry| {
570
526
                std::mem::swap(&mut value, map_entry);
571
526
            });
572
526
            Ok(value)
573
        }
574
2725
    }
575

            
576
57508
    fn retrieve_key_from_disk(
577
57508
        roots: &Roots<AnyFile>,
578
57508
        key: &str,
579
57508
    ) -> Result<Option<Entry>, nebari::Error> {
580
57508
        roots
581
57508
            .tree(Unversioned::tree(KEY_TREE))?
582
57508
            .get(key.as_bytes())
583
57508
            .map(|current| current.and_then(|current| bincode::deserialize::<Entry>(&current).ok()))
584
57508
    }
585

            
586
964268
    fn update_background_worker_target(&mut self) {
587
964268
        let key_expiration_target =
588
964268
            self.expiration_order
589
964268
                .get(0)
590
964268
                .map_or_else(Timestamp::max, |key| {
591
2734
                    let expiration_timeout = self.expiring_keys.get(key).unwrap();
592
2734
                    *expiration_timeout
593
964268
                });
594
964268
        let now = Timestamp::now();
595
964268
        if self.keys_being_persisted.is_some() {
596
836530
            drop(
597
836530
                self.background_worker_target
598
836530
                    .send(BackgroundWorkerProcessTarget::Never),
599
836530
            );
600
836530
            return;
601
127738
        }
602
127738
        let duration_until_commit = self.persistence.duration_until_next_commit(
603
127738
            self.dirty_keys.len(),
604
127738
            (now - self.last_commit).unwrap_or_default(),
605
127738
        );
606
127738
        if duration_until_commit == Duration::ZERO {
607
69399
            drop(
608
69399
                self.background_worker_target
609
69399
                    .send(BackgroundWorkerProcessTarget::Now),
610
69399
            );
611
69399
        } else {
612
58339
            let commit_target = now + duration_until_commit;
613
58339
            let closest_target = key_expiration_target.min(commit_target);
614
58339
            if *self.background_worker_target.borrow()
615
58339
                != BackgroundWorkerProcessTarget::Timestamp(closest_target)
616
58155
            {
617
58155
                drop(
618
58155
                    self.background_worker_target
619
58155
                        .send(BackgroundWorkerProcessTarget::Timestamp(closest_target)),
620
58155
                );
621
58155
            }
622
        }
623
964268
    }
624

            
625
839922
    fn remove_expired_keys(&mut self, now: Timestamp) {
626
840170
        while !self.expiration_order.is_empty()
627
1346
            && self.expiring_keys.get(&self.expiration_order[0]).unwrap() <= &now
628
248
        {
629
248
            let key = self.expiration_order.pop_front().unwrap();
630
248
            self.expiring_keys.remove(&key);
631
248
            self.dirty_keys.insert(key, None);
632
248
        }
633
839922
    }
634

            
635
839602
    fn needs_commit(&mut self, now: Timestamp) -> bool {
636
839602
        if self.keys_being_persisted.is_some() {
637
712372
            false
638
        } else {
639
127230
            let since_last_commit = (now - self.last_commit).unwrap_or_default();
640
127230
            self.persistence
641
127230
                .should_commit(self.dirty_keys.len(), since_last_commit)
642
        }
643
839602
    }
644

            
645
139476
    fn stage_dirty_keys(&mut self) -> Option<Arc<BTreeMap<String, Option<Entry>>>> {
646
139476
        if !self.dirty_keys.is_empty() && self.keys_being_persisted.is_none() {
647
124150
            let keys = Arc::new(std::mem::take(&mut self.dirty_keys));
648
124150
            self.keys_being_persisted = Some(keys.clone());
649
124150
            Some(keys)
650
        } else {
651
15326
            None
652
        }
653
139476
    }
654

            
655
    fn commit_dirty_keys(&mut self, state: &Arc<Mutex<KeyValueState>>) {
656
139291
        if let Some(keys) = self.stage_dirty_keys() {
657
124149
            let roots = self.roots.clone();
658
124149
            let state = state.clone();
659
124149
            let tokio = Handle::current();
660
124149
            tokio::task::spawn_blocking(move || Self::persist_keys(&state, &roots, &keys, &tokio));
661
124149
            self.last_commit = Timestamp::now();
662
124149
        }
663
139291
    }
664

            
665
124150
    fn persist_keys(
666
124150
        key_value_state: &Arc<Mutex<KeyValueState>>,
667
124150
        roots: &Roots<AnyFile>,
668
124150
        keys: &BTreeMap<String, Option<Entry>>,
669
124150
        runtime: &Handle,
670
124150
    ) -> Result<(), bonsaidb_core::Error> {
671
124150
        let mut transaction = roots
672
124150
            .transaction(&[Unversioned::tree(KEY_TREE)])
673
124150
            .map_err(Error::from)?;
674
124150
        let all_keys = keys
675
124150
            .keys()
676
124655
            .map(|key| ArcBytes::from(key.as_bytes().to_vec()))
677
124150
            .collect();
678
124150
        let mut changed_keys = Vec::new();
679
124150
        transaction
680
124150
            .tree::<Unversioned>(0)
681
124150
            .unwrap()
682
124150
            .modify(
683
124150
                all_keys,
684
124655
                Operation::CompareSwap(CompareSwap::new(&mut |key, existing_value| {
685
124655
                    let full_key = std::str::from_utf8(key).unwrap();
686
124655
                    let (namespace, key) = split_key(full_key).unwrap();
687

            
688
124655
                    if let Some(new_value) = keys.get(full_key).unwrap() {
689
123799
                        changed_keys.push(ChangedKey {
690
123799
                            namespace,
691
123799
                            key,
692
123799
                            deleted: false,
693
123799
                        });
694
123799
                        let bytes = bincode::serialize(new_value).unwrap();
695
123799
                        nebari::tree::KeyOperation::Set(ArcBytes::from(bytes))
696
856
                    } else if existing_value.is_some() {
697
538
                        changed_keys.push(ChangedKey {
698
538
                            namespace,
699
538
                            key,
700
538
                            deleted: existing_value.is_some(),
701
538
                        });
702
538
                        nebari::tree::KeyOperation::Remove
703
                    } else {
704
318
                        nebari::tree::KeyOperation::Skip
705
                    }
706
124655
                })),
707
124150
            )
708
124150
            .map_err(Error::from)?;
709

            
710
124150
        if !changed_keys.is_empty() {
711
123832
            transaction
712
123832
                .entry_mut()
713
123832
                .set_data(compat::serialize_executed_transaction_changes(
714
123832
                    &Changes::Keys(changed_keys),
715
123832
                )?)
716
123832
                .map_err(Error::from)?;
717
123832
            transaction.commit().map_err(Error::from)?;
718
318
        }
719

            
720
        // If we are shutting down, check if we still have dirty keys.
721
124150
        if let Some(final_keys) = runtime.block_on(async {
722
124150
            let mut state = fast_async_lock!(key_value_state);
723
124150
            state.keys_being_persisted = None;
724
124150
            state.update_background_worker_target();
725
124150
            // This block is a little ugly to avoid having to acquire the lock
726
124150
            // twice. If we're shutting down and have no dirty keys, we notify
727
124150
            // the waiting shutdown task. If we have any dirty keys, we wait do
728
124150
            // to that step because we're going to recurse and reach this spot
729
124150
            // again.
730
124150
            if state.shutdown.is_some() {
731
185
                let staged_keys = state.stage_dirty_keys();
732
185
                if staged_keys.is_none() {
733
184
                    let shutdown = state.shutdown.take().unwrap();
734
184
                    let _ = shutdown.send(());
735
184
                }
736
185
                staged_keys
737
            } else {
738
123965
                None
739
            }
740
124150
        }) {
741
1
            Self::persist_keys(key_value_state, roots, &final_keys, runtime)?;
742
124149
        }
743
124150
        Ok(())
744
124150
    }
745
}
746

            
747
21693
pub async fn background_worker(
748
21693
    key_value_state: Arc<Mutex<KeyValueState>>,
749
21693
    mut timestamp_receiver: watch::Receiver<BackgroundWorkerProcessTarget>,
750
21693
) -> Result<(), Error> {
751
217204
    loop {
752
217204
        let mut perform_operations = false;
753
217204
        let current_timestamp = *timestamp_receiver.borrow_and_update();
754
217204
        let changed_result = match current_timestamp {
755
154340
            BackgroundWorkerProcessTarget::Never => timestamp_receiver.changed().await,
756
32329
            BackgroundWorkerProcessTarget::Timestamp(target) => {
757
32329
                let remaining = target - Timestamp::now();
758
32329
                if let Some(remaining) = remaining {
759
32329
                    tokio::select! {
760
31137
                        changed = timestamp_receiver.changed() => changed,
761
                        _ = tokio::time::sleep(remaining) => {
762
                            perform_operations = true;
763
                            Ok(())
764
                        },
765
                    }
766
                } else {
767
                    perform_operations = true;
768
                    Ok(())
769
                }
770
            }
771
            BackgroundWorkerProcessTarget::Now => {
772
30535
                perform_operations = true;
773
30535
                Ok(())
774
            }
775
        };
776

            
777
195653
        if changed_result.is_err() {
778
            break;
779
195653
        }
780
195653

            
781
195653
        if perform_operations {
782
48049
            let mut state = fast_async_lock!(key_value_state);
783
30784
            let now = Timestamp::now();
784
30784
            state.remove_expired_keys(now);
785
30784
            if state.needs_commit(now) {
786
14742
                state.commit_dirty_keys(&key_value_state);
787
16068
            }
788
30784
            state.update_background_worker_target();
789
164869
        }
790
    }
791

            
792
    Ok(())
793
}
794

            
795
58339
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
796
pub enum BackgroundWorkerProcessTarget {
797
    Now,
798
    Timestamp(Timestamp),
799
    Never,
800
}
801

            
802
#[derive(Debug)]
803
pub struct ExpirationLoader {
804
    pub database: Database,
805
    pub launched_at: Timestamp,
806
}
807

            
808
impl Keyed<Task> for ExpirationLoader {
809
10194
    fn key(&self) -> Task {
810
10194
        Task::ExpirationLoader(self.database.data.name.clone())
811
10194
    }
812
}
813

            
814
#[async_trait]
815
impl Job for ExpirationLoader {
816
    type Output = ();
817
    type Error = Error;
818

            
819
26360
    #[cfg_attr(feature = "tracing", tracing::instrument)]
820
8804
    async fn execute(&mut self) -> Result<Self::Output, Self::Error> {
821
8804
        let database = self.database.clone();
822
8804
        let launched_at = self.launched_at;
823

            
824
8804
        for ((namespace, key), entry) in database.all_key_value_entries().await? {
825
791
            if entry.last_updated < launched_at && entry.expiration.is_some() {
826
1
                self.database
827
1
                    .update_key_expiration_async(
828
1
                        full_key(namespace.as_deref(), &key),
829
1
                        entry.expiration,
830
1
                    )
831
                    .await;
832
790
            }
833
        }
834

            
835
8720
        self.database
836
8720
            .storage()
837
8720
            .tasks()
838
8720
            .mark_key_value_expiration_loaded(self.database.data.name.clone())
839
            .await;
840

            
841
8720
        Ok(())
842
17524
    }
843
}
844

            
845
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use bonsaidb_core::{
        arc_bytes::serde::Bytes,
        test_util::{TestDirectory, TimingTest},
    };
    use futures::Future;
    use nebari::io::any::{AnyFile, AnyFileManager};

    use super::*;
    use crate::{config::PersistenceThreshold, database::Context};

    /// Opens nebari roots inside a fresh test directory named `name`.
    /// It builds a `Context` using `persistence` and runs `test_contents`
    /// with that context and the roots.
    async fn run_test_with_persistence<
        F: Fn(Context, nebari::Roots<AnyFile>) -> R + Send,
        R: Future<Output = anyhow::Result<()>> + Send,
    >(
        name: &str,
        persistence: KeyValuePersistence,
        test_contents: &F,
    ) -> anyhow::Result<()> {
        let dir = TestDirectory::new(name);
        let sled = nebari::Config::new(&dir)
            .file_manager(AnyFileManager::std())
            .open()?;

        let context = Context::new(sled.clone(), persistence);

        test_contents(context, sled).await?;

        Ok(())
    }

    /// Same as `run_test_with_persistence`, using `KeyValuePersistence::default()`.
    async fn run_test<
        F: Fn(Context, nebari::Roots<AnyFile>) -> R + Send,
        R: Future<Output = anyhow::Result<()>> + Send,
    >(
        name: &str,
        test_contents: F,
    ) -> anyhow::Result<()> {
        run_test_with_persistence(name, KeyValuePersistence::default(), &test_contents).await
    }

    #[tokio::test]
    async fn basic_expiration() -> anyhow::Result<()> {
        run_test("kv-basic-expiration", |sender, sled| async move {
            loop {
                sled.delete_tree(KEY_TREE)?;
                let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
                tree.set(b"atree\0akey", b"somevalue")?;
                let timing = TimingTest::new(Duration::from_millis(100));
                sender
                    .update_key_expiration_async(
                        full_key(Some("atree"), "akey"),
                        Some(Timestamp::now() + Duration::from_millis(100)),
                    )
                    .await;
                if !timing.wait_until(Duration::from_secs(1)).await {
                    println!("basic_expiration restarting due to timing discrepency");
                    continue;
                }
                // Look up the full stored key (namespace, NUL, key). Querying
                // only `b"akey"` would pass even if the key never expired.
                assert!(tree.get(b"atree\0akey")?.is_none());
                break;
            }

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn updating_expiration() -> anyhow::Result<()> {
        run_test("kv-updating-expiration", |sender, sled| async move {
            loop {
                sled.delete_tree(KEY_TREE)?;
                let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
                tree.set(b"atree\0akey", b"somevalue")?;
                let timing = TimingTest::new(Duration::from_millis(100));
                sender
                    .update_key_expiration_async(
                        full_key(Some("atree"), "akey"),
                        Some(Timestamp::now() + Duration::from_millis(100)),
                    )
                    .await;
                sender
                    .update_key_expiration_async(
                        full_key(Some("atree"), "akey"),
                        Some(Timestamp::now() + Duration::from_secs(1)),
                    )
                    .await;
                if timing.elapsed() > Duration::from_millis(100)
                    || !timing.wait_until(Duration::from_millis(500)).await
                {
                    continue;
                }
                assert!(tree.get(b"atree\0akey")?.is_some());

                timing.wait_until(Duration::from_secs_f32(1.5)).await;
                assert_eq!(tree.get(b"atree\0akey")?, None);
                break;
            }

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn multiple_keys_expiration() -> anyhow::Result<()> {
        run_test("kv-multiple-keys-expiration", |sender, sled| async move {
            loop {
                sled.delete_tree(KEY_TREE)?;
                let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
                tree.set(b"atree\0akey", b"somevalue")?;
                tree.set(b"atree\0bkey", b"somevalue")?;

                let timing = TimingTest::new(Duration::from_millis(100));
                sender
                    .update_key_expiration_async(
                        full_key(Some("atree"), "akey"),
                        Some(Timestamp::now() + Duration::from_millis(100)),
                    )
                    .await;
                sender
                    .update_key_expiration_async(
                        full_key(Some("atree"), "bkey"),
                        Some(Timestamp::now() + Duration::from_secs(1)),
                    )
                    .await;

                if !timing.wait_until(Duration::from_millis(200)).await {
                    continue;
                }

                assert!(tree.get(b"atree\0akey")?.is_none());
                assert!(tree.get(b"atree\0bkey")?.is_some());
                timing.wait_until(Duration::from_millis(1100)).await;
                assert!(tree.get(b"atree\0bkey")?.is_none());

                break;
            }

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn clearing_expiration() -> anyhow::Result<()> {
        run_test("kv-clearing-expiration", |sender, sled| async move {
            loop {
                sled.delete_tree(KEY_TREE)?;
                let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
                tree.set(b"atree\0akey", b"somevalue")?;
                let timing = TimingTest::new(Duration::from_millis(100));
                sender
                    .update_key_expiration_async(
                        full_key(Some("atree"), "akey"),
                        Some(Timestamp::now() + Duration::from_millis(100)),
                    )
                    .await;
                sender
                    .update_key_expiration_async(full_key(Some("atree"), "akey"), None)
                    .await;
                if timing.elapsed() > Duration::from_millis(100) {
                    // Restart, took too long.
                    continue;
                }
                timing.wait_until(Duration::from_millis(150)).await;
                assert!(tree.get(b"atree\0akey")?.is_some());
                break;
            }

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn out_of_order_expiration() -> anyhow::Result<()> {
        run_test("kv-out-of-order-expiration", |sender, sled| async move {
            let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
            tree.set(b"atree\0akey", b"somevalue")?;
            tree.set(b"atree\0bkey", b"somevalue")?;
            tree.set(b"atree\0ckey", b"somevalue")?;
            sender
                .update_key_expiration_async(
                    full_key(Some("atree"), "akey"),
                    Some(Timestamp::now() + Duration::from_secs(3)),
                )
                .await;
            sender
                .update_key_expiration_async(
                    full_key(Some("atree"), "ckey"),
                    Some(Timestamp::now() + Duration::from_secs(1)),
                )
                .await;
            sender
                .update_key_expiration_async(
                    full_key(Some("atree"), "bkey"),
                    Some(Timestamp::now() + Duration::from_secs(2)),
                )
                .await;
            tokio::time::sleep(Duration::from_millis(1200)).await;
            assert!(tree.get(b"atree\0akey")?.is_some());
            assert!(tree.get(b"atree\0bkey")?.is_some());
            assert!(tree.get(b"atree\0ckey")?.is_none());
            tokio::time::sleep(Duration::from_secs(1)).await;
            assert!(tree.get(b"atree\0akey")?.is_some());
            assert!(tree.get(b"atree\0bkey")?.is_none());
            tokio::time::sleep(Duration::from_secs(1)).await;
            assert!(tree.get(b"atree\0akey")?.is_none());

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn basic_persistence() -> anyhow::Result<()> {
        run_test_with_persistence(
            "kv-basic-persistence",
            KeyValuePersistence::lazy([
                PersistenceThreshold::after_changes(2),
                PersistenceThreshold::after_changes(1).and_duration(Duration::from_secs(2)),
            ]),
            &|sender, sled| async move {
                loop {
                    let timing = TimingTest::new(Duration::from_millis(100));
                    let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
                    // Set three keys in quick succession. The first two should
                    // persist immediately, and the third should show up after 2
                    // seconds.
                    sender
                        .perform_kv_operation(KeyOperation {
                            namespace: None,
                            key: String::from("key1"),
                            command: Command::Set(SetCommand {
                                value: Value::Bytes(Bytes::default()),
                                expiration: None,
                                keep_existing_expiration: false,
                                check: None,
                                return_previous_value: false,
                            }),
                        })
                        .await
                        .unwrap();
                    sender
                        .perform_kv_operation(KeyOperation {
                            namespace: None,
                            key: String::from("key2"),
                            command: Command::Set(SetCommand {
                                value: Value::Bytes(Bytes::default()),
                                expiration: None,
                                keep_existing_expiration: false,
                                check: None,
                                return_previous_value: false,
                            }),
                        })
                        .await
                        .unwrap();
                    sender
                        .perform_kv_operation(KeyOperation {
                            namespace: None,
                            key: String::from("key3"),
                            command: Command::Set(SetCommand {
                                value: Value::Bytes(Bytes::default()),
                                expiration: None,
                                keep_existing_expiration: false,
                                check: None,
                                return_previous_value: false,
                            }),
                        })
                        .await
                        .unwrap();
                    // Persisting is handled in the background. Sleep for a bit
                    // to give it a chance to happen, but not long enough to
                    // trigger the longer time-based rule.
                    if timing.elapsed() > Duration::from_millis(500)
                        || !timing.wait_until(Duration::from_secs(1)).await
                    {
                        println!("basic_persistence restarting due to timing discrepency");
                        continue;
                    }
                    assert!(tree.get(b"\0key1").unwrap().is_some());
                    assert!(tree.get(b"\0key2").unwrap().is_some());
                    assert!(tree.get(b"\0key3").unwrap().is_none());
                    if !timing.wait_until(Duration::from_secs(3)).await {
                        println!("basic_persistence restarting due to timing discrepency");
                        continue;
                    }
                    assert!(tree.get(b"\0key3").unwrap().is_some());
                    break;
                }

                Ok(())
            },
        )
        .await
    }

    #[tokio::test]
    async fn saves_on_drop() -> anyhow::Result<()> {
        let dir = TestDirectory::new("saves-on-drop.bonsaidb");
        let sled = nebari::Config::new(&dir)
            .file_manager(AnyFileManager::std())
            .open()?;
        let tree = sled.tree(Unversioned::tree(KEY_TREE))?;

        let context = Context::new(
            sled.clone(),
            KeyValuePersistence::lazy([PersistenceThreshold::after_changes(2)]),
        );
        context
            .perform_kv_operation(KeyOperation {
                namespace: None,
                key: String::from("key1"),
                command: Command::Set(SetCommand {
                    value: Value::Bytes(Bytes::default()),
                    expiration: None,
                    keep_existing_expiration: false,
                    check: None,
                    return_previous_value: false,
                }),
            })
            .await
            .unwrap();
        assert!(tree.get(b"\0key1").unwrap().is_none());
        drop(context);
        // Dropping spawns a task that should persist the keys. Give a moment
        // for the runtime to execute the task.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(tree.get(b"\0key1").unwrap().is_some());

        Ok(())
    }
}