1
use std::{
2
    collections::HashMap,
3
    fmt::{Debug, Display},
4
    marker::PhantomData,
5
    path::{Path, PathBuf},
6
    sync::Arc,
7
};
8

            
9
use async_lock::{Mutex, RwLock};
10
use async_trait::async_trait;
11
pub use bonsaidb_core::circulate::Relay;
12
#[cfg(feature = "password-hashing")]
13
use bonsaidb_core::connection::{Authenticated, Authentication};
14
#[cfg(any(feature = "encryption", feature = "compression"))]
15
use bonsaidb_core::document::KeyId;
16
use bonsaidb_core::{
17
    admin::{
18
        self,
19
        database::{self, ByName, Database as DatabaseRecord},
20
        Admin, ADMIN_DATABASE_NAME,
21
    },
22
    connection::{self, Connection, StorageConnection},
23
    schema::{Schema, SchemaName, Schematic},
24
};
25
#[cfg(feature = "multiuser")]
26
use bonsaidb_core::{
27
    admin::{user::User, PermissionGroup, Role},
28
    document::CollectionDocument,
29
    schema::{Nameable, NamedCollection},
30
};
31
use bonsaidb_utils::{fast_async_lock, fast_async_read, fast_async_write};
32
use futures::TryFutureExt;
33
use itertools::Itertools;
34
use nebari::{
35
    io::{
36
        any::{AnyFile, AnyFileManager},
37
        FileManager,
38
    },
39
    ChunkCache, ThreadPool,
40
};
41
use rand::{thread_rng, Rng};
42
use tokio::{
43
    fs::{self, File},
44
    io::{AsyncReadExt, AsyncWriteExt},
45
};
46

            
47
#[cfg(feature = "compression")]
48
use crate::config::Compression;
49
#[cfg(feature = "encryption")]
50
use crate::vault::{self, LocalVaultKeyStorage, Vault};
51
use crate::{
52
    config::{KeyValuePersistence, StorageConfiguration},
53
    database::Context,
54
    jobs::manager::Manager,
55
    tasks::TaskManager,
56
    Database, Error,
57
};
58

            
59
#[cfg(feature = "password-hashing")]
60
mod argon;
61

            
62
mod backup;
63
pub use backup::BackupLocation;
64

            
65
/// A file-based, multi-database, multi-user database engine.
///
/// ## Converting from `Database::open` to `Storage::open`
///
/// [`Database::open`](Database::open) is a simple method that uses `Storage` to
/// create a database named `default` with the schema provided. These two ways
/// of opening the database are the same:
///
/// ```rust
/// // `bonsaidb_core` is re-exported to `bonsaidb::core` or `bonsaidb_local::core`.
/// use bonsaidb_core::{connection::StorageConnection, schema::Schema};
/// // `bonsaidb_local` is re-exported to `bonsaidb::local` if using the omnibus crate.
/// use bonsaidb_local::{
///     config::{Builder, StorageConfiguration},
///     Database, Storage,
/// };
/// # async fn open<MySchema: Schema>() -> anyhow::Result<()> {
/// // This creates a Storage instance, creates a database, and returns it.
/// let db = Database::open::<MySchema>(StorageConfiguration::new("my-db.bonsaidb")).await?;
///
/// // This is the equivalent code being executed:
/// let storage =
///     Storage::open(StorageConfiguration::new("my-db.bonsaidb").with_schema::<MySchema>()?)
///         .await?;
/// storage.create_database::<MySchema>("default", true).await?;
/// let db = storage.database::<MySchema>("default").await?;
/// #     Ok(())
/// # }
/// ```
///
/// ## Using multiple databases
///
/// This example shows how to use `Storage` to create and use multiple databases
/// with multiple schemas:
///
/// ```rust
/// use bonsaidb_core::{
///     connection::StorageConnection,
///     schema::{Collection, Schema},
/// };
/// use bonsaidb_local::{
///     config::{Builder, StorageConfiguration},
///     Storage,
/// };
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Debug, Schema)]
/// #[schema(name = "my-schema", collections = [BlogPost, Author])]
/// # #[schema(core = bonsaidb_core)]
/// struct MySchema;
///
/// #[derive(Debug, Serialize, Deserialize, Collection)]
/// #[collection(name = "blog-posts")]
/// # #[collection(core = bonsaidb_core)]
/// struct BlogPost {
///     pub title: String,
///     pub contents: String,
///     pub author_id: u64,
/// }
///
/// #[derive(Debug, Serialize, Deserialize, Collection)]
/// // NOTE(review): this example previously reused the "blog-posts" name for
/// // the Author collection; two collections in one schema should not share a
/// // name, so it was renamed to "authors".
/// #[collection(name = "authors")]
/// # #[collection(core = bonsaidb_core)]
/// struct Author {
///     pub name: String,
/// }
///
/// # async fn test_fn() -> Result<(), bonsaidb_core::Error> {
/// let storage = Storage::open(
///     StorageConfiguration::new("my-db.bonsaidb")
///         .with_schema::<BlogPost>()?
///         .with_schema::<MySchema>()?,
/// )
/// .await?;
///
/// storage
///     .create_database::<BlogPost>("ectons-blog", true)
///     .await?;
/// let ectons_blog = storage.database::<BlogPost>("ectons-blog").await?;
/// storage
///     .create_database::<MySchema>("another-db", true)
///     .await?;
/// let another_db = storage.database::<MySchema>("another-db").await?;
///
/// #     Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct Storage {
    // All shared state lives behind a single `Arc`, making `Storage` cheap to
    // clone: every clone refers to the same underlying engine.
    data: Arc<Data>,
}
156

            
157
/// Shared interior state for a [`Storage`] instance.
#[derive(Debug)]
struct Data {
    // Unique server id; persisted in the `server-id` file unless overridden
    // via the configuration (see `lookup_or_create_id`).
    id: StorageId,
    // Root directory that contains every database managed by this storage.
    path: PathBuf,
    // Thread pool shared by all opened nebari roots.
    threadpool: ThreadPool<AnyFile>,
    // Produces in-memory or on-disk files depending on `memory_only`.
    file_manager: AnyFileManager,
    pub(crate) tasks: TaskManager,
    // Registered schema openers, keyed by schema name.
    schemas: RwLock<HashMap<SchemaName, Arc<dyn DatabaseOpener>>>,
    // Cache of database name -> schema name, mirroring the admin database.
    available_databases: RwLock<HashMap<String, SchemaName>>,
    // Contexts for databases already opened by this process, keyed by name.
    open_roots: Mutex<HashMap<String, Context>>,
    #[cfg(feature = "password-hashing")]
    argon: argon::Hasher,
    #[cfg(feature = "encryption")]
    pub(crate) vault: Arc<Vault>,
    #[cfg(feature = "encryption")]
    default_encryption_key: Option<KeyId>,
    #[cfg(any(feature = "compression", feature = "encryption"))]
    tree_vault: Option<TreeVault>,
    pub(crate) key_value_persistence: KeyValuePersistence,
    // Chunk cache shared by every database opened from this storage.
    chunk_cache: ChunkCache,
    pub(crate) check_view_integrity_on_database_open: bool,
    relay: Relay,
}
180

            
181
impl Storage {
182
    /// Creates or opens a multi-database [`Storage`] with its data stored in `directory`.
183
2594
    pub async fn open(configuration: StorageConfiguration) -> Result<Self, Error> {
184
171
        let owned_path = configuration
185
171
            .path
186
171
            .clone()
187
171
            .unwrap_or_else(|| PathBuf::from("db.bonsaidb"));
188
171
        let file_manager = if configuration.memory_only {
189
30
            AnyFileManager::memory()
190
        } else {
191
141
            AnyFileManager::std()
192
        };
193

            
194
171
        let manager = Manager::default();
195
684
        for _ in 0..configuration.workers.worker_count {
196
684
            manager.spawn_worker();
197
684
        }
198
171
        let tasks = TaskManager::new(manager);
199
171

            
200
171
        fs::create_dir_all(&owned_path).await?;
201

            
202
342
        let id = Self::lookup_or_create_id(&configuration, &owned_path).await?;
203

            
204
        #[cfg(feature = "encryption")]
205
170
        let vault = {
206
171
            let vault_key_storage = match configuration.vault_key_storage {
207
3
                Some(storage) => storage,
208
                None => Arc::new(
209
168
                    LocalVaultKeyStorage::new(owned_path.join("vault-keys"))
210
153
                        .await
211
168
                        .map_err(|err| Error::Vault(vault::Error::Initializing(err.to_string())))?,
212
                ),
213
            };
214

            
215
1285
            Arc::new(Vault::initialize(id, &owned_path, vault_key_storage).await?)
216
        };
217

            
218
170
        let check_view_integrity_on_database_open = configuration.views.check_integrity_on_open;
219
170
        let key_value_persistence = configuration.key_value_persistence;
220
170
        #[cfg(feature = "password-hashing")]
221
170
        let argon = argon::Hasher::new(configuration.argon);
222
170
        #[cfg(feature = "encryption")]
223
170
        let default_encryption_key = configuration.default_encryption_key;
224
170
        #[cfg(all(feature = "compression", feature = "encryption"))]
225
170
        let tree_vault = TreeVault::new_if_needed(
226
170
            default_encryption_key.clone(),
227
170
            &vault,
228
170
            configuration.default_compression,
229
170
        );
230
        #[cfg(all(not(feature = "compression"), feature = "encryption"))]
231
        let tree_vault = TreeVault::new_if_needed(default_encryption_key.clone(), &vault);
232
        #[cfg(all(feature = "compression", not(feature = "encryption")))]
233
        let tree_vault = TreeVault::new_if_needed(configuration.default_compression);
234

            
235
170
        let storage = tokio::task::spawn_blocking::<_, Result<Self, Error>>(move || {
236
170
            Ok(Self {
237
170
                data: Arc::new(Data {
238
170
                    id,
239
170
                    tasks,
240
170
                    #[cfg(feature = "password-hashing")]
241
170
                    argon,
242
170
                    #[cfg(feature = "encryption")]
243
170
                    vault,
244
170
                    #[cfg(feature = "encryption")]
245
170
                    default_encryption_key,
246
170
                    #[cfg(any(feature = "compression", feature = "encryption"))]
247
170
                    tree_vault,
248
170
                    path: owned_path,
249
170
                    file_manager,
250
170
                    chunk_cache: ChunkCache::new(2000, 160_384),
251
170
                    threadpool: ThreadPool::default(),
252
170
                    schemas: RwLock::new(configuration.initial_schemas),
253
170
                    available_databases: RwLock::default(),
254
170
                    open_roots: Mutex::default(),
255
170
                    key_value_persistence,
256
170
                    check_view_integrity_on_database_open,
257
170
                    relay: Relay::default(),
258
170
                }),
259
170
            })
260
170
        })
261
122
        .await??;
262

            
263
355
        storage.cache_available_databases().await?;
264

            
265
170
        storage.create_admin_database_if_needed().await?;
266

            
267
170
        Ok(storage)
268
171
    }
269

            
270
    /// Returns the path of the database storage.
271
    #[must_use]
272
16354
    pub fn path(&self) -> &Path {
273
16354
        &self.data.path
274
16354
    }
275

            
276
2619
    async fn lookup_or_create_id(
277
2619
        configuration: &StorageConfiguration,
278
2619
        path: &Path,
279
2619
    ) -> Result<StorageId, Error> {
280
171
        Ok(StorageId(if let Some(id) = configuration.unique_id {
281
            // The configuraiton id override is not persisted to disk. This is
282
            // mostly to prevent someone from accidentally adding this
283
            // configuration, realizing it breaks things, and then wanting to
284
            // revert. This makes reverting to the old value easier.
285
            id
286
        } else {
287
            // Load/Store a randomly generated id into a file. While the value
288
            // is numerical, the file contents are the ascii decimal, making it
289
            // easier for a human to view, and if needed, edit.
290
171
            let id_path = path.join("server-id");
291
171

            
292
171
            if id_path.exists() {
293
                // This value is important enought to not allow launching the
294
                // server if the file can't be read or contains unexpected data.
295
17
                let existing_id = String::from_utf8(
296
17
                    File::open(id_path)
297
17
                        .and_then(|mut f| async move {
298
17
                            let mut bytes = Vec::new();
299
34
                            f.read_to_end(&mut bytes).await.map(|_| bytes)
300
51
                        })
301
51
                        .await
302
17
                        .expect("error reading server-id file"),
303
17
                )
304
17
                .expect("server-id contains invalid data");
305
17

            
306
17
                existing_id.parse().expect("server-id isn't numeric")
307
            } else {
308
154
                let id = { thread_rng().gen::<u64>() };
309
154
                File::create(id_path)
310
154
                    .and_then(|mut file| async move {
311
154
                        let id = id.to_string();
312
154
                        file.write_all(id.as_bytes()).await?;
313
154
                        file.shutdown().await
314
291
                    })
315
291
                    .await
316
154
                    .map_err(|err| {
317
                        Error::Core(bonsaidb_core::Error::Configuration(format!(
318
                            "Error writing server-id file: {}",
319
                            err
320
                        )))
321
154
                    })?;
322
154
                id
323
            }
324
        }))
325
171
    }
326

            
327
2594
    async fn cache_available_databases(&self) -> Result<(), Error> {
328
170
        let available_databases = self
329
170
            .admin()
330
170
            .await
331
170
            .view::<ByName>()
332
186
            .query()
333
185
            .await?
334
170
            .into_iter()
335
178
            .map(|map| (map.key, map.value))
336
170
            .collect();
337
170
        let mut storage_databases = fast_async_write!(self.data.available_databases);
338
170
        *storage_databases = available_databases;
339
170
        Ok(())
340
170
    }
341

            
342
2594
    async fn create_admin_database_if_needed(&self) -> Result<(), Error> {
343
170
        self.register_schema::<Admin>().await?;
344
170
        match self.database::<Admin>(ADMIN_DATABASE_NAME).await {
345
16
            Ok(_) => {}
346
            Err(bonsaidb_core::Error::DatabaseNotFound(_)) => {
347
154
                self.create_database::<Admin>(ADMIN_DATABASE_NAME, true)
348
154
                    .await?;
349
            }
350
            Err(err) => return Err(Error::Core(err)),
351
        }
352
170
        Ok(())
353
170
    }
354

            
355
    /// Returns the unique id of the server.
356
    ///
357
    /// This value is set from the [`StorageConfiguration`] or randomly
358
    /// generated when creating a server. It shouldn't be changed after a server
359
    /// is in use, as doing can cause issues. For example, the vault that
360
    /// manages encrypted storage uses the server ID to store the vault key. If
361
    /// the server ID changes, the vault key storage will need to be updated
362
    /// with the new server ID.
363
    #[must_use]
364
    pub fn unique_id(&self) -> StorageId {
365
        self.data.id
366
    }
367

            
368
    #[must_use]
369
    #[cfg(feature = "encryption")]
370
1401375
    pub(crate) fn vault(&self) -> &Arc<Vault> {
371
1401375
        &self.data.vault
372
1401375
    }
373

            
374
    #[must_use]
375
    #[cfg(any(feature = "encryption", feature = "compression"))]
376
2438559
    pub(crate) fn tree_vault(&self) -> Option<&TreeVault> {
377
2438559
        self.data.tree_vault.as_ref()
378
2438559
    }
379

            
380
    #[must_use]
381
    #[cfg(feature = "encryption")]
382
2393993
    pub(crate) fn default_encryption_key(&self) -> Option<&KeyId> {
383
2393993
        self.data.default_encryption_key.as_ref()
384
2393993
    }
385

            
386
    #[must_use]
387
    #[cfg(all(feature = "compression", not(feature = "encryption")))]
388
    #[allow(clippy::unused_self)]
389
    pub(crate) fn default_encryption_key(&self) -> Option<&KeyId> {
390
        None
391
    }
392

            
393
    /// Registers a schema for use within the server.
394
170
    pub async fn register_schema<DB: Schema>(&self) -> Result<(), Error> {
395
170
        let mut schemas = fast_async_write!(self.data.schemas);
396
        if schemas
397
            .insert(
398
170
                DB::schema_name(),
399
170
                Arc::new(StorageSchemaOpener::<DB>::new()?),
400
            )
401
170
            .is_none()
402
        {
403
170
            Ok(())
404
        } else {
405
            Err(Error::Core(bonsaidb_core::Error::SchemaAlreadyRegistered(
406
                DB::schema_name(),
407
            )))
408
        }
409
170
    }
410

            
411
    #[cfg_attr(
412
        not(any(feature = "encryption", feature = "compression")),
413
        allow(unused_mut)
414
    )]
