1
use std::borrow::Cow;
2
use std::fmt::Debug;
3

            
4
use arc_bytes::serde::{Bytes, CowBytes};
5
use serde::de::{self, Visitor};
6
use serde::ser::SerializeStruct;
7
use serde::{Deserialize, Serialize};
8

            
9
use crate::connection::{AsyncConnection, Connection};
10
use crate::document::{
11
    BorrowedDocument, CollectionHeader, DocumentId, HasHeader, Header, OwnedDocument,
12
};
13
use crate::schema::SerializedCollection;
14
use crate::transaction::{Operation, Transaction};
15
use crate::Error;
16

            
17
/// A document with serializable contents.
18
29
#[derive(Clone, Debug, Eq, PartialEq)]
19
pub struct CollectionDocument<C>
20
where
21
    C: SerializedCollection,
22
{
23
    /// The header of the document, which contains the id and `Revision`.
24
    pub header: CollectionHeader<C::PrimaryKey>,
25

            
26
    /// The document's contents.
27
    pub contents: C::Contents,
28
}
29

            
30
impl<'a, C> TryFrom<&'a BorrowedDocument<'a>> for CollectionDocument<C>
31
where
32
    C: SerializedCollection,
33
{
34
    type Error = Error;
35

            
36
59561
    fn try_from(value: &'a BorrowedDocument<'a>) -> Result<Self, Self::Error> {
37
59561
        Ok(Self {
38
59561
            contents: C::deserialize(&value.contents)?,
39
59561
            header: CollectionHeader::try_from(value.header.clone())?,
40
        })
41
59561
    }
42
}
43

            
44
impl<'a, C> TryFrom<&'a OwnedDocument> for CollectionDocument<C>
45
where
46
    C: SerializedCollection,
47
{
48
    type Error = Error;
49

            
50
27684
    fn try_from(value: &'a OwnedDocument) -> Result<Self, Self::Error> {
51
27684
        Ok(Self {
52
27684
            contents: C::deserialize(&value.contents)?,
53
27684
            header: CollectionHeader::try_from(value.header.clone())?,
54
        })
55
27684
    }
56
}
57

            
58
impl<'a, 'b, C> TryFrom<&'b CollectionDocument<C>> for BorrowedDocument<'a>
59
where
60
    C: SerializedCollection,
61
{
62
    type Error = crate::Error;
63

            
64
    fn try_from(value: &'b CollectionDocument<C>) -> Result<Self, Self::Error> {
65
        Ok(Self {
66
            contents: CowBytes::from(C::serialize(&value.contents)?),
67
            header: Header::try_from(value.header.clone())?,
68
        })
69
    }
70
}
71

            
72
impl<C> CollectionDocument<C>
73
where
74
    C: SerializedCollection,
