1
#![allow(clippy::missing_panics_doc)]
2

            
3
use std::{
4
    fmt::{Debug, Display},
5
    io::ErrorKind,
6
    ops::Deref,
7
    path::{Path, PathBuf},
8
    time::{Duration, Instant},
9
};
10

            
11
use itertools::Itertools;
12
use serde::{Deserialize, Serialize};
13
use transmog_pot::Pot;
14

            
15
#[cfg(feature = "multiuser")]
16
use crate::admin::{PermissionGroup, Role, User};
17
use crate::{
18
    connection::{AccessPolicy, Connection, StorageConnection},
19
    document::{
20
        AnyDocumentId, BorrowedDocument, CollectionDocument, CollectionHeader, DocumentId, Emit,
21
        Header, KeyId,
22
    },
23
    keyvalue::KeyValue,
24
    limits::{LIST_TRANSACTIONS_DEFAULT_RESULT_COUNT, LIST_TRANSACTIONS_MAX_RESULTS},
25
    schema::{
26
        view::{
27
            map::{Mappings, ViewMappedValue},
28
            ReduceResult, ViewSchema,
29
        },
30
        Collection, CollectionName, MappedValue, NamedCollection, Schema, SchemaName, Schematic,
31
        SerializedCollection, View, ViewMapResult,
32
    },
33
    Error,
34
};
35

            
36
786866
#[derive(Serialize, Deserialize, Debug, PartialEq, Default, Clone, Collection)]
37
// This collection purposely uses names with characters that need
38
// escaping, since it's used in backup/restore.
39
#[collection(name = "_basic", authority = "khonsulabs_", views = [BasicCount, BasicByParentId, BasicByTag, BasicByCategory], core = crate)]
40
pub struct Basic {
41
    pub value: String,
42
    pub category: Option<String>,
43
    pub parent_id: Option<u64>,
44
    pub tags: Vec<String>,
45
}
46

            
47
impl Basic {
48
2622
    pub fn new(value: impl Into<String>) -> Self {
49
2622
        Self {
50
2622
            value: value.into(),
51
2622
            tags: Vec::default(),
52
2622
            category: None,
53
2622
            parent_id: None,
54
2622
        }
55
2622
    }
56

            
57
25
    pub fn with_category(mut self, category: impl Into<String>) -> Self {
58
25
        self.category = Some(category.into());
59
25
        self
60
25
    }
61

            
62
20
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
63
20
        self.tags.push(tag.into());
64
20
        self
65
20
    }
66

            
67
    #[must_use]
68
26
    pub fn with_parent_id(mut self, parent_id: impl Into<AnyDocumentId<u64>>) -> Self {
69
26
        self.parent_id = Some(parent_id.into().to_primary_key().unwrap());
70
26
        self
71
26
    }
72
}
73

            
74
1278657
#[derive(Debug, Clone, View)]
75
#[view(collection = Basic, key = (), value = usize, name = "count", core = crate)]
76
pub struct BasicCount;
77

            
78
impl ViewSchema for BasicCount {
79
    type View = Self;
80

            
81
26
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
82
26
        document.header.emit_key_and_value((), 1)
83
26
    }
84

            
85
26
    fn reduce(
86
26
        &self,
87
26
        mappings: &[ViewMappedValue<Self::View>],
88
26
        _rereduce: bool,
89
26
    ) -> ReduceResult<Self::View> {
90
26
        Ok(mappings.iter().map(|map| map.value).sum())
91
26
    }
92
}
93

            
94
1291083
#[derive(Debug, Clone, View)]
95
#[view(collection = Basic, key = Option<u64>, value = usize, name = "by-parent-id", core = crate)]
96
pub struct BasicByParentId;
97

            
98
impl ViewSchema for BasicByParentId {
99
    type View = Self;
100

            
101
1352
    fn version(&self) -> u64 {
102
1352
        1
103
1352
    }
104

            
105
1326
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
106
1326
        let contents = Basic::document_contents(document)?;
107
1326
        document.header.emit_key_and_value(contents.parent_id, 1)
108
1326
    }
109

            
110
1846
    fn reduce(
111
1846
        &self,
112
1846
        mappings: &[ViewMappedValue<Self::View>],
113
1846
        _rereduce: bool,
114
1846
    ) -> ReduceResult<Self::View> {
115
1976
        Ok(mappings.iter().map(|map| map.value).sum())
116
1846
    }
117
}
118

            
119
1279201
#[derive(Debug, Clone, View)]
120
#[view(collection = Basic, key = String, value = usize, name = "by-category", core = crate)]
121
pub struct BasicByCategory;
122

            
123
impl ViewSchema for BasicByCategory {
124
    type View = Self;
125

            
126
676
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
127
676
        let contents = Basic::document_contents(document)?;
128
676
        if let Some(category) = &contents.category {
129
390
            document
130
390
                .header
131
390
                .emit_key_and_value(category.to_lowercase(), 1)
132
        } else {
133
286
            Ok(Mappings::none())
134
        }
135
676
    }
136

            
137
390
    fn reduce(
138
390
        &self,
139
390
        mappings: &[ViewMappedValue<Self::View>],
140
390
        _rereduce: bool,
141
390
    ) -> ReduceResult<Self::View> {
142
520
        Ok(mappings.iter().map(|map| map.value).sum())
143
390
    }
144
}
145

            
146
1284375
#[derive(Debug, Clone, View)]
147
#[view(collection = Basic, key = String, value = usize, name = "by-tag", core = crate)]
148
pub struct BasicByTag;
149

            
150
impl ViewSchema for BasicByTag {
151
    type View = Self;
152

            
153
546
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
154
546
        let contents = Basic::document_contents(document)?;
155
546
        contents
156
546
            .tags
157
546
            .iter()
158
780
            .map(|tag| document.header.emit_key_and_value(tag.clone(), 1))
159
546
            .collect()
160
546
    }
161

            
162
1040
    fn reduce(
163
1040
        &self,
164
1040
        mappings: &[ViewMappedValue<Self::View>],
165
1040
        _rereduce: bool,
166
1040
    ) -> ReduceResult<Self::View> {
167
1300
        Ok(mappings.iter().map(|map| map.value).sum())
168
1040
    }
169
}
170

            
171
182
#[derive(Debug, Clone, View)]
172
#[view(collection = Basic, key = (), value = (), name = "by-parent-id", core = crate)]
173
pub struct BasicByBrokenParentId;
174

            
175
impl ViewSchema for BasicByBrokenParentId {
176
    type View = Self;
177

            
178
26
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
179
26
        document.header.emit()
180
26
    }
181
}
182

            
183
1573208
#[derive(Serialize, Deserialize, Debug, PartialEq, Default, Clone, Collection)]
184
#[collection(name = "encrypted-basic", authority = "khonsulabs", views = [EncryptedBasicCount, EncryptedBasicByParentId, EncryptedBasicByCategory])]
185
#[collection(encryption_key = Some(KeyId::Master), encryption_optional, core = crate)]
186
pub struct EncryptedBasic {
187
    pub value: String,
188
    pub category: Option<String>,
189
    pub parent_id: Option<u64>,
190
}
191

            
192
impl EncryptedBasic {
193
1
    pub fn new(value: impl Into<String>) -> Self {
194
1
        Self {
195
1
            value: value.into(),
196
1
            category: None,
197
1
            parent_id: None,
198
1
        }
199
1
    }
200

            
201
    pub fn with_category(mut self, category: impl Into<String>) -> Self {
202
        self.category = Some(category.into());
203
        self
204
    }
205

            
206
    #[must_use]
207
    pub const fn with_parent_id(mut self, parent_id: u64) -> Self {
208
        self.parent_id = Some(parent_id);
209
        self
210
    }
211
}
212

            
213
786786
#[derive(Debug, Clone, View)]
214
#[view(collection = EncryptedBasic, key = (), value = usize, name = "count", core = crate)]
215
pub struct EncryptedBasicCount;
216

            
217
impl ViewSchema for EncryptedBasicCount {
218
    type View = Self;
219

            
220
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
221
        document.header.emit_key_and_value((), 1)
222
    }
223

            
224
    fn reduce(
225
        &self,
226
        mappings: &[ViewMappedValue<Self::View>],
227
        _rereduce: bool,
228
    ) -> ReduceResult<Self::View> {
229
        Ok(mappings.iter().map(|map| map.value).sum())
230
    }
231
}
232

            
233
786786
#[derive(Debug, Clone, View)]
234
#[view(collection = EncryptedBasic, key = Option<u64>, value = usize, name = "by-parent-id", core = crate)]
235
pub struct EncryptedBasicByParentId;
236

            
237
impl ViewSchema for EncryptedBasicByParentId {
238
    type View = Self;
239

            
240
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
241
        let contents = EncryptedBasic::document_contents(document)?;
242
        document.header.emit_key_and_value(contents.parent_id, 1)
243
    }