415
1317963
    pub(crate) async fn open_roots(&self, name: &str) -> Result<Context, Error> {
416
1317963
        let mut open_roots = fast_async_lock!(self.data.open_roots);
417
1317963
        if let Some(roots) = open_roots.get(name) {
418
1297125
            Ok(roots.clone())
419
        } else {
420
20838
            let task_self = self.clone();
421
20838
            let task_name = name.to_string();
422
20838
            let roots = tokio::task::spawn_blocking(move || {
423
20838
                let mut config = nebari::Config::new(task_self.data.path.join(task_name))
424
20838
                    .file_manager(task_self.data.file_manager.clone())
425
20838
                    .cache(task_self.data.chunk_cache.clone())
426
20838
                    .shared_thread_pool(&task_self.data.threadpool);
427

            
428
                #[cfg(any(feature = "encryption", feature = "compression"))]
429
20838
                if let Some(vault) = task_self.data.tree_vault.clone() {
430
4495
                    config = config.vault(vault);
431
16445
                }
432

            
433
20838
                config.open().map_err(Error::from)
434
20838
            })
435
20662
            .await
436
20838
            .unwrap()?;
437
20838
            let context = Context::new(roots, self.data.key_value_persistence.clone());
438
20838

            
439
20838
            open_roots.insert(name.to_owned(), context.clone());
440
20838

            
441
20838
            Ok(context)
442
        }
443
1317963
    }