75
{
76
    /// Updates the document stored in the database with the contents of this
77
    /// collection document.
78
    ///
79
    /// ```rust
80
    /// # bonsaidb_core::__doctest_prelude!();
81
    /// # use bonsaidb_core::connection::Connection;
82
    /// # fn test_fn<C: Connection>(db: C) -> Result<(), Error> {
83
    /// if let Some(mut document) = MyCollection::get(&42, &db)? {
84
    ///     // ... do something `document`
85
    ///     document.update(&db)?;
86
    ///     println!(
87
    ///         "The document has been updated: {:?}",
88
    ///         document.header.revision
89
    ///     );
90
    /// }
91
    /// # Ok(())
92
    /// # }
93
    /// ```
94
869
    pub fn update<Cn: Connection>(&mut self, connection: &Cn) -> Result<(), Error> {
95
869
        let mut doc = self.to_document()?;
96

            
97
869
        connection.update::<C, _>(&mut doc)?;
98

            
99
866
        self.header = CollectionHeader::try_from(doc.header)?;
100

            
101
866
        Ok(())
102
869
    }
103

            
104
    /// Pushes an update [`Operation`] to the transaction for this document.
105
    ///
106
    /// The changes will happen once the transaction is applied.
107
8
    pub fn update_in_transaction(&self, transaction: &mut Transaction) -> Result<(), Error> {
108
8
        transaction.push(Operation::update_serialized::<C>(
109
8
            self.header.clone(),
110
8
            &self.contents,
111
8
        )?);
112
8
        Ok(())
113
8
    }
114

            
115
    /// Stores the new value of `contents` in the document.
116
    ///
117
    /// ```rust
118
    /// # bonsaidb_core::__doctest_prelude!();
119
    /// # use bonsaidb_core::connection::AsyncConnection;
120
    /// # fn test_fn<C: AsyncConnection>(db: C) -> Result<(), Error> {
121
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
122
    /// if let Some(mut document) = MyCollection::get_async(&42, &db).await? {
123
    ///     // modify the document
124
    ///     document.update_async(&db).await?;
125
    ///     println!("Updated revision: {:?}", document.header.revision);
126
    /// }
127
    /// # Ok(())
128
    /// # })
129
    /// # }
130
    /// ```
131
5133
    pub async fn update_async<Cn: AsyncConnection>(
132
5133
        &mut self,
133
5133
        connection: &Cn,
134
5133
    ) -> Result<(), Error> {
135
5133
        let mut doc = self.to_document()?;
136

            
137
7745
        connection.update::<C, _>(&mut doc).await?;
138

            
139
5128
        self.header = CollectionHeader::try_from(doc.header)?;
140

            
141
5128
        Ok(())
142
5133
    }
143

            
144
    /// Modifies `self`, automatically retrying the modification if the document
145
    /// has been updated on the server.
146
    ///
147
    /// The contents of `Ok` will be the last value returned from `modifier`.
148
    ///
149
    /// ## Data loss warning
150
    ///
151
    /// If you've modified `self` before calling this function and a conflict
152
    /// occurs, all changes to self will be lost when the current document is
153
    /// fetched before retrying the process again. When you use this function,
154
    /// you should limit the edits to the value to within the `modifier`
155
    /// callback.
156
3
    pub fn modify<R, Cn: Connection, Modifier: FnMut(&mut Self) -> R + Send + Sync>(
157
3
        &mut self,
158
3
        connection: &Cn,
159
3
        mut modifier: Modifier,
160
3
    ) -> Result<R, Error>
161
3
    where
162
3
        C::Contents: Clone,
163
3
    {
164
3
        let mut is_first_loop = true;
165
        // TODO this should have a retry-limit.
166
3
        loop {
167
3
            // On the first attempt, we want to try sending the update to the
168
3
            // database without fetching new contents. If we receive a conflict,
169
3
            // on future iterations we will first re-load the data.
170
3
            if is_first_loop {
171
3
                is_first_loop = false;
172
3
            } else {
173
                *self =
174
                    C::get(&self.header.id, connection)?.ok_or_else(|| {
175
                        match DocumentId::new(&self.header.id) {
176
                            Ok(id) => Error::DocumentNotFound(C::collection_name(), Box::new(id)),
177
                            Err(err) => err,
178
                        }
179
                    })?;
180
            }
181
3
            let result = modifier(&mut *self);
182
3
            match self.update(connection) {
183
                Err(Error::DocumentConflict(..)) => {}
184
3
                other => return other.map(|()| result),
185
            }
186
        }
187
3
    }
188

            
189
    /// Modifies `self`, automatically retrying the modification if the document
190
    /// has been updated on the server.
191
    ///
192
    /// The contents of `Ok` will be the last value returned from `modifier`.
193
    ///
194
    /// ## Data loss warning
195
    ///
196
    /// If you've modified `self` before calling this function and a conflict
197
    /// occurs, all changes to self will be lost when the current document is
198
    /// fetched before retrying the process again. When you use this function,
199
    /// you should limit the edits to the value to within the `modifier`
200
    /// callback.
201
5
    pub async fn modify_async<
202
5
        R,
203
5
        Cn: AsyncConnection,
204
5
        Modifier: FnMut(&mut Self) -> R + Send + Sync,
205
5
    >(
206
5
        &mut self,
207
5
        connection: &Cn,
208
5
        mut modifier: Modifier,
209
5
    ) -> Result<R, Error>
210
5
    where
211
5
        C::Contents: Clone,
212
5
    {
213
5
        let mut is_first_loop = true;
214
        // TODO this should have a retry-limit.
215
5
        loop {
216
5
            // On the first attempt, we want to try sending the update to the
217
5
            // database without fetching new contents. If we receive a conflict,
218
5
            // on future iterations we will first re-load the data.
219
5
            if is_first_loop {
220
5
                is_first_loop = false;
221
5
            } else {
222
                *self = C::get_async(&self.header.id, connection)
223
                    .await?
224
                    .ok_or_else(|| match DocumentId::new(&self.header.id) {
225
                        Ok(id) => Error::DocumentNotFound(C::collection_name(), Box::new(id)),
226
                        Err(err) => err,
227
                    })?;
228
            }
229
5
            let result = modifier(&mut *self);
230
5
            match self.update_async(connection).await {
231
                Err(Error::DocumentConflict(..)) => {}
232
5
                other => return other.map(|()| result),
233
            }
234
        }
235
5
    }
236

            
237
    /// Removes the document from the collection.
238
    ///
239
    /// ```rust
240
    /// # bonsaidb_core::__doctest_prelude!();
241
    /// # use bonsaidb_core::connection::Connection;
242
    /// # fn test_fn<C: Connection>(db: C) -> Result<(), Error> {
243
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
244
    /// if let Some(document) = MyCollection::get(&42, &db)? {
245
    ///     document.delete(&db)?;
246
    /// }
247
    /// # Ok(())
248
    /// # })
249
    /// # }
250
    /// ```
251
11
    pub fn delete<Cn: Connection>(&self, connection: &Cn) -> Result<(), Error> {
252
11
        connection.collection::<C>().delete(self)?;
253

            
254
11
        Ok(())
255
11
    }
256

            
257
    /// Removes the document from the collection.
258
    ///
259
    /// ```rust
260
    /// # bonsaidb_core::__doctest_prelude!();
261
    /// # use bonsaidb_core::connection::AsyncConnection;
262
    /// # fn test_fn<C: AsyncConnection>(db: C) -> Result<(), Error> {
263
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
264
    /// if let Some(document) = MyCollection::get_async(&42, &db).await? {
265
    ///     document.delete_async(&db).await?;
266
    /// }
267
    /// # Ok(())
268
    /// # })
269
    /// # }
270
    /// ```
271
523
    pub async fn delete_async<Cn: AsyncConnection>(&self, connection: &Cn) -> Result<(), Error> {
272
523
        connection.collection::<C>().delete(self).await?;
273

            
274
521
        Ok(())
275
523
    }
276

            
277
    /// Pushes a delete [`Operation`] to the transaction for this document.
278
    ///
279
    /// The document will be deleted once the transaction is applied.
280
8
    pub fn delete_in_transaction(&self, transaction: &mut Transaction) -> Result<(), Error> {
281
8
        transaction.push(Operation::delete(C::collection_name(), self.header()?));
282
8
        Ok(())
283
8
    }
284

            
285
    /// Refreshes this instance from `connection`. If the document is no longer
286
    /// present, [`Error::DocumentNotFound`] will be returned.
287
3
    pub fn refresh<Cn: Connection>(&mut self, connection: &Cn) -> Result<(), Error> {
288
3
        let id = DocumentId::new(&self.header.id)?;
289
3
        *self = C::get(&id, connection)?
290
3
            .ok_or_else(|| Error::DocumentNotFound(C::collection_name(), Box::new(id)))?;
291
3
        Ok(())
292
3
    }
293

            
294
    /// Refreshes this instance from `connection`. If the document is no longer
295
    /// present, [`Error::DocumentNotFound`] will be returned.
296
5
    pub async fn refresh_async<Cn: AsyncConnection>(
297
5
        &mut self,
298
5
        connection: &Cn,
299
5
    ) -> Result<(), Error> {
300
5
        let id = DocumentId::new(&self.header.id)?;
301
5
        *self = C::get_async(&id, connection)
302
5
            .await?
303
5
            .ok_or_else(|| Error::DocumentNotFound(C::collection_name(), Box::new(id)))?;
304
5
        Ok(())
305
5
    }
306

            
307
    /// Converts this value to a serialized `Document`.
308
6002
    pub fn to_document(&self) -> Result<OwnedDocument, Error> {
309
6002
        Ok(OwnedDocument {
310
6002
            contents: Bytes::from(C::serialize(&self.contents)?),
311
6002
            header: Header::try_from(self.header.clone())?,
312
        })
313
6002
    }
314
}
315

            
316
impl<C> Serialize for CollectionDocument<C>
317
where
318
    C: SerializedCollection,
