1
use std::{
2
    borrow::Cow,
3
    collections::{btree_map, BTreeMap, VecDeque},
4
    sync::{Arc, Weak},
5
    time::Duration,
6
};
7

            
8
use bonsaidb_core::{
9
    connection::{Connection, HasSession},
10
    keyvalue::{
11
        Command, KeyCheck, KeyOperation, KeyStatus, KeyValue, Numeric, Output, SetCommand,
12
        Timestamp, Value,
13
    },
14
    permissions::bonsai::{
15
        keyvalue_key_resource_name, BonsaiAction, DatabaseAction, KeyValueAction,
16
    },
17
    transaction::{ChangedKey, Changes},
18
};
19
use nebari::{
20
    io::any::AnyFile,
21
    tree::{CompareSwap, Operation, Root, ScanEvaluation, Unversioned},
22
    AbortError, ArcBytes, Roots,
23
};
24
use parking_lot::Mutex;
25
use serde::{Deserialize, Serialize};
26
use watchable::{Watchable, Watcher};
27

            
28
use crate::{
29
    config::KeyValuePersistence,
30
    database::compat,
31
    tasks::{Job, Keyed, Task},
32
    Database, DatabaseNonBlocking, Error,
33
};
34

            
35
1241025
#[derive(Serialize, Deserialize, Debug, Clone)]
36
pub struct Entry {
37
    pub value: Value,
38
    pub expiration: Option<Timestamp>,
39
    #[serde(default)]
40
    pub last_updated: Timestamp,
41
}
42

            
43
impl Entry {
44
    pub(crate) fn restore(
45
        self,
46
        namespace: Option<String>,
47
        key: String,
48
        database: &Database,
49
    ) -> Result<(), bonsaidb_core::Error> {
50
3
        database.execute_key_operation(KeyOperation {
51
3
            namespace,
52
3
            key,
53
3
            command: Command::Set(SetCommand {
54
3
                value: self.value,
55
3
                expiration: self.expiration,
56
3
                keep_existing_expiration: false,
57
3
                check: None,
58
3
                return_previous_value: false,
59
3
            }),
60
3
        })?;
61
3
        Ok(())
62
3
    }
63
}
64

            
65
impl KeyValue for Database {
66
    fn execute_key_operation(&self, op: KeyOperation) -> Result<Output, bonsaidb_core::Error> {
67
1253557
        self.check_permission(
68
1253557
            keyvalue_key_resource_name(self.name(), op.namespace.as_deref(), &op.key),
69
1253557
            &BonsaiAction::Database(DatabaseAction::KeyValue(KeyValueAction::ExecuteOperation)),
70
1253557
        )?;
71
1253557
        self.data.context.perform_kv_operation(op)
72
1253557
    }
73
}
74

            
75
impl Database {
76
12385
    pub(crate) fn all_key_value_entries(
77
12385
        &self,
78
12385
    ) -> Result<BTreeMap<(Option<String>, String), Entry>, Error> {
79
12385
        // Lock the state so that new new modifications can be made while we gather this snapshot.
80
12385
        let state = self.data.context.key_value_state.lock();
81
12385
        let database = self.clone();
82
12385
        // Initialize our entries with any dirty keys and any keys that are about to be persisted.
83
12385
        let mut all_entries = BTreeMap::new();
84
12385
        database
85
12385
            .roots()
86
12385
            .tree(Unversioned::tree(KEY_TREE))?
87
12383
            .scan::<Error, _, _, _, _>(
88
12383
                &(..),
89
12383
                true,
90
12383
                |_, _, _| ScanEvaluation::ReadData,
91
12383
                |_, _| ScanEvaluation::ReadData,
92
12383
                |key, _, entry: ArcBytes<'static>| {
93
25
                    let entry = bincode::deserialize::<Entry>(&entry)
94
25
                        .map_err(|err| AbortError::Other(Error::from(err)))?;
95
25
                    let full_key = std::str::from_utf8(&key)
96
25
                        .map_err(|err| AbortError::Other(Error::from(err)))?;
97

            
98
25
                    if let Some(split_key) = split_key(full_key) {
99
25
                        // Do not overwrite the existing key
100
25
                        all_entries.entry(split_key).or_insert(entry);
101
25
                    }
102

            
103
25
                    Ok(())
104
12383
                },
105
12383
            )?;
106

            
107
        // Apply the pending writes first
108
12382
        if let Some(pending_keys) = &state.keys_being_persisted {
109
68
            for (key, possible_entry) in pending_keys.iter() {
110
68
                let (namespace, key) = split_key(key).unwrap();
111
68
                if let Some(updated_entry) = possible_entry {
112
37
                    all_entries.insert((namespace, key), updated_entry.clone());
113
37
                } else {
114
31
                    all_entries.remove(&(namespace, key));
115
31
                }
116
            }
117
12315
        }
118

            
119
12382
        for (key, possible_entry) in &state.dirty_keys {
120
105
            let (namespace, key) = split_key(key).unwrap();
121
105
            if let Some(updated_entry) = possible_entry {
122
73
                all_entries.insert((namespace, key), updated_entry.clone());
123
73
            } else {
124
32
                all_entries.remove(&(namespace, key));
125
32
            }
126
        }
127

            
128
12382
        Ok(all_entries)
129
12385
    }