244

            
245
    fn reduce(
246
        &self,
247
        mappings: &[ViewMappedValue<Self::View>],
248
        _rereduce: bool,
249
    ) -> ReduceResult<Self::View> {
250
        Ok(mappings.iter().map(|map| map.value).sum())
251
    }
252
}
253

            
254
786786
#[derive(Debug, Clone, View)]
255
#[view(collection = EncryptedBasic, key = String, value = usize, name = "by-category", core = crate)]
256
pub struct EncryptedBasicByCategory;
257

            
258
impl ViewSchema for EncryptedBasicByCategory {
259
    type View = Self;
260

            
261
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
262
        let contents = EncryptedBasic::document_contents(document)?;
263
        if let Some(category) = &contents.category {
264
            document
265
                .header
266
                .emit_key_and_value(category.to_lowercase(), 1)
267
        } else {
268
            Ok(Mappings::none())
269
        }
270
    }
271

            
272
    fn reduce(
273
        &self,
274
        mappings: &[ViewMappedValue<Self::View>],
275
        _rereduce: bool,
276
    ) -> ReduceResult<Self::View> {
277
        Ok(mappings.iter().map(|map| map.value).sum())
278
    }
279
}
280

            
281
786604
#[derive(Debug, Schema)]
282
#[schema(name = "basic", collections = [Basic, EncryptedBasic, Unique], core = crate)]
283
pub struct BasicSchema;
284

            
285
786604
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Default, Collection)]
286
#[collection(name = "unique", authority = "khonsulabs", views = [UniqueValue], core = crate)]
287
pub struct Unique {
288
    pub value: String,
289
}
290

            
291
impl Unique {
292
25
    pub fn new(value: impl Display) -> Self {
293
25
        Self {
294
25
            value: value.to_string(),
295
25
        }
296
25
    }
297
}
298

            
299
792922
#[derive(Debug, Clone, View)]
300
#[view(collection = Unique, key = String, value = (), name = "unique-value", core = crate)]
301
pub struct UniqueValue;
302

            
303
impl ViewSchema for UniqueValue {
304
    type View = Self;
305

            
306
790374
    fn unique(&self) -> bool {
307
790374
        true
308
790374
    }
309

            
310
1040
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
311
1040
        let entry = Unique::document_contents(document)?;
312
1040
        document.header.emit_key(entry.value)
313
1040
    }
314
}
315

            
316
impl NamedCollection for Unique {
317
    type ByNameView = UniqueValue;
318
}
319

            
320
/// A path inside the system temporary directory that is cleared when the
/// value is created and removed again when the value is dropped.
pub struct TestDirectory(pub PathBuf);

impl TestDirectory {
    /// Returns a handle for `name` joined onto [`std::env::temp_dir()`],
    /// deleting any directory tree already present at that path. The
    /// directory itself is not created.
    ///
    /// # Panics
    ///
    /// Panics if removing the existing contents fails for any reason other
    /// than the path not existing.
    pub fn new<S: AsRef<Path>>(name: S) -> Self {
        let path = std::env::temp_dir().join(name);
        // Remove unconditionally and tolerate `NotFound` instead of checking
        // `exists()` first: the directory can disappear between the check
        // and the removal when another test or process cleans it up.
        match std::fs::remove_dir_all(&path) {
            Ok(()) => {}
            Err(err) if err.kind() == ErrorKind::NotFound => {}
            Err(err) => panic!("error clearing temporary directory: {:?}", err),
        }
        Self(path)
    }
}

impl Drop for TestDirectory {
    /// Removes the directory tree. A missing directory is ignored; any other
    /// failure is reported on stderr rather than panicking.
    fn drop(&mut self) {
        if let Err(err) = std::fs::remove_dir_all(&self.0) {
            if err.kind() != ErrorKind::NotFound {
                eprintln!("Failed to clean up temporary folder: {:?}", err);
            }
        }
    }
}

impl AsRef<Path> for TestDirectory {
    fn as_ref(&self) -> &Path {
        &self.0
    }
}

impl Deref for TestDirectory {
    type Target = PathBuf;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
355

            
356
#[derive(Debug)]
357
pub struct BasicCollectionWithNoViews;
358

            
359
impl Collection for BasicCollectionWithNoViews {
360
    type PrimaryKey = u64;
361

            
362
208
    fn collection_name() -> CollectionName {
363
208
        Basic::collection_name()
364
208
    }
365

            
366
52
    fn define_views(_schema: &mut Schematic) -> Result<(), Error> {
367
52
        Ok(())
368
52
    }
369
}
370

            
371
impl SerializedCollection for BasicCollectionWithNoViews {
372
    type Contents = Basic;
373
    type Format = Pot;
374

            
375
26
    fn format() -> Self::Format {
376
26
        Pot::default()
377
26
    }
378
}
379

            
380
#[derive(Debug)]
381
pub struct BasicCollectionWithOnlyBrokenParentId;
382

            
383
impl Collection for BasicCollectionWithOnlyBrokenParentId {
384
    type PrimaryKey = u64;
385

            
386
182
    fn collection_name() -> CollectionName {
387
182
        Basic::collection_name()
388
182
    }
389

            
390
52
    fn define_views(schema: &mut Schematic) -> Result<(), Error> {
391
52
        schema.define_view(BasicByBrokenParentId)
392
52
    }
393
}
394

            
395
260
#[derive(Serialize, Deserialize, Clone, Debug, Collection)]
396
#[collection(name = "unassociated", authority = "khonsulabs", core = crate)]
397
pub struct UnassociatedCollection;
398

            
399
3874
/// Identifies one test of the shared suites.
///
/// Discriminants start at `1` and increase by one per variant, in
/// declaration order; [`HarnessTest::port`] uses them as port offsets.
#[derive(Copy, Clone, Debug)]
pub enum HarnessTest {
    ServerConnectionTests = 1,
    StoreRetrieveUpdate,
    NotFound,
    Conflict,
    BadUpdate,
    NoUpdate,
    GetMultiple,
    List,
    ListTransactions,
    ViewQuery,
    UnassociatedCollection,
    Compact,
    ViewUpdate,
    ViewMultiEmit,
    ViewUnimplementedReduce,
    ViewAccessPolicies,
    Encryption,
    UniqueViews,
    NamedCollection,
    PubSubSimple,
    UserManagement,
    PubSubMultipleSubscribers,
    PubSubDropAndSend,
    PubSubUnsubscribe,
    PubSubDropCleanup,
    PubSubPublishAll,
    KvBasic,
    KvConcurrency,
    KvSet,
    KvIncrementDecrement,
    KvExpiration,
    KvDeleteExpire,
    KvTransactions,
}

impl HarnessTest {
    /// Returns `base` plus this variant's discriminant, so that each test
    /// gets a distinct port.
    #[must_use]
    pub const fn port(self, base: u16) -> u16 {
        base + self as u16
    }
}

impl Display for HarnessTest {
    /// Writes the variant name, exactly as the derived `Debug` does.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `self` is already a reference; borrowing it again would only route
        // the call through the blanket `Debug` impl for `&T`.
        Debug::fmt(self, f)
    }
}
448

            
449
/// Asserts that two `f64` values differ by at most [`f64::EPSILON`].
///
/// Two infinities of the same sign are also accepted: their difference is
/// NaN, which would otherwise fail the tolerance check. NaN never compares
/// equal to anything.
#[macro_export]
macro_rules! assert_f64_eq {
    ($a:expr, $b:expr) => {{
        let a: f64 = $a;
        let b: f64 = $b;
        // `inf - inf` is NaN, so matching infinities need their own check.
        let same_infinity =
            a.is_infinite() && b.is_infinite() && a.is_sign_positive() == b.is_sign_positive();
        assert!(
            same_infinity || (a - b).abs() <= f64::EPSILON,
            "{:?} <> {:?}",
            a,
            b
        );
    }};
}
458

            
459
/// Creates a test suite that tests methods available on [`Connection`]
///
/// Invoke it with the name of a harness type. Each generated `#[tokio::test]`
/// builds the harness for one [`HarnessTest`](crate::test_util::HarnessTest)
/// variant, runs the matching function from `test_util`, and then shuts the
/// harness down. The `@db_test` rule is an implementation detail that stamps
/// out the tests which only need a database connection.
#[macro_export]
macro_rules! define_connection_test_suite {
    ($harness:ident) => {
        #[tokio::test]
        async fn server_connection_tests() -> anyhow::Result<()> {
            let harness =
                $harness::new($crate::test_util::HarnessTest::ServerConnectionTests).await?;
            let db = harness.server();
            $crate::test_util::basic_server_connection_tests(
                db.clone(),
                &format!("server-connection-tests-{}", $harness::server_name()),
            )
            .await?;
            harness.shutdown().await
        }

        $crate::define_connection_test_suite!(@db_test $harness, store_retrieve_update_delete, StoreRetrieveUpdate, store_retrieve_update_delete_tests);
        $crate::define_connection_test_suite!(@db_test $harness, not_found, NotFound, not_found_tests);
        $crate::define_connection_test_suite!(@db_test $harness, conflict, Conflict, conflict_tests);
        $crate::define_connection_test_suite!(@db_test $harness, bad_update, BadUpdate, bad_update_tests);
        $crate::define_connection_test_suite!(@db_test $harness, no_update, NoUpdate, no_update_tests);
        $crate::define_connection_test_suite!(@db_test $harness, get_multiple, GetMultiple, get_multiple_tests);
        $crate::define_connection_test_suite!(@db_test $harness, list, List, list_tests);
        $crate::define_connection_test_suite!(@db_test $harness, list_transactions, ListTransactions, list_transactions_tests);
        $crate::define_connection_test_suite!(@db_test $harness, view_query, ViewQuery, view_query_tests);
        $crate::define_connection_test_suite!(@db_test $harness, unassociated_collection, UnassociatedCollection, unassociated_collection_tests);
        $crate::define_connection_test_suite!(@db_test $harness, unimplemented_reduce, ViewUnimplementedReduce, unimplemented_reduce);
        $crate::define_connection_test_suite!(@db_test $harness, view_update, ViewUpdate, view_update_tests);
        $crate::define_connection_test_suite!(@db_test $harness, view_multi_emit, ViewMultiEmit, view_multi_emit_tests);
        $crate::define_connection_test_suite!(@db_test $harness, view_access_policies, ViewAccessPolicies, view_access_policy_tests);
        $crate::define_connection_test_suite!(@db_test $harness, unique_views, UniqueViews, unique_view_tests);
        $crate::define_connection_test_suite!(@db_test $harness, named_collection, NamedCollection, named_collection_tests);

        #[tokio::test]
        #[cfg(any(feature = "multiuser", feature = "local-multiuser", feature = "server"))]
        async fn user_management() -> anyhow::Result<()> {
            let harness = $harness::new($crate::test_util::HarnessTest::UserManagement).await?;
            let _db = harness.connect().await?;
            let server = harness.server();
            let admin = server
                .database::<$crate::admin::Admin>($crate::admin::ADMIN_DATABASE_NAME)
                .await?;

            $crate::test_util::user_management_tests(
                &admin,
                server.clone(),
                $harness::server_name(),
            )
            .await?;
            harness.shutdown().await
        }

        $crate::define_connection_test_suite!(@db_test $harness, compaction, Compact, compaction_tests);
    };

    // Internal: one test that connects to the harness database, hands the
    // connection to `test_util::$run`, and shuts the harness down.
    (@db_test $harness:ident, $name:ident, $variant:ident, $run:ident) => {
        #[tokio::test]
        async fn $name() -> anyhow::Result<()> {
            let harness = $harness::new($crate::test_util::HarnessTest::$variant).await?;
            let db = harness.connect().await?;

            $crate::test_util::$run(&db).await?;
            harness.shutdown().await
        }
    };
}
651

            
652
506
pub async fn store_retrieve_update_delete_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
653
506
    let original_value = Basic::new("initial_value");
