1
use std::{
2
    collections::HashMap,
3
    fmt::{Debug, Display},
4
    marker::PhantomData,
5
    path::{Path, PathBuf},
6
    sync::Arc,
7
};
8

            
9
use async_lock::{Mutex, RwLock};
10
use async_trait::async_trait;
11
pub use bonsaidb_core::circulate::Relay;
12
#[cfg(feature = "password-hashing")]
13
use bonsaidb_core::connection::{Authenticated, Authentication};
14
#[cfg(any(feature = "encryption", feature = "compression"))]
15
use bonsaidb_core::document::KeyId;
16
use bonsaidb_core::{
17
    admin::{
18
        self,
19
        database::{self, ByName, Database as DatabaseRecord},
20
        Admin, ADMIN_DATABASE_NAME,
21
    },
22
    connection::{self, Connection, StorageConnection},
23
    schema::{Schema, SchemaName, Schematic},
24
};
25
#[cfg(feature = "multiuser")]
26
use bonsaidb_core::{
27
    admin::{user::User, PermissionGroup, Role},
28
    document::CollectionDocument,
29
    schema::{Nameable, NamedCollection},
30
};
31
use bonsaidb_utils::{fast_async_lock, fast_async_read, fast_async_write};
32
use futures::TryFutureExt;
33
use itertools::Itertools;
34
use nebari::{
35
    io::{
36
        any::{AnyFile, AnyFileManager},
37
        FileManager,
38
    },
39
    ChunkCache, ThreadPool,
40
};
41
use rand::{thread_rng, Rng};
42
use tokio::{
43
    fs::{self, File},
44
    io::{AsyncReadExt, AsyncWriteExt},
45
};
46

            
47
#[cfg(feature = "compression")]
48
use crate::config::Compression;
49
#[cfg(feature = "encryption")]
50
use crate::vault::{self, LocalVaultKeyStorage, Vault};
51
use crate::{
52
    config::{KeyValuePersistence, StorageConfiguration},
53
    database::Context,
54
    tasks::{manager::Manager, TaskManager},
55
    Database, Error,
56
};
57

            
58
#[cfg(feature = "password-hashing")]
59
mod argon;
60

            
61
mod backup;
62
pub use backup::BackupLocation;
63

            
64
/// A file-based, multi-database, multi-user database engine.
65
///
66
/// ## Converting from `Database::open` to `Storage::open`
67
///
68
/// [`Database::open`](Database::open) is a simple method that uses `Storage` to
69
/// create a database named `default` with the schema provided. These two ways
70
/// of opening the database are the same:
71
///
72
/// ```rust
73
/// // `bonsaidb_core` is re-exported to `bonsaidb::core` or `bonsaidb_local::core`.
74
/// use bonsaidb_core::{connection::StorageConnection, schema::Schema};
75
/// // `bonsaidb_local` is re-exported to `bonsaidb::local` if using the omnibus crate.
76
/// use bonsaidb_local::{
77
///     config::{Builder, StorageConfiguration},
78
///     Database, Storage,
79
/// };
80
/// # async fn open<MySchema: Schema>() -> anyhow::Result<()> {
81
/// // This creates a Storage instance, creates a database, and returns it.
82
/// let db = Database::open::<MySchema>(StorageConfiguration::new("my-db.bonsaidb")).await?;
83
///
84
/// // This is the equivalent code being executed:
85
/// let storage =
86
///     Storage::open(StorageConfiguration::new("my-db.bonsaidb").with_schema::<MySchema>()?)
87
///         .await?;
88
/// storage.create_database::<MySchema>("default", true).await?;
89
/// let db = storage.database::<MySchema>("default").await?;
90
/// #     Ok(())
91
/// # }
92
/// ```
93
///
94
/// ## Using multiple databases
95
///
96
/// This example shows how to use `Storage` to create and use multiple databases
97
/// with multiple schemas:
98
///
99
/// ```rust
100
/// use bonsaidb_core::{
101
///     connection::StorageConnection,
102
///     schema::{Collection, Schema},
103
/// };
104
/// use bonsaidb_local::{
105
///     config::{Builder, StorageConfiguration},
106
///     Storage,
107
/// };
108
/// use serde::{Deserialize, Serialize};
109
///
110
/// #[derive(Debug, Schema)]
111
/// #[schema(name = "my-schema", collections = [BlogPost, Author])]
112
/// # #[schema(core = bonsaidb_core)]
113
/// struct MySchema;
114
///
115
/// #[derive(Debug, Serialize, Deserialize, Collection)]
116
/// #[collection(name = "blog-posts")]
117
/// # #[collection(core = bonsaidb_core)]
118
/// struct BlogPost {
119
///     pub title: String,
120
///     pub contents: String,
121
///     pub author_id: u64,
122
/// }
123
///
124
/// #[derive(Debug, Serialize, Deserialize, Collection)]
125
/// #[collection(name = "blog-posts")]
126
/// # #[collection(core = bonsaidb_core)]
127
/// struct Author {
128
///     pub name: String,
129
/// }
130
///
131
/// # async fn test_fn() -> Result<(), bonsaidb_core::Error> {
132
/// let storage = Storage::open(
133
///     StorageConfiguration::new("my-db.bonsaidb")
134
///         .with_schema::<BlogPost>()?
135
///         .with_schema::<MySchema>()?,
136
/// )
137
/// .await?;
138
///
139
/// storage
140
///     .create_database::<BlogPost>("ectons-blog", true)
141
///     .await?;
142
/// let ectons_blog = storage.database::<BlogPost>("ectons-blog").await?;
143
/// storage
144
///     .create_database::<MySchema>("another-db", true)
145
///     .await?;
146
/// let another_db = storage.database::<MySchema>("another-db").await?;
147
///
148
/// #     Ok(())
149
/// # }
150
/// ```
151
204148
#[derive(Debug, Clone)]
pub struct Storage {
    // All state lives behind a single `Arc`, so cloning a `Storage` is a
    // cheap reference-count bump and every clone shares the same engine.
    data: Arc<Data>,
}
155

            
156
/// Shared state behind a [`Storage`] handle.
#[derive(Debug)]
struct Data {
    // Unique id of this storage instance; loaded or generated by
    // `lookup_or_create_id` during `Storage::open`.
    id: StorageId,
    // Root directory containing all databases and support files.
    path: PathBuf,
    // Parallelism hint from the configuration; also sizes `threadpool`.
    parallelization: usize,
    threadpool: ThreadPool<AnyFile>,
    // Either an in-memory or on-disk file manager, depending on
    // `StorageConfiguration::memory_only`.
    file_manager: AnyFileManager,
    pub(crate) tasks: TaskManager,
    // Registered schema openers, keyed by schema name.
    schemas: RwLock<HashMap<SchemaName, Arc<dyn DatabaseOpener>>>,
    // Cache of database name -> schema name, rebuilt from the admin
    // database by `cache_available_databases`.
    available_databases: RwLock<HashMap<String, SchemaName>>,
    // Cache of opened nebari roots, keyed by database name (see `open_roots`).
    open_roots: Mutex<HashMap<String, Context>>,
    #[cfg(feature = "password-hashing")]
    argon: argon::Hasher,
    #[cfg(feature = "encryption")]
    pub(crate) vault: Arc<Vault>,
    #[cfg(feature = "encryption")]
    default_encryption_key: Option<KeyId>,
    // Compression/encryption wrapper applied to trees when configured.
    #[cfg(any(feature = "compression", feature = "encryption"))]
    tree_vault: Option<TreeVault>,
    pub(crate) key_value_persistence: KeyValuePersistence,
    // Shared chunk cache handed to every opened nebari configuration.
    chunk_cache: ChunkCache,
    pub(crate) check_view_integrity_on_database_open: bool,
    // PubSub relay shared by all databases in this storage.
    relay: Relay,
}
180

            
181
impl Storage {
182
    /// Creates or opens a multi-database [`Storage`] with its data stored in `directory`.
183
2699
    pub async fn open(configuration: StorageConfiguration) -> Result<Self, Error> {
184
175
        let owned_path = configuration
185
175
            .path
186
175
            .clone()
187
175
            .unwrap_or_else(|| PathBuf::from("db.bonsaidb"));
188
175
        let file_manager = if configuration.memory_only {
189
30
            AnyFileManager::memory()
190
        } else {
191
145
            AnyFileManager::std()
192
        };
193

            
194
175
        let manager = Manager::default();
195
700
        for _ in 0..configuration.workers.worker_count {
196
700
            manager.spawn_worker();
197
700
        }
198
175
        let tasks = TaskManager::new(manager);
199
175

            
200
175
        fs::create_dir_all(&owned_path).await?;
201

            
202
361
        let id = Self::lookup_or_create_id(&configuration, &owned_path).await?;
203

            
204
        #[cfg(feature = "encryption")]
205
174
        let vault = {
206
175
            let vault_key_storage = match configuration.vault_key_storage {
207
3
                Some(storage) => storage,
208
                None => Arc::new(
209
172
                    LocalVaultKeyStorage::new(owned_path.join("vault-keys"))
210
153
                        .await
211
172
                        .map_err(|err| Error::Vault(vault::Error::Initializing(err.to_string())))?,
212
                ),
213
            };
214

            
215
1320
            Arc::new(Vault::initialize(id, &owned_path, vault_key_storage).await?)
216
        };
217

            
218
174
        let parallelization = configuration.workers.parallelization;
219
174
        let check_view_integrity_on_database_open = configuration.views.check_integrity_on_open;
220
174
        let key_value_persistence = configuration.key_value_persistence;
221
174
        #[cfg(feature = "password-hashing")]
222
174
        let argon = argon::Hasher::new(configuration.argon);
223
174
        #[cfg(feature = "encryption")]
224
174
        let default_encryption_key = configuration.default_encryption_key;
225
174
        #[cfg(all(feature = "compression", feature = "encryption"))]
226
174
        let tree_vault = TreeVault::new_if_needed(
227
174
            default_encryption_key.clone(),
228
174
            &vault,
229
174
            configuration.default_compression,
230
174
        );
231
        #[cfg(all(not(feature = "compression"), feature = "encryption"))]
232
        let tree_vault = TreeVault::new_if_needed(default_encryption_key.clone(), &vault);
233
        #[cfg(all(feature = "compression", not(feature = "encryption")))]
234
        let tree_vault = TreeVault::new_if_needed(configuration.default_compression);
235

            
236
174
        let storage = tokio::task::spawn_blocking::<_, Result<Self, Error>>(move || {
237
174
            Ok(Self {
238
174
                data: Arc::new(Data {
239
174
                    id,
240
174
                    tasks,
241
174
                    parallelization,
242
174
                    #[cfg(feature = "password-hashing")]
243
174
                    argon,
244
174
                    #[cfg(feature = "encryption")]
245
174
                    vault,
246
174
                    #[cfg(feature = "encryption")]
247
174
                    default_encryption_key,
248
174
                    #[cfg(any(feature = "compression", feature = "encryption"))]
249
174
                    tree_vault,
250
174
                    path: owned_path,
251
174
                    file_manager,
252
174
                    chunk_cache: ChunkCache::new(2000, 160_384),
253
174
                    threadpool: ThreadPool::new(parallelization),
254
174
                    schemas: RwLock::new(configuration.initial_schemas),
255
174
                    available_databases: RwLock::default(),
256
174
                    open_roots: Mutex::default(),
257
174
                    key_value_persistence,
258
174
                    check_view_integrity_on_database_open,
259
174
                    relay: Relay::default(),
260
174
                }),
261
174
            })
262
174
        })
263
119
        .await??;
264

            
265
520
        storage.cache_available_databases().await?;
266

            
267
174
        storage.create_admin_database_if_needed().await?;
268

            
269
174
        Ok(storage)
270
175
    }
271

            
272
    /// Returns the path of the database storage.
273
    #[must_use]
274
17008
    pub fn path(&self) -> &Path {
275
17008
        &self.data.path
276
17008
    }
277

            
278
2725
    /// Determines the [`StorageId`] for this storage instance.
    ///
    /// An id supplied via `configuration.unique_id` takes precedence and is
    /// never written to disk. Otherwise the id is read from the `server-id`
    /// file inside `path`, creating that file with a random `u64` on first
    /// launch.
    async fn lookup_or_create_id(
        configuration: &StorageConfiguration,
        path: &Path,
    ) -> Result<StorageId, Error> {
        Ok(StorageId(if let Some(id) = configuration.unique_id {
            // The configuration id override is not persisted to disk. This is
            // mostly to prevent someone from accidentally adding this
            // configuration, realizing it breaks things, and then wanting to
            // revert. This makes reverting to the old value easier.
            id
        } else {
            // Load/Store a randomly generated id into a file. While the value
            // is numerical, the file contents are the ascii decimal, making it
            // easier for a human to view, and if needed, edit.
            let id_path = path.join("server-id");

            if id_path.exists() {
                // This value is important enough to not allow launching the
                // server if the file can't be read or contains unexpected data.
                let existing_id = String::from_utf8(
                    File::open(id_path)
                        .and_then(|mut f| async move {
                            let mut bytes = Vec::new();
                            f.read_to_end(&mut bytes).await.map(|_| bytes)
                        })
                        .await
                        .expect("error reading server-id file"),
                )
                .expect("server-id contains invalid data");

                existing_id.parse().expect("server-id isn't numeric")
            } else {
                let id = { thread_rng().gen::<u64>() };
                // Persist the freshly generated id as ascii decimal; a write
                // failure is a configuration error, not a panic.
                File::create(id_path)
                    .and_then(|mut file| async move {
                        let id = id.to_string();
                        file.write_all(id.as_bytes()).await?;
                        file.shutdown().await
                    })
                    .await
                    .map_err(|err| {
                        Error::Core(bonsaidb_core::Error::Configuration(format!(
                            "Error writing server-id file: {}",
                            err
                        )))
                    })?;
                id
            }
        }))
    }
328

            
329
2699
    /// Rebuilds the in-memory database-name -> schema-name cache from the
    /// admin database's `ByName` view.
    ///
    /// Called during `Storage::open`; the cache is kept up to date afterwards
    /// by the create/delete database operations.
    async fn cache_available_databases(&self) -> Result<(), Error> {
        let available_databases = self
            .admin()
            .await
            .view::<ByName>()
            .query()
            .await?
            .into_iter()
            // Each view entry maps a database name (key) to its schema (value).
            .map(|map| (map.key, map.value))
            .collect();
        // Replace the cache wholesale under the write lock.
        let mut storage_databases = fast_async_write!(self.data.available_databases);
        *storage_databases = available_databases;
        Ok(())
    }
343

            
344
2699
    /// Ensures the admin database exists, creating it on first launch.
    ///
    /// Registers the [`Admin`] schema, then tries to open the admin database;
    /// a `DatabaseNotFound` error triggers creation, while any other error is
    /// propagated.
    async fn create_admin_database_if_needed(&self) -> Result<(), Error> {
        self.register_schema::<Admin>().await?;
        match self.database::<Admin>(ADMIN_DATABASE_NAME).await {
            Ok(_) => {}
            Err(bonsaidb_core::Error::DatabaseNotFound(_)) => {
                self.create_database::<Admin>(ADMIN_DATABASE_NAME, true)
                    .await?;
            }
            Err(err) => return Err(Error::Core(err)),
        }
        Ok(())
    }
356

            
357
    /// Returns the unique id of the server.
358
    ///
359
    /// This value is set from the [`StorageConfiguration`] or randomly
360
    /// generated when creating a server. It shouldn't be changed after a server
361
    /// is in use, as doing can cause issues. For example, the vault that
362
    /// manages encrypted storage uses the server ID to store the vault key. If
363
    /// the server ID changes, the vault key storage will need to be updated
364
    /// with the new server ID.
365
    #[must_use]
366
    pub fn unique_id(&self) -> StorageId {
367
        self.data.id
368
    }
369

            
370
    #[must_use]
371
95620
    pub(crate) fn parallelization(&self) -> usize {
372
95620
        self.data.parallelization
373
95620
    }
374

            
375
    #[must_use]
376
    #[cfg(feature = "encryption")]
377
1392320
    pub(crate) fn vault(&self) -> &Arc<Vault> {
378
1392320
        &self.data.vault
379
1392320
    }
380

            
381
    #[must_use]
382
    #[cfg(any(feature = "encryption", feature = "compression"))]
383
2471760
    pub(crate) fn tree_vault(&self) -> Option<&TreeVault> {
384
2471760
        self.data.tree_vault.as_ref()
385
2471760
    }
386

            
387
    #[must_use]
388
    #[cfg(feature = "encryption")]
389
2416769
    pub(crate) fn default_encryption_key(&self) -> Option<&KeyId> {
390
2416769
        self.data.default_encryption_key.as_ref()
391
2416769
    }
392

            
393
    #[must_use]
394
    #[cfg(all(feature = "compression", not(feature = "encryption")))]
395
    #[allow(clippy::unused_self)]
396
    pub(crate) fn default_encryption_key(&self) -> Option<&KeyId> {
397
        None
398
    }
399

            
400
    /// Registers a schema for use within the server.
    ///
    /// # Errors
    ///
    /// Returns [`bonsaidb_core::Error::SchemaAlreadyRegistered`] if a schema
    /// with the same name has already been registered.
    pub async fn register_schema<DB: Schema>(&self) -> Result<(), Error> {
        let mut schemas = fast_async_write!(self.data.schemas);
        // `insert` returning `None` means the name was previously unused;
        // `Some` means another opener was already registered under this name.
        if schemas
            .insert(
                DB::schema_name(),
                Arc::new(StorageSchemaOpener::<DB>::new()?),
            )
            .is_none()
        {
            Ok(())
        } else {
            Err(Error::Core(bonsaidb_core::Error::SchemaAlreadyRegistered(
                DB::schema_name(),
            )))
        }
    }
417

            
418
    /// Returns the nebari [`Context`] for database `name`, opening (and
    /// caching) it on first use.
    ///
    /// The `open_roots` mutex is held for the full call, serializing opens of
    /// the same database. The actual open runs on a blocking thread because
    /// nebari's `open` performs file I/O.
    #[cfg_attr(
        not(any(feature = "encryption", feature = "compression")),
        allow(unused_mut)
    )]
    pub(crate) async fn open_roots(&self, name: &str) -> Result<Context, Error> {
        let mut open_roots = fast_async_lock!(self.data.open_roots);
        if let Some(roots) = open_roots.get(name) {
            // Fast path: this database was already opened.
            Ok(roots.clone())
        } else {
            let task_self = self.clone();
            let task_name = name.to_string();
            let roots = tokio::task::spawn_blocking(move || {
                let mut config = nebari::Config::new(task_self.data.path.join(task_name))
                    .file_manager(task_self.data.file_manager.clone())
                    .cache(task_self.data.chunk_cache.clone())
                    .shared_thread_pool(&task_self.data.threadpool);

                // Apply encryption/compression to the trees when configured.
                #[cfg(any(feature = "encryption", feature = "compression"))]
                if let Some(vault) = task_self.data.tree_vault.clone() {
                    config = config.vault(vault);
                }

                config.open().map_err(Error::from)
            })
            .await
            // `unwrap` propagates a panic from the blocking task; the open
            // error itself is surfaced through `?`.
            .unwrap()?;
            let context = Context::new(roots, self.data.key_value_persistence.clone());

            open_roots.insert(name.to_owned(), context.clone());

            Ok(context)
        }
    }