130
}
131

            
132
/// Name of the nebari tree that stores every key-value entry, keyed by `full_key` output.
pub(crate) const KEY_TREE: &str = "kv";
133

            
134
1253864
/// Builds the tree key for `key` within `namespace`.
///
/// The layout is `<namespace>\0<key>`. A missing namespace contributes no bytes, so the key
/// always contains at least the `\0` separator.
fn full_key(namespace: Option<&str>, key: &str) -> String {
    let prefix = namespace.unwrap_or("");
    let mut combined = String::with_capacity(prefix.len() + 1 + key.len());
    combined.push_str(prefix);
    combined.push('\0');
    combined.push_str(key);
    combined
}
144

            
145
/// Splits a tree key produced by `full_key` back into `(namespace, key)`.
///
/// The split happens at the first `\0`. An empty namespace part yields `None`. Returns `None`
/// if `full_key` contains no separator.
fn split_key(full_key: &str) -> Option<(Option<String>, String)> {
    let (namespace, key) = full_key.split_once('\0')?;
    let namespace = (!namespace.is_empty()).then(|| namespace.to_owned());
    Some((namespace, key.to_owned()))
}
157

            
158
1241510
fn increment(existing: &Numeric, amount: &Numeric, saturating: bool) -> Numeric {
159
1241510
    match amount {
160
372
        Numeric::Integer(amount) => {
161
372
            let existing_value = existing.as_i64_lossy(saturating);
162
372
            let new_value = if saturating {
163
248
                existing_value.saturating_add(*amount)
164
            } else {
165
124
                existing_value.wrapping_add(*amount)
166
            };
167
372
            Numeric::Integer(new_value)
168
        }
169
1240766
        Numeric::UnsignedInteger(amount) => {
170
1240766
            let existing_value = existing.as_u64_lossy(saturating);
171
1240766
            let new_value = if saturating {
172
1240642
                existing_value.saturating_add(*amount)
173
            } else {
174
124
                existing_value.wrapping_add(*amount)
175
            };
176
1240766
            Numeric::UnsignedInteger(new_value)
177
        }
178
372
        Numeric::Float(amount) => {
179
372
            let existing_value = existing.as_f64_lossy();
180
372
            let new_value = existing_value + *amount;
181
372
            Numeric::Float(new_value)
182
        }
183
    }
184
1241510
}
185

            
186
1116
fn decrement(existing: &Numeric, amount: &Numeric, saturating: bool) -> Numeric {
187
1116
    match amount {
188
372
        Numeric::Integer(amount) => {
189
372
            let existing_value = existing.as_i64_lossy(saturating);
190
372
            let new_value = if saturating {
191
248
                existing_value.saturating_sub(*amount)
192
            } else {
193
124
                existing_value.wrapping_sub(*amount)
194
            };
195
372
            Numeric::Integer(new_value)
196
        }
197
496
        Numeric::UnsignedInteger(amount) => {
198
496
            let existing_value = existing.as_u64_lossy(saturating);
199
496
            let new_value = if saturating {
200
248
                existing_value.saturating_sub(*amount)
201
            } else {
202
248
                existing_value.wrapping_sub(*amount)
203
            };
204
496
            Numeric::UnsignedInteger(new_value)
205
        }
206
248
        Numeric::Float(amount) => {
207
248
            let existing_value = existing.as_f64_lossy();
208
248
            let new_value = existing_value - *amount;
209
248
            Numeric::Float(new_value)
210
        }
211
    }
212
1116
}
213

            
214
#[derive(Debug)]
215
pub struct KeyValueState {
216
    roots: Roots<AnyFile>,
217
    persistence: KeyValuePersistence,
218
    last_commit: Timestamp,
219
    background_worker_target: Watchable<BackgroundWorkerProcessTarget>,
220
    expiring_keys: BTreeMap<String, Timestamp>,
221
    expiration_order: VecDeque<String>,
222
    dirty_keys: BTreeMap<String, Option<Entry>>,
223
    keys_being_persisted: Option<Arc<BTreeMap<String, Option<Entry>>>>,
224
    last_persistence: Watchable<Timestamp>,
225
    shutdown: Option<flume::Sender<()>>,
226
}
227

            
228
impl KeyValueState {
229
26877
    pub fn new(
230
26877
        persistence: KeyValuePersistence,
231
26877
        roots: Roots<AnyFile>,
232
26877
        background_worker_target: Watchable<BackgroundWorkerProcessTarget>,
233
26877
    ) -> Self {
234
26877
        Self {
235
26877
            roots,
236
26877
            persistence,
237
26877
            last_commit: Timestamp::now(),
238
26877
            expiring_keys: BTreeMap::new(),
239
26877
            background_worker_target,
240
26877
            expiration_order: VecDeque::new(),
241
26877
            dirty_keys: BTreeMap::new(),
242
26877
            keys_being_persisted: None,
243
26877
            last_persistence: Watchable::new(Timestamp::MIN),
244
26877
            shutdown: None,
245
26877
        }
246
26877
    }
247

            
248
24027
    pub fn shutdown(&mut self, state: &Arc<Mutex<KeyValueState>>) -> Option<flume::Receiver<()>> {
249
24027
        if self.keys_being_persisted.is_none() && self.commit_dirty_keys(state) {
250
301
            let (shutdown_sender, shutdown_receiver) = flume::bounded(1);
251
301
            self.shutdown = Some(shutdown_sender);
252
301
            Some(shutdown_receiver)
253
        } else {
254
23726
            None
255
        }
256
24027
    }
257

            
258
1253974
    pub fn perform_kv_operation(
259
1253974
        &mut self,
260
1253974
        op: KeyOperation,
261
1253974
        state: &Arc<Mutex<KeyValueState>>,
262
1253974
    ) -> Result<Output, bonsaidb_core::Error> {
263
1253974
        let now = Timestamp::now();
264
1253974
        // If there are any keys that have expired, clear them before executing any operations.
265
1253974
        self.remove_expired_keys(now);
266
1253974
        let result = match op.command {
267
5323
            Command::Set(command) => {
268
5323
                self.execute_set_operation(op.namespace.as_deref(), &op.key, command, now)
269
            }
270
4785
            Command::Get { delete } => {
271
4785
                self.execute_get_operation(op.namespace.as_deref(), &op.key, delete)
272
            }
273
992
            Command::Delete => self.execute_delete_operation(op.namespace.as_deref(), &op.key),
274
1241634
            Command::Increment { amount, saturating } => self.execute_increment_operation(
275
1241634
                op.namespace.as_deref(),
276
1241634
                &op.key,
277
1241634
                &amount,
278
1241634
                saturating,
279
1241634
                now,
280
1241634
            ),
281
1240
            Command::Decrement { amount, saturating } => self.execute_decrement_operation(
282
1240
                op.namespace.as_deref(),
283
1240
                &op.key,
284
1240
                &amount,
285
1240
                saturating,
286
1240
                now,
287
1240
            ),
288
        };
289
1253974
        if result.is_ok() {
290
1253478
            if self.needs_commit(now) {
291
21875
                self.commit_dirty_keys(state);
292
1231603
            }
293
1253478
            self.update_background_worker_target();
294
496
        }
295
1253974
        result
296
1253974
    }
297

            
298
5323
    fn execute_set_operation(
299
5323
        &mut self,
300
5323
        namespace: Option<&str>,
301
5323
        key: &str,
302
5323
        set: SetCommand,
303
5323
        now: Timestamp,
304
5323
    ) -> Result<Output, bonsaidb_core::Error> {
305
5199
        let mut entry = Entry {
306
5323
            value: set.value.validate()?,
307
5199
            expiration: set.expiration,
308
5199
            last_updated: now,
309
5199
        };
310
5199
        let full_key = full_key(namespace, key);
311
5199
        let possible_existing_value =
312
5199
            if set.check.is_some() || set.return_previous_value || set.keep_existing_expiration {
313
1172
                Some(self.get(&full_key).map_err(Error::from)?)
314
            } else {
315
4027
                None
316
            };
317
5199
        let existing_value_ref = possible_existing_value.as_ref().and_then(Option::as_ref);
318

            
319
5199
        let updating = match set.check {
320
308
            Some(KeyCheck::OnlyIfPresent) => existing_value_ref.is_some(),
321
368
            Some(KeyCheck::OnlyIfVacant) => existing_value_ref.is_none(),
322
4523
            None => true,
323
        };
324
5199
        if updating {
325
4891
            if set.keep_existing_expiration {
326
124
                if let Some(existing_value) = existing_value_ref {
327
124
                    entry.expiration = existing_value.expiration;
328
124
                }
329
4767
            }
330
4891
            self.update_key_expiration(&full_key, entry.expiration);
331

            
332
4891
            let previous_value = if let Some(existing_value) = possible_existing_value {
333
                // we already fetched, no need to ask for the existing value back
334
864
                self.set(full_key, entry);
335
864
                existing_value
336
            } else {
337
4027
                self.replace(full_key, entry).map_err(Error::from)?
338
            };
339
4891
            if set.return_previous_value {
340
432
                Ok(Output::Value(previous_value.map(|entry| entry.value)))
341
4459
            } else if previous_value.is_none() {
342
1983
                Ok(Output::Status(KeyStatus::Inserted))
343
            } else {
344
2476
                Ok(Output::Status(KeyStatus::Updated))
345
            }
346
        } else {
347
308
            Ok(Output::Status(KeyStatus::NotChanged))
348
        }
349
5323
    }
350

            
351
6145
    pub fn update_key_expiration<'key>(
352
6145
        &mut self,
353
6145
        tree_key: impl Into<Cow<'key, str>>,
354
6145
        expiration: Option<Timestamp>,
355
6145
    ) {
356
6145
        let tree_key = tree_key.into();
357
6145
        let mut changed_first_expiration = false;
358
6145
        if let Some(expiration) = expiration {
359
815
            let key = if self.expiring_keys.contains_key(tree_key.as_ref()) {
360
                // Update the existing entry.
361
249
                let existing_entry_index = self
362
249
                    .expiration_order
363
249
                    .iter()
364
249
                    .enumerate()
365
249
                    .find_map(
366
249
                        |(index, key)| {
367
249
                            if &tree_key == key {
368
249
                                Some(index)
369
                            } else {
370
                                None
371
                            }
372
249
                        },
373
249
                    )
374
249
                    .unwrap();
375
249
                changed_first_expiration = existing_entry_index == 0;
376
249
                self.expiration_order.remove(existing_entry_index).unwrap()
377
            } else {
378
566
                tree_key.into_owned()
379
            };
380

            
381
            // Insert the key into the expiration_order queue
382
815
            let mut insert_at = None;
383
815
            for (index, expiring_key) in self.expiration_order.iter().enumerate() {
384
376
                if self.expiring_keys.get(expiring_key).unwrap() > &expiration {
385
156
                    insert_at = Some(index);
386
156
                    break;
387
220
                }
388
            }
389
815
            if let Some(insert_at) = insert_at {
390
156
                changed_first_expiration |= insert_at == 0;
391
156

            
392
156
                self.expiration_order.insert(insert_at, key.clone());
393
659
            } else {
394
659
                changed_first_expiration |= self.expiration_order.is_empty();
395
659
                self.expiration_order.push_back(key.clone());
396
659
            }
397
815
            self.expiring_keys.insert(key, expiration);
398
5330
        } else if self.expiring_keys.remove(tree_key.as_ref()).is_some() {
399
125
            let index = self
400
125
                .expiration_order
401
125
                .iter()
402
125
                .enumerate()
403
125
                .find_map(|(index, key)| {
404
125
                    if tree_key.as_ref() == key {
405
125
                        Some(index)
406
                    } else {
407
                        None
408
                    }
409
125
                })
410
125
                .unwrap();
411
125

            
412
125
            changed_first_expiration |= index == 0;
413
125
            self.expiration_order.remove(index);
414
5205
        }
415

            
416
6145
        if changed_first_expiration {
417
844
            self.update_background_worker_target();
418
5301
        }
419
6145
    }