654
506
    let collection = db.collection::<Basic>();
655
2060
    let header = collection.push(&original_value).await?;
656

            
657
506
    let mut doc = collection
658
1398
        .get(header.id)
659
1398
        .await?
660
506
        .expect("couldn't retrieve stored item");
661
506
    let mut value = Basic::document_contents(&doc)?;
662
506
    assert_eq!(original_value, value);
663
506
    let old_revision = doc.header.revision;
664
506

            
665
506
    // Update the value
666
506
    value.value = String::from("updated_value");
667
506
    Basic::set_document_contents(&mut doc, value.clone())?;
668
1607
    db.update::<Basic, _>(&mut doc).await?;
669

            
670
    // update should cause the revision to be changed
671
506
    assert_ne!(doc.header.revision, old_revision);
672

            
673
    // Check the value in the database to ensure it has the new document
674
506
    let doc = collection
675
1367
        .get(header.id)
676
1367
        .await?
677
506
        .expect("couldn't retrieve stored item");
678
506
    assert_eq!(Basic::document_contents(&doc)?, value);
679

            
680
    // These operations should have created two transactions with one change each
681
1444
    let transactions = db.list_executed_transactions(None, None).await?;
682
506
    assert_eq!(transactions.len(), 2);
683
506
    assert!(transactions[0].id < transactions[1].id);
684
1518
    for transaction in &transactions {
685
1012
        let changed_documents = transaction
686
1012
            .changes
687
1012
            .documents()
688
1012
            .expect("incorrect transaction type");
689
1012
        assert_eq!(changed_documents.len(), 1);
690
1012
        assert_eq!(changed_documents[0].collection, Basic::collection_name());
691
1012
        assert_eq!(header.id, changed_documents[0].id.deserialize()?);
692
1012
        assert!(!changed_documents[0].deleted);
693
    }
694

            
695
1559
    db.collection::<Basic>().delete(&doc).await?;
696
1467
    assert!(collection.get(header.id).await?.is_none());
697
506
    let transactions = db
698
1427
        .list_executed_transactions(Some(transactions.last().as_ref().unwrap().id + 1), None)
699
1427
        .await?;
700
506
    assert_eq!(transactions.len(), 1);
701
506
    let transaction = transactions.first().unwrap();
702
506
    let changed_documents = transaction
703
506
        .changes
704
506
        .documents()
705
506
        .expect("incorrect transaction type");
706
506
    assert_eq!(changed_documents.len(), 1);
707
506
    assert_eq!(changed_documents[0].collection, Basic::collection_name());
708
506
    assert_eq!(header.id, changed_documents[0].id.deserialize()?);
709
506
    assert!(changed_documents[0].deleted);
710

            
711
    // Use the Collection interface
712
1600
    let mut doc = original_value.clone().push_into(db).await?;
713
506
    doc.contents.category = Some(String::from("updated"));
714
1656
    doc.update(db).await?;
715
1520
    let reloaded = Basic::get(doc.header.id, db).await?.unwrap();
716
506
    assert_eq!(doc.contents, reloaded.contents);
717

            
718
    // Test Connection::insert with a specified id
719
506
    let doc = BorrowedDocument::with_contents::<Basic>(42, &Basic::new("42"))?;
720
506
    let document_42 = db
721
1559
        .insert::<Basic, _, _>(Some(doc.header.id), doc.contents.into_vec())
722
1559
        .await?;
723
506
    assert_eq!(document_42.id, 42);
724
1531
    let document_43 = Basic::new("43").insert_into(43, db).await?;
725
506
    assert_eq!(document_43.header.id, 43);
726

            
727
    // Test that inserting a document with the same ID results in a conflict:
728
506
    let conflict_err = Basic::new("43")
729
1474
        .insert_into(doc.header.id, db)
730
1474
        .await
731
506
        .unwrap_err();
732
506
    assert!(matches!(conflict_err.error, Error::DocumentConflict(..)));
733

            
734
    // Test that overwriting works
735
506
    let overwritten = Basic::new("43")
736
1550
        .overwrite_into(doc.header.id, db)
737
1550
        .await
738
506
        .unwrap();
739
506
    assert!(overwritten.header.revision.id > doc.header.revision.id);
740

            
741
506
    Ok(())
742
506
}
743

            
744
5
/// Checks that `db` returns no document for id `1` in the [`Basic`]
/// collection and reports no last transaction id.
pub async fn not_found_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
    let missing_document = db.collection::<Basic>().get(1).await?;
    assert!(missing_document.is_none());

    let last_transaction = db.last_transaction_id().await?;
    assert!(last_transaction.is_none());

    Ok(())
}
751

            
752
5
pub async fn conflict_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
753
5
    let original_value = Basic::new("initial_value");
754
5
    let collection = db.collection::<Basic>();
755
5
    let header = collection.push(&original_value).await?;
756

            
757
5
    let mut doc = collection
758
5
        .get(header.id)
759
5
        .await?
760
5
        .expect("couldn't retrieve stored item");
761
5
    let mut value = Basic::document_contents(&doc)?;
762
5
    value.value = String::from("updated_value");
763
5
    Basic::set_document_contents(&mut doc, value.clone())?;
764
5
    db.update::<Basic, _>(&mut doc).await?;
765

            
766
    // To generate a conflict, let's try to do the same update again by
767
    // reverting the header
768
5
    doc.header = Header::try_from(header).unwrap();
769
5
    match db
770
5
        .update::<Basic, _>(&mut doc)
771
5
        .await
772
5
        .expect_err("conflict should have generated an error")
773
    {
774
5
        Error::DocumentConflict(collection, header) => {
775
5
            assert_eq!(collection, Basic::collection_name());
776
5
            assert_eq!(header.id, doc.header.id);
777
        }
778
        other => return Err(anyhow::Error::from(other)),
779
    }