451

            
452
1850467
    pub(crate) fn tasks(&self) -> &'_ TaskManager {
453
1850467
        &self.data.tasks
454
1850467
    }
455

            
456
1386229
    pub(crate) fn check_view_integrity_on_database_open(&self) -> bool {
457
1386229
        self.data.check_view_integrity_on_database_open
458
1386229
    }
459

            
460
344
    pub(crate) fn relay(&self) -> &'_ Relay {
461
344
        &self.data.relay
462
344
    }
463

            
464
22099
    fn validate_name(name: &str) -> Result<(), Error> {
465
22099
        if name.chars().enumerate().all(|(index, c)| {
466
153160
            c.is_ascii_alphanumeric()
467
7216
                || (index == 0 && c == '_')
468
2702
                || (index > 0 && (c == '.' || c == '-'))
469
153160
        }) {
470
22017
            Ok(())
471
        } else {
472
82
            Err(Error::Core(bonsaidb_core::Error::InvalidDatabaseName(
473
82
                name.to_owned(),
474
82
            )))
475
        }
476
22099
    }
477

            
478
    /// Returns the administration database.
479
    #[allow(clippy::missing_panics_doc)]
480
39322
    pub async fn admin(&self) -> Database {
481
        Database::new::<Admin, _>(
482
            ADMIN_DATABASE_NAME,
483
39322
            self.open_roots(ADMIN_DATABASE_NAME).await.unwrap(),
484
39322
            self.clone(),
485
        )
486
        .await
487
39322
        .unwrap()
488
39322
    }