444

            
445
1747447
    pub(crate) fn tasks(&self) -> &'_ TaskManager {
446
1747447
        &self.data.tasks
447
1747447
    }
448

            
449
1317938
    pub(crate) fn check_view_integrity_on_database_open(&self) -> bool {
450
1317938
        self.data.check_view_integrity_on_database_open
451
1317938
    }
452

            
453
332
    pub(crate) fn relay(&self) -> &'_ Relay {
454
332
        &self.data.relay
455
332
    }
456

            
457
21250
    fn validate_name(name: &str) -> Result<(), Error> {
458
21250
        if name.chars().enumerate().all(|(index, c)| {
459
147238
            c.is_ascii_alphanumeric()
460
6912
                || (index == 0 && c == '_')
461
2595
                || (index > 0 && (c == '.' || c == '-'))
462
147238
        }) {
463
21146
            Ok(())
464
        } else {
465
104
            Err(Error::Core(bonsaidb_core::Error::InvalidDatabaseName(
466
104
                name.to_owned(),
467
104
            )))
468
        }
469
21250
    }
470

            
471
    /// Returns the administration database.
472
    #[allow(clippy::missing_panics_doc)]
473
37733
    pub async fn admin(&self) -> Database {
474
        Database::new::<Admin, _>(
475
            ADMIN_DATABASE_NAME,
476
37758
            self.open_roots(ADMIN_DATABASE_NAME).await.unwrap(),
477
37758
            self.clone(),
478
        )
479
        .await
480
37758
        .unwrap()
481
37758
    }