780

            
781
    // Let's force an update through overwrite. After this succeeds, the header
782
    // is updated to the new revision.
783
5
    db.collection::<Basic>().overwrite(&mut doc).await.unwrap();
784

            
785
    // Now, let's use the CollectionDocument API to modify the document through a refetch.
786
5
    let mut doc = CollectionDocument::<Basic>::try_from(&doc)?;
787
5
    doc.modify(db, |doc| {
788
5
        doc.contents.value = String::from("modify worked");
789
5
    })
790
5
    .await?;
791
5
    assert_eq!(doc.contents.value, "modify worked");
792
5
    let doc = Basic::get(doc.header.id, db).await?.unwrap();
793
5
    assert_eq!(doc.contents.value, "modify worked");
794

            
795
5
    Ok(())
796
5
}
797

            
798
5
pub async fn bad_update_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
799
5
    let mut doc = BorrowedDocument::with_contents::<Basic>(1, &Basic::default())?;
800
5
    match db.update::<Basic, _>(&mut doc).await {
801
5
        Err(Error::DocumentNotFound(collection, id)) => {
802
5
            assert_eq!(collection, Basic::collection_name());
803
5
            assert_eq!(id.as_ref(), &DocumentId::from_u64(1));
804
5
            Ok(())
805
        }
806
        other => panic!("expected DocumentNotFound from update but got: {:?}", other),
807
    }
808
5
}
809

            
810
5
/// Checks that updating a document with unchanged contents succeeds and
/// leaves its header equal to the one returned by the original push.
pub async fn no_update_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
    let basics = db.collection::<Basic>();
    let pushed_header = basics.push(&Basic::new("initial_value")).await?;

    let mut unchanged = basics
        .get(pushed_header.id)
        .await?
        .expect("couldn't retrieve stored item");
    db.update::<Basic, _>(&mut unchanged).await?;

    assert_eq!(CollectionHeader::try_from(unchanged.header)?, pushed_header);

    Ok(())
}
825

            
826
5
pub async fn get_multiple_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
827
5
    let collection = db.collection::<Basic>();
828
5
    let doc1_value = Basic::new("initial_value");
829
5
    let doc1 = collection.push(&doc1_value).await?;
830

            
831
5
    let doc2_value = Basic::new("second_value");
832
5
    let doc2 = collection.push(&doc2_value).await?;
833

            
834
5
    let both_docs = Basic::get_multiple([doc1.id, doc2.id], db).await?;
835
5
    assert_eq!(both_docs.len(), 2);
836

            
837
5
    let out_of_order = Basic::get_multiple([doc2.id, doc1.id], db).await?;
838
5
    assert_eq!(out_of_order.len(), 2);
839

            
840
    // The order of get_multiple isn't guaranteed, so these two checks are done
841
    // with iterators instead of direct indexing
842
5
    let doc1 = both_docs
843
5
        .iter()
844
5
        .find(|doc| doc.header.id == doc1.id)
845
5
        .expect("Couldn't find doc1");
846
5
    assert_eq!(doc1.contents.value, doc1_value.value);
847
5
    let doc2 = both_docs
848
5
        .iter()
849
10
        .find(|doc| doc.header.id == doc2.id)
850
5
        .expect("Couldn't find doc2");
851
5
    assert_eq!(doc2.contents.value, doc2_value.value);
852

            
853
5
    Ok(())
854
5
}
855

            
856
5
pub async fn list_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
857
5
    let collection = db.collection::<Basic>();
858
5
    let doc1_value = Basic::new("initial_value");
859
5
    let doc1 = collection.push(&doc1_value).await?;
860

            
861
5
    let doc2_value = Basic::new("second_value");
862
5
    let doc2 = collection.push(&doc2_value).await?;
863

            
864
5
    let all_docs = Basic::all(db).await?;
865
5
    assert_eq!(all_docs.len(), 2);
866

            
867
5
    let both_docs = Basic::list(doc1.id..=doc2.id, db).await?;
868
5
    assert_eq!(both_docs.len(), 2);
869

            
870
5
    assert_eq!(both_docs[0].contents.value, doc1_value.value);
871
5
    assert_eq!(both_docs[1].contents.value, doc2_value.value);
872

            
873
5
    let one_doc = Basic::list(doc1.id..doc2.id, db).await?;
874
5
    assert_eq!(one_doc.len(), 1);
875

            
876
5
    let limited = Basic::list(doc1.id..=doc2.id, db)
877
5
        .limit(1)
878
5
        .descending()
879
5
        .await?;
880
5
    assert_eq!(limited.len(), 1);
881
5
    assert_eq!(limited[0].contents.value, doc2_value.value);
882

            
883
5
    Ok(())
884
5
}
885

            
886
5
pub async fn list_transactions_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
887
5
    let collection = db.collection::<Basic>();
888
5

            
889
5
    // create LIST_TRANSACTIONS_MAX_RESULTS + 1 items, giving us just enough
890
5
    // transactions to test the edge cases of `list_transactions`
891
5
    futures::future::join_all(
892
5
        (0..=(LIST_TRANSACTIONS_MAX_RESULTS))
893
5005
            .map(|_| async { collection.push(&Basic::default()).await.unwrap() }),
894
4496
    )
895
4496
    .await;
896

            
897
    // Test defaults
898
5
    let transactions = db.list_executed_transactions(None, None).await?;
899
5
    assert_eq!(transactions.len(), LIST_TRANSACTIONS_DEFAULT_RESULT_COUNT);
900

            
901
    // Test max results limit
902
5
    let transactions = db
903
5
        .list_executed_transactions(None, Some(LIST_TRANSACTIONS_MAX_RESULTS + 1))
904
5
        .await?;
905
5
    assert_eq!(transactions.len(), LIST_TRANSACTIONS_MAX_RESULTS);
906

            
907
    // Test requesting 0 items
908
5
    let transactions = db.list_executed_transactions(None, Some(0)).await?;
909
5
    assert!(transactions.is_empty());
910

            
911
    // Test doing a loop fetching until we get no more results
912
5
    let mut transactions = Vec::new();
913
5
    let mut starting_id = None;
914
    loop {
915
60
        let chunk = db
916
60
            .list_executed_transactions(starting_id, Some(100))
917
60
            .await?;
918
60
        if chunk.is_empty() {
919
5
            break;
920
55
        }
921
55

            
922
55
        let max_id = chunk.last().map(|tx| tx.id).unwrap();
923
55
        starting_id = Some(max_id + 1);
924
55
        transactions.extend(chunk);
925
    }
926

            
927
5
    assert_eq!(transactions.len(), LIST_TRANSACTIONS_MAX_RESULTS + 1);
928

            
929
5
    Ok(())
930
5
}
931

            
932
5
pub async fn view_query_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
933
5
    let collection = db.collection::<Basic>();
934
5
    let a = collection.push(&Basic::new("A")).await?;
935
5
    let b = collection.push(&Basic::new("B")).await?;
936
5
    let a_child = collection
937
5
        .push(
938
5
            &Basic::new("A.1")
939
5
                .with_parent_id(a.id)
940
5
                .with_category("Alpha"),
941
5
        )
942
5
        .await?;
943
5
    collection
944
5
        .push(&Basic::new("B.1").with_parent_id(b.id).with_category("Beta"))
945
5
        .await?;
946
5
    collection
947
5
        .push(&Basic::new("B.2").with_parent_id(b.id).with_category("beta"))
948
5
        .await?;
949

            
950
5
    let a_children = db
951
5
        .view::<BasicByParentId>()
952
5
        .with_key(Some(a.id))
953
9
        .query()
954
9
        .await?;
955
5
    assert_eq!(a_children.len(), 1);
956

            
957
5
    let a_children = db
958
5
        .view::<BasicByParentId>()
959
5
        .with_key(Some(a.id))
960
7
        .query_with_collection_docs()
961
6
        .await?;
962
5
    assert_eq!(a_children.len(), 1);
963
5
    assert_eq!(a_children.get(0).unwrap().document.header, a_child);
964

            
965
5
    let b_children = db
966
5
        .view::<BasicByParentId>()
967
5
        .with_key(Some(b.id))
968
5
        .query()
969
2
        .await?;
970
5
    assert_eq!(b_children.len(), 2);
971

            
972
5
    let a_and_b_children = db
973
5
        .view::<BasicByParentId>()
974
5
        .with_keys([Some(a.id), Some(b.id)])
975
5
        .query()
976
2
        .await?;
977
5
    assert_eq!(a_and_b_children.len(), 3);
978

            
979
    // Test out of order keys
980
5
    let a_and_b_children = db
981
5
        .view::<BasicByParentId>()
982
5
        .with_keys([Some(b.id), Some(a.id)])
983
5
        .query()
984
2
        .await?;
985
5
    assert_eq!(a_and_b_children.len(), 3);