489

            
490
    /// Opens a database through a generic-free trait.
    ///
    /// Resolves `name` to its schema via the `available_databases` cache,
    /// then delegates to the registered [`DatabaseOpener`] for that schema.
    ///
    /// # Errors
    ///
    /// * `DatabaseNotFound` if `name` is not in the cache.
    /// * `SchemaNotRegistered` if the database's schema has no opener.
    pub(crate) async fn database_without_schema(&self, name: &str) -> Result<Database, Error> {
        let schema = {
            // Scope the read lock so it is released before taking the
            // `schemas` write lock below.
            let available_databases = fast_async_read!(self.data.available_databases);
            available_databases
                .get(name)
                .ok_or_else(|| {
                    Error::Core(bonsaidb_core::Error::DatabaseNotFound(name.to_string()))
                })?
                .clone()
        };

        let mut schemas = fast_async_write!(self.data.schemas);
        if let Some(schema) = schemas.get_mut(&schema) {
            let db = schema.open(name.to_string(), self.clone()).await?;
            Ok(db)
        } else {
            Err(Error::Core(bonsaidb_core::Error::SchemaNotRegistered(
                schema,
            )))
        }
    }
512

            
513
    #[cfg(feature = "internal-apis")]
514
    #[doc(hidden)]
515
    /// Opens a database through a generic-free trait.