420

            
421
4785
    fn execute_get_operation(
422
4785
        &mut self,
423
4785
        namespace: Option<&str>,
424
4785
        key: &str,
425
4785
        delete: bool,
426
4785
    ) -> Result<Output, bonsaidb_core::Error> {
427
4785
        let full_key = full_key(namespace, key);
428
4785
        let entry = if delete {
429
248
            self.remove(full_key).map_err(Error::from)?
430
        } else {
431
4537
            self.get(&full_key).map_err(Error::from)?
432
        };
433

            
434
4785
        Ok(Output::Value(entry.map(|e| e.value)))
435
4785
    }
436

            
437
992
    fn execute_delete_operation(
438
992
        &mut self,
439
992
        namespace: Option<&str>,
440
992
        key: &str,
441
992
    ) -> Result<Output, bonsaidb_core::Error> {
442
992
        let full_key = full_key(namespace, key);
443
992
        let value = self.remove(full_key).map_err(Error::from)?;
444
992
        if value.is_some() {
445
372
            Ok(Output::Status(KeyStatus::Deleted))
446
        } else {
447
620
            Ok(Output::Status(KeyStatus::NotChanged))
448
        }
449
992
    }
450

            
451
1241634
    fn execute_increment_operation(
452
1241634
        &mut self,
453
1241634
        namespace: Option<&str>,
454
1241634
        key: &str,
455
1241634
        amount: &Numeric,
456
1241634
        saturating: bool,
457
1241634
        now: Timestamp,
458
1241634
    ) -> Result<Output, bonsaidb_core::Error> {
459
1241634
        self.execute_numeric_operation(namespace, key, amount, saturating, now, increment)
460
1241634
    }