482

            
483
    /// Opens a database through a generic-free trait.
484
1282495
    pub(crate) async fn database_without_schema(&self, name: &str) -> Result<Database, Error> {
485
51285
        let schema = {
486
51439
            let available_databases = fast_async_read!(self.data.available_databases);
487
51439
            available_databases
488
51439
                .get(name)
489
51439
                .ok_or_else(|| {
490
154
                    Error::Core(bonsaidb_core::Error::DatabaseNotFound(name.to_string()))
491
51439
                })?
492
51285
                .clone()
493
        };
494

            
495
51285
        let mut schemas = fast_async_write!(self.data.schemas);
496
51285
        if let Some(schema) = schemas.get_mut(&schema) {
497
51285
            let db = schema.open(name.to_string(), self.clone()).await?;
498
51285
            Ok(db)
499
        } else {
500
            Err(Error::Core(bonsaidb_core::Error::SchemaNotRegistered(
501
                schema,
502
            )))
503
        }
504
51439
    }
505

            
506
    #[cfg(feature = "internal-apis")]
507
    #[doc(hidden)]
508
    /// Opens a database through a generic-free trait.
509
1270350
    pub async fn database_without_schema_internal(&self, name: &str) -> Result<Database, Error> {
510
94665
        self.database_without_schema(name).await
511
50814
    }
512

            
513
    #[cfg(feature = "multiuser")]
514
34
    async fn update_user_with_named_id<
515
34
        'user,
516
34
        'other,
517
34
        Col: NamedCollection<PrimaryKey = u64>,
518
34
        U: Nameable<'user, u64> + Send + Sync,
519
34
        O: Nameable<'other, u64> + Send + Sync,
520
34
        F: FnOnce(&mut CollectionDocument<User>, u64) -> bool,
521
34
    >(
522
34
        &self,
523
34
        user: U,
524
34
        other: O,
525
34
        callback: F,
526
34
    ) -> Result<(), bonsaidb_core::Error> {
527
34
        let admin = self.admin().await;
528
34
        let other = other.name()?;
529
34
        let (user, other) =
530
34
            futures::try_join!(User::load(user.name()?, &admin), other.id::<Col, _>(&admin),)?;
531
34
        match (user, other) {
532
34
            (Some(mut user), Some(other)) => {
533
34
                if callback(&mut user, other) {
534
17
                    user.update(&admin).await?;
535
17
                }
536
34
                Ok(())
537
            }
538
            // TODO make this a generic not found with a name parameter.
539
            _ => Err(bonsaidb_core::Error::UserNotFound),
540
        }
541
34
    }
