1
#![allow(clippy::missing_panics_doc)]
2

            
3
use std::{
4
    fmt::{Debug, Display},
5
    io::ErrorKind,
6
    ops::Deref,
7
    path::{Path, PathBuf},
8
    time::{Duration, Instant},
9
};
10

            
11
use itertools::Itertools;
12
use serde::{Deserialize, Serialize};
13
use transmog_pot::Pot;
14

            
15
#[cfg(feature = "multiuser")]
16
use crate::admin::{PermissionGroup, Role, User};
17
use crate::{
18
    connection::{AccessPolicy, Connection, StorageConnection},
19
    document::{
20
        AnyDocumentId, BorrowedDocument, CollectionDocument, CollectionHeader, DocumentId, Emit,
21
        Header, KeyId,
22
    },
23
    keyvalue::KeyValue,
24
    limits::{LIST_TRANSACTIONS_DEFAULT_RESULT_COUNT, LIST_TRANSACTIONS_MAX_RESULTS},
25
    schema::{
26
        view::{
27
            map::{Mappings, ViewMappedValue},
28
            ReduceResult, ViewSchema,
29
        },
30
        Collection, CollectionName, MappedValue, NamedCollection, Schema, SchemaName, Schematic,
31
        SerializedCollection, View, ViewMapResult,
32
    },
33
    Error,
34
};
35

            
36
817265
#[derive(Serialize, Deserialize, Debug, PartialEq, Default, Clone, Collection)]
37
// This collection purposely uses names with characters that need
38
// escaping, since it's used in backup/restore.
39
#[collection(name = "_basic", authority = "khonsulabs_", views = [BasicCount, BasicByParentId, BasicByTag, BasicByCategory], core = crate)]
40
#[must_use]
41
pub struct Basic {
42
    pub value: String,
43
    pub category: Option<String>,
44
    pub parent_id: Option<u64>,
45
    pub tags: Vec<String>,
46
}
47

            
48
impl Basic {
49
2622
    pub fn new(value: impl Into<String>) -> Self {
50
2622
        Self {
51
2622
            value: value.into(),
52
2622
            tags: Vec::default(),
53
2622
            category: None,
54
2622
            parent_id: None,
55
2622
        }
56
2622
    }
57

            
58
25
    pub fn with_category(mut self, category: impl Into<String>) -> Self {
59
25
        self.category = Some(category.into());
60
25
        self
61
25
    }
62

            
63
20
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
64
20
        self.tags.push(tag.into());
65
20
        self
66
20
    }
67

            
68
26
    pub fn with_parent_id(mut self, parent_id: impl Into<AnyDocumentId<u64>>) -> Self {
69
26
        self.parent_id = Some(parent_id.into().to_primary_key().unwrap());
70
26
        self
71
26
    }
72
}
73

            
74
1327944
#[derive(Debug, Clone, View)]
75
#[view(collection = Basic, key = (), value = usize, name = "count", core = crate)]
76
pub struct BasicCount;
77

            
78
impl ViewSchema for BasicCount {
79
    type View = Self;
80

            
81
27
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
82
27
        document.header.emit_key_and_value((), 1)
83
27
    }
84

            
85
27
    fn reduce(
86
27
        &self,
87
27
        mappings: &[ViewMappedValue<Self::View>],
88
27
        _rereduce: bool,
89
27
    ) -> ReduceResult<Self::View> {
90
27
        Ok(mappings.iter().map(|map| map.value).sum())
91
27
    }
92
}
93

            
94
1340929
#[derive(Debug, Clone, View)]
95
#[view(collection = Basic, key = Option<u64>, value = usize, name = "by-parent-id", core = crate)]
96
pub struct BasicByParentId;
97

            
98
impl ViewSchema for BasicByParentId {
99
    type View = Self;
100

            
101
1404
    fn version(&self) -> u64 {
102
1404
        1
103
1404
    }
104

            
105
1377
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
106
1377
        let contents = Basic::document_contents(document)?;
107
1377
        document.header.emit_key_and_value(contents.parent_id, 1)
108
1377
    }
109

            
110
1512
    fn reduce(
111
1512
        &self,
112
1512
        mappings: &[ViewMappedValue<Self::View>],
113
1512
        _rereduce: bool,
114
1512
    ) -> ReduceResult<Self::View> {
115
1647
        Ok(mappings.iter().map(|map| map.value).sum())
116
1512
    }
117
}
118

            
119
1328509
#[derive(Debug, Clone, View)]
120
#[view(collection = Basic, key = String, value = usize, name = "by-category", core = crate)]
121
pub struct BasicByCategory;
122

            
123
impl ViewSchema for BasicByCategory {
124
    type View = Self;
125

            
126
702
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
127
702
        let contents = Basic::document_contents(document)?;
128
702
        if let Some(category) = &contents.category {
129
405
            document
130
405
                .header
131
405
                .emit_key_and_value(category.to_lowercase(), 1)
132
        } else {
133
297
            Ok(Mappings::none())
134
        }
135
702
    }
136

            
137
270
    fn reduce(
138
270
        &self,
139
270
        mappings: &[ViewMappedValue<Self::View>],
140
270
        _rereduce: bool,
141
270
    ) -> ReduceResult<Self::View> {
142
405
        Ok(mappings.iter().map(|map| map.value).sum())
143
270
    }
144
}
145

            
146
1333909
#[derive(Debug, Clone, View)]
147
#[view(collection = Basic, key = String, value = usize, name = "by-tag", core = crate)]
148
pub struct BasicByTag;
149

            
150
impl ViewSchema for BasicByTag {
151
    type View = Self;
152

            
153
567
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
154
567
        let contents = Basic::document_contents(document)?;
155
567
        contents
156
567
            .tags
157
567
            .iter()
158
810
            .map(|tag| document.header.emit_key_and_value(tag.clone(), 1))
159
567
            .collect()
160
567
    }
161

            
162
945
    fn reduce(
163
945
        &self,
164
945
        mappings: &[ViewMappedValue<Self::View>],
165
945
        _rereduce: bool,
166
945
    ) -> ReduceResult<Self::View> {
167
1215
        Ok(mappings.iter().map(|map| map.value).sum())
168
945
    }
169
}
170

            
171
189
#[derive(Debug, Clone, View)]
172
#[view(collection = Basic, key = (), value = (), name = "by-parent-id", core = crate)]
173
pub struct BasicByBrokenParentId;
174

            
175
impl ViewSchema for BasicByBrokenParentId {
176
    type View = Self;
177

            
178
27
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
179
27
        document.header.emit()
180
27
    }
181
}
182

            
183
1633986
#[derive(Serialize, Deserialize, Debug, PartialEq, Default, Clone, Collection)]
184
#[collection(name = "encrypted-basic", authority = "khonsulabs", views = [EncryptedBasicCount, EncryptedBasicByParentId, EncryptedBasicByCategory])]
185
#[collection(encryption_key = Some(KeyId::Master), encryption_optional, core = crate)]
186
#[must_use]
187
pub struct EncryptedBasic {
188
    pub value: String,
189
    pub category: Option<String>,
190
    pub parent_id: Option<u64>,
191
}
192

            
193
impl EncryptedBasic {
194
1
    pub fn new(value: impl Into<String>) -> Self {
195
1
        Self {
196
1
            value: value.into(),
197
1
            category: None,
198
1
            parent_id: None,
199
1
        }
200
1
    }
201

            
202
    pub fn with_category(mut self, category: impl Into<String>) -> Self {
203
        self.category = Some(category.into());
204
        self
205
    }
206

            
207
    pub const fn with_parent_id(mut self, parent_id: u64) -> Self {
208
        self.parent_id = Some(parent_id);
209
        self
210
    }
211
}
212

            
213
817182
#[derive(Debug, Clone, View)]
214
#[view(collection = EncryptedBasic, key = (), value = usize, name = "count", core = crate)]
215
pub struct EncryptedBasicCount;
216

            
217
impl ViewSchema for EncryptedBasicCount {
218
    type View = Self;
219

            
220
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
221
        document.header.emit_key_and_value((), 1)
222
    }
223

            
224
    fn reduce(
225
        &self,
226
        mappings: &[ViewMappedValue<Self::View>],
227
        _rereduce: bool,
228
    ) -> ReduceResult<Self::View> {
229
        Ok(mappings.iter().map(|map| map.value).sum())
230
    }
231
}
232

            
233
817182
#[derive(Debug, Clone, View)]
234
#[view(collection = EncryptedBasic, key = Option<u64>, value = usize, name = "by-parent-id", core = crate)]
235
pub struct EncryptedBasicByParentId;
236

            
237
impl ViewSchema for EncryptedBasicByParentId {
238
    type View = Self;
239

            
240
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
241
        let contents = EncryptedBasic::document_contents(document)?;
242
        document.header.emit_key_and_value(contents.parent_id, 1)
243
    }