516
1336634
    pub async fn database_without_schema_internal(&self, name: &str) -> Result<Database, Error> {
517
577964
        self.database_without_schema(name).await
518
51409
    }
519

            
520
    #[cfg(feature = "multiuser")]
521
34
    async fn update_user_with_named_id<
522
34
        'user,
523
34
        'other,
524
34
        Col: NamedCollection<PrimaryKey = u64>,
525
34
        U: Nameable<'user, u64> + Send + Sync,
526
34
        O: Nameable<'other, u64> + Send + Sync,
527
34
        F: FnOnce(&mut CollectionDocument<User>, u64) -> bool,
528
34
    >(
529
34
        &self,
530
34
        user: U,
531
34
        other: O,
532
34
        callback: F,
533
34
    ) -> Result<(), bonsaidb_core::Error> {
534
34
        let admin = self.admin().await;
535
34
        let other = other.name()?;
536
34
        let (user, other) =
537
34
            futures::try_join!(User::load(user.name()?, &admin), other.id::<Col, _>(&admin),)?;
538
34
        match (user, other) {
539
34
            (Some(mut user), Some(other)) => {
540
34
                if callback(&mut user, other) {
541
17
                    user.update(&admin).await?;
542
17
                }
543
34
                Ok(())
544
            }
545
            // TODO make this a generic not found with a name parameter.
546
            _ => Err(bonsaidb_core::Error::UserNotFound),
547
        }
548
34
    }