542
}
543

            
544
/// Type-erased access to a schema's database opener, allowing databases to be
/// opened without knowing their schema type at compile time.
#[async_trait]
pub trait DatabaseOpener: Send + Sync + Debug {
    /// Returns the schematic describing the schema this opener serves.
    fn schematic(&self) -> &'_ Schematic;
    /// Opens the database named `name` within `storage`.
    async fn open(&self, name: String, storage: Storage) -> Result<Database, Error>;
}
549

            
550
/// A [`DatabaseOpener`] for a statically-known schema `DB`.
#[derive(Debug)]
pub struct StorageSchemaOpener<DB: Schema> {
    schematic: Schematic,
    // Ties the opener to `DB` at the type level without storing a value.
    _phantom: PhantomData<DB>,
}
555

            
556
impl<DB> StorageSchemaOpener<DB>
557
where
558
    DB: Schema,
559
{
560
419
    pub fn new() -> Result<Self, Error> {
561
419
        let schematic = DB::schematic()?;
562
419
        Ok(Self {
563
419
            schematic,
564
419
            _phantom: PhantomData::default(),
565
419
        })
566
419
    }
567
}
568

            
569
#[async_trait]
570
impl<DB> DatabaseOpener for StorageSchemaOpener<DB>
571
where
572
    DB: Schema,
573
{
574
    fn schematic(&self) -> &'_ Schematic {
575
        &self.schematic
576
    }
577

            
578
51285
    async fn open(&self, name: String, storage: Storage) -> Result<Database, Error> {
579
51285
        let roots = storage.open_roots(&name).await?;
580
51285
        let db = Database::new::<DB, _>(name, roots, storage).await?;
581
51285
        Ok(db)
582
102570
    }
583
}
584

            
585
#[async_trait]
586
impl StorageConnection for Storage {
587
    type Database = Database;
588

            
589
    #[cfg_attr(
590
        feature = "tracing",
591
63738
        tracing::instrument(skip(name, schema, only_if_needed))
592
    )]
593
    async fn create_database_with_schema(
594
        &self,
595
        name: &str,
596
        schema: SchemaName,
597
        only_if_needed: bool,
598
21246
    ) -> Result<(), bonsaidb_core::Error> {
599
21246
        Self::validate_name(name)?;
600

            
601
        {
602
21169
            let schemas = fast_async_read!(self.data.schemas);
603
21169
            if !schemas.contains_key(&schema) {
604
77
                return Err(bonsaidb_core::Error::SchemaNotRegistered(schema));
605
21092
            }
606
        }
607

            
608
21092
        let mut available_databases = fast_async_write!(self.data.available_databases);
609
21092
        let admin = self.admin().await;
610
21092
        if !admin
611
21092
            .view::<database::ByName>()
612
21092
            .with_key(name.to_ascii_lowercase())
613
21092
            .query()
614
18172
            .await?
615
21092
            .is_empty()
616
        {
617
610
            if only_if_needed {
618
533
                return Ok(());
619
77
            }
620
77

            
621
77
            return Err(bonsaidb_core::Error::DatabaseNameAlreadyTaken(
622
77
                name.to_string(),
623
77
            ));
624
20482
        }
625
20482

            
626
20482
        admin
627
20482
            .collection::<DatabaseRecord>()
628
20482
            .push(&admin::Database {
629
20482
                name: name.to_string(),
630
20482
                schema: schema.clone(),
631
20482
            })
632
19682
            .await?;
633
20482
        available_databases.insert(name.to_string(), schema);
634
20482

            
635
20482
        Ok(())
636
42492
    }
637

            
638
614
    async fn database<DB: Schema>(
639
614
        &self,
640
614
        name: &str,
641
614
    ) -> Result<Self::Database, bonsaidb_core::Error> {
642
614
        let db = self.database_without_schema(name).await?;
643
460
        if db.data.schema.name == DB::schema_name() {
644
460
            Ok(db)
645
        } else {
646
            Err(bonsaidb_core::Error::SchemaMismatch {
647
                database_name: name.to_owned(),
648
                schema: DB::schema_name(),
649
                stored_schema: db.data.schema.name.clone(),
650
            })
651
        }
652
1228
    }
653

            
654
37962
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(name)))]
655
12654
    async fn delete_database(&self, name: &str) -> Result<(), bonsaidb_core::Error> {
656
12654
        let admin = self.admin().await;
657
12654
        let mut available_databases = fast_async_write!(self.data.available_databases);
658
12654
        available_databases.remove(name);
659

            
660
12654
        let mut open_roots = fast_async_lock!(self.data.open_roots);
661
12654
        open_roots.remove(name);
662
12654

            
663
12654
        let database_folder = self.path().join(name);
664
12654
        if database_folder.exists() {
665
12500
            let file_manager = self.data.file_manager.clone();
666
12500
            tokio::task::spawn_blocking(move || file_manager.delete_directory(&database_folder))
667
9375
                .await
668
12500
                .unwrap()
669
12500
                .map_err(Error::Nebari)?;
670
154
        }
671

            
672
12654
        if let Some(entry) = admin
673
12654
            .view::<database::ByName>()
674
12654
            .with_key(name.to_ascii_lowercase())
675
12654
            .query()
676
12654
            .await?
677
12654
            .first()
678
        {
679
12577
            admin.delete::<DatabaseRecord, _>(&entry.source).await?;
680

            
681
12577
            Ok(())
682
        } else {
683
77
            return Err(bonsaidb_core::Error::DatabaseNotFound(name.to_string()));
684
        }
685
25308
    }
686

            
687
231
    #[cfg_attr(feature = "tracing", tracing::instrument)]