244

            
245
    fn reduce(
246
        &self,
247
        mappings: &[ViewMappedValue<Self::View>],
248
        _rereduce: bool,
249
    ) -> ReduceResult<Self::View> {
250
        Ok(mappings.iter().map(|map| map.value).sum())
251
    }
252
}
253

            
254
817182
#[derive(Debug, Clone, View)]
255
#[view(collection = EncryptedBasic, key = String, value = usize, name = "by-category", core = crate)]
256
pub struct EncryptedBasicByCategory;
257

            
258
impl ViewSchema for EncryptedBasicByCategory {
259
    type View = Self;
260

            
261
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
262
        let contents = EncryptedBasic::document_contents(document)?;
263
        if let Some(category) = &contents.category {
264
            document
265
                .header
266
                .emit_key_and_value(category.to_lowercase(), 1)
267
        } else {
268
            Ok(Mappings::none())
269
        }
270
    }
271

            
272
    fn reduce(
273
        &self,
274
        mappings: &[ViewMappedValue<Self::View>],
275
        _rereduce: bool,
276
    ) -> ReduceResult<Self::View> {
277
        Ok(mappings.iter().map(|map| map.value).sum())
278
    }
279
}
280

            
281
816993
#[derive(Debug, Schema)]
282
#[schema(name = "basic", collections = [Basic, EncryptedBasic, Unique], core = crate)]
283
pub struct BasicSchema;
284

            
285
816993
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Default, Collection)]
286
#[collection(name = "unique", authority = "khonsulabs", views = [UniqueValue], core = crate)]
287
pub struct Unique {
288
    pub value: String,
289
}
290

            
291
impl Unique {
292
25
    pub fn new(value: impl Display) -> Self {
293
25
        Self {
294
25
            value: value.to_string(),
295
25
        }
296
25
    }
297
}
298

            
299
823689
#[derive(Debug, Clone, View)]
300
#[view(collection = Unique, key = String, value = (), name = "unique-value", core = crate)]
301
pub struct UniqueValue;
302

            
303
impl ViewSchema for UniqueValue {
304
    type View = Self;
305

            
306
820908
    fn unique(&self) -> bool {
307
820908
        true
308
820908
    }
309

            
310
1080
    fn map(&self, document: &BorrowedDocument<'_>) -> ViewMapResult<Self::View> {
311
1080
        let entry = Unique::document_contents(document)?;
312
1080
        document.header.emit_key(entry.value)
313
1080
    }
314
}
315

            
316
impl NamedCollection for Unique {
317
    type ByNameView = UniqueValue;
318
}
319

            
320
/// A directory path that is cleared on construction and removed on drop.
///
/// The constructors only delete whatever is already at the path; they do not
/// create the directory. Callers are expected to create it when needed.
pub struct TestDirectory(pub PathBuf);

impl TestDirectory {
    /// Uses `path` as-is, removing any directory that already exists there.
    ///
    /// # Panics
    ///
    /// Panics if an existing directory at `path` cannot be removed.
    pub fn absolute<S: AsRef<Path>>(path: S) -> Self {
        let path = path.as_ref().to_owned();
        Self::clear(&path);
        Self(path)
    }

    /// Uses `name` joined onto [`std::env::temp_dir`], removing any directory
    /// that already exists there.
    ///
    /// # Panics
    ///
    /// Panics if an existing directory at the resolved path cannot be removed.
    pub fn new<S: AsRef<Path>>(name: S) -> Self {
        let path = std::env::temp_dir().join(name);
        Self::clear(&path);
        Self(path)
    }

    /// Recursively removes `path`, treating "not found" as success.
    ///
    /// The removal is attempted directly instead of checking `exists()` first.
    /// This avoids a race where the path changes between the check and the
    /// removal. The panic message matches what `Result::expect` would print.
    fn clear(path: &Path) {
        if let Err(err) = std::fs::remove_dir_all(path) {
            if err.kind() != ErrorKind::NotFound {
                panic!("error clearing temporary directory: {:?}", err);
            }
        }
    }
}

impl Drop for TestDirectory {
    /// Best-effort cleanup: a missing directory is ignored, and any other
    /// failure is reported on stderr rather than panicking during drop.
    fn drop(&mut self) {
        if let Err(err) = std::fs::remove_dir_all(&self.0) {
            if err.kind() != ErrorKind::NotFound {
                eprintln!("Failed to clean up temporary folder: {:?}", err);
            }
        }
    }
}

impl AsRef<Path> for TestDirectory {
    fn as_ref(&self) -> &Path {
        &self.0
    }
}

impl Deref for TestDirectory {
    type Target = PathBuf;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
362

            
363
#[derive(Debug)]
364
pub struct BasicCollectionWithNoViews;
365

            
366
impl Collection for BasicCollectionWithNoViews {
367
    type PrimaryKey = u64;
368

            
369
216
    fn collection_name() -> CollectionName {
370
216
        Basic::collection_name()
371
216
    }
372

            
373
54
    fn define_views(_schema: &mut Schematic) -> Result<(), Error> {
374
54
        Ok(())
375
54
    }
376
}
377

            
378
impl SerializedCollection for BasicCollectionWithNoViews {
379
    type Contents = Basic;
380
    type Format = Pot;
381

            
382
27
    fn format() -> Self::Format {
383
27
        Pot::default()
384
27
    }
385
}
386

            
387
#[derive(Debug)]
388
pub struct BasicCollectionWithOnlyBrokenParentId;
389

            
390
impl Collection for BasicCollectionWithOnlyBrokenParentId {
391
    type PrimaryKey = u64;
392

            
393
189
    fn collection_name() -> CollectionName {
394
189
        Basic::collection_name()
395
189
    }
396

            
397
54
    fn define_views(schema: &mut Schematic) -> Result<(), Error> {
398
54
        schema.define_view(BasicByBrokenParentId)
399
54
    }
400
}
401

            
402
270
#[derive(Serialize, Deserialize, Clone, Debug, Collection)]
403
#[collection(name = "unassociated", authority = "khonsulabs", core = crate)]
404
pub struct UnassociatedCollection;
405

            
406
4023
/// Identifies each test in the shared harness suites.
///
/// Discriminants start at 1 and increase by one per variant, so
/// [`HarnessTest::port`] gives every test a distinct offset from its base port.
#[derive(Copy, Clone, Debug)]
pub enum HarnessTest {
    ServerConnectionTests = 1,
    StoreRetrieveUpdate,
    NotFound,
    Conflict,
    BadUpdate,
    NoUpdate,
    GetMultiple,
    List,
    ListTransactions,
    ViewQuery,
    UnassociatedCollection,
    Compact,
    ViewUpdate,
    ViewMultiEmit,
    ViewUnimplementedReduce,
    ViewAccessPolicies,
    Encryption,
    UniqueViews,
    NamedCollection,
    PubSubSimple,
    UserManagement,
    PubSubMultipleSubscribers,
    PubSubDropAndSend,
    PubSubUnsubscribe,
    PubSubDropCleanup,
    PubSubPublishAll,
    KvBasic,
    KvConcurrency,
    KvSet,
    KvIncrementDecrement,
    KvExpiration,
    KvDeleteExpire,
    KvTransactions,
}

impl HarnessTest {
    /// Returns `base` offset by this test's discriminant.
    #[must_use]
    pub const fn port(self, base: u16) -> u16 {
        let offset = self as u16;
        base + offset
    }
}

impl Display for HarnessTest {
    /// Renders the variant name, exactly as `Debug` does.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(self, f)
    }
}
455

            
456
/// Asserts that two `f64` values are equal within one machine epsilon,
/// scaled by the larger magnitude of the two operands (never below 1.0).
///
/// `f64::EPSILON` is the gap between 1.0 and the next representable value.
/// As an absolute tolerance it only makes sense near magnitude 1.0; for larger
/// values it degenerates into exact equality. Scaling keeps the comparison
/// relative while leaving the tolerance for values up to magnitude 1.0
/// unchanged. Bit-identical values (including equal infinities) always pass.
///
/// Panics with both values in the message when they differ by more than the
/// tolerance (or either is NaN).
#[macro_export]
macro_rules! assert_f64_eq {
    ($a:expr, $b:expr) => {{
        let a: f64 = $a;
        let b: f64 = $b;
        let tolerance = f64::EPSILON * a.abs().max(b.abs()).max(1.0);
        // `to_bits` equality covers identical infinities, whose difference is NaN.
        assert!(
            a.to_bits() == b.to_bits() || (a - b).abs() <= tolerance,
            "{:?} <> {:?}",
            a,
            b
        );
    }};
}
465

            
466
/// Creates a test suite that tests methods available on [`Connection`].
///
/// `$harness` must name a type providing `new(HarnessTest)`, `connect()`,
/// `server()`, `shutdown()` and `server_name()`, as used below. Each generated
/// `#[tokio::test]` builds its own harness, runs one shared test function from
/// `$crate::test_util`, and then shuts the harness down.
#[macro_export]
macro_rules! define_connection_test_suite {
    // Internal rule: a test that obtains a database via `harness.connect()` and
    // passes it by reference to `$crate::test_util::$body`.
    (@connected $harness:ident, $name:ident, $variant:ident, $body:ident) => {
        #[tokio::test]
        async fn $name() -> anyhow::Result<()> {
            let harness = $harness::new($crate::test_util::HarnessTest::$variant).await?;
            let db = harness.connect().await?;

            $crate::test_util::$body(&db).await?;
            harness.shutdown().await
        }
    };
    ($harness:ident) => {
        #[tokio::test]
        async fn server_connection_tests() -> anyhow::Result<()> {
            let harness =
                $harness::new($crate::test_util::HarnessTest::ServerConnectionTests).await?;
            let db = harness.server();
            $crate::test_util::basic_server_connection_tests(
                db.clone(),
                &format!("server-connection-tests-{}", $harness::server_name()),
            )
            .await?;
            harness.shutdown().await
        }

        $crate::define_connection_test_suite!(
            @connected $harness, store_retrieve_update_delete, StoreRetrieveUpdate,
            store_retrieve_update_delete_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, not_found, NotFound, not_found_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, conflict, Conflict, conflict_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, bad_update, BadUpdate, bad_update_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, no_update, NoUpdate, no_update_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, get_multiple, GetMultiple, get_multiple_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, list, List, list_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, list_transactions, ListTransactions, list_transactions_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, view_query, ViewQuery, view_query_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, unassociated_collection, UnassociatedCollection,
            unassociated_collection_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, unimplemented_reduce, ViewUnimplementedReduce,
            unimplemented_reduce
        );
        $crate::define_connection_test_suite!(
            @connected $harness, view_update, ViewUpdate, view_update_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, view_multi_emit, ViewMultiEmit, view_multi_emit_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, view_access_policies, ViewAccessPolicies,
            view_access_policy_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, unique_views, UniqueViews, unique_view_tests
        );
        $crate::define_connection_test_suite!(
            @connected $harness, named_collection, NamedCollection, named_collection_tests
        );

        #[tokio::test]
        #[cfg(any(feature = "multiuser", feature = "local-multiuser", feature = "server"))]
        async fn user_management() -> anyhow::Result<()> {
            let harness = $harness::new($crate::test_util::HarnessTest::UserManagement).await?;
            let _db = harness.connect().await?;
            let server = harness.server();
            let admin = server
                .database::<$crate::admin::Admin>($crate::admin::ADMIN_DATABASE_NAME)
                .await?;

            $crate::test_util::user_management_tests(
                &admin,
                server.clone(),
                $harness::server_name(),
            )
            .await?;
            harness.shutdown().await
        }

        $crate::define_connection_test_suite!(
            @connected $harness, compaction, Compact, compaction_tests
        );
    };
}
658

            
659
506
pub async fn store_retrieve_update_delete_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
660
506
    let original_value = Basic::new("initial_value");