549
}
550

            
551
#[async_trait]
552
pub trait DatabaseOpener: Send + Sync + Debug {
553
    fn schematic(&self) -> &'_ Schematic;
554
    async fn open(&self, name: String, storage: Storage) -> Result<Database, Error>;
555
}
556

            
557
#[derive(Debug)]
558
pub struct StorageSchemaOpener<DB: Schema> {
559
    schematic: Schematic,
560
    _phantom: PhantomData<DB>,
561
}
562

            
563
impl<DB> StorageSchemaOpener<DB>
564
where
565
    DB: Schema,
566
{
567
430
    pub fn new() -> Result<Self, Error> {
568
430
        let schematic = DB::schematic()?;
569
430
        Ok(Self {
570
430
            schematic,
571
430
            _phantom: PhantomData::default(),
572
430
        })
573
430
    }
574
}
575

            
576
#[async_trait]
577
impl<DB> DatabaseOpener for StorageSchemaOpener<DB>
578
where
579
    DB: Schema,
580
{
581
    fn schematic(&self) -> &'_ Schematic {
582
        &self.schematic
583
    }
584

            
585
51907
    async fn open(&self, name: String, storage: Storage) -> Result<Database, Error> {
586
51907
        let roots = storage.open_roots(&name).await?;
587
51907
        let db = Database::new::<DB, _>(name, roots, storage).await?;
588
51907
        Ok(db)
589
103814
    }
590
}
591

            
592
#[async_trait]
593
impl StorageConnection for Storage {
594
    type Database = Database;
595

            
596
    #[cfg_attr(
597
        feature = "tracing",
598
66285
        tracing::instrument(skip(name, schema, only_if_needed))
599
    )]
600
    async fn create_database_with_schema(
601
        &self,
602
        name: &str,
603
        schema: SchemaName,
604
        only_if_needed: bool,
605
22095
    ) -> Result<(), bonsaidb_core::Error> {
606
22095
        Self::validate_name(name)?;
607

            
608
        {
609
22015
            let schemas = fast_async_read!(self.data.schemas);
610
22015
            if !schemas.contains_key(&schema) {
611
80
                return Err(bonsaidb_core::Error::SchemaNotRegistered(schema));
612
21935
            }
613
        }
614

            
615
21935
        let mut available_databases = fast_async_write!(self.data.available_databases);
616
21935
        let admin = self.admin().await;
617
21935
        if !admin
618
21935
            .view::<database::ByName>()
619
21935
            .with_key(name.to_ascii_lowercase())
620
21935
            .query()
621
18820
            .await?
622
21935
            .is_empty()
623
        {
624
634
            if only_if_needed {
625
554
                return Ok(());
626
80
            }
627
80

            
628
80
            return Err(bonsaidb_core::Error::DatabaseNameAlreadyTaken(
629
80
                name.to_string(),
630
80
            ));
631
21301
        }
632
21301

            
633
21301
        admin
634
21301
            .collection::<DatabaseRecord>()
635
21301
            .push(&admin::Database {
636
21301
                name: name.to_string(),
637
21301
                schema: schema.clone(),
638
21301
            })
639
21301
            .await?;
640
21301
        available_databases.insert(name.to_string(), schema);
641
21301

            
642
21301
        Ok(())
643
44190
    }
644

            
645
642
    async fn database<DB: Schema>(
646
642
        &self,
647
642
        name: &str,
648
642
    ) -> Result<Self::Database, bonsaidb_core::Error> {
649
642
        let db = self.database_without_schema(name).await?;
650
487
        if db.data.schema.name == DB::schema_name() {
651
487
            Ok(db)
652
        } else {
653
            Err(bonsaidb_core::Error::SchemaMismatch {
654
                database_name: name.to_owned(),
655
                schema: DB::schema_name(),
656
                stored_schema: db.data.schema.name.clone(),
657
            })
658
        }
659
1284
    }
660

            
661
39480
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(name)))]
662
13160
    async fn delete_database(&self, name: &str) -> Result<(), bonsaidb_core::Error> {
663
13160
        let admin = self.admin().await;
664
13160
        let mut available_databases = fast_async_write!(self.data.available_databases);
665
13160
        available_databases.remove(name);
666

            
667
13160
        let mut open_roots = fast_async_lock!(self.data.open_roots);
668
13160
        open_roots.remove(name);
669
13160

            
670
13160
        let database_folder = self.path().join(name);
671
13160
        if database_folder.exists() {
672
13000
            let file_manager = self.data.file_manager.clone();
673
13000
            tokio::task::spawn_blocking(move || file_manager.delete_directory(&database_folder))
674
10322
                .await
675
13000
                .unwrap()
676
13000
                .map_err(Error::Nebari)?;
677
160
        }
678

            
679
13160
        if let Some(entry) = admin
680
13160
            .view::<database::ByName>()
681
13160
            .with_key(name.to_ascii_lowercase())
682
13160
            .query()
683
13160
            .await?
684
13160
            .first()
685
        {
686
13080
            admin.delete::<DatabaseRecord, _>(&entry.source).await?;
687

            
688
13080
            Ok(())
689
        } else {
690
80
            return Err(bonsaidb_core::Error::DatabaseNotFound(name.to_string()));
691
        }
692
26320
    }
693

            
694
240
    #[cfg_attr(feature = "tracing", tracing::instrument)]
695
80
    async fn list_databases(&self) -> Result<Vec<connection::Database>, bonsaidb_core::Error> {
696
80
        let available_databases = fast_async_read!(self.data.available_databases);
697
80
        Ok(available_databases
698
80
            .iter()
699
2006
            .map(|(name, schema)| connection::Database {
700
2006
                name: name.to_string(),
701
2006
                schema: schema.clone(),
702
2006
            })
703
80
            .collect())
704
160
    }