688
77
    async fn list_databases(&self) -> Result<Vec<connection::Database>, bonsaidb_core::Error> {
689
77
        let available_databases = fast_async_read!(self.data.available_databases);
690
77
        Ok(available_databases
691
77
            .iter()
692
1929
            .map(|(name, schema)| connection::Database {
693
1929
                name: name.to_string(),
694
1929
                schema: schema.clone(),
695
1929
            })
696
77
            .collect())
697
154
    }
698

            
699
231
    #[cfg_attr(feature = "tracing", tracing::instrument)]
700
77
    async fn list_available_schemas(&self) -> Result<Vec<SchemaName>, bonsaidb_core::Error> {
701
77
        let available_databases = fast_async_read!(self.data.available_databases);
702
77
        Ok(available_databases.values().unique().cloned().collect())
703
154
    }
704

            
705
531
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(username)))]
706
    #[cfg(feature = "multiuser")]
707
177
    async fn create_user(&self, username: &str) -> Result<u64, bonsaidb_core::Error> {
708
177
        let result = self
709
177
            .admin()
710
            .await
711
177
            .collection::<User>()
712
304
            .push(&User::default_with_username(username))
713
304
            .await?;
714
152
        Ok(result.id)
715
354
    }
716

            
717
    #[cfg(feature = "password-hashing")]
718
9
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user, password)))]
719
    async fn set_user_password<'user, U: Nameable<'user, u64> + Send + Sync>(
720
        &self,
721
        user: U,
722
        password: bonsaidb_core::connection::SensitiveString,
723
3
    ) -> Result<(), bonsaidb_core::Error> {
724
3
        let admin = self.admin().await;
725
5
        let mut user = User::load(user, &admin)
726
5
            .await?
727
3
            .ok_or(bonsaidb_core::Error::UserNotFound)?;
728
3
        user.contents.argon_hash = Some(self.data.argon.hash(user.header.id, password).await?);
729
3
        user.update(&admin).await
730
6
    }
731

            
732
    #[cfg(all(feature = "multiuser", feature = "password-hashing"))]
733
15
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user)))]
734
    async fn authenticate<'user, U: Nameable<'user, u64> + Send + Sync>(
735
        &self,
736
        user: U,
737
        authentication: Authentication,
738
5
    ) -> Result<Authenticated, bonsaidb_core::Error> {
739
5
        let admin = self.admin().await;
740
5
        let user = User::load(user, &admin)
741
            .await?
742
5
            .ok_or(bonsaidb_core::Error::InvalidCredentials)?;
743
5
        match authentication {
744
5
            Authentication::Password(password) => {
745
5
                let saved_hash = user
746
5
                    .contents
747
5
                    .argon_hash
748
5
                    .clone()
749
5
                    .ok_or(bonsaidb_core::Error::InvalidCredentials)?;
750

            
751
5
                self.data
752
5
                    .argon
753
5
                    .verify(user.header.id, password, saved_hash)
754
5
                    .await?;
755
5
                let permissions = user.contents.effective_permissions(&admin).await?;
756
5
                Ok(Authenticated {
757
5
                    user_id: user.header.id,
758
5
                    permissions,
759
5
                })
760
            }
761
        }
762
10
    }
763

            
764
30
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user, permission_group)))]
765
    #[cfg(feature = "multiuser")]
766
    async fn add_permission_group_to_user<
767
        'user,
768
        'group,
769
        U: Nameable<'user, u64> + Send + Sync,
770
        G: Nameable<'group, u64> + Send + Sync,
771
    >(
772
        &self,
773
        user: U,
774
        permission_group: G,
775
10
    ) -> Result<(), bonsaidb_core::Error> {
776
10
        self.update_user_with_named_id::<PermissionGroup, _, _, _>(
777
10
            user,
778
10
            permission_group,
779
10
            |user, permission_group_id| {
780
10
                if user.contents.groups.contains(&permission_group_id) {
781
5
                    false
782
                } else {
783
5
                    user.contents.groups.push(permission_group_id);
784
5
                    true
785
                }
786
10
            },
787
13
        )
788
13
        .await
789
20
    }
790

            
791
24
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user, permission_group)))]
792
    #[cfg(feature = "multiuser")]
793
    async fn remove_permission_group_from_user<
794
        'user,
795
        'group,
796
        U: Nameable<'user, u64> + Send + Sync,
797
        G: Nameable<'group, u64> + Send + Sync,
798
    >(
799
        &self,
800
        user: U,
801
        permission_group: G,
802
8
    ) -> Result<(), bonsaidb_core::Error> {
803
8
        self.update_user_with_named_id::<PermissionGroup, _, _, _>(
804
8
            user,
805
8
            permission_group,
806
8
            |user, permission_group_id| {
807
8
                let old_len = user.contents.groups.len();
808
8
                user.contents.groups.retain(|id| id != &permission_group_id);
809
8
                old_len != user.contents.groups.len()
810
8
            },
811
8
        )
812
7
        .await
813
16
    }
814

            
815
24
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user, role)))]
816
    #[cfg(feature = "multiuser")]
817
    async fn add_role_to_user<
818
        'user,
819
        'group,
820
        U: Nameable<'user, u64> + Send + Sync,
821
        G: Nameable<'group, u64> + Send + Sync,
822
    >(
823
        &self,
824
        user: U,
825
        role: G,
826
8
    ) -> Result<(), bonsaidb_core::Error> {
827
8
        self.update_user_with_named_id::<PermissionGroup, _, _, _>(user, role, |user, role_id| {
828
8
            if user.contents.roles.contains(&role_id) {
829
4
                false
830
            } else {
831
4
                user.contents.roles.push(role_id);
832
4
                true
833
            }
834
11
        })
835
11
        .await
836
16
    }