986

            
987
5
    let has_parent = db
988
5
        .view::<BasicByParentId>()
989
5
        .with_key_range(Some(0)..=Some(u64::MAX))
990
5
        .query()
991
2
        .await?;
992
5
    assert_eq!(has_parent.len(), 3);
993
    // Verify the result is sorted ascending
994
5
    assert!(has_parent
995
5
        .windows(2)
996
10
        .all(|window| window[0].key <= window[1].key));
997

            
998
    // Test limiting and descending order
999
5
    let last_with_parent = db
5
        .view::<BasicByParentId>()
5
        .with_key_range(Some(0)..=Some(u64::MAX))
5
        .descending()
5
        .limit(1)
5
        .query()
2
        .await?;
10
    assert_eq!(last_with_parent.iter().map(|m| m.key).unique().count(), 1);
5
    assert_eq!(last_with_parent[0].key, has_parent[2].key);

            
9
    let items_with_categories = db.view::<BasicByCategory>().query().await?;
5
    assert_eq!(items_with_categories.len(), 3);

            
    // Test deleting
5
    let deleted_count = db
5
        .view::<BasicByParentId>()
5
        .with_key(Some(b.id))
8
        .delete_docs()
8
        .await?;
5
    assert_eq!(b_children.len() as u64, deleted_count);
    assert_eq!(
5
        db.view::<BasicByParentId>()
5
            .with_key(Some(b.id))
5
            .query()
5
            .await?
5
            .len(),
        0
    );

            
5
    Ok(())
5
}

            
5
pub async fn unassociated_collection_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
5
    let result = db
5
        .collection::<UnassociatedCollection>()
5
        .push(&UnassociatedCollection)
5
        .await;
5
    match result {
5
        Err(Error::CollectionNotFound) => {}
        other => unreachable!("unexpected result: {:?}", other),
    }

            
5
    Ok(())
5
}

            
5
pub async fn unimplemented_reduce<C: Connection>(db: &C) -> anyhow::Result<()> {
5
    assert!(matches!(
5
        db.view::<UniqueValue>().reduce().await,
        Err(Error::ReduceUnimplemented)
    ));
5
    Ok(())
5
}

            
5
pub async fn view_update_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
5
    let collection = db.collection::<Basic>();
5
    let a = collection.push(&Basic::new("A")).await?;

            
5
    let a_children = db
5
        .view::<BasicByParentId>()
5
        .with_key(Some(a.id))
8
        .query()
8
        .await?;
5
    assert_eq!(a_children.len(), 0);
    // The reduce function of `BasicByParentId` acts as a "count" of records.
    assert_eq!(
5
        db.view::<BasicByParentId>()
5
            .with_key(Some(a.id))
5
            .reduce()
5
            .await?,
        0
    );

            
    // Test inserting a new record and the view being made available
5
    let a_child = collection
5
        .push(
5
            &Basic::new("A.1")
5
                .with_parent_id(a.id)
5
                .with_category("Alpha"),
5
        )
5
        .await?;

            
5
    let a_children = db
5
        .view::<BasicByParentId>()
5
        .with_key(Some(a.id))
5
        .query()
5
        .await?;
5
    assert_eq!(a_children.len(), 1);
    assert_eq!(
5
        db.view::<BasicByParentId>()
5
            .with_key(Some(a.id))
5
            .reduce()
5
            .await?,
        1
    );

            
    // Verify reduce_grouped matches our expectations.
    assert_eq!(
5
        db.view::<BasicByParentId>().reduce_grouped().await?,
5
        vec![MappedValue::new(None, 1,), MappedValue::new(Some(a.id), 1,),]
    );

            
    // Test updating the record and the view being updated appropriately
5
    let mut doc = db.collection::<Basic>().get(a_child.id).await?.unwrap();
5
    let mut basic = Basic::document_contents(&doc)?;
5
    basic.parent_id = None;
5
    Basic::set_document_contents(&mut doc, basic)?;
5
    db.update::<Basic, _>(&mut doc).await?;

            
5
    let a_children = db
5
        .view::<BasicByParentId>()
5
        .with_key(Some(a.id))
5
        .query()
5
        .await?;
5
    assert_eq!(a_children.len(), 0);
    assert_eq!(
5
        db.view::<BasicByParentId>()
5
            .with_key(Some(a.id))
5
            .reduce()
5
            .await?,
        0
    );
5
    assert_eq!(db.view::<BasicByParentId>().reduce().await?, 2);

            
    // Test deleting a record and ensuring it goes away
5
    db.collection::<Basic>().delete(&doc).await?;

            
5
    let all_entries = db.view::<BasicByParentId>().query().await?;
5
    assert_eq!(all_entries.len(), 1);

            
    // Verify reduce_grouped matches our expectations.
    assert_eq!(
5
        db.view::<BasicByParentId>().reduce_grouped().await?,
5
        vec![MappedValue::new(None, 1,),]
    );

            
5
    Ok(())
5
}

            
5
/// Verifies the `BasicByTag` view for documents carrying several tags, and that
/// the view follows tag changes.
///
/// Document "A" starts with the tags red and green, document "B" with blue and
/// green. The per-tag entry counts are checked initially, after "A" is retagged
/// to red and blue, and after all tags are removed from "B".
///
/// # Panics
///
/// Panics if any tag query returns an unexpected number of entries.
pub async fn view_multi_emit_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
    let mut a = Basic::new("A")
        .with_tag("red")
        .with_tag("green")
        .push_into(db)
        .await?;
    let mut b = Basic::new("B")
        .with_tag("blue")
        .with_tag("green")
        .push_into(db)
        .await?;

    // Two documents with two tags each.
    let every_entry = db.view::<BasicByTag>().query().await?;
    assert_eq!(every_entry.len(), 4);

    for (tag, expected) in [("green", 2), ("red", 1), ("blue", 1)] {
        let tagged = db
            .view::<BasicByTag>()
            .with_key(String::from(tag))
            .query()
            .await?;
        assert_eq!(tagged.len(), expected);
    }

    // Change tags
    a.contents.tags = vec![String::from("red"), String::from("blue")];
    a.update(db).await?;

    for (tag, expected) in [("green", 1), ("red", 1), ("blue", 2)] {
        let tagged = db
            .view::<BasicByTag>()
            .with_key(String::from(tag))
            .query()
            .await?;
        assert_eq!(tagged.len(), expected);
    }

    // Remove every tag from the second document.
    b.contents.tags.clear();
    b.update(db).await?;

    for (tag, expected) in [("green", 0), ("red", 1), ("blue", 1)] {
        let tagged = db
            .view::<BasicByTag>()
            .with_key(String::from(tag))
            .query()
            .await?;
        assert_eq!(tagged.len(), expected);
    }

    Ok(())
}

            
5
pub async fn view_access_policy_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
5
    let collection = db.collection::<Basic>();
5
    let a = collection.push(&Basic::new("A")).await?;

            
    // Test inserting a record that should match the view, but ask for it to be
    // NoUpdate. Verify we get no matches.
5
    collection
5
        .push(
5
            &Basic::new("A.1")
5
                .with_parent_id(a.id)
5
                .with_category("Alpha"),
5
        )
5
        .await?;

            
5
    let a_children = db
5
        .view::<BasicByParentId>()
5
        .with_key(Some(a.id))
5
        .with_access_policy(AccessPolicy::NoUpdate)
5
        .query()
2
        .await?;
5
    assert_eq!(a_children.len(), 0);

            
5
    tokio::time::sleep(Duration::from_millis(20)).await;

            
    // Verify the view still have no value, but this time ask for it to be
    // updated after returning
5
    let a_children = db
5
        .view::<BasicByParentId>()
5
        .with_key(Some(a.id))
5
        .with_access_policy(AccessPolicy::UpdateAfter)
5
        .query()
2
        .await?;
5
    assert_eq!(a_children.len(), 0);

            
    // Waiting on background jobs can be unreliable in a CI environment
5
    for _ in 0..10_u8 {
5
        tokio::time::sleep(Duration::from_millis(20)).await;

            
        // Now, the view should contain the entry.
5
        let a_children = db
5
            .view::<BasicByParentId>()
5
            .with_key(Some(a.id))
5
            .with_access_policy(AccessPolicy::NoUpdate)
5
            .query()
2
            .await?;
5
        if a_children.len() == 1 {
5
            return Ok(());
        }
    }
    panic!("view never updated")
5
}

            
5
pub async fn unique_view_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
8
    let first_doc = db.collection::<Unique>().push(&Unique::new("1")).await?;

            
    if let Err(Error::UniqueKeyViolation {
5
        view,
5
        existing_document,
5
        conflicting_document,
5
    }) = db.collection::<Unique>().push(&Unique::new("1")).await
    {
5
        assert_eq!(view, UniqueValue.view_name());
5
        assert_eq!(first_doc.id, existing_document.id.deserialize()?);
        // We can't predict the conflicting document id since it's generated
        // inside of the transaction, but we can assert that it's different than
        // the document that was previously stored.
5
        assert_ne!(conflicting_document, existing_document);
    } else {
        unreachable!("unique key violation not triggered");
    }

            
