1
use std::borrow::Cow;
2
use std::fmt::Debug;
3

            
4
use arc_bytes::serde::{Bytes, CowBytes};
5
use serde::de::{self, Visitor};
6
use serde::ser::SerializeStruct;
7
use serde::{Deserialize, Serialize};
8

            
9
use crate::connection::{AsyncConnection, Connection};
10
use crate::document::{
11
    BorrowedDocument, CollectionHeader, DocumentId, HasHeader, Header, OwnedDocument,
12
};
13
use crate::schema::SerializedCollection;
14
use crate::transaction::{Operation, Transaction};
15
use crate::Error;
16

            
17
/// A document with serializable contents.
18
29
#[derive(Clone, Debug, Eq, PartialEq)]
19
pub struct CollectionDocument<C>
20
where
21
    C: SerializedCollection,
22
{
23
    /// The header of the document, which contains the id and `Revision`.
24
    pub header: CollectionHeader<C::PrimaryKey>,
25

            
26
    /// The document's contents.
27
    pub contents: C::Contents,
28
}
29

            
30
impl<'a, C> TryFrom<&'a BorrowedDocument<'a>> for CollectionDocument<C>
31
where
32
    C: SerializedCollection,
33
{
34
    type Error = Error;
35

            
36
58526
    fn try_from(value: &'a BorrowedDocument<'a>) -> Result<Self, Self::Error> {
37
58526
        Ok(Self {
38
58526
            contents: C::deserialize(&value.contents)?,
39
58526
            header: CollectionHeader::try_from(value.header.clone())?,
40
        })
41
58526
    }
42
}
43

            
44
impl<'a, C> TryFrom<&'a OwnedDocument> for CollectionDocument<C>
45
where
46
    C: SerializedCollection,
47
{
48
    type Error = Error;
49

            
50
27988
    fn try_from(value: &'a OwnedDocument) -> Result<Self, Self::Error> {
51
27988
        Ok(Self {
52
27988
            contents: C::deserialize(&value.contents)?,
53
27988
            header: CollectionHeader::try_from(value.header.clone())?,
54
        })
55
27988
    }
56
}
57

            
58
impl<'a, 'b, C> TryFrom<&'b CollectionDocument<C>> for BorrowedDocument<'a>
59
where
60
    C: SerializedCollection,
61
{
62
    type Error = crate::Error;
63

            
64
    fn try_from(value: &'b CollectionDocument<C>) -> Result<Self, Self::Error> {
65
        Ok(Self {
66
            contents: CowBytes::from(C::serialize(&value.contents)?),
67
            header: Header::try_from(value.header.clone())?,
68
        })
69
    }
70
}
71

            
72
impl<C> CollectionDocument<C>
73
where
74
    C: SerializedCollection,