661
506
    let collection = db.collection::<Basic>();
662
2052
    let header = collection.push(&original_value).await?;
663

            
664
506
    let mut doc = collection
665
1498
        .get(header.id)
666
1498
        .await?
667
506
        .expect("couldn't retrieve stored item");
668
506
    let mut value = Basic::document_contents(&doc)?;
669
506
    assert_eq!(original_value, value);
670
506
    let old_revision = doc.header.revision;
671
506

            
672
506
    // Update the value
673
506
    value.value = String::from("updated_value");
674
506
    Basic::set_document_contents(&mut doc, value.clone())?;
675
1743
    db.update::<Basic, _>(&mut doc).await?;
676

            
677
    // update should cause the revision to be changed
678
506
    assert_ne!(doc.header.revision, old_revision);
679

            
680
    // Check the value in the database to ensure it has the new document
681
506
    let doc = collection
682
1546
        .get(header.id)
683
1546
        .await?
684
506
        .expect("couldn't retrieve stored item");
685
506
    assert_eq!(Basic::document_contents(&doc)?, value);
686

            
687
    // These operations should have created two transactions with one change each
688
1515
    let transactions = db.list_executed_transactions(None, None).await?;
689
506
    assert_eq!(transactions.len(), 2);
690
506
    assert!(transactions[0].id < transactions[1].id);
691
1518
    for transaction in &transactions {
692
1012
        let changes = transaction
693
1012
            .changes
694
1012
            .documents()
695
1012
            .expect("incorrect transaction type");
696
1012
        assert_eq!(changes.documents.len(), 1);
697
1012
        assert_eq!(changes.collections.len(), 1);
698
1012
        assert_eq!(changes.collections[0], Basic::collection_name());
699
1012
        assert_eq!(changes.documents[0].collection, 0);
700
1012
        assert_eq!(header.id, changes.documents[0].id.deserialize()?);
701
1012
        assert!(!changes.documents[0].deleted);
702
    }
703

            
704
1653
    db.collection::<Basic>().delete(&doc).await?;
705
1583
    assert!(collection.get(header.id).await?.is_none());
706
506
    let transactions = db
707
1524
        .list_executed_transactions(Some(transactions.last().as_ref().unwrap().id + 1), None)
708
1524
        .await?;
709
506
    assert_eq!(transactions.len(), 1);
710
506
    let transaction = transactions.first().unwrap();
711
506
    let changes = transaction
712
506
        .changes
713
506
        .documents()
714
506
        .expect("incorrect transaction type");
715
506
    assert_eq!(changes.documents.len(), 1);
716
506
    assert_eq!(changes.collections[0], Basic::collection_name());
717
506
    assert_eq!(header.id, changes.documents[0].id.deserialize()?);
718
506
    assert!(changes.documents[0].deleted);
719

            
720
    // Use the Collection interface
721
1702
    let mut doc = original_value.clone().push_into(db).await?;
722
506
    doc.contents.category = Some(String::from("updated"));
723
1740
    doc.update(db).await?;
724
1504
    let reloaded = Basic::get(doc.header.id, db).await?.unwrap();
725
506
    assert_eq!(doc.contents, reloaded.contents);
726

            
727
    // Test Connection::insert with a specified id
728
506
    let doc = BorrowedDocument::with_contents::<Basic>(42, &Basic::new("42"))?;
729
506
    let document_42 = db
730
1742
        .insert::<Basic, _, _>(Some(doc.header.id), doc.contents.into_vec())
731
1742
        .await?;
732
506
    assert_eq!(document_42.id, 42);
733
1719
    let document_43 = Basic::new("43").insert_into(43, db).await?;
734
506
    assert_eq!(document_43.header.id, 43);
735

            
736
    // Test that inserting a document with the same ID results in a conflict:
737
506
    let conflict_err = Basic::new("43")
738
1647
        .insert_into(doc.header.id, db)
739
1647
        .await
740
506
        .unwrap_err();
741
506
    assert!(matches!(conflict_err.error, Error::DocumentConflict(..)));
742

            
743
    // Test that overwriting works
744
506
    let overwritten = Basic::new("43")
745
1809
        .overwrite_into(doc.header.id, db)
746
1809
        .await
747
506
        .unwrap();
748
506
    assert!(overwritten.header.revision.id > doc.header.revision.id);
749

            
750
506
    Ok(())
751
506
}
752

            
753
5
/// Checks that fetching a missing [`Basic`] document yields `None`, and that
/// `db` reports no last transaction id (i.e. an empty database is expected).
pub async fn not_found_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
    let missing_document = db.collection::<Basic>().get(1).await?;
    assert!(missing_document.is_none());

    let last_transaction = db.last_transaction_id().await?;
    assert!(last_transaction.is_none());

    Ok(())
}
760

            
761
5
pub async fn conflict_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
762
5
    let original_value = Basic::new("initial_value");
763
5
    let collection = db.collection::<Basic>();
764
5
    let header = collection.push(&original_value).await?;
765

            
766
5
    let mut doc = collection
767
5
        .get(header.id)
768
5
        .await?
769
5
        .expect("couldn't retrieve stored item");
770
5
    let mut value = Basic::document_contents(&doc)?;
771
5
    value.value = String::from("updated_value");
772
5
    Basic::set_document_contents(&mut doc, value.clone())?;
773
5
    db.update::<Basic, _>(&mut doc).await?;
774

            
775
    // To generate a conflict, let's try to do the same update again by
776
    // reverting the header
777
5
    doc.header = Header::try_from(header).unwrap();
778
5
    match db
779
5
        .update::<Basic, _>(&mut doc)
780
5
        .await
781
5
        .expect_err("conflict should have generated an error")