705

            
706
240
    #[cfg_attr(feature = "tracing", tracing::instrument)]
707
80
    async fn list_available_schemas(&self) -> Result<Vec<SchemaName>, bonsaidb_core::Error> {
708
80
        let available_databases = fast_async_read!(self.data.available_databases);
709
80
        Ok(available_databases.values().unique().cloned().collect())
710
160
    }
711

            
712
552
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(username)))]
713
    #[cfg(feature = "multiuser")]
714
184
    async fn create_user(&self, username: &str) -> Result<u64, bonsaidb_core::Error> {
715
184
        let result = self
716
184
            .admin()
717
            .await
718
184
            .collection::<User>()
719
396
            .push(&User::default_with_username(username))
720
396
            .await?;
721
158
        Ok(result.id)
722
368
    }
723

            
724
12
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user)))]
725
    #[cfg(feature = "multiuser")]
726
    async fn delete_user<'user, U: Nameable<'user, u64> + Send + Sync>(
727
        &self,
728
        user: U,
729
4
    ) -> Result<(), bonsaidb_core::Error> {
730
4
        let admin = self.admin().await;
731
4
        let doc = User::load(user, &admin)
732
4
            .await?
733
4
            .ok_or(bonsaidb_core::Error::UserNotFound)?;
734
4
        doc.delete(&admin).await?;
735

            
736
4
        Ok(())
737
8
    }
738

            
739
    #[cfg(feature = "password-hashing")]
740
9
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user, password)))]
741
    async fn set_user_password<'user, U: Nameable<'user, u64> + Send + Sync>(
742
        &self,
743
        user: U,
744
        password: bonsaidb_core::connection::SensitiveString,
745
3
    ) -> Result<(), bonsaidb_core::Error> {
746
3
        let admin = self.admin().await;
747
6
        let mut user = User::load(user, &admin)
748
6
            .await?
749
3
            .ok_or(bonsaidb_core::Error::UserNotFound)?;
750
3
        user.contents.argon_hash = Some(self.data.argon.hash(user.header.id, password).await?);
751
3
        user.update(&admin).await
752
6
    }
753

            
754
    #[cfg(all(feature = "multiuser", feature = "password-hashing"))]
755
15
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user)))]
756
    async fn authenticate<'user, U: Nameable<'user, u64> + Send + Sync>(
757
        &self,
758
        user: U,
759
        authentication: Authentication,
760
5
    ) -> Result<Authenticated, bonsaidb_core::Error> {
761
5
        let admin = self.admin().await;
762
5
        let user = User::load(user, &admin)
763
3
            .await?
764
5
            .ok_or(bonsaidb_core::Error::InvalidCredentials)?;
765
5
        match authentication {
766
5
            Authentication::Password(password) => {
767
5
                let saved_hash = user
768
5
                    .contents
769
5
                    .argon_hash
770
5
                    .clone()
771
5
                    .ok_or(bonsaidb_core::Error::InvalidCredentials)?;
772

            
773
5
                self.data
774
5
                    .argon
775
5
                    .verify(user.header.id, password, saved_hash)
776
5
                    .await?;
777
5
                let permissions = user.contents.effective_permissions(&admin).await?;
778
5
                Ok(Authenticated {
779
5
                    user_id: user.header.id,
780
5
                    permissions,
781
5
                })
782
            }
783
        }
784
10
    }
785

            
786
30
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user, permission_group)))]
787
    #[cfg(feature = "multiuser")]
788
    async fn add_permission_group_to_user<
789
        'user,
790
        'group,
791
        U: Nameable<'user, u64> + Send + Sync,
792
        G: Nameable<'group, u64> + Send + Sync,
793
    >(
794
        &self,
795
        user: U,
796
        permission_group: G,
797
10
    ) -> Result<(), bonsaidb_core::Error> {
798
10
        self.update_user_with_named_id::<PermissionGroup, _, _, _>(
799
10
            user,
800
10
            permission_group,
801
10
            |user, permission_group_id| {
802
10
                if user.contents.groups.contains(&permission_group_id) {
803
5
                    false
804
                } else {
805
5
                    user.contents.groups.push(permission_group_id);
806
5
                    true
807
                }
808
10
            },
809
16
        )
810
16
        .await
811
20
    }
812

            
813
24
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user, permission_group)))]
814
    #[cfg(feature = "multiuser")]
815
    async fn remove_permission_group_from_user<
816
        'user,
817
        'group,
818
        U: Nameable<'user, u64> + Send + Sync,
819
        G: Nameable<'group, u64> + Send + Sync,
820
    >(
821
        &self,
822
        user: U,
823
        permission_group: G,
824
8
    ) -> Result<(), bonsaidb_core::Error> {
825
8
        self.update_user_with_named_id::<PermissionGroup, _, _, _>(
826
8
            user,
827
8
            permission_group,
828
8
            |user, permission_group_id| {
829
8
                let old_len = user.contents.groups.len();
830
8
                user.contents.groups.retain(|id| id != &permission_group_id);
831
8
                old_len != user.contents.groups.len()
832
8
            },
833
11
        )
834
11
        .await
835
16
    }
836

            
837
24
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user, role)))]
    #[cfg(feature = "multiuser")]
    async fn add_role_to_user<
        'user,
        'group,
        U: Nameable<'user, u64> + Send + Sync,
        G: Nameable<'group, u64> + Send + Sync,
    >(
        &self,
        user: U,
        role: G,
    ) -> Result<(), bonsaidb_core::Error> {
        // Resolve the id against the `Role` collection, matching
        // `remove_role_from_user`. The previous code looked the name up in
        // `PermissionGroup`, which resolved role names against the wrong
        // collection and could attach an unrelated id to the user's roles.
        self.update_user_with_named_id::<Role, _, _, _>(user, role, |user, role_id| {
            // Only persist the document when the role was actually added.
            if user.contents.roles.contains(&role_id) {
                false
            } else {
                user.contents.roles.push(role_id);
                true
            }
        })
        .await
    }