75
{
76
    /// Updates the document stored in the database with the contents of this
77
    /// collection document.
78
    ///
79
    /// ```rust
80
    /// # bonsaidb_core::__doctest_prelude!();
81
    /// # use bonsaidb_core::connection::Connection;
82
    /// # fn test_fn<C: Connection>(db: C) -> Result<(), Error> {
83
    /// if let Some(mut document) = MyCollection::get(&42, &db)? {
84
    ///     // ... do something `document`
85
    ///     document.update(&db)?;
86
    ///     println!(
87
    ///         "The document has been updated: {:?}",
88
    ///         document.header.revision
89
    ///     );
90
    /// }
91
    /// # Ok(())
92
    /// # }
93
    /// ```
94
892
    pub fn update<Cn: Connection>(&mut self, connection: &Cn) -> Result<(), Error> {
95
892
        let mut doc = self.to_document()?;
96

            
97
892
        connection.update::<C, _>(&mut doc)?;
98

            
99
889
        self.header = CollectionHeader::try_from(doc.header)?;
100

            
101
889
        Ok(())
102
892
    }
103

            
104
    /// Pushes an update [`Operation`] to the transaction for this document.
105
    ///
106
    /// The changes will happen once the transaction is applied.
107
8
    pub fn update_in_transaction(&self, transaction: &mut Transaction) -> Result<(), Error> {
108
8
        transaction.push(Operation::update_serialized::<C>(
109
8
            self.header.clone(),
110
8
            &self.contents,
111
8
        )?);
112
8
        Ok(())
113
8
    }
114

            
115
    /// Stores the new value of `contents` in the document.
116
    ///
117
    /// ```rust
118
    /// # bonsaidb_core::__doctest_prelude!();
119
    /// # use bonsaidb_core::connection::AsyncConnection;
120
    /// # fn test_fn<C: AsyncConnection>(db: C) -> Result<(), Error> {
121
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
122
    /// if let Some(mut document) = MyCollection::get_async(&42, &db).await? {
123
    ///     // modify the document
124
    ///     document.update_async(&db).await?;
125
    ///     println!("Updated revision: {:?}", document.header.revision);
126
    /// }
127
    /// # Ok(())
128
    /// # })
129
    /// # }
130
    /// ```
131
5277
    pub async fn update_async<Cn: AsyncConnection>(
132
5277
        &mut self,
133
5277
        connection: &Cn,
134
5277
    ) -> Result<(), Error> {
135
5277
        let mut doc = self.to_document()?;
136

            
137
8317
        connection.update::<C, _>(&mut doc).await?;
138

            
139
5272
        self.header = CollectionHeader::try_from(doc.header)?;
140

            
141
5272
        Ok(())
142
5277
    }
143

            
144
    /// Modifies `self`, automatically retrying the modification if the document
145
    /// has been updated on the server.
146
    ///
147
    /// ## Data loss warning
148
    ///
149
    /// If you've modified `self` before calling this function and a conflict
150
    /// occurs, all changes to self will be lost when the current document is
151
    /// fetched before retrying the process again. When you use this function,
152
    /// you should limit the edits to the value to within the `modifier`
153
    /// callback.
154
3
    pub fn modify<Cn: Connection, Modifier: FnMut(&mut Self) + Send + Sync>(
155
3
        &mut self,
156
3
        connection: &Cn,
157
3
        mut modifier: Modifier,
158
3
    ) -> Result<(), Error>
159
3
    where
160
3
        C::Contents: Clone,
161
3
    {
162
3
        let mut is_first_loop = true;
163
        // TODO this should have a retry-limit.
164
3
        loop {
165
3
            // On the first attempt, we want to try sending the update to the
166
3
            // database without fetching new contents. If we receive a conflict,
167
3
            // on future iterations we will first re-load the data.
168
3
            if is_first_loop {
169
3
                is_first_loop = false;
170
3
            } else {
171
                *self =
172
                    C::get(&self.header.id, connection)?.ok_or_else(|| {
173
                        match DocumentId::new(&self.header.id) {
174
                            Ok(id) => Error::DocumentNotFound(C::collection_name(), Box::new(id)),
175
                            Err(err) => err,
176
                        }
177
                    })?;
178
            }
179
3
            modifier(&mut *self);
180
3
            match self.update(connection) {
181
                Err(Error::DocumentConflict(..)) => {}
182
3
                other => return other,
183
            }
184
        }
185
3
    }
186

            
187
    /// Modifies `self`, automatically retrying the modification if the document
188
    /// has been updated on the server.
189
    ///
190
    /// ## Data loss warning
191
    ///
192
    /// If you've modified `self` before calling this function and a conflict
193
    /// occurs, all changes to self will be lost when the current document is
194
    /// fetched before retrying the process again. When you use this function,
195
    /// you should limit the edits to the value to within the `modifier`
196
    /// callback.
197
5
    pub async fn modify_async<Cn: AsyncConnection, Modifier: FnMut(&mut Self) + Send + Sync>(
198
5
        &mut self,
199
5
        connection: &Cn,
200
5
        mut modifier: Modifier,
201
5
    ) -> Result<(), Error>
202
5
    where
203
5
        C::Contents: Clone,
204
5
    {
205
5
        let mut is_first_loop = true;
206
        // TODO this should have a retry-limit.
207
5
        loop {
208
5
            // On the first attempt, we want to try sending the update to the
209
5
            // database without fetching new contents. If we receive a conflict,
210
5
            // on future iterations we will first re-load the data.
211
5
            if is_first_loop {
212
5
                is_first_loop = false;
213
5
            } else {
214
                *self = C::get_async(&self.header.id, connection)
215
                    .await?
216
                    .ok_or_else(|| match DocumentId::new(&self.header.id) {
217
                        Ok(id) => Error::DocumentNotFound(C::collection_name(), Box::new(id)),
218
                        Err(err) => err,
219
                    })?;
220
            }
221
5
            modifier(&mut *self);
222
5
            match self.update_async(connection).await {
223
                Err(Error::DocumentConflict(..)) => {}
224
5
                other => return other,
225
            }
226
        }
227
5
    }
228

            
229
    /// Removes the document from the collection.
230
    ///
231
    /// ```rust
232
    /// # bonsaidb_core::__doctest_prelude!();
233
    /// # use bonsaidb_core::connection::Connection;
234
    /// # fn test_fn<C: Connection>(db: C) -> Result<(), Error> {
235
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
236
    /// if let Some(document) = MyCollection::get(&42, &db)? {
237
    ///     document.delete(&db)?;
238
    /// }
239
    /// # Ok(())
240
    /// # })
241
    /// # }
242
    /// ```
243
    pub fn delete<Cn: Connection>(&self, connection: &Cn) -> Result<(), Error> {
244
11
        connection.collection::<C>().delete(self)?;
245

            
246
11
        Ok(())
247
11
    }
248

            
249
    /// Removes the document from the collection.
250
    ///
251
    /// ```rust
252
    /// # bonsaidb_core::__doctest_prelude!();
253
    /// # use bonsaidb_core::connection::AsyncConnection;
254
    /// # fn test_fn<C: AsyncConnection>(db: C) -> Result<(), Error> {
255
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
256
    /// if let Some(document) = MyCollection::get_async(&42, &db).await? {
257
    ///     document.delete_async(&db).await?;
258
    /// }
259
    /// # Ok(())
260
    /// # })
261
    /// # }
262
    /// ```
263
535
    pub async fn delete_async<Cn: AsyncConnection>(&self, connection: &Cn) -> Result<(), Error> {
264
535
        connection.collection::<C>().delete(self).await?;
265

            
266
533
        Ok(())
267
535
    }
268

            
269
    /// Pushes a delete [`Operation`] to the transaction for this document.
270
    ///
271
    /// The document will be deleted once the transaction is applied.
272
8
    pub fn delete_in_transaction(&self, transaction: &mut Transaction) -> Result<(), Error> {
273
8
        transaction.push(Operation::delete(C::collection_name(), self.header()?));
274
8
        Ok(())
275
8
    }
276

            
277
    /// Refreshes this instance from `connection`. If the document is no longer
278
    /// present, [`Error::DocumentNotFound`] will be returned.
279
3
    pub fn refresh<Cn: Connection>(&mut self, connection: &Cn) -> Result<(), Error> {
280
3
        let id = DocumentId::new(&self.header.id)?;
281
3
        *self = C::get(&id, connection)?
282
3
            .ok_or_else(|| Error::DocumentNotFound(C::collection_name(), Box::new(id)))?;
283
3
        Ok(())
284
3
    }
285

            
286
    /// Refreshes this instance from `connection`. If the document is no longer
287
    /// present, [`Error::DocumentNotFound`] will be returned.
288
5
    pub async fn refresh_async<Cn: AsyncConnection>(
289
5
        &mut self,
290
5
        connection: &Cn,
291
5
    ) -> Result<(), Error> {
292
5
        let id = DocumentId::new(&self.header.id)?;
293
5
        *self = C::get_async(&id, connection)
294
5
            .await?
295
5
            .ok_or_else(|| Error::DocumentNotFound(C::collection_name(), Box::new(id)))?;
296
5
        Ok(())
297
5
    }
298

            
299
    /// Converts this value to a serialized `Document`.
300
6169
    pub fn to_document(&self) -> Result<OwnedDocument, Error> {
301
6169
        Ok(OwnedDocument {
302
6169
            contents: Bytes::from(C::serialize(&self.contents)?),
303
6169
            header: Header::try_from(self.header.clone())?,
304
        })
305
6169
    }
306
}
307

            
308
impl<C> Serialize for CollectionDocument<C>
309
where
310
    C: SerializedCollection,