319
    C::Contents: Serialize,
320
    C::PrimaryKey: Serialize,
321
{
322
3
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
323
3
    where
324
3
        S: serde::Serializer,
325
3
    {
326
3
        let mut s = serializer.serialize_struct("CollectionDocument", 2)?;
327
3
        s.serialize_field("header", &self.header)?;
328
3
        s.serialize_field("contents", &self.contents)?;
329
3
        s.end()
330
3
    }
331
}
332

            
333
impl<'de, C> Deserialize<'de> for CollectionDocument<C>
334
where
335
    C: SerializedCollection,
336
    C::PrimaryKey: Deserialize<'de>,
337
    C::Contents: Deserialize<'de>,
338
{
339
2
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
340
2
    where
341
2
        D: serde::Deserializer<'de>,
342
2
    {
343
2
        struct CollectionDocumentVisitor<C>
344
2
        where
345
2
            C: SerializedCollection,
346
2
        {
347
2
            header: Option<CollectionHeader<C::PrimaryKey>>,
348
2
            contents: Option<C::Contents>,
349
2
        }
350
2

            
351
2
        impl<C> Default for CollectionDocumentVisitor<C>
352
2
        where
353
2
            C: SerializedCollection,
354
2
        {
355
2
            fn default() -> Self {
356
2
                Self {
357
2
                    header: None,
358
2
                    contents: None,
359
2
                }
360
2
            }
361
2
        }
362
2

            
363
2
        impl<'de, C> Visitor<'de> for CollectionDocumentVisitor<C>
364
2
        where
365
2
            C: SerializedCollection,
366
2
            C::PrimaryKey: Deserialize<'de>,
367
2
            C::Contents: Deserialize<'de>,
368
2
        {
369
2
            type Value = CollectionDocument<C>;
370
2

            
371
2
            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
372
                formatter.write_str("a collection document")
373
            }
374
2

            
375
2
            fn visit_map<A>(mut self, mut map: A) -> Result<Self::Value, A::Error>
376
1
            where
377
1
                A: serde::de::MapAccess<'de>,
378
1
            {
379
3
                while let Some(key) = map.next_key::<Cow<'_, str>>()? {
380
2
                    match key.as_ref() {
381
2
                        "header" => {
382
1
                            self.header = Some(map.next_value()?);
383
2
                        }
384
2
                        "contents" => {
385
1
                            self.contents = Some(map.next_value()?);
386
2
                        }
387
2
                        _ => {
388
2
                            return Err(<A::Error as de::Error>::custom(format!(
389
                                "unknown field {key}"
390
                            )))
391
2
                        }
392
2
                    }
393
2
                }
394
2

            
395
2
                Ok(CollectionDocument {
396
2
                    header: self
397
1
                        .header
398
1
                        .ok_or_else(|| <A::Error as de::Error>::custom("`header` missing"))?,
399
2
                    contents: self
400
1
                        .contents
401
1
                        .ok_or_else(|| <A::Error as de::Error>::custom("`contents` missing"))?,
402
2
                })
403
2
            }
404
2

            
405
2
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
406
1
            where
407
1
                A: de::SeqAccess<'de>,
408
1
            {
409
2
                let header = seq
410
1
                    .next_element()?
411
2
                    .ok_or_else(|| <A::Error as de::Error>::custom("`header` missing"))?;
412
2
                let contents = seq
413
1
                    .next_element()?
414
2
                    .ok_or_else(|| <A::Error as de::Error>::custom("`contents` missing"))?;
415
2
                Ok(CollectionDocument { header, contents })
416
2
            }
417
2
        }