859

            
860
24
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(user, role)))]
861
    #[cfg(feature = "multiuser")]
862
    async fn remove_role_from_user<
863
        'user,
864
        'group,
865
        U: Nameable<'user, u64> + Send + Sync,
866
        G: Nameable<'group, u64> + Send + Sync,
867
    >(
868
        &self,
869
        user: U,
870
        role: G,
871
8
    ) -> Result<(), bonsaidb_core::Error> {
872
8
        self.update_user_with_named_id::<Role, _, _, _>(user, role, |user, role_id| {
873
8
            let old_len = user.contents.roles.len();
874
8
            user.contents.roles.retain(|id| id != &role_id);
875
8
            old_len != user.contents.roles.len()
876
11
        })
877
11
        .await
878
16
    }
879
}
880

            
881
1
#[test]
fn name_validation_tests() {
    // ASCII alphanumerics are allowed anywhere; `.` and `-` are allowed
    // after the first character.
    assert!(matches!(Storage::validate_name("azAZ09.-"), Ok(())));
    // A leading underscore is permitted (used for internal names).
    assert!(matches!(
        Storage::validate_name("_internal-names-work"),
        Ok(())
    ));
    // The first character must be alphanumeric or `_` — `-` is rejected.
    assert!(matches!(
        Storage::validate_name("-alphaunmericfirstrequired"),
        Err(Error::Core(bonsaidb_core::Error::InvalidDatabaseName(_)))
    ));
    // Non-ASCII characters are rejected everywhere.
    assert!(matches!(
        Storage::validate_name("\u{2661}"),
        Err(Error::Core(bonsaidb_core::Error::InvalidDatabaseName(_)))
    ));
}
897

            
898
/// The unique id of a [`Storage`] instance.
///
/// A newtype over a `u64`; use [`StorageId::as_u64`] for raw access. The
/// `Debug` and `Display` impls render it as 16 zero-padded hex digits.
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
pub struct StorageId(u64);
901

            
902
impl StorageId {
903
    /// Returns the id as a u64.
904
    #[must_use]
905
    pub const fn as_u64(self) -> u64 {
906
        self.0
907
    }
908
}
909

            
910
impl Debug for StorageId {
911
5080
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
912
5080
        // let formatted_length = format!();
913
5080
        write!(f, "{:016x}", self.0)
914
5080
    }
915
}
916

            
917
impl Display for StorageId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Display intentionally shares the zero-padded hex form of `Debug`.
        Debug::fmt(self, f)
    }
}
922

            
923
1224195
/// Payload transformer installed into nebari (via its `Vault` trait impls
/// below) that optionally compresses and/or encrypts tree data before it is
/// persisted.
#[derive(Debug, Clone)]
#[cfg(any(feature = "compression", feature = "encryption"))]
pub(crate) struct TreeVault {
    // Compression algorithm to apply, or `None` to store payloads unchanged.
    #[cfg(feature = "compression")]
    compression: Option<Compression>,
    // Id of the encryption key to use, or `None` to skip encryption.
    #[cfg(feature = "encryption")]
    pub key: Option<KeyId>,
    // Vault that performs the actual encrypt/decrypt operations for `key`.
    #[cfg(feature = "encryption")]
    pub vault: Arc<Vault>,
}
933

            
934
#[cfg(all(feature = "compression", feature = "encryption"))]
impl TreeVault {
    /// Builds a `TreeVault` only when at least one transformation (an
    /// encryption key or a compression algorithm) is configured. Returns
    /// `None` when payloads would pass through untouched, so no vault is
    /// installed at all in that case.
    pub(crate) fn new_if_needed(
        key: Option<KeyId>,
        vault: &Arc<Vault>,
        compression: Option<Compression>,
    ) -> Option<Self> {
        if key.is_none() && compression.is_none() {
            None
        } else {
            Some(Self {
                key,
                compression,
                vault: vault.clone(),
            })
        }
    }

    /// Computes the flags byte stored after the `b"trv"` magic bytes: the
    /// high bit marks an encrypted payload, and the low 7 bits carry the
    /// compression algorithm (only when `compressed` is true — see the
    /// matching masks in `decrypt`).
    fn header(&self, compressed: bool) -> u8 {
        let mut bits = if self.key.is_some() { 0b1000_0000 } else { 0 };

        if compressed {
            if let Some(compression) = self.compression {
                bits |= compression as u8;
            }
        }

        bits
    }
}
964

            
965
#[cfg(all(feature = "compression", feature = "encryption"))]
impl nebari::Vault for TreeVault {
    type Error = Error;

    /// Compresses `payload` (only when it is at least 128 bytes and LZ4 is
    /// configured), then encrypts it (when a key is configured). If either
    /// transformation applied, the result is prefixed with the 4-byte
    /// `b"trv"` + flags header produced by `header()`; otherwise the payload
    /// is returned as-is with no header.
    fn encrypt(&self, payload: &[u8]) -> Result<Vec<u8>, Error> {
        use std::borrow::Cow;

        // TODO this allocates too much. The vault should be able to do an
        // in-place encryption operation so that we can use a single buffer.
        let mut includes_compression = false;
        let compressed = match (payload.len(), self.compression) {
            // Small payloads are not worth compressing; 128 bytes is the
            // cutoff.
            (128..=usize::MAX, Some(Compression::Lz4)) => {
                includes_compression = true;
                Cow::Owned(lz4_flex::block::compress_prepend_size(payload))
            }
            _ => Cow::Borrowed(payload),
        };

        let mut complete = if let Some(key) = &self.key {
            self.vault.encrypt_payload(key, &compressed, None)?
        } else {
            compressed.into_owned()
        };

        // A zero header means neither transformation ran, so no prefix is
        // written and the stored bytes equal the input bytes.
        let header = self.header(includes_compression);
        if header != 0 {
            let header = [b't', b'r', b'v', header];
            complete.splice(0..0, header);
        }

        Ok(complete)
    }