461

            
462
1240
    fn execute_decrement_operation(
463
1240
        &mut self,
464
1240
        namespace: Option<&str>,
465
1240
        key: &str,
466
1240
        amount: &Numeric,
467
1240
        saturating: bool,
468
1240
        now: Timestamp,
469
1240
    ) -> Result<Output, bonsaidb_core::Error> {
470
1240
        self.execute_numeric_operation(namespace, key, amount, saturating, now, decrement)
471
1240
    }
472

            
473
1242874
    fn execute_numeric_operation<F: Fn(&Numeric, &Numeric, bool) -> Numeric>(
474
1242874
        &mut self,
475
1242874
        namespace: Option<&str>,
476
1242874
        key: &str,
477
1242874
        amount: &Numeric,
478
1242874
        saturating: bool,
479
1242874
        now: Timestamp,
480
1242874
        op: F,
481
1242874
    ) -> Result<Output, bonsaidb_core::Error> {
482
1242874
        let full_key = full_key(namespace, key);
483
1242874
        let current = self.get(&full_key).map_err(Error::from)?;
484
1242874
        let mut entry = current.unwrap_or(Entry {
485
1242874
            value: Value::Numeric(Numeric::UnsignedInteger(0)),
486
1242874
            expiration: None,
487
1242874
            last_updated: now,
488
1242874
        });
489
1242874

            
490
1242874
        match entry.value {
491
1242626
            Value::Numeric(existing) => {
492
1242626
                let value = Value::Numeric(op(&existing, amount, saturating).validate()?);
493
1242502
                entry.value = value.clone();
494
1242502

            
495
1242502
                self.set(full_key, entry);
496
1242502
                Ok(Output::Value(Some(value)))
497
            }
498
248
            Value::Bytes(_) => Err(bonsaidb_core::Error::Database(String::from(
499
248
                "type of stored `Value` is not `Numeric`",
500
248
            ))),
501
        }
502
1242874
    }
503

            
504
1240
    fn remove(&mut self, key: String) -> Result<Option<Entry>, nebari::Error> {
505
1240
        self.update_key_expiration(&key, None);
506

            
507
1240
        if let Some(dirty_entry) = self.dirty_keys.get_mut(&key) {
508
285
            Ok(dirty_entry.take())
509
955
        } else if let Some(persisting_entry) = self
510
955
            .keys_being_persisted
511
955
            .as_ref()
512
955
            .and_then(|keys| keys.get(&key))
513
        {
514
121
            self.dirty_keys.insert(key, None);
515
121
            Ok(persisting_entry.clone())
516
        } else {
517
            // There might be a value on-disk we need to remove.
518
834
            let previous_value = Self::retrieve_key_from_disk(&self.roots, &key)?;
519
834
            self.dirty_keys.insert(key, None);
520
834
            Ok(previous_value)
521
        }
522
1240
    }
523

            
524
    fn get(&self, key: &str) -> Result<Option<Entry>, nebari::Error> {
525
1248583
        if let Some(entry) = self.dirty_keys.get(key) {
526
1159371
            Ok(entry.clone())
527
89212
        } else if let Some(persisting_entry) = self
528
89212
            .keys_being_persisted
529
89212
            .as_ref()
530
89212
            .and_then(|keys| keys.get(key))
531
        {
532
81480
            Ok(persisting_entry.clone())
533
        } else {
534
7732
            Self::retrieve_key_from_disk(&self.roots, key)
535
        }
536
1248583
    }
537

            
538
1243366
    fn set(&mut self, key: String, value: Entry) {
539
1243366
        self.dirty_keys.insert(key, Some(value));
540
1243366
    }
541

            
542
4027
    fn replace(&mut self, key: String, value: Entry) -> Result<Option<Entry>, nebari::Error> {
543
4027
        let mut value = Some(value);
544
4027
        let map_entry = self.dirty_keys.entry(key);
545
4027
        if matches!(map_entry, btree_map::Entry::Vacant(_)) {
546
            // This key is clean, and the caller is expecting the previous
547
            // value.
548
3005
            let stored_value = if let Some(persisting_entry) = self
549
3005
                .keys_being_persisted
550
3005
                .as_ref()
551
3005
                .and_then(|keys| keys.get(map_entry.key()))
552
            {
553
289
                persisting_entry.clone()
554
            } else {
555
2716
                Self::retrieve_key_from_disk(&self.roots, map_entry.key())?
556
            };
557
3005
            map_entry.or_insert(value);
558
3005
            Ok(stored_value)
559
        } else {
560
            // This key is already dirty, we can just replace the value and
561
            // return the old value.
562
1022
            map_entry.and_modify(|map_entry| {
563
1022
                std::mem::swap(&mut value, map_entry);
564
1022
            });
565
1022
            Ok(value)
566
        }
567
4027
    }