837

            
838
24
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user, role)))]
839
    #[cfg(feature = "multiuser")]
840
    async fn remove_role_from_user<
841
        'user,
842
        'group,
843
        U: Nameable<'user, u64> + Send + Sync,
844
        G: Nameable<'group, u64> + Send + Sync,
845
    >(
846
        &self,
847
        user: U,
848
        role: G,
849
8
    ) -> Result<(), bonsaidb_core::Error> {
850
8
        self.update_user_with_named_id::<Role, _, _, _>(user, role, |user, role_id| {
851
8
            let old_len = user.contents.roles.len();
852
8
            user.contents.roles.retain(|id| id != &role_id);
853
8
            old_len != user.contents.roles.len()
854
8
        })
855
7
        .await
856
16
    }
857
}
858

            
859
1
#[test]
fn name_validation_tests() {
    let accepts = |name: &str| matches!(Storage::validate_name(name), Ok(()));
    let rejects = |name: &str| {
        matches!(
            Storage::validate_name(name),
            Err(Error::Core(bonsaidb_core::Error::InvalidDatabaseName(_)))
        )
    };

    // Ascii alphanumerics plus '.'/'-' after the first character are valid.
    assert!(accepts("azAZ09.-"));
    // A leading underscore marks internal database names.
    assert!(accepts("_internal-names-work"));
    // The first character must be alphanumeric (or '_').
    assert!(rejects("-alphaunmericfirstrequired"));
    // Non-ascii characters are rejected outright.
    assert!(rejects("\u{2661}"));
}
875

            
876
/// The unique id of a [`Storage`] instance.
///
/// Wraps a `u64`; its `Debug` and `Display` implementations render the value
/// as 16 zero-padded lowercase hexadecimal digits.
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
pub struct StorageId(u64);
879

            
880
impl StorageId {
881
    /// Returns the id as a u64.
882
    #[must_use]
883
    pub const fn as_u64(self) -> u64 {
884
        self.0
885
    }
886
}
887

            
888
impl Debug for StorageId {
889
4910
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
890
4910
        // let formatted_length = format!();
891
4910
        write!(f, "{:016x}", self.0)
892
4910
    }
893
}
894

            
895
impl Display for StorageId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Display intentionally shares the Debug representation: 16
        // zero-padded hex digits.
        Debug::fmt(self, f)
    }
}
900

            
901
1195606
/// Per-tree payload transformer handed to nebari: optionally compresses
/// and/or encrypts payloads before they are written.
#[derive(Debug, Clone)]
#[cfg(any(feature = "compression", feature = "encryption"))]
pub(crate) struct TreeVault {
    /// The compression algorithm applied to stored payloads, if any.
    #[cfg(feature = "compression")]
    compression: Option<Compression>,
    /// The id of the key used to encrypt payloads, or `None` for
    /// unencrypted storage.
    #[cfg(feature = "encryption")]
    pub key: Option<KeyId>,
    /// The vault that performs the actual encryption and decryption.
    #[cfg(feature = "encryption")]
    pub vault: Arc<Vault>,
}
911

            
912
#[cfg(all(feature = "compression", feature = "encryption"))]
impl TreeVault {
    /// Builds a `TreeVault` when encryption and/or compression is requested.
    /// Returns `None` when neither is configured so callers can skip payload
    /// transformation entirely.
    pub(crate) fn new_if_needed(
        key: Option<KeyId>,
        vault: &Arc<Vault>,
        compression: Option<Compression>,
    ) -> Option<Self> {
        if key.is_some() || compression.is_some() {
            Some(Self {
                key,
                compression,
                vault: vault.clone(),
            })
        } else {
            None
        }
    }

    /// Computes the flags byte written after the `trv` magic: the high bit
    /// marks encryption, and the low seven bits carry the compression
    /// algorithm id when `compressed` is true.
    fn header(&self, compressed: bool) -> u8 {
        let mut bits = 0_u8;
        if self.key.is_some() {
            bits |= 0b1000_0000;
        }
        if compressed {
            if let Some(algorithm) = self.compression {
                bits |= algorithm as u8;
            }
        }
        bits
    }
}
942

            
943
#[cfg(all(feature = "compression", feature = "encryption"))]
impl nebari::Vault for TreeVault {
    type Error = Error;

    /// Compresses and/or encrypts `payload`. When either transform is
    /// applied, the result is prefixed with a 4-byte header: the magic
    /// `b"trv"` followed by the flags byte from `TreeVault::header`. When
    /// neither applies, the payload is returned unprefixed and unchanged.
    fn encrypt(&self, payload: &[u8]) -> Result<Vec<u8>, Error> {
        use std::borrow::Cow;

        // TODO this allocates too much. The vault should be able to do an
        // in-place encryption operation so that we can use a single buffer.
        let mut includes_compression = false;
        // Only payloads of at least 128 bytes are compressed.
        let compressed = match (payload.len(), self.compression) {
            (128..=usize::MAX, Some(Compression::Lz4)) => {
                includes_compression = true;
                Cow::Owned(lz4_flex::block::compress_prepend_size(payload))
            }
            _ => Cow::Borrowed(payload),
        };

        // Encryption happens after compression, when a key is configured.
        let mut complete = if let Some(key) = &self.key {
            self.vault.encrypt_payload(key, &compressed, None)?
        } else {
            compressed.into_owned()
        };

        // A zero header means neither transform was applied; in that case no
        // prefix is written, so untouched payloads stay byte-identical.
        let header = self.header(includes_compression);
        if header != 0 {
            let header = [b't', b'r', b'v', header];
            complete.splice(0..0, header);
        }

        Ok(complete)
    }