    /// Reverses `encrypt`: strips the `b"trv"` header if present, decrypts
    /// when the header's high bit is set, and decompresses per the header's
    /// low 7 bits.
    fn decrypt(&self, payload: &[u8]) -> Result<Vec<u8>, Error> {
        use std::borrow::Cow;

        if payload.len() >= 4 && &payload[0..3] == b"trv" {
            let header = payload[3];
            let payload = &payload[4..];
            // Bit layout mirrors `header()`: high bit = encrypted, low
            // 7 bits = compression algorithm id.
            let encrypted = (header & 0b1000_0000) != 0;
            let compression = header & 0b0111_1111;
            let decrypted = if encrypted {
                Cow::Owned(self.vault.decrypt_payload(payload, None)?)
            } else {
                Cow::Borrowed(payload)
            };
            #[allow(clippy::single_match)] // Make it an error when we add a new algorithm
            return Ok(match Compression::from_u8(compression) {
                Some(Compression::Lz4) => {
                    lz4_flex::block::decompress_size_prepended(&decrypted).map_err(Error::from)?
                }
                None => decrypted.into_owned(),
            });
        }
        // NOTE(review): headerless payloads are always handed to the vault.
        // This assumes any payload stored without a header is either fully
        // encrypted or that `decrypt_payload` tolerates plaintext written by
        // the header == 0 path of `encrypt` — confirm against `vault`'s
        // payload format.
        self.vault.decrypt_payload(payload, None)
    }
}

            
#[cfg(all(feature = "compression", not(feature = "encryption")))]
impl TreeVault {
    /// Returns a vault wrapping `compression`, or `None` when no compression
    /// is configured and payloads can be stored verbatim.
    pub(crate) fn new_if_needed(compression: Option<Compression>) -> Option<Self> {
        match compression {
            Some(algorithm) => Some(Self {
                compression: Some(algorithm),
            }),
            None => None,
        }
    }
}

            
#[cfg(all(feature = "compression", not(feature = "encryption")))]
impl nebari::Vault for TreeVault {
    type Error = Error;

    /// Compresses `payload` with LZ4 when it is at least 128 bytes and
    /// compression is configured; smaller or uncompressed payloads are
    /// returned unchanged. Compressed output is framed as:
    /// bytes 0..4 = `b"trv"` + algorithm id, bytes 4..8 = uncompressed
    /// length (little-endian u32), bytes 8.. = LZ4 block data.
    fn encrypt(&self, payload: &[u8]) -> Result<Vec<u8>, Error> {
        Ok(match (payload.len(), self.compression) {
            (128..=usize::MAX, Some(Compression::Lz4)) => {
                // Reserve worst-case LZ4 output plus 8 bytes of framing.
                let mut destination =
                    vec![0; lz4_flex::block::get_maximum_output_size(payload.len()) + 8];
                let compressed_length =
                    lz4_flex::block::compress_into(payload, &mut destination[8..])
                        .expect("lz4-flex documents this shouldn't fail");
                destination.truncate(compressed_length + 8);
                destination[0..4].copy_from_slice(&[b't', b'r', b'v', Compression::Lz4 as u8]);
                // to_le_bytes() makes it compatible with lz4-flex decompress_size_prepended.
                let uncompressed_length =
                    u32::try_from(payload.len()).expect("nebari doesn't support >32 bit blocks");
                destination[4..8].copy_from_slice(&uncompressed_length.to_le_bytes());
                destination
            }
            // TODO this shouldn't copy
            _ => payload.to_vec(),
        })
    }

    /// Reverses `encrypt`: strips the 4-byte header when present and
    /// decompresses. An encrypted flag is an error in this build, since the
    /// encryption feature is disabled.
    fn decrypt(&self, payload: &[u8]) -> Result<Vec<u8>, Error> {
        if payload.len() >= 4 && &payload[0..3] == b"trv" {
            let header = payload[3];
            let payload = &payload[4..];
            // Bit layout: high bit = encrypted, low 7 bits = algorithm id.
            let encrypted = (header & 0b1000_0000) != 0;
            let compression = header & 0b0111_1111;
            if encrypted {
                return Err(Error::EncryptionDisabled);
            }

            #[allow(clippy::single_match)] // Make it an error when we add a new algorithm
            return Ok(match Compression::from_u8(compression) {
                Some(Compression::Lz4) => {
                    lz4_flex::block::decompress_size_prepended(payload).map_err(Error::from)?
                }
                None => payload.to_vec(),
            });
        }
        // No header: the payload was stored verbatim.
        Ok(payload.to_vec())
    }
}

            
#[cfg(all(not(feature = "compression"), feature = "encryption"))]
impl TreeVault {
    /// Returns a vault wrapping `key`, or `None` when no key is configured
    /// and payloads can be stored unencrypted.
    pub(crate) fn new_if_needed(key: Option<KeyId>, vault: &Arc<Vault>) -> Option<Self> {
        match key {
            Some(key) => Some(Self {
                key: Some(key),
                vault: vault.clone(),
            }),
            None => None,
        }
    }

    #[allow(dead_code)] // This implementation is sort of documentation for what it would be. But our Vault payload already can detect if a parsing error occurs, so we don't need a header if only encryption is enabled.
    fn header(&self) -> u8 {
        match self.key {
            Some(_) => 0b1000_0000,
            None => 0,
        }
    }
}

            
#[cfg(all(not(feature = "compression"), feature = "encryption"))]
impl nebari::Vault for TreeVault {
    type Error = Error;

    /// Encrypts `payload` with the configured key, or returns a copy of it
    /// unchanged when no key is set.
    fn encrypt(&self, payload: &[u8]) -> Result<Vec<u8>, Error> {
        match &self.key {
            Some(key) => self.vault.encrypt_payload(key, payload, None),
            // TODO does this need to copy?
            None => Ok(payload.to_vec()),
        }
    }

    /// Delegates decryption directly to the vault.
    fn decrypt(&self, payload: &[u8]) -> Result<Vec<u8>, Error> {
        self.vault.decrypt_payload(payload, None)
    }
}