782
    {
783
5
        Error::DocumentConflict(collection, header) => {
784
5
            assert_eq!(collection, Basic::collection_name());
785
5
            assert_eq!(header.id, doc.header.id);
786
        }
787
        other => return Err(anyhow::Error::from(other)),
788
    }
789

            
790
    // Let's force an update through overwrite. After this succeeds, the header
791
    // is updated to the new revision.
792
5
    db.collection::<Basic>().overwrite(&mut doc).await.unwrap();
793

            
794
    // Now, let's use the CollectionDocument API to modify the document through a refetch.
795
5
    let mut doc = CollectionDocument::<Basic>::try_from(&doc)?;
796
5
    doc.modify(db, |doc| {
797
5
        doc.contents.value = String::from("modify worked");
798
5
    })
799
5
    .await?;
800
5
    assert_eq!(doc.contents.value, "modify worked");
801
5
    let doc = Basic::get(doc.header.id, db).await?.unwrap();
802
5
    assert_eq!(doc.contents.value, "modify worked");
803

            
804
5
    Ok(())
805
5
}
806

            
807
5
pub async fn bad_update_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
808
5
    let mut doc = BorrowedDocument::with_contents::<Basic>(1, &Basic::default())?;
809
5
    match db.update::<Basic, _>(&mut doc).await {
810
5
        Err(Error::DocumentNotFound(collection, id)) => {
811
5
            assert_eq!(collection, Basic::collection_name());
812
5
            assert_eq!(id.as_ref(), &DocumentId::from_u64(1));
813
5
            Ok(())
814
        }
815
        other => panic!("expected DocumentNotFound from update but got: {:?}", other),
816
    }
817
5
}
818

            
819
5
/// Saving a document whose contents did not change must leave its header
/// equal to the one returned by the original push.
pub async fn no_update_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
    let collection = db.collection::<Basic>();
    let value = Basic::new("initial_value");
    let pushed = collection.push(&value).await?;

    let mut doc = collection
        .get(pushed.id)
        .await?
        .expect("couldn't retrieve stored item");
    db.update::<Basic, _>(&mut doc).await?;

    let header_after_update = CollectionHeader::try_from(doc.header)?;
    assert_eq!(header_after_update, pushed);

    Ok(())
}
834

            
835
5
pub async fn get_multiple_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
836
5
    let collection = db.collection::<Basic>();
837
5
    let doc1_value = Basic::new("initial_value");
838
5
    let doc1 = collection.push(&doc1_value).await?;
839

            
840
5
    let doc2_value = Basic::new("second_value");
841
5
    let doc2 = collection.push(&doc2_value).await?;
842

            
843
5
    let both_docs = Basic::get_multiple([doc1.id, doc2.id], db).await?;
844
5
    assert_eq!(both_docs.len(), 2);
845

            
846
5
    let out_of_order = Basic::get_multiple([doc2.id, doc1.id], db).await?;
847
5
    assert_eq!(out_of_order.len(), 2);
848

            
849
    // The order of get_multiple isn't guaranteed, so these two checks are done
850
    // with iterators instead of direct indexing
851
5
    let doc1 = both_docs
852
5
        .iter()
853
5
        .find(|doc| doc.header.id == doc1.id)
854
5
        .expect("Couldn't find doc1");
855
5
    assert_eq!(doc1.contents.value, doc1_value.value);
856
5
    let doc2 = both_docs
857
5
        .iter()
858
10
        .find(|doc| doc.header.id == doc2.id)
859
5
        .expect("Couldn't find doc2");
860
5
    assert_eq!(doc2.contents.value, doc2_value.value);
861

            
862
5
    Ok(())
863
5
}
864

            
865
5
pub async fn list_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
866
5
    let collection = db.collection::<Basic>();
867
5
    let doc1_value = Basic::new("initial_value");
868
5
    let doc1 = collection.push(&doc1_value).await?;
869

            
870
5
    let doc2_value = Basic::new("second_value");
871
5
    let doc2 = collection.push(&doc2_value).await?;
872

            
873
5
    let all_docs = Basic::all(db).await?;
874
5
    assert_eq!(all_docs.len(), 2);
875
5
    assert_eq!(Basic::all(db).count().await?, 2);
876

            
877
5
    let both_docs = Basic::list(doc1.id..=doc2.id, db).await?;
878
5
    assert_eq!(both_docs.len(), 2);
879
5
    assert_eq!(Basic::list(doc1.id..=doc2.id, db).count().await?, 2);
880

            
881
5
    assert_eq!(both_docs[0].contents.value, doc1_value.value);
882
5
    assert_eq!(both_docs[1].contents.value, doc2_value.value);
883

            
884
5
    let one_doc = Basic::list(doc1.id..doc2.id, db).await?;
885
5
    assert_eq!(one_doc.len(), 1);
886

            
887
5
    let limited = Basic::list(doc1.id..=doc2.id, db)
888
5
        .limit(1)
889
5
        .descending()
890
5
        .await?;
891
5
    assert_eq!(limited.len(), 1);
892
5
    assert_eq!(limited[0].contents.value, doc2_value.value);
893

            
894
5
    Ok(())
895
5
}
896

            
897
5
pub async fn list_transactions_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
898
5
    let collection = db.collection::<Basic>();
899
5

            
900
5
    // create LIST_TRANSACTIONS_MAX_RESULTS + 1 items, giving us just enough
901
5
    // transactions to test the edge cases of `list_transactions`
902
5
    futures::future::join_all(
903
5
        (0..=(LIST_TRANSACTIONS_MAX_RESULTS))
904
5005
            .map(|_| async { collection.push(&Basic::default()).await.unwrap() }),
905
4570
    )
906
4570
    .await;
907

            
908
    // Test defaults
909
5
    let transactions = db.list_executed_transactions(None, None).await?;
910
5
    assert_eq!(transactions.len(), LIST_TRANSACTIONS_DEFAULT_RESULT_COUNT);
911

            
912
    // Test max results limit
913
5
    let transactions = db
914
5
        .list_executed_transactions(None, Some(LIST_TRANSACTIONS_MAX_RESULTS + 1))
915
5
        .await?;
916
5
    assert_eq!(transactions.len(), LIST_TRANSACTIONS_MAX_RESULTS);
917

            
918
    // Test requesting 0 items
919
5
    let transactions = db.list_executed_transactions(None, Some(0)).await?;
920
5
    assert!(transactions.is_empty());
921

            
922
    // Test doing a loop fetching until we get no more results
923
5
    let mut transactions = Vec::new();
924
5
    let mut starting_id = None;
925
    loop {
926
60
        let chunk = db
927
60
            .list_executed_transactions(starting_id, Some(100))
928
60
            .await?;
929
60
        if chunk.is_empty() {
930
5
            break;
931
55
        }
932
55

            
933
55
        let max_id = chunk.last().map(|tx| tx.id).unwrap();
934
55
        starting_id = Some(max_id + 1);
935
55
        transactions.extend(chunk);
936
    }
937

            
938
5
    assert_eq!(transactions.len(), LIST_TRANSACTIONS_MAX_RESULTS + 1);
939

            
940
5
    Ok(())
941
5
}
942

            
943
5
pub async fn view_query_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
944
5
    let collection = db.collection::<Basic>();
945
5
    let a = collection.push(&Basic::new("A")).await?;
946
5
    let b = collection.push(&Basic::new("B")).await?;
947
5
    let a_child = collection
948
5
        .push(
949
5
            &Basic::new("A.1")
950
5
                .with_parent_id(a.id)
951
5
                .with_category("Alpha"),
952
5
        )
953
5
        .await?;
954
5
    collection
955
5
        .push(&Basic::new("B.1").with_parent_id(b.id).with_category("Beta"))
956
5
        .await?;
957
5
    collection
958
5
        .push(&Basic::new("B.2").with_parent_id(b.id).with_category("beta"))
959
5
        .await?;
960

            
961
5
    let a_children = db
962
5
        .view::<BasicByParentId>()
963
5
        .with_key(Some(a.id))
964
8
        .query()
965
8
        .await?;
966
5
    assert_eq!(a_children.len(), 1);
967

            
968
5
    let a_children = db
969
5
        .view::<BasicByParentId>()
970
5
        .with_key(Some(a.id))
971
8
        .query_with_collection_docs()
972
8
        .await?;
973
5
    assert_eq!(a_children.len(), 1);
974
5
    assert_eq!(a_children.get(0).unwrap().document.header, a_child);
975

            
976
5
    let b_children = db
977
5
        .view::<BasicByParentId>()
978
5
        .with_key(Some(b.id))
979
5
        .query()
980
2
        .await?;
981
5
    assert_eq!(b_children.len(), 2);
982

            
983
5
    let a_and_b_children = db
984
5
        .view::<BasicByParentId>()
985
5
        .with_keys([Some(a.id), Some(b.id)])
986
5
        .query()
987
2
        .await?;