5
    let second_doc = db.collection::<Unique>().push(&Unique::new("2")).await?;
5
    let mut second_doc = db.collection::<Unique>().get(second_doc.id).await?.unwrap();
5
    let mut contents = Unique::document_contents(&second_doc)?;
5
    contents.value = String::from("1");
5
    Unique::set_document_contents(&mut second_doc, contents)?;
    if let Err(Error::UniqueKeyViolation {
5
        view,
5
        existing_document,
5
        conflicting_document,
5
    }) = db.update::<Unique, _>(&mut second_doc).await
    {
5
        assert_eq!(view, UniqueValue.view_name());
5
        assert_eq!(first_doc.id, existing_document.id.deserialize()?);
5
        assert_eq!(conflicting_document.id, second_doc.header.id);
    } else {
        unreachable!("unique key violation not triggered");
    }

            
5
    Ok(())
5
}

            
5
pub async fn named_collection_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
8
    Unique::new("0").push_into(db).await?;
5
    let original_entry = Unique::entry("1", db)
5
        .update_with(|_existing: &mut Unique| unreachable!())
13
        .or_insert_with(|| Unique::new("1"))
13
        .await?
5
        .expect("Document not inserted");

            
5
    let updated = Unique::entry("1", db)
5
        .update_with(|existing: &mut Unique| {
5
            existing.value = String::from("2");
5
        })
12
        .or_insert_with(|| unreachable!())
12
        .await?
5
        .unwrap();
5
    assert_eq!(original_entry.header.id, updated.header.id);
5
    assert_ne!(original_entry.contents.value, updated.contents.value);

            
7
    let retrieved = Unique::entry("2", db).await?.unwrap();
5
    assert_eq!(retrieved.contents.value, updated.contents.value);

            
5
    let conflict = Unique::entry("2", db)
5
        .update_with(|existing: &mut Unique| {
5
            existing.value = String::from("0");
7
        })
6
        .await;
5
    assert!(matches!(conflict, Err(Error::UniqueKeyViolation { .. })));

            
5
    Ok(())