568

            
569
11282
    fn retrieve_key_from_disk(
570
11282
        roots: &Roots<AnyFile>,
571
11282
        key: &str,
572
11282
    ) -> Result<Option<Entry>, nebari::Error> {
573
11282
        roots
574
11282
            .tree(Unversioned::tree(KEY_TREE))?
575
11282
            .get(key.as_bytes())
576
11282
            .map(|current| current.and_then(|current| bincode::deserialize::<Entry>(&current).ok()))
577
11282
    }
578

            
579
1424915
    fn update_background_worker_target(&mut self) {
580
1424915
        let key_expiration_target = self.expiration_order.get(0).map(|key| {
581
3901
            let expiration_timeout = self.expiring_keys.get(key).unwrap();
582
3901
            *expiration_timeout
583
1424915
        });
584
1424915
        let now = Timestamp::now();
585
1424915
        let persisting = self.keys_being_persisted.is_some();
586
1424915
        let commit_target = (!persisting)
587
1424915
            .then(|| {
588
92417
                self.persistence.duration_until_next_commit(
589
92417
                    self.dirty_keys.len(),
590
92417
                    (now - self.last_commit).unwrap_or_default(),
591
92417
                )
592
1424915
            })
593
1424915
            .flatten()
594
1424915
            .map(|duration| now + duration);
595
1424915
        match (commit_target, key_expiration_target) {
596
84938
            (Some(target), _) | (_, Some(target)) if target <= now => {
597
81403
                self.background_worker_target
598
81403
                    .replace(BackgroundWorkerProcessTarget::Now);
599
81403
            }
600
            (Some(commit_target), Some(key_target)) => {
601
                let closest_target = key_target.min(commit_target);
602
                let new_target = BackgroundWorkerProcessTarget::Timestamp(closest_target);
603
                let _ = self.background_worker_target.update(new_target);
604
            }
605
3535
            (Some(target), None) | (None, Some(target)) => {
606
3535
                let _ = self
607
3535
                    .background_worker_target
608
3535
                    .update(BackgroundWorkerProcessTarget::Timestamp(target));
609
3535
            }
610
1339977
            (None, None) => {
611
1339977
                let _ = self
612
1339977
                    .background_worker_target
613
1339977
                    .update(BackgroundWorkerProcessTarget::Never);
614
1339977
            }
615
        }
616
1424915
    }
617

            
618
1335421
    /// Removes every key whose expiration is at or before `now`.
    ///
    /// Each expired key is staged as a removal in `dirty_keys` so the deletion is persisted.
    fn remove_expired_keys(&mut self, now: Timestamp) {
        while let Some(soonest) = self.expiration_order.front() {
            if self.expiring_keys[soonest] > now {
                // The queue is ordered, so nothing after this key has expired either.
                break;
            }
            let expired = self.expiration_order.pop_front().unwrap();
            self.expiring_keys.remove(&expired);
            self.dirty_keys.insert(expired, None);
        }
    }
627

            
628
1334925
    fn needs_commit(&mut self, now: Timestamp) -> bool {
629
1334925
        if self.keys_being_persisted.is_some() {
630
1243154
            false
631
        } else {
632
91771
            let since_last_commit = (now - self.last_commit).unwrap_or_default();
633
91771
            self.persistence
634
91771
                .should_commit(self.dirty_keys.len(), since_last_commit)
635
        }
636
1334925
    }
637

            
638
112883
    fn stage_dirty_keys(&mut self) -> Option<Arc<BTreeMap<String, Option<Entry>>>> {
639
112883
        if !self.dirty_keys.is_empty() && self.keys_being_persisted.is_none() {
640
89176
            let keys = Arc::new(std::mem::take(&mut self.dirty_keys));
641
89176
            self.keys_being_persisted = Some(keys.clone());
642
89176
            Some(keys)
643
        } else {
644
23707
            None
645
        }
646
112883
    }
647

            
648
    pub fn commit_dirty_keys(&mut self, state: &Arc<Mutex<KeyValueState>>) -> bool {
649
112582
        if let Some(keys) = self.stage_dirty_keys() {
650
89176
            let roots = self.roots.clone();
651
89176
            let state = state.clone();
652
89176
            std::thread::Builder::new()
653
89176
                .name(String::from("keyvalue-persist"))
654
89176
                .spawn(move || Self::persist_keys(&state, &roots, &keys))
655
89176
                .unwrap();
656
89176
            self.last_commit = Timestamp::now();
657
89176
            true
658
        } else {
659
23406
            false
660
        }
661
112582
    }
662

            
663
    /// Returns a watcher that observes the timestamp of each completed persistence.
    #[cfg(test)]
    pub fn persistence_watcher(&self) -> Watcher<Timestamp> {
        let last_persistence = &self.last_persistence;
        last_persistence.watch()
    }