988
5
    assert_eq!(a_and_b_children.len(), 3);
989

            
990
    // Test out of order keys
991
5
    let a_and_b_children = db
992
5
        .view::<BasicByParentId>()
993
5
        .with_keys([Some(b.id), Some(a.id)])
994
5
        .query()
995
2
        .await?;
996
5
    assert_eq!(a_and_b_children.len(), 3);
997

            
998
5
    let has_parent = db
999
5
        .view::<BasicByParentId>()
5
        .with_key_range(Some(0)..=Some(u64::MAX))
5
        .query()
2
        .await?;
5
    assert_eq!(has_parent.len(), 3);
    // Verify the result is sorted ascending
5
    assert!(has_parent
5
        .windows(2)
10
        .all(|window| window[0].key <= window[1].key));

            
    // Test limiting and descending order
5
    let last_with_parent = db
5
        .view::<BasicByParentId>()
5
        .with_key_range(Some(0)..=Some(u64::MAX))
5
        .descending()
5
        .limit(1)
5
        .query()
2
        .await?;
10
    assert_eq!(last_with_parent.iter().map(|m| m.key).unique().count(), 1);
5
    assert_eq!(last_with_parent[0].key, has_parent[2].key);

            
8
    let items_with_categories = db.view::<BasicByCategory>().query().await?;
5
    assert_eq!(items_with_categories.len(), 3);

            
    // Test deleting
5
    let deleted_count = db
5
        .view::<BasicByParentId>()
5
        .with_key(Some(b.id))
8
        .delete_docs()
8
        .await?;
5
    assert_eq!(b_children.len() as u64, deleted_count);
    assert_eq!(
5
        db.view::<BasicByParentId>()
5
            .with_key(Some(b.id))
5
            .query()
5
            .await?
5
            .len(),
        0
    );

            
5
    Ok(())
5
}

            
5
pub async fn unassociated_collection_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
5
    let result = db
5
        .collection::<UnassociatedCollection>()
5
        .push(&UnassociatedCollection)
5
        .await;
5
    match result {
5
        Err(Error::CollectionNotFound) => {}
        other => unreachable!("unexpected result: {:?}", other),
    }

            
5
    Ok(())
5
}

            
5
pub async fn unimplemented_reduce<C: Connection>(db: &C) -> anyhow::Result<()> {
5
    assert!(matches!(
8
        db.view::<UniqueValue>().reduce().await,
        Err(Error::ReduceUnimplemented)
    ));
5
    Ok(())
5
}

            
5
pub async fn view_update_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
5
    let collection = db.collection::<Basic>();
5
    let a = collection.push(&Basic::new("A")).await?;

            
5
    let a_children = db
5
        .view::<BasicByParentId>()
5
        .with_key(Some(a.id))
8
        .query()
8
        .await?;
5
    assert_eq!(a_children.len(), 0);
    // The reduce function of `BasicByParentId` acts as a "count" of records.
    assert_eq!(
5
        db.view::<BasicByParentId>()
5
            .with_key(Some(a.id))
5
            .reduce()
5
            .await?,
        0
    );

            
    // Test inserting a new record and the view being made available
5
    let a_child = collection
5
        .push(
5
            &Basic::new("A.1")
5
                .with_parent_id(a.id)
5
                .with_category("Alpha"),
5
        )
5
        .await?;

            
5
    let a_children = db
5
        .view::<BasicByParentId>()
5
        .with_key(Some(a.id))
5
        .query()
5
        .await?;
5
    assert_eq!(a_children.len(), 1);
    assert_eq!(
5
        db.view::<BasicByParentId>()
5
            .with_key(Some(a.id))
5
            .reduce()
5
            .await?,
        1
    );

            
    // Verify reduce_grouped matches our expectations.
    assert_eq!(
5
        db.view::<BasicByParentId>().reduce_grouped().await?,
5
        vec![MappedValue::new(None, 1,), MappedValue::new(Some(a.id), 1,),]
    );

            
    // Test updating the record and the view being updated appropriately
5
    let mut doc = db.collection::<Basic>().get(a_child.id).await?.unwrap();
5
    let mut basic = Basic::document_contents(&doc)?;
5
    basic.parent_id = None;
5
    Basic::set_document_contents(&mut doc, basic)?;
5
    db.update::<Basic, _>(&mut doc).await?;

            
5
    let a_children = db
5
        .view::<BasicByParentId>()
5
        .with_key(Some(a.id))
5
        .query()
5
        .await?;
5
    assert_eq!(a_children.len(), 0);
    assert_eq!(
5
        db.view::<BasicByParentId>()
5
            .with_key(Some(a.id))
5
            .reduce()
5
            .await?,
        0
    );
5
    assert_eq!(db.view::<BasicByParentId>().reduce().await?, 2);

            
    // Test deleting a record and ensuring it goes away
5
    db.collection::<Basic>().delete(&doc).await?;

            
5
    let all_entries = db.view::<BasicByParentId>().query().await?;
5
    assert_eq!(all_entries.len(), 1);

            
    // Verify reduce_grouped matches our expectations.
    assert_eq!(
5
        db.view::<BasicByParentId>().reduce_grouped().await?,
5
        vec![MappedValue::new(None, 1,),]
    );

            
5
    Ok(())
5
}

            
5
/// Verifies a view that emits several mappings for one document
/// (`BasicByTag`) as tags are added, replaced, and cleared.
pub async fn view_multi_emit_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
    /// Number of `BasicByTag` mappings stored under `tag`.
    async fn tagged<C: Connection>(db: &C, tag: &str) -> anyhow::Result<usize> {
        Ok(db
            .view::<BasicByTag>()
            .with_key(String::from(tag))
            .query()
            .await?
            .len())
    }

    let mut first = Basic::new("A")
        .with_tag("red")
        .with_tag("green")
        .push_into(db)
        .await?;
    let mut second = Basic::new("B")
        .with_tag("blue")
        .with_tag("green")
        .push_into(db)
        .await?;

    // Two documents with two tags each.
    assert_eq!(db.view::<BasicByTag>().query().await?.len(), 4);
    assert_eq!(tagged(db, "green").await?, 2);
    assert_eq!(tagged(db, "red").await?, 1);
    assert_eq!(tagged(db, "blue").await?, 1);

    // Change A's tags from red/green to red/blue.
    first.contents.tags = vec![String::from("red"), String::from("blue")];
    first.update(db).await?;
    assert_eq!(tagged(db, "green").await?, 1);
    assert_eq!(tagged(db, "red").await?, 1);
    assert_eq!(tagged(db, "blue").await?, 2);

    // Clearing B's tags drops all of B's mappings.
    second.contents.tags.clear();
    second.update(db).await?;
    assert_eq!(tagged(db, "green").await?, 0);
    assert_eq!(tagged(db, "red").await?, 1);
    assert_eq!(tagged(db, "blue").await?, 1);

    Ok(())
}

            
5
pub async fn view_access_policy_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
5
    let collection = db.collection::<Basic>();
5
    let a = collection.push(&Basic::new("A")).await?;

            
    // Test inserting a record that should match the view, but ask for it to be
    // NoUpdate. Verify we get no matches.
5
    collection
5
        .push(
5
            &Basic::new("A.1")
5
                .with_parent_id(a.id)
5
                .with_category("Alpha"),
5
        )
5
        .await?;

            
5
    let a_children = db
5
        .view::<BasicByParentId>()
5
        .with_key(Some(a.id))
5
        .with_access_policy(AccessPolicy::NoUpdate)
5
        .query()
2
        .await?;
5
    assert_eq!(a_children.len(), 0);

            
5
    tokio::time::sleep(Duration::from_millis(20)).await;

            
    // Verify the view still have no value, but this time ask for it to be
    // updated after returning
5
    let a_children = db
5
        .view::<BasicByParentId>()
5
        .with_key(Some(a.id))
5
        .with_access_policy(AccessPolicy::UpdateAfter)
5
        .query()
2
        .await?;
5
    assert_eq!(a_children.len(), 0);

            
    // Waiting on background jobs can be unreliable in a CI environment
6
    for _ in 0..10_u8 {
6
        tokio::time::sleep(Duration::from_millis(20)).await;

            
        // Now, the view should contain the entry.
6
        let a_children = db
6
            .view::<BasicByParentId>()
6
            .with_key(Some(a.id))
6
            .with_access_policy(AccessPolicy::NoUpdate)
6
            .query()
3
            .await?;
6
        if a_children.len() == 1 {
5
            return Ok(());
1
        }
    }
    panic!("view never updated")