5
}

            
5
/// Smoke-tests the compaction entry points: a single collection, the key-value
/// store, and the whole database.
///
/// One `Basic` document and one key are written first so that each compaction
/// call has data to operate on. The test only checks that every call succeeds.
pub async fn compaction_tests<C: Connection + KeyValue>(db: &C) -> anyhow::Result<()> {
    let basics = db.collection::<Basic>();
    let initial_contents = Basic::new("initial_value");
    basics.push(&initial_contents).await?;

    // Test a collection compaction
    db.compact_collection::<Basic>().await?;

    // Test the key value store compaction
    db.set_key("foo", &1_u32).await?;
    db.compact_key_value_store().await?;

    // Compact everything... again...
    db.compact().await?;

    Ok(())
}

            
/// Exercises creating a user and assigning/removing a role and a permission
/// group through a `StorageConnection`.
///
/// `server_name` is interpolated into the user, role and group names. `admin`
/// is the connection used to read `User` documents and to store the `Role` and
/// `PermissionGroup`.
///
/// The test checks the freshly created user has no groups or roles, adds one
/// of each (by user id), adds them again by username and expects no change,
/// removes both, and finally removes both a second time expecting no error.
///
/// # Panics
///
/// Panics if a user lookup fails or finds nothing, if storing the role or
/// group fails, if an add/remove call in the middle sections fails, or if the
/// user's groups or roles differ from what is expected.
#[cfg(feature = "multiuser")]
pub async fn user_management_tests<C: Connection, S: StorageConnection>(
    admin: &C,
    server: S,
    server_name: &str,
) -> anyhow::Result<()> {
    let username = format!("user-management-tests-{}", server_name);
    let user_id = server.create_user(&username).await?;

    // Test the default created user state.
    let fresh_user = User::get(user_id, admin)
        .await
        .unwrap()
        .expect("user not found");
    assert_eq!(fresh_user.contents.username, username);
    assert!(fresh_user.contents.groups.is_empty());
    assert!(fresh_user.contents.roles.is_empty());

    let role = Role::named(format!("role-{}", server_name))
        .push_into(admin)
        .await
        .unwrap();
    let group = PermissionGroup::named(format!("group-{}", server_name))
        .push_into(admin)
        .await
        .unwrap();

    // Add the role and group.
    server
        .add_permission_group_to_user(user_id, &group)
        .await
        .unwrap();
    server.add_role_to_user(user_id, &role).await.unwrap();

    // Test the results
    let after_add = User::get(user_id, admin)
        .await
        .unwrap()
        .expect("user not found");
    assert_eq!(after_add.contents.groups, vec![group.header.id]);
    assert_eq!(after_add.contents.roles, vec![role.header.id]);

    // Add the same things again (should not do anything). With names this time.
    server
        .add_permission_group_to_user(&username, &group)
        .await
        .unwrap();
    server.add_role_to_user(&username, &role).await.unwrap();

    let after_repeat_add = User::load(&username, admin)
        .await
        .unwrap()
        .expect("user not found");
    assert_eq!(after_repeat_add.contents.groups, vec![group.header.id]);
    assert_eq!(after_repeat_add.contents.roles, vec![role.header.id]);

    // Remove the group and the role.
    server
        .remove_permission_group_from_user(user_id, &group)
        .await
        .unwrap();
    server.remove_role_from_user(user_id, &role).await.unwrap();

    let after_remove = User::get(user_id, admin)
        .await
        .unwrap()
        .expect("user not found");
    assert!(after_remove.contents.groups.is_empty());
    assert!(after_remove.contents.roles.is_empty());

    // Removing again shouldn't cause an error.
    server
        .remove_permission_group_from_user(user_id, &group)
        .await?;
    server.remove_role_from_user(user_id, &role).await?;

    Ok(())
}

            
/// Defines the `KeyValue` test suite
#[macro_export]
macro_rules! define_kv_test_suite {
    ($harness:ident) => {
        #[tokio::test]
5
        async fn basic_kv_test() -> anyhow::Result<()> {
            use $crate::keyvalue::{KeyStatus, KeyValue};
            let harness = $harness::new($crate::test_util::HarnessTest::KvBasic).await?;
            let db = harness.connect().await?;
            assert_eq!(
                db.set_key("akey", &String::from("avalue")).await?,
                KeyStatus::Inserted
            );
            assert_eq!(
                db.get_key("akey").into().await?,
                Some(String::from("avalue"))
            );
            assert_eq!(
                db.set_key("akey", &String::from("new_value"))
                    .returning_previous_as()
                    .await?,
                Some(String::from("avalue"))
            );
            assert_eq!(
                db.get_key("akey").into().await?,
                Some(String::from("new_value"))
            );
            assert_eq!(
                db.get_key("akey").and_delete().into().await?,
                Some(String::from("new_value"))
            );
            assert_eq!(db.get_key("akey").await?, None);
            assert_eq!(
                db.set_key("akey", &String::from("new_value"))
                    .returning_previous()
                    .await?,
                None
            );
            assert_eq!(db.delete_key("akey").await?, KeyStatus::Deleted);
            assert_eq!(db.delete_key("akey").await?, KeyStatus::NotChanged);

            
            harness.shutdown().await?;

            
            Ok(())
        }

            
        #[tokio::test]
5
        async fn kv_concurrency() -> anyhow::Result<()> {
            use $crate::keyvalue::{KeyStatus, KeyValue};
            const WRITERS: usize = 100;
            const INCREMENTS: usize = 100;
            let harness = $harness::new($crate::test_util::HarnessTest::KvConcurrency).await?;
            let db = harness.connect().await?;

            
            let handles = (0..WRITERS).map(|_| {
                let db = db.clone();
                tokio::task::spawn(async move {
                    for _ in 0..INCREMENTS {
                        db.increment_key_by("concurrency", 1_u64).await.unwrap();
                    }
                })
            });
            futures::future::join_all(handles).await;

            
            assert_eq!(
                db.get_key("concurrency").into_u64().await.unwrap().unwrap(),
                (WRITERS * INCREMENTS) as u64
            );

            
            harness.shutdown().await?;

            
            Ok(())
        }

            
        #[tokio::test]
5
        async fn kv_set_tests() -> anyhow::Result<()> {
            use $crate::keyvalue::{KeyStatus, KeyValue};
            let harness = $harness::new($crate::test_util::HarnessTest::KvSet).await?;
            let db = harness.connect().await?;
            let kv = db.with_key_namespace("set");

            
            assert_eq!(
                kv.set_key("a", &0_u32).only_if_exists().await?,
                KeyStatus::NotChanged
            );
            assert_eq!(
                kv.set_key("a", &0_u32).only_if_vacant().await?,
                KeyStatus::Inserted
            );
            assert_eq!(
                kv.set_key("a", &1_u32).only_if_vacant().await?,
                KeyStatus::NotChanged
            );
            assert_eq!(
                kv.set_key("a", &2_u32).only_if_exists().await?,
                KeyStatus::Updated,
            );
            assert_eq!(
                kv.set_key("a", &3_u32).returning_previous_as().await?,
                Some(2_u32),
            );

            
            harness.shutdown().await?;

            
            Ok(())
        }

            
        #[tokio::test]
5
        async fn kv_increment_decrement_tests() -> anyhow::Result<()> {
            use $crate::keyvalue::{KeyStatus, KeyValue};
            let harness =
                $harness::new($crate::test_util::HarnessTest::KvIncrementDecrement).await?;
            let db = harness.connect().await?;
            let kv = db.with_key_namespace("increment_decrement");

            
            // Empty keys should be equal to 0
            assert_eq!(kv.increment_key_by("i64", 1_i64).await?, 1_i64);
            assert_eq!(kv.get_key("i64").into_i64().await?, Some(1_i64));
            assert_eq!(kv.increment_key_by("u64", 1_u64).await?, 1_u64);
            $crate::assert_f64_eq!(kv.increment_key_by("f64", 1_f64).await?, 1_f64);

            
            // Test float incrementing/decrementing an existing value
            $crate::assert_f64_eq!(kv.increment_key_by("f64", 1_f64).await?, 2_f64);
            $crate::assert_f64_eq!(kv.decrement_key_by("f64", 2_f64).await?, 0_f64);

            
            // Empty keys should be equal to 0
            assert_eq!(kv.decrement_key_by("i64_2", 1_i64).await?, -1_i64);
            assert_eq!(
                kv.decrement_key_by("u64_2", 42_u64)
                    .allow_overflow()
                    .await?,
                u64::MAX - 41
            );
            assert_eq!(kv.decrement_key_by("u64_3", 42_u64).await?, u64::MIN);
            $crate::assert_f64_eq!(kv.decrement_key_by("f64_2", 1_f64).await?, -1_f64);

            
            // Test decrement wrapping with overflow
            kv.set_numeric_key("i64", i64::MIN).await?;
            assert_eq!(
                kv.decrement_key_by("i64", 1_i64).allow_overflow().await?,
                i64::MAX
            );
            assert_eq!(
                kv.decrement_key_by("u64", 2_u64).allow_overflow().await?,
                u64::MAX
            );

            
            // Test increment wrapping with overflow
            assert_eq!(
                kv.increment_key_by("i64", 1_i64).allow_overflow().await?,
                i64::MIN
            );
            assert_eq!(
                kv.increment_key_by("u64", 1_u64).allow_overflow().await?,
                u64::MIN
            );

            
            // Test saturating increments.
            kv.set_numeric_key("i64", i64::MAX - 1).await?;
            kv.set_numeric_key("u64", u64::MAX - 1).await?;
            assert_eq!(kv.increment_key_by("i64", 2_i64).await?, i64::MAX);
            assert_eq!(kv.increment_key_by("u64", 2_u64).await?, u64::MAX);

            
            // Test saturating decrements.
            kv.set_numeric_key("i64", i64::MIN + 1).await?;
            kv.set_numeric_key("u64", u64::MIN + 1).await?;
            assert_eq!(kv.decrement_key_by("i64", 2_i64).await?, i64::MIN);
            assert_eq!(kv.decrement_key_by("u64", 2_u64).await?, u64::MIN);

            
            // Test numerical conversion safety using get
            {
                // For i64 -> f64, the limit is 2^53 (2^`f64::MANTISSA_DIGITS`) in
                // either the positive or negative direction.
                kv.set_numeric_key("i64", (2_i64.pow(f64::MANTISSA_DIGITS)))
                    .await?;
                $crate::assert_f64_eq!(
                    kv.get_key("i64").into_f64().await?.unwrap(),
                    9_007_199_254_740_992_f64
                );
                kv.set_numeric_key("i64", -(2_i64.pow(f64::MANTISSA_DIGITS)))
                    .await?;
                $crate::assert_f64_eq!(
                    kv.get_key("i64").into_f64().await?.unwrap(),
                    -9_007_199_254_740_992_f64
                );

            
                kv.set_numeric_key("i64", (2_i64.pow(f64::MANTISSA_DIGITS) + 1))
                    .await?;
                assert!(matches!(kv.get_key("i64").into_f64().await, Err(_)));
                $crate::assert_f64_eq!(
                    kv.get_key("i64").into_f64_lossy().await?.unwrap(),
                    9_007_199_254_740_993_f64
                );
                kv.set_numeric_key("i64", -(2_i64.pow(f64::MANTISSA_DIGITS) + 1))
                    .await?;
                assert!(matches!(kv.get_key("i64").into_f64().await, Err(_)));
                $crate::assert_f64_eq!(
                    kv.get_key("i64").into_f64_lossy().await?.unwrap(),
                    -9_007_199_254_740_993_f64
                );

            
                // For i64 -> u64, the only limit is sign.
                kv.set_numeric_key("i64", -1_i64).await?;
                assert!(matches!(kv.get_key("i64").into_u64().await, Err(_)));
                assert_eq!(
                    kv.get_key("i64").into_u64_lossy(true).await?.unwrap(),
                    0_u64
                );
                assert_eq!(
                    kv.get_key("i64").into_u64_lossy(false).await?.unwrap(),
                    u64::MAX
                );

            
                // For f64 -> i64, the limit is fractional numbers. Saturating isn't tested in this conversion path.
                kv.set_numeric_key("f64", 1.1_f64).await?;
                assert!(matches!(kv.get_key("f64").into_i64().await, Err(_)));
                assert_eq!(
                    kv.get_key("f64").into_i64_lossy(false).await?.unwrap(),
                    1_i64
                );
                kv.set_numeric_key("f64", -1.1_f64).await?;
                assert!(matches!(kv.get_key("f64").into_i64().await, Err(_)));
                assert_eq!(
                    kv.get_key("f64").into_i64_lossy(false).await?.unwrap(),
                    -1_i64
                );

            
                // For f64 -> u64, the limit is fractional numbers or negative numbers. Saturating isn't tested in this conversion path.
                kv.set_numeric_key("f64", 1.1_f64).await?;
                assert!(matches!(kv.get_key("f64").into_u64().await, Err(_)));
                assert_eq!(
                    kv.get_key("f64").into_u64_lossy(false).await?.unwrap(),
                    1_u64
                );
                kv.set_numeric_key("f64", -1.1_f64).await?;
                assert!(matches!(kv.get_key("f64").into_u64().await, Err(_)));
                assert_eq!(
                    kv.get_key("f64").into_u64_lossy(false).await?.unwrap(),
                    0_u64
                );

            
                // For u64 -> i64, the limit is > i64::MAX
                kv.set_numeric_key("u64", i64::MAX as u64 + 1).await?;
                assert!(matches!(kv.get_key("u64").into_i64().await, Err(_)));
                assert_eq!(
                    kv.get_key("u64").into_i64_lossy(true).await?.unwrap(),
                    i64::MAX
                );
                assert_eq!(
                    kv.get_key("u64").into_i64_lossy(false).await?.unwrap(),
                    i64::MIN
                );
            }

            
            // Test that non-numeric keys won't be changed when attempting to incr/decr
            kv.set_key("non-numeric", &String::from("test")).await?;
            assert!(matches!(
                kv.increment_key_by("non-numeric", 1_i64).await,
                Err(_)
            ));
            assert!(matches!(
                kv.decrement_key_by("non-numeric", 1_i64).await,
                Err(_)
            ));
            assert_eq!(
                kv.get_key("non-numeric").into::<String>().await?.unwrap(),
                String::from("test")
            );

            
            // Test that NaN cannot be stored
            kv.set_numeric_key("f64", 0_f64).await?;
            assert!(matches!(
                kv.set_numeric_key("f64", f64::NAN).await,
                Err(bonsaidb_core::Error::NotANumber)
            ));
            // Verify the value was unchanged.
            $crate::assert_f64_eq!(kv.get_key("f64").into_f64().await?.unwrap(), 0.);
            // Try to increment by nan
            assert!(matches!(
                kv.increment_key_by("f64", f64::NAN).await,
                Err(bonsaidb_core::Error::NotANumber)
            ));
            $crate::assert_f64_eq!(kv.get_key("f64").into_f64().await?.unwrap(), 0.);

            
            harness.shutdown().await?;

            
            Ok(())
        }

            
        #[tokio::test]
5
        async fn kv_expiration_tests() -> anyhow::Result<()> {
            use std::time::Duration;

            
            use $crate::keyvalue::{KeyStatus, KeyValue};

            
            let harness = $harness::new($crate::test_util::HarnessTest::KvExpiration).await?;
            let db = harness.connect().await?;

            
            loop {
                let kv = db.with_key_namespace("expiration");

            
                kv.delete_key("a").await?;
                kv.delete_key("b").await?;

            
                // Test that the expiration is updated for key a, but not for key b.
                let timing = $crate::test_util::TimingTest::new(Duration::from_millis(500));
                let (r1, r2) = tokio::join!(
                    kv.set_key("a", &0_u32).expire_in(Duration::from_secs(2)),
                    kv.set_key("b", &0_u32).expire_in(Duration::from_secs(2))
                );
                if timing.elapsed() > Duration::from_millis(500) {
                    println!(
                        "Restarting test {}. Took too long {:?}",
                        line!(),
                        timing.elapsed(),
                    );
                    continue;
                }
                assert_eq!(r1?, KeyStatus::Inserted);
                assert_eq!(r2?, KeyStatus::Inserted);
                let (r1, r2) = tokio::join!(
                    kv.set_key("a", &1_u32).expire_in(Duration::from_secs(4)),
                    kv.set_key("b", &1_u32)
                        .expire_in(Duration::from_secs(100))
                        .keep_existing_expiration()
                );
                if timing.elapsed() > Duration::from_secs(1) {
                    println!(
                        "Restarting test {}. Took too long {:?}",
                        line!(),
                        timing.elapsed(),
                    );
                    continue;
                }

            
                assert_eq!(r1?, KeyStatus::Updated, "a wasn't an update");
                assert_eq!(r2?, KeyStatus::Updated, "b wasn't an update");

            
                let a = kv.get_key("a").into().await?;
                assert_eq!(a, Some(1u32), "a shouldn't have expired yet");

            
                // Before checking the value, make sure we haven't elapsed too
                // much time. If so, just restart the test.
                if !timing.wait_until(Duration::from_secs_f32(3.)).await {
                    println!(
                        "Restarting test {}. Took too long {:?}",
                        line!(),
                        timing.elapsed()
                    );
                    continue;
                }

            
                assert_eq!(kv.get_key("b").await?, None, "b never expired");

            
                timing.wait_until(Duration::from_secs_f32(5.)).await;
                assert_eq!(kv.get_key("a").await?, None, "a never expired");
                break;
            }
            harness.shutdown().await?;

            
            Ok(())
        }

            
        #[tokio::test]
5
        async fn delete_expire_tests() -> anyhow::Result<()> {
            use std::time::Duration;

            
            use $crate::keyvalue::{KeyStatus, KeyValue};

            
            let harness = $harness::new($crate::test_util::HarnessTest::KvDeleteExpire).await?;
            let db = harness.connect().await?;

            
            loop {
                let kv = db.with_key_namespace("delete_expire");

            
                kv.delete_key("a").await?;

            
                let timing = $crate::test_util::TimingTest::new(Duration::from_millis(100));

            
                // Create a key with an expiration. Delete the key. Set a new
                // value at that key with no expiration. Ensure it doesn't
                // expire.
                kv.set_key("a", &0_u32)
                    .expire_in(Duration::from_secs(2))
                    .await?;
                kv.delete_key("a").await?;
                kv.set_key("a", &1_u32).await?;
                if timing.elapsed() > Duration::from_secs(1) {
                    println!(
                        "Restarting test {}. Took too long {:?}",
                        line!(),
                        timing.elapsed(),
                    );
                    continue;
                }
                if !timing.wait_until(Duration::from_secs_f32(2.5)).await {
                    println!(
                        "Restarting test {}. Took too long {:?}",
                        line!(),
                        timing.elapsed()
                    );
                    continue;
                }

            
                assert_eq!(kv.get_key("a").into().await?, Some(1u32));

            
                break;
            }
            harness.shutdown().await?;

            
            Ok(())
        }

            
        #[tokio::test]
5
        async fn kv_transaction_tests() -> anyhow::Result<()> {
            use std::time::Duration;

            
            use $crate::{
                connection::Connection,
                keyvalue::{KeyStatus, KeyValue},
            };
            let harness = $harness::new($crate::test_util::HarnessTest::KvTransactions).await?;
            let db = harness.connect().await?;
            // Generate several transactions that we can validate. Persisting
            // happens in the background, so we delay between each step to give
            // it a moment.
            db.set_key("expires", &0_u32)
                .expire_in(Duration::from_secs(1))
                .await?;
            tokio::time::sleep(Duration::from_millis(100)).await;
            db.set_key("akey", &String::from("avalue")).await?;
            tokio::time::sleep(Duration::from_millis(100)).await;
            db.get_key("akey").and_delete().await?;
            tokio::time::sleep(Duration::from_millis(100)).await;
            db.set_numeric_key("nkey", 0_u64).await?;
            tokio::time::sleep(Duration::from_millis(100)).await;
            db.increment_key_by("nkey", 1_u64).await?;
            tokio::time::sleep(Duration::from_millis(100)).await;
            db.delete_key("nkey").await?;
            tokio::time::sleep(Duration::from_millis(100)).await;
            // Ensure this doesn't generate a transaction.
            db.delete_key("nkey").await?;

            
            tokio::time::sleep(Duration::from_secs(1)).await;

            
            let transactions = Connection::list_executed_transactions(&db, None, None).await?;
            let deleted_keys = transactions
                .iter()
                .filter_map(|tx| tx.changes.keys())
                .flatten()
                .filter(|changed_key| changed_key.deleted)
                .count();
            assert_eq!(deleted_keys, 3);
            let akey_changes = transactions
                .iter()
                .filter_map(|tx| tx.changes.keys())
                .flatten()
                .filter(|changed_key| changed_key.key == "akey")
                .count();
            assert_eq!(akey_changes, 2);
            let nkey_changes = transactions
                .iter()
                .filter_map(|tx| tx.changes.keys())
                .flatten()
                .filter(|changed_key| changed_key.key == "nkey")
                .count();
            assert_eq!(nkey_changes, 3);

            
            harness.shutdown().await?;

            
            Ok(())
        }
    };
}

            
/// Measures time relative to a fixed starting instant, for tests whose
/// assertions are only meaningful while the test is running on schedule.
#[derive(Debug, Clone, Copy)]
pub struct TimingTest {
    /// How far past a requested deadline is still treated as on time.
    tolerance: Duration,
    /// The instant against which all durations are measured.
    start: Instant,
}

            
impl TimingTest {
    #[must_use]
416
    pub fn new(tolerance: Duration) -> Self {
416
        Self {
416
            tolerance,
416
            start: Instant::now(),
416
        }
416
    }

            
624
    pub async fn wait_until(&self, absolute_duration: Duration) -> bool {
24
        let target = self.start + absolute_duration;
24
        let mut now = Instant::now();
24
        if now < target {
24
            tokio::time::sleep_until(target.into()).await;
24
            now = Instant::now();
        }
24
        let amount_past = now.checked_duration_since(target);
24

            
24
        // Return false if we're beyond the tolerance given
24
        amount_past.unwrap_or_default() < self.tolerance
24
    }

            
    #[must_use]
494
    pub fn elapsed(&self) -> Duration {
494
        Instant::now()
494
            .checked_duration_since(self.start)
494
            .unwrap_or_default()
494
    }
}

            
5
pub async fn basic_server_connection_tests<C: StorageConnection>(
5
    server: C,
5
    newdb_name: &str,
5
) -> anyhow::Result<()> {
5
    let mut schemas = server.list_available_schemas().await?;
5
    schemas.sort();
5
    assert!(schemas.contains(&BasicSchema::schema_name()));
5
    assert!(schemas.contains(&SchemaName::new("khonsulabs", "bonsaidb-admin")));

            
5
    let databases = server.list_databases().await?;
37
    assert!(databases.iter().any(|db| db.name == "tests"));

            
5
    server
8
        .create_database::<BasicSchema>(newdb_name, false)
8
        .await?;
8
    server.delete_database(newdb_name).await?;

            
    assert!(matches!(
5
        server.delete_database(newdb_name).await,
        Err(Error::DatabaseNotFound(_))
    ));

            
    assert!(matches!(
5
        server.create_database::<BasicSchema>("tests", false).await,
        Err(Error::DatabaseNameAlreadyTaken(_))
    ));

            
    assert!(matches!(
5
        server.create_database::<BasicSchema>("tests", true).await,
        Ok(_)
    ));

            
    assert!(matches!(
5
        server
5
            .create_database::<BasicSchema>("|invalidname", false)
2
            .await,
        Err(Error::InvalidDatabaseName(_))
    ));

            
    assert!(matches!(
5
        server
5
            .create_database::<UnassociatedCollection>(newdb_name, false)
2
            .await,
        Err(Error::SchemaNotRegistered(_))
    ));

            
5
    Ok(())
5
}