311
    C::Contents: Serialize,
312
    C::PrimaryKey: Serialize,
313
{
314
3
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
315
3
    where
316
3
        S: serde::Serializer,
317
3
    {
318
3
        let mut s = serializer.serialize_struct("CollectionDocument", 2)?;
319
3
        s.serialize_field("header", &self.header)?;
320
3
        s.serialize_field("contents", &self.contents)?;
321
3
        s.end()
322
3
    }
323
}
324

            
325
impl<'de, C> Deserialize<'de> for CollectionDocument<C>
326
where
327
    C: SerializedCollection,
328
    C::PrimaryKey: Deserialize<'de>,
329
    C::Contents: Deserialize<'de>,
330
{
331
2
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
332
2
    where
333
2
        D: serde::Deserializer<'de>,
334
2
    {
335
2
        struct CollectionDocumentVisitor<C>
336
2
        where
337
2
            C: SerializedCollection,
338
2
        {
339
2
            header: Option<CollectionHeader<C::PrimaryKey>>,
340
2
            contents: Option<C::Contents>,
341
2
        }
342
2

            
343
2
        impl<C> Default for CollectionDocumentVisitor<C>
344
2
        where
345
2
            C: SerializedCollection,
346
2
        {
347
2
            fn default() -> Self {
348
2
                Self {
349
2
                    header: None,
350
2
                    contents: None,
351
2
                }
352
2
            }
353
2
        }
354
2

            
355
2
        impl<'de, C> Visitor<'de> for CollectionDocumentVisitor<C>
356
2
        where
357
2
            C: SerializedCollection,
358
2
            C::PrimaryKey: Deserialize<'de>,
359
2
            C::Contents: Deserialize<'de>,
360
2
        {
361
2
            type Value = CollectionDocument<C>;
362
2

            
363
2
            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364
                formatter.write_str("a collection document")
365
            }
366
2

            
367
2
            fn visit_map<A>(mut self, mut map: A) -> Result<Self::Value, A::Error>
368
1
            where
369
1
                A: serde::de::MapAccess<'de>,
370
1
            {
371
3
                while let Some(key) = map.next_key::<Cow<'_, str>>()? {
372
2
                    match key.as_ref() {
373
2
                        "header" => {
374
1
                            self.header = Some(map.next_value()?);
375
2
                        }
376
2
                        "contents" => {
377
1
                            self.contents = Some(map.next_value()?);
378
2
                        }
379
2
                        _ => {
380
2
                            return Err(<A::Error as de::Error>::custom(format!(
381
                                "unknown field {key}"
382
                            )))
383
2
                        }
384
2
                    }
385
2
                }
386
2

            
387
2
                Ok(CollectionDocument {
388
2
                    header: self
389
1
                        .header
390
1
                        .ok_or_else(|| <A::Error as de::Error>::custom("`header` missing"))?,
391
2
                    contents: self
392
1
                        .contents
393
1
                        .ok_or_else(|| <A::Error as de::Error>::custom("`contents` missing"))?,
394
2
                })
395
2
            }
396
2

            
397
2
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
398
1
            where
399
1
                A: de::SeqAccess<'de>,
400
1
            {
401
2
                let header = seq
402
1
                    .next_element()?
403
2
                    .ok_or_else(|| <A::Error as de::Error>::custom("`header` missing"))?;
404
2
                let contents = seq
405
1
                    .next_element()?
406
2
                    .ok_or_else(|| <A::Error as de::Error>::custom("`contents` missing"))?;
407
2
                Ok(CollectionDocument { header, contents })
408
2
            }
409
2
        }