    /// Reverses `encrypt`: parses the optional `trv` header, decrypting
    /// and/or decompressing as its flags dictate. Payloads without the
    /// header are handed directly to the vault for decryption.
    fn decrypt(&self, payload: &[u8]) -> Result<Vec<u8>, Error> {
        use std::borrow::Cow;

        if payload.len() >= 4 && &payload[0..3] == b"trv" {
            let header = payload[3];
            let payload = &payload[4..];
            // High bit flags encryption; the low seven bits identify the
            // compression algorithm (0 = none).
            let encrypted = (header & 0b1000_0000) != 0;
            let compression = header & 0b0111_1111;
            // Decrypt first: compression was applied before encryption.
            let decrypted = if encrypted {
                Cow::Owned(self.vault.decrypt_payload(payload, None)?)
            } else {
                Cow::Borrowed(payload)
            };
            #[allow(clippy::single_match)] // Make it an error when we add a new algorithm
            return Ok(match Compression::from_u8(compression) {
                Some(Compression::Lz4) => {
                    lz4_flex::block::decompress_size_prepended(&decrypted).map_err(Error::from)?
                }
                None => decrypted.into_owned(),
            });
        }
        self.vault.decrypt_payload(payload, None)
    }
}

            
#[cfg(all(feature = "compression", not(feature = "encryption")))]
impl TreeVault {
    /// Builds a `TreeVault` wrapping `compression`, or `None` when no
    /// compression is configured (and no payload transformation is needed).
    pub(crate) fn new_if_needed(compression: Option<Compression>) -> Option<Self> {
        let algorithm = compression?;
        Some(Self {
            compression: Some(algorithm),
        })
    }
}

            
#[cfg(all(feature = "compression", not(feature = "encryption")))]
impl nebari::Vault for TreeVault {
    type Error = Error;

    /// Compresses `payload` when compression is configured. Output layout:
    /// bytes 0..4 are the `trv` magic plus the algorithm id, bytes 4..8 are
    /// the uncompressed length (little-endian `u32`), and the compressed
    /// block follows. Without compression, the payload is copied unchanged.
    fn encrypt(&self, payload: &[u8]) -> Result<Vec<u8>, Error> {
        Ok(match self.compression {
            Some(Compression::Lz4) => {
                // Reserve worst-case compressed size plus 8 header bytes, then
                // compress directly into the tail of the buffer.
                let mut destination =
                    vec![0; lz4_flex::block::get_maximum_output_size(payload.len()) + 8];
                let compressed_length =
                    lz4_flex::block::compress_into(payload, &mut destination[8..])
                        .expect("lz4-flex documents this shouldn't fail");
                destination.truncate(compressed_length + 8);
                destination[0..4].copy_from_slice(&[b't', b'r', b'v', Compression::Lz4 as u8]);
                // to_le_bytes() makes it compatible with lz4-flex decompress_size_prepended.
                let uncompressed_length =
                    u32::try_from(payload.len()).expect("nebari doesn't support >32 bit blocks");
                destination[4..8].copy_from_slice(&uncompressed_length.to_le_bytes());
                destination
            }
            // TODO this shouldn't copy
            None => payload.to_vec(),
        })
    }

    /// Reverses `encrypt`: parses the optional `trv` header and decompresses
    /// accordingly. Payloads flagged as encrypted cannot be handled because
    /// the encryption feature is disabled, so they produce an error.
    fn decrypt(&self, payload: &[u8]) -> Result<Vec<u8>, Error> {
        if payload.len() >= 4 && &payload[0..3] == b"trv" {
            let header = payload[3];
            let payload = &payload[4..];
            // High bit flags encryption; low seven bits name the algorithm.
            let encrypted = (header & 0b1000_0000) != 0;
            let compression = header & 0b0111_1111;
            if encrypted {
                return Err(Error::EncryptionDisabled);
            }

            #[allow(clippy::single_match)] // Make it an error when we add a new algorithm
            return Ok(match Compression::from_u8(compression) {
                Some(Compression::Lz4) => {
                    lz4_flex::block::decompress_size_prepended(payload).map_err(Error::from)?
                }
                None => payload.to_vec(),
            });
        }
        Ok(payload.to_vec())
    }
}

            
#[cfg(all(not(feature = "compression"), feature = "encryption"))]
impl TreeVault {
    /// Builds a `TreeVault` for `key`, or `None` when no encryption key is
    /// configured (and no payload transformation is needed).
    pub(crate) fn new_if_needed(key: Option<KeyId>, vault: &Arc<Vault>) -> Option<Self> {
        let key = key?;
        Some(Self {
            key: Some(key),
            vault: Arc::clone(vault),
        })
    }

    #[allow(dead_code)] // This implementation is sort of documentation for what it would be. But our Vault payload already can detect if a parsing error occurs, so we don't need a header if only encryption is enabled.
    fn header(&self) -> u8 {
        match self.key {
            Some(_) => 0b1000_0000,
            None => 0,
        }
    }
}

            
#[cfg(all(not(feature = "compression"), feature = "encryption"))]
impl nebari::Vault for TreeVault {
    type Error = Error;

    /// Encrypts `payload` with the configured key, or returns an unmodified
    /// copy when no key is set.
    fn encrypt(&self, payload: &[u8]) -> Result<Vec<u8>, Error> {
        match &self.key {
            Some(key) => self.vault.encrypt_payload(key, payload, None),
            // TODO does this need to copy?
            None => Ok(payload.to_vec()),
        }
    }

    /// Delegates decryption entirely to the vault.
    fn decrypt(&self, payload: &[u8]) -> Result<Vec<u8>, Error> {
        self.vault.decrypt_payload(payload, None)
    }
}