5
}

            
5
pub async fn unique_view_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
11
    let first_doc = db.collection::<Unique>().push(&Unique::new("1")).await?;

            
    if let Err(Error::UniqueKeyViolation {
5
        view,
5
        existing_document,
5
        conflicting_document,
5
    }) = db.collection::<Unique>().push(&Unique::new("1")).await
    {
5
        assert_eq!(view, UniqueValue.view_name());
5
        assert_eq!(first_doc.id, existing_document.id.deserialize()?);
        // We can't predict the conflicting document id since it's generated
        // inside of the transaction, but we can assert that it's different than
        // the document that was previously stored.
5
        assert_ne!(conflicting_document, existing_document);
    } else {
        unreachable!("unique key violation not triggered");
    }

            
5
    let second_doc = db.collection::<Unique>().push(&Unique::new("2")).await?;
5
    let mut second_doc = db.collection::<Unique>().get(second_doc.id).await?.unwrap();
5
    let mut contents = Unique::document_contents(&second_doc)?;
5
    contents.value = String::from("1");
5
    Unique::set_document_contents(&mut second_doc, contents)?;
    if let Err(Error::UniqueKeyViolation {
5
        view,
5
        existing_document,
5
        conflicting_document,
5
    }) = db.update::<Unique, _>(&mut second_doc).await
    {
5
        assert_eq!(view, UniqueValue.view_name());
5
        assert_eq!(first_doc.id, existing_document.id.deserialize()?);
5
        assert_eq!(conflicting_document.id, second_doc.header.id);
    } else {
        unreachable!("unique key violation not triggered");
    }

            
5
    Ok(())