418
2

            
419
2
        deserializer.deserialize_struct(
420
2
            "CollectionDocument",
421
2
            &["header", "contents"],
422
2
            CollectionDocumentVisitor::default(),
423
2
        )
424
2
    }
425
}
426

            
427
/// Helper functions for a slice of [`OwnedDocument`]s.
428
pub trait OwnedDocuments {
429
    /// Returns a list of deserialized documents.
430
    fn collection_documents<C: SerializedCollection>(
431
        &self,
432
    ) -> Result<Vec<CollectionDocument<C>>, Error>;
433
}
434

            
435
impl OwnedDocuments for [OwnedDocument] {
436
680
    fn collection_documents<C: SerializedCollection>(
437
680
        &self,
438
680
    ) -> Result<Vec<CollectionDocument<C>>, Error> {
439
680
        self.iter().map(CollectionDocument::try_from).collect()
440
680
    }
441
}
442

            
443
1
#[test]
fn collection_document_serialization() {
    use crate::test_util::Basic;

    let original = CollectionDocument::<Basic> {
        header: CollectionHeader {
            id: 1,
            revision: super::Revision::new(b"hello world"),
        },
        contents: Basic::new("test"),
    };

    // Pot uses a map to represent a struct, exercising `visit_map`.
    let pot_bytes = pot::to_vec(&original).unwrap();
    let from_pot: CollectionDocument<Basic> = pot::from_slice(&pot_bytes).unwrap();
    assert_eq!(from_pot, original);

    // Bincode uses a sequence to represent a struct, exercising `visit_seq`.
    let bincode_bytes = transmog_bincode::bincode::serialize(&original).unwrap();
    let from_bincode: CollectionDocument<Basic> =
        transmog_bincode::bincode::deserialize(&bincode_bytes).unwrap();
    assert_eq!(from_bincode, original);
}