667

            
668
89176
    fn persist_keys(
669
89176
        key_value_state: &Arc<Mutex<KeyValueState>>,
670
89176
        roots: &Roots<AnyFile>,
671
89176
        keys: &BTreeMap<String, Option<Entry>>,
672
89176
    ) -> Result<(), bonsaidb_core::Error> {
673
89176
        let mut transaction = roots
674
89176
            .transaction(&[Unversioned::tree(KEY_TREE)])
675
89176
            .map_err(Error::from)?;
676
89176
        let all_keys = keys
677
89176
            .keys()
678
89568
            .map(|key| ArcBytes::from(key.as_bytes().to_vec()))
679
89176
            .collect();
680
89176
        let mut changed_keys = Vec::new();
681
89176
        transaction
682
89176
            .tree::<Unversioned>(0)
683
89176
            .unwrap()
684
89176
            .modify(
685
89176
                all_keys,
686
89568
                Operation::CompareSwap(CompareSwap::new(&mut |key, existing_value| {
687
89568
                    let full_key = std::str::from_utf8(key).unwrap();
688
89568
                    let (namespace, key) = split_key(full_key).unwrap();
689

            
690
89568
                    if let Some(new_value) = keys.get(full_key).unwrap() {
691
88237
                        changed_keys.push(ChangedKey {
692
88237
                            namespace,
693
88237
                            key,
694
88237
                            deleted: false,
695
88237
                        });
696
88237
                        let bytes = bincode::serialize(new_value).unwrap();
697
88237
                        nebari::tree::KeyOperation::Set(ArcBytes::from(bytes))
698
1331
                    } else if existing_value.is_some() {
699
809
                        changed_keys.push(ChangedKey {
700
809
                            namespace,
701
809
                            key,
702
809
                            deleted: existing_value.is_some(),
703
809
                        });
704
809
                        nebari::tree::KeyOperation::Remove
705
                    } else {
706
522
                        nebari::tree::KeyOperation::Skip
707
                    }
708
89568
                })),
709
89176
            )
710
89176
            .map_err(Error::from)?;
711

            
712
89176
        if !changed_keys.is_empty() {
713
88654
            transaction
714
88654
                .entry_mut()
715
88654
                .set_data(compat::serialize_executed_transaction_changes(
716
88654
                    &Changes::Keys(changed_keys),
717
88654
                )?)
718
88654
                .map_err(Error::from)?;
719
88654
            transaction.commit().map_err(Error::from)?;
720
522
        }
721

            
722
        // If we are shutting down, check if we still have dirty keys.
723
89146
        let final_keys = {
724
89146
            let mut state = key_value_state.lock();
725
89146
            state.last_persistence.replace(Timestamp::now());
726
89146
            state.keys_being_persisted = None;
727
89146
            state.update_background_worker_target();
728
89146
            // This block is a little ugly to avoid having to acquire the lock
729
89146
            // twice. If we're shutting down and have no dirty keys, we notify
730
89146
            // the waiting shutdown task. If we have any dirty keys, we wait do
731
89146
            // to that step because we're going to recurse and reach this spot
732
89146
            // again.
733
89146
            if state.shutdown.is_some() {
734
301
                let staged_keys = state.stage_dirty_keys();
735
301
                if staged_keys.is_none() {
736
301
                    let shutdown = state.shutdown.take().unwrap();
737
301
                    let _ = shutdown.send(());
738
301
                }
739
301
                staged_keys
740
            } else {
741
88845
                None
742
            }
743
        };
744
89146
        if let Some(final_keys) = final_keys {
745
            Self::persist_keys(key_value_state, roots, &final_keys)?;
746
89146
        }
747
89146
        Ok(())
748
89146
    }