5
}

            
5
pub async fn named_collection_tests<C: Connection>(db: &C) -> anyhow::Result<()> {
11
    Unique::new("0").push_into(db).await?;
5
    let original_entry = Unique::entry("1", db)
5
        .update_with(|_existing: &mut Unique| unreachable!())
13
        .or_insert_with(|| Unique::new("1"))
13
        .await?
5
        .expect("Document not inserted");

            
5
    let updated = Unique::entry("1", db)
5
        .update_with(|existing: &mut Unique| {
5
            existing.value = String::from("2");
5
        })
13
        .or_insert_with(|| unreachable!())
13
        .await?
5
        .unwrap();
5
    assert_eq!(original_entry.header.id, updated.header.id);
5
    assert_ne!(original_entry.contents.value, updated.contents.value);

            
8
    let retrieved = Unique::entry("2", db).await?.unwrap();
5
    assert_eq!(retrieved.contents.value, updated.contents.value);

            
5
    let conflict = Unique::entry("2", db)
5
        .update_with(|existing: &mut Unique| {
5
            existing.value = String::from("0");
10
        })
10
        .await;
5
    assert!(matches!(conflict, Err(Error::UniqueKeyViolation { .. })));

            
5
    Ok(())
5
}

            
5
/// Smoke-tests compaction.
///
/// Compacts the `Basic` collection after an insert, then the key-value store
/// after a write, then the whole database.
pub async fn compaction_tests<C: Connection + KeyValue>(db: &C) -> anyhow::Result<()> {
    // Collection compaction, with one document stored.
    db.collection::<Basic>()
        .push(&Basic::new("initial_value"))
        .await?;
    db.compact_collection::<Basic>().await?;

    // Key-value store compaction, with one key written.
    db.set_key("foo", &1_u32).await?;
    db.compact_key_value_store().await?;

    // Compact everything... again...
    db.compact().await?;

    Ok(())
}

            
/// Exercises user management through `server`.
///
/// Covers:
/// - creating a user;
/// - adding a role and a permission group (by user id, then again by
///   username, which must not duplicate them);
/// - removing them twice, where the second removal must not error;
/// - deleting the user.
///
/// `admin` is used to read back `User` documents and to create the `Role` and
/// `PermissionGroup`. `server_name` is interpolated into every created name.
///
/// # Errors
/// Returns the first error reported by `server` or `admin`.
#[cfg(feature = "multiuser")]
pub async fn user_management_tests<C: Connection, S: StorageConnection>(
    admin: &C,
    server: S,
    server_name: &str,
) -> anyhow::Result<()> {
    let username = format!("user-management-tests-{}", server_name);
    let user_id = server.create_user(&username).await?;
    // Test the default created user state.
    {
        let user = User::get(user_id, admin).await?.expect("user not found");
        assert_eq!(user.contents.username, username);
        assert!(user.contents.groups.is_empty());
        assert!(user.contents.roles.is_empty());
    }

    let role = Role::named(format!("role-{}", server_name))
        .push_into(admin)
        .await?;
    let group = PermissionGroup::named(format!("group-{}", server_name))
        .push_into(admin)
        .await?;

    // Add the role and group.
    server.add_permission_group_to_user(user_id, &group).await?;
    server.add_role_to_user(user_id, &role).await?;

    // Test the results
    {
        let user = User::get(user_id, admin).await?.expect("user not found");
        assert_eq!(user.contents.groups, vec![group.header.id]);
        assert_eq!(user.contents.roles, vec![role.header.id]);
    }

    // Add the same things again (should not do anything). With names this time.
    server.add_permission_group_to_user(&username, &group).await?;
    server.add_role_to_user(&username, &role).await?;
    {
        let user = User::load(&username, admin)
            .await?
            .expect("user not found");
        assert_eq!(user.contents.groups, vec![group.header.id]);
        assert_eq!(user.contents.roles, vec![role.header.id]);
    }

    // Remove the group and the role.
    server.remove_permission_group_from_user(user_id, &group).await?;
    server.remove_role_from_user(user_id, &role).await?;
    {
        let user = User::get(user_id, admin).await?.expect("user not found");
        assert!(user.contents.groups.is_empty());
        assert!(user.contents.roles.is_empty());
    }

    // Removing again shouldn't cause an error.
    server.remove_permission_group_from_user(user_id, &group).await?;
    server.remove_role_from_user(user_id, &role).await?;

    // Remove the user
    server.delete_user(user_id).await?;
    // Test if user is removed.
    assert!(User::get(user_id, admin).await?.is_none());

    Ok(())
}

            
/// Defines the `KeyValue` test suite.
///
/// Expands to a set of `#[tokio::test]` functions exercising the key-value
/// API. `$harness` must be a type with an async, fallible `new` that takes a
/// `$crate::test_util::HarnessTest` variant; the resulting value must provide
/// async, fallible `connect()` and `shutdown()` methods. Each generated test
/// builds its own harness from a distinct `HarnessTest` variant.
///
/// The expansion refers to `tokio`, `futures`, and `anyhow` directly, so the
/// invoking crate must depend on them.
#[macro_export]
macro_rules! define_kv_test_suite {
    ($harness:ident) => {
        #[tokio::test]
        async fn basic_kv_test() -> anyhow::Result<()> {
            use $crate::keyvalue::{KeyStatus, KeyValue};
            let harness = $harness::new($crate::test_util::HarnessTest::KvBasic).await?;
            let db = harness.connect().await?;
            assert_eq!(
                db.set_key("akey", &String::from("avalue")).await?,
                KeyStatus::Inserted
            );
            assert_eq!(
                db.get_key("akey").into().await?,
                Some(String::from("avalue"))
            );
            assert_eq!(
                db.set_key("akey", &String::from("new_value"))
                    .returning_previous_as()
                    .await?,
                Some(String::from("avalue"))
            );
            assert_eq!(
                db.get_key("akey").into().await?,
                Some(String::from("new_value"))
            );
            assert_eq!(
                db.get_key("akey").and_delete().into().await?,
                Some(String::from("new_value"))
            );
            assert_eq!(db.get_key("akey").await?, None);
            assert_eq!(
                db.set_key("akey", &String::from("new_value"))
                    .returning_previous()
                    .await?,
                None
            );
            assert_eq!(db.delete_key("akey").await?, KeyStatus::Deleted);
            assert_eq!(db.delete_key("akey").await?, KeyStatus::NotChanged);

            harness.shutdown().await?;

            Ok(())
        }

        #[tokio::test]
        async fn kv_concurrency() -> anyhow::Result<()> {
            use $crate::keyvalue::KeyValue;
            const WRITERS: usize = 100;
            const INCREMENTS: usize = 100;
            let harness = $harness::new($crate::test_util::HarnessTest::KvConcurrency).await?;
            let db = harness.connect().await?;

            let handles = (0..WRITERS).map(|_| {
                let db = db.clone();
                tokio::task::spawn(async move {
                    for _ in 0..INCREMENTS {
                        db.increment_key_by("concurrency", 1_u64).await.unwrap();
                    }
                })
            });
            // Propagate any writer task failure (e.g. a panic from the unwrap
            // above) instead of silently discarding its `JoinError`.
            for result in futures::future::join_all(handles).await {
                result?;
            }

            assert_eq!(
                db.get_key("concurrency").into_u64().await.unwrap().unwrap(),
                (WRITERS * INCREMENTS) as u64
            );

            harness.shutdown().await?;

            Ok(())
        }

        #[tokio::test]
        async fn kv_set_tests() -> anyhow::Result<()> {
            use $crate::keyvalue::{KeyStatus, KeyValue};
            let harness = $harness::new($crate::test_util::HarnessTest::KvSet).await?;
            let db = harness.connect().await?;
            let kv = db.with_key_namespace("set");

            assert_eq!(
                kv.set_key("a", &0_u32).only_if_exists().await?,
                KeyStatus::NotChanged
            );
            assert_eq!(
                kv.set_key("a", &0_u32).only_if_vacant().await?,
                KeyStatus::Inserted
            );
            assert_eq!(
                kv.set_key("a", &1_u32).only_if_vacant().await?,
                KeyStatus::NotChanged
            );
            assert_eq!(
                kv.set_key("a", &2_u32).only_if_exists().await?,
                KeyStatus::Updated,
            );
            assert_eq!(
                kv.set_key("a", &3_u32).returning_previous_as().await?,
                Some(2_u32),
            );

            harness.shutdown().await?;

            Ok(())
        }

        #[tokio::test]
        async fn kv_increment_decrement_tests() -> anyhow::Result<()> {
            use $crate::keyvalue::KeyValue;
            let harness =
                $harness::new($crate::test_util::HarnessTest::KvIncrementDecrement).await?;
            let db = harness.connect().await?;
            let kv = db.with_key_namespace("increment_decrement");

            // Empty keys should be equal to 0
            assert_eq!(kv.increment_key_by("i64", 1_i64).await?, 1_i64);
            assert_eq!(kv.get_key("i64").into_i64().await?, Some(1_i64));
            assert_eq!(kv.increment_key_by("u64", 1_u64).await?, 1_u64);
            $crate::assert_f64_eq!(kv.increment_key_by("f64", 1_f64).await?, 1_f64);

            // Test float incrementing/decrementing an existing value
            $crate::assert_f64_eq!(kv.increment_key_by("f64", 1_f64).await?, 2_f64);
            $crate::assert_f64_eq!(kv.decrement_key_by("f64", 2_f64).await?, 0_f64);

            // Empty keys should be equal to 0
            assert_eq!(kv.decrement_key_by("i64_2", 1_i64).await?, -1_i64);
            assert_eq!(
                kv.decrement_key_by("u64_2", 42_u64)
                    .allow_overflow()
                    .await?,
                u64::MAX - 41
            );
            assert_eq!(kv.decrement_key_by("u64_3", 42_u64).await?, u64::MIN);
            $crate::assert_f64_eq!(kv.decrement_key_by("f64_2", 1_f64).await?, -1_f64);

            // Test decrement wrapping with overflow
            kv.set_numeric_key("i64", i64::MIN).await?;
            assert_eq!(
                kv.decrement_key_by("i64", 1_i64).allow_overflow().await?,
                i64::MAX
            );
            assert_eq!(
                kv.decrement_key_by("u64", 2_u64).allow_overflow().await?,
                u64::MAX
            );

            // Test increment wrapping with overflow
            assert_eq!(
                kv.increment_key_by("i64", 1_i64).allow_overflow().await?,
                i64::MIN
            );
            assert_eq!(
                kv.increment_key_by("u64", 1_u64).allow_overflow().await?,
                u64::MIN
            );

            // Test saturating increments.
            kv.set_numeric_key("i64", i64::MAX - 1).await?;
            kv.set_numeric_key("u64", u64::MAX - 1).await?;
            assert_eq!(kv.increment_key_by("i64", 2_i64).await?, i64::MAX);
            assert_eq!(kv.increment_key_by("u64", 2_u64).await?, u64::MAX);

            // Test saturating decrements.
            kv.set_numeric_key("i64", i64::MIN + 1).await?;
            kv.set_numeric_key("u64", u64::MIN + 1).await?;
            assert_eq!(kv.decrement_key_by("i64", 2_i64).await?, i64::MIN);
            assert_eq!(kv.decrement_key_by("u64", 2_u64).await?, u64::MIN);

            // Test numerical conversion safety using get
            {
                // For i64 -> f64, the limit is 2^53 + 1 (f64::MANTISSA_DIGITS
                // is 53) in either the positive or negative direction: 2^53
                // converts exactly, 2^53 + 1 does not.
                kv.set_numeric_key("i64", 2_i64.pow(f64::MANTISSA_DIGITS))
                    .await?;
                $crate::assert_f64_eq!(
                    kv.get_key("i64").into_f64().await?.unwrap(),
                    9_007_199_254_740_992_f64
                );
                kv.set_numeric_key("i64", -(2_i64.pow(f64::MANTISSA_DIGITS)))
                    .await?;
                $crate::assert_f64_eq!(
                    kv.get_key("i64").into_f64().await?.unwrap(),
                    -9_007_199_254_740_992_f64
                );

                kv.set_numeric_key("i64", 2_i64.pow(f64::MANTISSA_DIGITS) + 1)
                    .await?;
                assert!(kv.get_key("i64").into_f64().await.is_err());
                $crate::assert_f64_eq!(
                    kv.get_key("i64").into_f64_lossy().await?.unwrap(),
                    9_007_199_254_740_993_f64
                );
                kv.set_numeric_key("i64", -(2_i64.pow(f64::MANTISSA_DIGITS) + 1))
                    .await?;
                assert!(kv.get_key("i64").into_f64().await.is_err());
                $crate::assert_f64_eq!(
                    kv.get_key("i64").into_f64_lossy().await?.unwrap(),
                    -9_007_199_254_740_993_f64
                );

                // For i64 -> u64, the only limit is sign.
                kv.set_numeric_key("i64", -1_i64).await?;
                assert!(kv.get_key("i64").into_u64().await.is_err());
                assert_eq!(
                    kv.get_key("i64").into_u64_lossy(true).await?.unwrap(),
                    0_u64
                );
                assert_eq!(
                    kv.get_key("i64").into_u64_lossy(false).await?.unwrap(),
                    u64::MAX
                );

                // For f64 -> i64, the limit is fractional numbers. Saturating isn't tested in this conversion path.
                kv.set_numeric_key("f64", 1.1_f64).await?;
                assert!(kv.get_key("f64").into_i64().await.is_err());
                assert_eq!(
                    kv.get_key("f64").into_i64_lossy(false).await?.unwrap(),
                    1_i64
                );
                kv.set_numeric_key("f64", -1.1_f64).await?;
                assert!(kv.get_key("f64").into_i64().await.is_err());
                assert_eq!(
                    kv.get_key("f64").into_i64_lossy(false).await?.unwrap(),
                    -1_i64
                );

                // For f64 -> u64, the limit is fractional numbers or negative numbers. Saturating isn't tested in this conversion path.
                kv.set_numeric_key("f64", 1.1_f64).await?;
                assert!(kv.get_key("f64").into_u64().await.is_err());
                assert_eq!(
                    kv.get_key("f64").into_u64_lossy(false).await?.unwrap(),
                    1_u64
                );
                kv.set_numeric_key("f64", -1.1_f64).await?;
                assert!(kv.get_key("f64").into_u64().await.is_err());
                assert_eq!(
                    kv.get_key("f64").into_u64_lossy(false).await?.unwrap(),
                    0_u64
                );

                // For u64 -> i64, the limit is > i64::MAX
                kv.set_numeric_key("u64", i64::MAX as u64 + 1).await?;
                assert!(kv.get_key("u64").into_i64().await.is_err());
                assert_eq!(
                    kv.get_key("u64").into_i64_lossy(true).await?.unwrap(),
                    i64::MAX
                );
                assert_eq!(
                    kv.get_key("u64").into_i64_lossy(false).await?.unwrap(),
                    i64::MIN
                );
            }

            // Test that non-numeric keys won't be changed when attempting to incr/decr
            kv.set_key("non-numeric", &String::from("test")).await?;
            assert!(kv.increment_key_by("non-numeric", 1_i64).await.is_err());
            assert!(kv.decrement_key_by("non-numeric", 1_i64).await.is_err());
            assert_eq!(
                kv.get_key("non-numeric").into::<String>().await?.unwrap(),
                String::from("test")
            );

            // Test that NaN cannot be stored
            kv.set_numeric_key("f64", 0_f64).await?;
            assert!(matches!(
                kv.set_numeric_key("f64", f64::NAN).await,
                Err($crate::Error::NotANumber)
            ));
            // Verify the value was unchanged.
            $crate::assert_f64_eq!(kv.get_key("f64").into_f64().await?.unwrap(), 0.);
            // Try to increment by nan
            assert!(matches!(
                kv.increment_key_by("f64", f64::NAN).await,
                Err($crate::Error::NotANumber)
            ));
            $crate::assert_f64_eq!(kv.get_key("f64").into_f64().await?.unwrap(), 0.);

            harness.shutdown().await?;

            Ok(())
        }

        #[tokio::test]
        async fn kv_expiration_tests() -> anyhow::Result<()> {
            use std::time::Duration;

            use $crate::keyvalue::{KeyStatus, KeyValue};

            let harness = $harness::new($crate::test_util::HarnessTest::KvExpiration).await?;
            let db = harness.connect().await?;

            // Timing-sensitive: whenever the setup steps take too long for
            // the assertions to be meaningful, the whole scenario restarts.
            loop {
                let kv = db.with_key_namespace("expiration");

                kv.delete_key("a").await?;
                kv.delete_key("b").await?;

                // Test that the expiration is updated for key a, but not for key b.
                let timing = $crate::test_util::TimingTest::new(Duration::from_millis(500));
                let (r1, r2) = tokio::join!(
                    kv.set_key("a", &0_u32).expire_in(Duration::from_secs(2)),
                    kv.set_key("b", &0_u32).expire_in(Duration::from_secs(2))
                );
                if timing.elapsed() > Duration::from_millis(500) {
                    println!(
                        "Restarting test {}. Took too long {:?}",
                        line!(),
                        timing.elapsed(),
                    );
                    continue;
                }
                assert_eq!(r1?, KeyStatus::Inserted);
                assert_eq!(r2?, KeyStatus::Inserted);
                let (r1, r2) = tokio::join!(
                    kv.set_key("a", &1_u32).expire_in(Duration::from_secs(4)),
                    kv.set_key("b", &1_u32)
                        .expire_in(Duration::from_secs(100))
                        .keep_existing_expiration()
                );
                if timing.elapsed() > Duration::from_secs(1) {
                    println!(
                        "Restarting test {}. Took too long {:?}",
                        line!(),
                        timing.elapsed(),
                    );
                    continue;
                }

                assert_eq!(r1?, KeyStatus::Updated, "a wasn't an update");
                assert_eq!(r2?, KeyStatus::Updated, "b wasn't an update");

                let a = kv.get_key("a").into().await?;
                assert_eq!(a, Some(1u32), "a shouldn't have expired yet");

                // Before checking the value, make sure we haven't elapsed too
                // much time. If so, just restart the test.
                if !timing.wait_until(Duration::from_secs_f32(3.)).await {
                    println!(
                        "Restarting test {}. Took too long {:?}",
                        line!(),
                        timing.elapsed()
                    );
                    continue;
                }

                assert_eq!(kv.get_key("b").await?, None, "b never expired");

                // Overshooting this deadline only makes the expiration of `a`
                // more certain, so the tolerance result is deliberately ignored.
                timing.wait_until(Duration::from_secs_f32(5.)).await;
                assert_eq!(kv.get_key("a").await?, None, "a never expired");
                break;
            }
            harness.shutdown().await?;

            Ok(())
        }

        #[tokio::test]
        async fn delete_expire_tests() -> anyhow::Result<()> {
            use std::time::Duration;

            use $crate::keyvalue::KeyValue;

            let harness = $harness::new($crate::test_util::HarnessTest::KvDeleteExpire).await?;
            let db = harness.connect().await?;

            loop {
                let kv = db.with_key_namespace("delete_expire");

                kv.delete_key("a").await?;

                let timing = $crate::test_util::TimingTest::new(Duration::from_millis(100));

                // Create a key with an expiration. Delete the key. Set a new
                // value at that key with no expiration. Ensure it doesn't
                // expire.
                kv.set_key("a", &0_u32)
                    .expire_in(Duration::from_secs(2))
                    .await?;
                kv.delete_key("a").await?;
                kv.set_key("a", &1_u32).await?;
                if timing.elapsed() > Duration::from_secs(1) {
                    println!(
                        "Restarting test {}. Took too long {:?}",
                        line!(),
                        timing.elapsed(),
                    );
                    continue;
                }
                if !timing.wait_until(Duration::from_secs_f32(2.5)).await {
                    println!(
                        "Restarting test {}. Took too long {:?}",
                        line!(),
                        timing.elapsed()
                    );
                    continue;
                }

                assert_eq!(kv.get_key("a").into().await?, Some(1u32));

                break;
            }
            harness.shutdown().await?;

            Ok(())
        }

        #[tokio::test]
        async fn kv_transaction_tests() -> anyhow::Result<()> {
            use std::time::Duration;

            use $crate::{connection::Connection, keyvalue::KeyValue};
            let harness = $harness::new($crate::test_util::HarnessTest::KvTransactions).await?;
            let db = harness.connect().await?;
            // Generate several transactions that we can validate. Persisting
            // happens in the background, so we delay between each step to give
            // it a moment.
            db.set_key("expires", &0_u32)
                .expire_in(Duration::from_secs(1))
                .await?;
            tokio::time::sleep(Duration::from_millis(100)).await;
            db.set_key("akey", &String::from("avalue")).await?;
            tokio::time::sleep(Duration::from_millis(100)).await;
            db.get_key("akey").and_delete().await?;
            tokio::time::sleep(Duration::from_millis(100)).await;
            db.set_numeric_key("nkey", 0_u64).await?;
            tokio::time::sleep(Duration::from_millis(100)).await;
            db.increment_key_by("nkey", 1_u64).await?;
            tokio::time::sleep(Duration::from_millis(100)).await;
            db.delete_key("nkey").await?;
            tokio::time::sleep(Duration::from_millis(100)).await;
            // Ensure this doesn't generate a transaction.
            db.delete_key("nkey").await?;

            tokio::time::sleep(Duration::from_secs(1)).await;

            let transactions = Connection::list_executed_transactions(&db, None, None).await?;
            // Expected deletions: "expires" (expired), "akey", and "nkey".
            let deleted_keys = transactions
                .iter()
                .filter_map(|tx| tx.changes.keys())
                .flatten()
                .filter(|changed_key| changed_key.deleted)
                .count();
            assert_eq!(deleted_keys, 3);
            let akey_changes = transactions
                .iter()
                .filter_map(|tx| tx.changes.keys())
                .flatten()
                .filter(|changed_key| changed_key.key == "akey")
                .count();
            assert_eq!(akey_changes, 2);
            let nkey_changes = transactions
                .iter()
                .filter_map(|tx| tx.changes.keys())
                .flatten()
                .filter(|changed_key| changed_key.key == "nkey")
                .count();
            assert_eq!(nkey_changes, 3);

            harness.shutdown().await?;

            Ok(())
        }
    };
}

            
pub struct TimingTest {
    tolerance: Duration,
    start: Instant,
}

            
impl TimingTest {
    #[must_use]
432
    pub fn new(tolerance: Duration) -> Self {
432
        Self {
432
            tolerance,
432
            start: Instant::now(),
432
        }
432
    }

            
648
    pub async fn wait_until(&self, absolute_duration: Duration) -> bool {
24
        let target = self.start + absolute_duration;
24
        let mut now = Instant::now();
24
        if now < target {
24
            tokio::time::sleep_until(target.into()).await;
24
            now = Instant::now();
        }
24
        let amount_past = now.checked_duration_since(target);
24

            
24
        // Return false if we're beyond the tolerance given
24
        amount_past.unwrap_or_default() < self.tolerance
24
    }

            
    #[must_use]
513
    pub fn elapsed(&self) -> Duration {
513
        Instant::now()
513
            .checked_duration_since(self.start)
513
            .unwrap_or_default()
513
    }
}

            
5
pub async fn basic_server_connection_tests<C: StorageConnection>(
5
    server: C,
5
    newdb_name: &str,
5
) -> anyhow::Result<()> {
5
    let mut schemas = server.list_available_schemas().await?;
5
    schemas.sort();
5
    assert!(schemas.contains(&BasicSchema::schema_name()));
5
    assert!(schemas.contains(&SchemaName::new("khonsulabs", "bonsaidb-admin")));

            
5
    let databases = server.list_databases().await?;
42
    assert!(databases.iter().any(|db| db.name == "tests"));

            
5
    server
8
        .create_database::<BasicSchema>(newdb_name, false)
8
        .await?;
8
    server.delete_database(newdb_name).await?;

            
    assert!(matches!(
5
        server.delete_database(newdb_name).await,
        Err(Error::DatabaseNotFound(_))
    ));

            
    assert!(matches!(
5
        server.create_database::<BasicSchema>("tests", false).await,
        Err(Error::DatabaseNameAlreadyTaken(_))
    ));

            
    assert!(matches!(
5
        server.create_database::<BasicSchema>("tests", true).await,
        Ok(_)
    ));

            
    assert!(matches!(
5
        server
5
            .create_database::<BasicSchema>("|invalidname", false)
2
            .await,
        Err(Error::InvalidDatabaseName(_))
    ));

            
    assert!(matches!(
5
        server
5
            .create_database::<UnassociatedCollection>(newdb_name, false)
2
            .await,
        Err(Error::SchemaNotRegistered(_))
    ));

            
5
    Ok(())
5
}