410
2

            
411
2
        deserializer.deserialize_struct(
412
2
            "CollectionDocument",
413
2
            &["header", "contents"],
414
2
            CollectionDocumentVisitor::default(),
415
2
        )
416
2
    }
417
}
418

            
419
/// Helper functions for a slice of [`OwnedDocument`]s.
420
pub trait OwnedDocuments {
421
    /// Returns a list of deserialized documents.
422
    fn collection_documents<C: SerializedCollection>(
423
        &self,
424
    ) -> Result<Vec<CollectionDocument<C>>, Error>;
425
}
426

            
427
impl OwnedDocuments for [OwnedDocument] {
428
697
    fn collection_documents<C: SerializedCollection>(
429
697
        &self,
430
697
    ) -> Result<Vec<CollectionDocument<C>>, Error> {
431
697
        self.iter().map(CollectionDocument::try_from).collect()
432
697
    }
433
}
434

            
435
1
#[test]
fn collection_document_serialization() {
    use crate::test_util::Basic;

    let original = CollectionDocument::<Basic> {
        header: CollectionHeader {
            id: 1,
            revision: super::Revision::new(b"hello world"),
        },
        contents: Basic::new("test"),
    };

    // Pot represents a struct as a map, which exercises `visit_map`.
    let pot_bytes = pot::to_vec(&original).unwrap();
    let from_pot: CollectionDocument<Basic> = pot::from_slice(&pot_bytes).unwrap();
    assert_eq!(from_pot, original);

    // Bincode represents a struct as a sequence, which exercises `visit_seq`.
    let bincode_bytes = transmog_bincode::bincode::serialize(&original).unwrap();
    let from_bincode: CollectionDocument<Basic> =
        transmog_bincode::bincode::deserialize(&bincode_bytes).unwrap();
    assert_eq!(from_bincode, original);
}