749
}
750

            
751
26847
pub fn background_worker(
752
26847
    key_value_state: &Weak<Mutex<KeyValueState>>,
753
26847
    timestamp_receiver: &mut Watcher<BackgroundWorkerProcessTarget>,
754
26847
) {
755
189333
    loop {
756
189333
        let mut perform_operations = false;
757
189333
        let current_target = *timestamp_receiver.read();
758
189333
        match current_target {
759
            // With no target, sleep until we receive a target.
760
            BackgroundWorkerProcessTarget::Never => {
761
107792
                if timestamp_receiver.watch().is_err() {
762
23806
                    break;
763
83986
                }
764
            }
765
906
            BackgroundWorkerProcessTarget::Timestamp(target) => {
766
906
                // With a target, we need to wait to receive a target only as
767
906
                // long as there is time remaining.
768
906
                let remaining = target - Timestamp::now();
769
906
                if let Some(remaining) = remaining {
770
                    // recv_timeout panics if Instant::checked_add(remaining)
771
                    // fails. So, we will cap the sleep time at 1 day.
772
906
                    let remaining = remaining.min(Duration::from_secs(60 * 60 * 24));
773
906
                    match timestamp_receiver.watch_timeout(remaining) {
774
875
                        Ok(_) | Err(watchable::TimeoutError::Timeout) => {
775
875
                            perform_operations = true;
776
875
                        }
777
1
                        Err(watchable::TimeoutError::Disconnected) => break,
778
                    }
779
                } else {
780
                    perform_operations = true;
781
                }
782
            }
783
80635
            BackgroundWorkerProcessTarget::Now => {
784
80635
                perform_operations = true;
785
80635
            }
786
        };
787

            
788
165496
        let key_value_state = match key_value_state.upgrade() {
789
165336
            Some(state) => state,
790
            None => {
791
                // The last reference has been dropped.
792
160
                break;
793
            }
794
        };
795

            
796
165336
        if perform_operations {
797
84297
            let mut state = key_value_state.lock();
798
84297
            let now = Timestamp::now();
799
84297
            state.remove_expired_keys(now);
800
84297
            if state.needs_commit(now) {
801
67000
                state.commit_dirty_keys(&key_value_state);
802
67007
            }
803
81447
            state.update_background_worker_target();
804
81039
        }
805
    }
806
23967
}
807

            
808
1343512
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
809
pub enum BackgroundWorkerProcessTarget {
810
    Now,
811
    Timestamp(Timestamp),
812
    Never,
813
}
814

            
815
#[derive(Debug)]
816
pub struct ExpirationLoader {
817
    pub database: Database,
818
    pub launched_at: Timestamp,
819
}
820

            
821
impl Keyed<Task> for ExpirationLoader {
822
12500
    fn key(&self) -> Task {
823
12500
        Task::ExpirationLoader(self.database.data.name.clone())
824
12500
    }
825
}
826

            
827
impl Job for ExpirationLoader {
828
    type Output = ();
829
    type Error = Error;
830

            
831
24340
    #[cfg_attr(feature = "tracing", tracing::instrument)]
832
12170
    fn execute(&mut self) -> Result<Self::Output, Self::Error> {
833
        let database = self.database.clone();
834
        let launched_at = self.launched_at;
835

            
836
        for ((namespace, key), entry) in database.all_key_value_entries()? {
837
            if entry.last_updated < launched_at && entry.expiration.is_some() {
838
                self.database
839
                    .update_key_expiration(full_key(namespace.as_deref(), &key), entry.expiration);
840
            }
841
        }
842

            
843
        self.database
844
            .storage()
845
            .instance
846
            .tasks()
847
            .mark_key_value_expiration_loaded(self.database.data.name.clone());
848

            
849
        Ok(())
850
    }
851
}
852

            
853
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use bonsaidb_core::{
        arc_bytes::serde::Bytes,
        test_util::{TestDirectory, TimingTest},
    };
    use nebari::io::any::{AnyFile, AnyFileManager};

    use super::*;
    use crate::{config::PersistenceThreshold, database::Context};

    /// Opens a nebari store in a `TestDirectory` named `name`.
    ///
    /// Builds a key-value `Context` over the store with the given
    /// `persistence` settings, then runs `test_contents` with the context and
    /// the raw roots.
    fn run_test_with_persistence<
        F: Fn(Context, nebari::Roots<AnyFile>) -> anyhow::Result<()> + Send,
    >(
        name: &str,
        persistence: KeyValuePersistence,
        test_contents: &F,
    ) -> anyhow::Result<()> {
        let dir = TestDirectory::new(name);
        let sled = nebari::Config::new(&dir)
            .file_manager(AnyFileManager::std())
            .open()?;

        let context = Context::new(sled.clone(), persistence);

        test_contents(context, sled)?;

        Ok(())
    }

    /// Like `run_test_with_persistence`, using `KeyValuePersistence::default()`.
    fn run_test<F: Fn(Context, nebari::Roots<AnyFile>) -> anyhow::Result<()> + Send>(
        name: &str,
        test_contents: F,
    ) -> anyhow::Result<()> {
        run_test_with_persistence(name, KeyValuePersistence::default(), &test_contents)
    }

    #[test]
    fn basic_expiration() -> anyhow::Result<()> {
        run_test("kv-basic-expiration", |sender, sled| {
            loop {
                sled.delete_tree(KEY_TREE)?;
                let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
                tree.set(b"atree\0akey", b"somevalue")?;
                let timing = TimingTest::new(Duration::from_millis(100));
                sender.update_key_expiration(
                    full_key(Some("atree"), "akey"),
                    Some(Timestamp::now() + Duration::from_millis(100)),
                );
                if !timing.wait_until(Duration::from_secs(1)) {
                    println!("basic_expiration restarting due to timing discrepency");
                    continue;
                }
                // Look up the namespaced key that was actually written above.
                // The bare `akey` is never stored, so checking it would always
                // pass regardless of whether expiration worked.
                assert!(tree.get(b"atree\0akey")?.is_none());
                break;
            }

            Ok(())
        })
    }

    #[test]
    fn updating_expiration() -> anyhow::Result<()> {
        run_test("kv-updating-expiration", |sender, sled| {
            loop {
                sled.delete_tree(KEY_TREE)?;
                let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
                tree.set(b"atree\0akey", b"somevalue")?;
                let timing = TimingTest::new(Duration::from_millis(100));
                sender.update_key_expiration(
                    full_key(Some("atree"), "akey"),
                    Some(Timestamp::now() + Duration::from_millis(100)),
                );
                sender.update_key_expiration(
                    full_key(Some("atree"), "akey"),
                    Some(Timestamp::now() + Duration::from_secs(1)),
                );
                if timing.elapsed() > Duration::from_millis(100)
                    || !timing.wait_until(Duration::from_millis(500))
                {
                    continue;
                }
                assert!(tree.get(b"atree\0akey")?.is_some());

                timing.wait_until(Duration::from_secs_f32(1.5));
                assert_eq!(tree.get(b"atree\0akey")?, None);
                break;
            }

            Ok(())
        })
    }

    #[test]
    fn multiple_keys_expiration() -> anyhow::Result<()> {
        run_test("kv-multiple-keys-expiration", |sender, sled| {
            loop {
                sled.delete_tree(KEY_TREE)?;
                let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
                tree.set(b"atree\0akey", b"somevalue")?;
                tree.set(b"atree\0bkey", b"somevalue")?;

                let timing = TimingTest::new(Duration::from_millis(100));
                sender.update_key_expiration(
                    full_key(Some("atree"), "akey"),
                    Some(Timestamp::now() + Duration::from_millis(100)),
                );
                sender.update_key_expiration(
                    full_key(Some("atree"), "bkey"),
                    Some(Timestamp::now() + Duration::from_secs(1)),
                );

                if !timing.wait_until(Duration::from_millis(200)) {
                    continue;
                }

                assert!(tree.get(b"atree\0akey")?.is_none());
                assert!(tree.get(b"atree\0bkey")?.is_some());
                timing.wait_until(Duration::from_millis(1100));
                assert!(tree.get(b"atree\0bkey")?.is_none());

                break;
            }

            Ok(())
        })
    }

    #[test]
    fn clearing_expiration() -> anyhow::Result<()> {
        run_test("kv-clearing-expiration", |sender, sled| {
            loop {
                sled.delete_tree(KEY_TREE)?;
                let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
                tree.set(b"atree\0akey", b"somevalue")?;
                let timing = TimingTest::new(Duration::from_millis(100));
                sender.update_key_expiration(
                    full_key(Some("atree"), "akey"),
                    Some(Timestamp::now() + Duration::from_millis(100)),
                );
                sender.update_key_expiration(full_key(Some("atree"), "akey"), None);
                if timing.elapsed() > Duration::from_millis(100) {
                    // Restart, took too long.
                    continue;
                }
                timing.wait_until(Duration::from_millis(150));
                assert!(tree.get(b"atree\0akey")?.is_some());
                break;
            }

            Ok(())
        })
    }

    #[test]
    fn out_of_order_expiration() -> anyhow::Result<()> {
        run_test("kv-out-of-order-expiration", |context, sled| loop {
            context.update_key_expiration(full_key(Some("atree"), "akey"), None);
            context.update_key_expiration(full_key(Some("atree"), "bkey"), None);
            context.update_key_expiration(full_key(Some("atree"), "ckey"), None);
            let mut persistence_watcher = context.kv_persistence_watcher();
            drop(sled.delete_tree(KEY_TREE));
            let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
            tree.set(b"atree\0akey", b"somevalue")?;
            tree.set(b"atree\0bkey", b"somevalue")?;
            tree.set(b"atree\0ckey", b"somevalue")?;
            let timing = TimingTest::new(Duration::from_millis(100));
            context.update_key_expiration(
                full_key(Some("atree"), "akey"),
                Some(Timestamp::now() + Duration::from_secs(3)),
            );
            context.update_key_expiration(
                full_key(Some("atree"), "ckey"),
                Some(Timestamp::now() + Duration::from_secs(1)),
            );
            context.update_key_expiration(
                full_key(Some("atree"), "bkey"),
                Some(Timestamp::now() + Duration::from_secs(2)),
            );
            persistence_watcher.mark_read();
            if timing.elapsed() > Duration::from_millis(500) {
                println!("Restarting");
                continue;
            }

            // Wait for the first key to expire.
            persistence_watcher
                .watch_timeout(Duration::from_secs(5))
                .unwrap();
            persistence_watcher.mark_read();
            if timing.elapsed() > Duration::from_millis(1500) {
                println!("Restarting");
                continue;
            }
            assert!(tree.get(b"atree\0akey")?.is_some());
            assert!(tree.get(b"atree\0bkey")?.is_some());
            assert!(tree.get(b"atree\0ckey")?.is_none());

            // Wait for the next key to expire.
            persistence_watcher
                .watch_timeout(Duration::from_secs(5))
                .unwrap();
            persistence_watcher.mark_read();
            if timing.elapsed() > Duration::from_millis(2500) {
                println!("Restarting");
                continue;
            }
            assert!(tree.get(b"atree\0akey")?.is_some());
            assert!(tree.get(b"atree\0bkey")?.is_none());

            // Wait for the final key to expire.
            persistence_watcher
                .watch_timeout(Duration::from_secs(5))
                .unwrap();
            if timing.elapsed() > Duration::from_millis(3500) {
                println!("Restarting");
                continue;
            }
            assert!(tree.get(b"atree\0akey")?.is_none());

            return Ok(());
        })
    }

    #[test]
    fn basic_persistence() -> anyhow::Result<()> {
        run_test_with_persistence(
            "kv-basic-persistence",
            KeyValuePersistence::lazy([
                PersistenceThreshold::after_changes(2),
                PersistenceThreshold::after_changes(1).and_duration(Duration::from_secs(2)),
            ]),
            &|sender, sled| {
                loop {
                    let timing = TimingTest::new(Duration::from_millis(200));
                    let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
                    // Set three keys in quick succession. The first two should
                    // persist immediately, and the third should show up after 2
                    // seconds.
                    sender
                        .perform_kv_operation(KeyOperation {
                            namespace: None,
                            key: String::from("key1"),
                            command: Command::Set(SetCommand {
                                value: Value::Bytes(Bytes::default()),
                                expiration: None,
                                keep_existing_expiration: false,
                                check: None,
                                return_previous_value: false,
                            }),
                        })
                        .unwrap();
                    sender
                        .perform_kv_operation(KeyOperation {
                            namespace: None,
                            key: String::from("key2"),
                            command: Command::Set(SetCommand {
                                value: Value::Bytes(Bytes::default()),
                                expiration: None,
                                keep_existing_expiration: false,
                                check: None,
                                return_previous_value: false,
                            }),
                        })
                        .unwrap();
                    sender
                        .perform_kv_operation(KeyOperation {
                            namespace: None,
                            key: String::from("key3"),
                            command: Command::Set(SetCommand {
                                value: Value::Bytes(Bytes::default()),
                                expiration: None,
                                keep_existing_expiration: false,
                                check: None,
                                return_previous_value: false,
                            }),
                        })
                        .unwrap();
                    // Persisting is handled in the background. Sleep for a bit
                    // to give it a chance to happen, but not long enough to
                    // trigger the longer time-based rule.
                    if timing.elapsed() > Duration::from_millis(500)
                        || !timing.wait_until(Duration::from_secs(1))
                    {
                        println!("basic_persistence restarting due to timing discrepency");
                        continue;
                    }
                    assert!(tree.get(b"\0key1").unwrap().is_some());
                    assert!(tree.get(b"\0key2").unwrap().is_some());
                    assert!(tree.get(b"\0key3").unwrap().is_none());
                    if !timing.wait_until(Duration::from_secs(3)) {
                        println!("basic_persistence restarting due to timing discrepency");
                        continue;
                    }
                    assert!(tree.get(b"\0key3").unwrap().is_some());
                    break;
                }

                Ok(())
            },
        )
    }

    #[test]
    fn saves_on_drop() -> anyhow::Result<()> {
        let dir = TestDirectory::new("saves-on-drop.bonsaidb");
        let sled = nebari::Config::new(&dir)
            .file_manager(AnyFileManager::std())
            .open()?;
        let tree = sled.tree(Unversioned::tree(KEY_TREE))?;

        let context = Context::new(
            sled,
            KeyValuePersistence::lazy([PersistenceThreshold::after_changes(2)]),
        );
        context
            .perform_kv_operation(KeyOperation {
                namespace: None,
                key: String::from("key1"),
                command: Command::Set(SetCommand {
                    value: Value::Bytes(Bytes::default()),
                    expiration: None,
                    keep_existing_expiration: false,
                    check: None,
                    return_previous_value: false,
                }),
            })
            .unwrap();
        assert!(tree.get(b"\0key1").unwrap().is_none());
        drop(context);
        // Dropping spawns a task that should persist the keys. Give a moment
        // for the runtime to execute the task.
        std::thread::sleep(Duration::from_millis(100));
        assert!(tree.get(b"\0key1").unwrap().is_some());

        Ok(())
    }
}