1
use std::any::TypeId;
2
use std::collections::{hash_map, HashMap};
3
use std::fmt::Debug;
4
use std::marker::PhantomData;
5

            
6
use derive_where::derive_where;
7

            
8
use crate::document::{BorrowedDocument, DocumentId, KeyId};
9
use crate::key::{ByteSource, Key, KeyDescription};
10
use crate::schema::collection::Collection;
11
use crate::schema::view::map::{self, MappedValue};
12
use crate::schema::view::{
13
    self, MapReduce, Serialized, SerializedView, ViewSchema, ViewUpdatePolicy,
14
};
15
use crate::schema::{CollectionName, Schema, SchemaName, View, ViewName};
16
use crate::Error;
17

            
18
/// A collection of defined collections and views.
19
pub struct Schematic {
20
    /// The name of the schema this was built from.
21
    pub name: SchemaName,
22
    contained_collections: HashMap<CollectionName, KeyDescription>,
23
    collections_by_type_id: HashMap<TypeId, CollectionName>,
24
    collection_encryption_keys: HashMap<CollectionName, KeyId>,
25
    collection_id_generators: HashMap<CollectionName, Box<dyn IdGenerator>>,
26
    views: HashMap<TypeId, Box<dyn view::Serialized>>,
27
    views_by_name: HashMap<ViewName, TypeId>,
28
    views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
29
    eager_views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
30
}
31

            
32
impl Schematic {
33
    /// Returns an initialized version from `S`.
34
148284
    pub fn from_schema<S: Schema + ?Sized>() -> Result<Self, Error> {
35
148284
        let mut schematic = Self {
36
148284
            name: S::schema_name(),
37
148284
            contained_collections: HashMap::new(),
38
148284
            collections_by_type_id: HashMap::new(),
39
148284
            collection_encryption_keys: HashMap::new(),
40
148284
            collection_id_generators: HashMap::new(),
41
148284
            views: HashMap::new(),
42
148284
            views_by_name: HashMap::new(),
43
148284
            views_by_collection: HashMap::new(),
44
148284
            eager_views_by_collection: HashMap::new(),
45
148284
        };
46
148284
        S::define_collections(&mut schematic)?;
47
148284
        Ok(schematic)
48
148284
    }
49

            
50
    /// Adds the collection `C` and its views.
51
5790439
    pub fn define_collection<C: Collection + 'static>(&mut self) -> Result<(), Error> {
52
5790439
        let name = C::collection_name();
53
5790439
        match self.contained_collections.entry(name.clone()) {
54
5790439
            hash_map::Entry::Vacant(entry) => {
55
5790439
                self.collections_by_type_id
56
5790439
                    .insert(TypeId::of::<C>(), name.clone());
57
5790439
                if let Some(key) = C::encryption_key() {
58
1799362
                    self.collection_encryption_keys.insert(name.clone(), key);
59
3991077
                }
60
5790439
                self.collection_id_generators
61
5790439
                    .insert(name, Box::<KeyIdGenerator<C>>::default());
62
5790439
                entry.insert(KeyDescription::for_key::<C::PrimaryKey>());
63
5790439
                C::define_views(self)
64
            }
65
            hash_map::Entry::Occupied(_) => Err(Error::CollectionAlreadyDefined),
66
        }
67
5790439
    }
68

            
69
    /// Adds the view `V`.
70
17315132
    pub fn define_view<V: MapReduce + ViewSchema<View = V> + SerializedView + Clone + 'static>(
71
17315132
        &mut self,
72
17315132
        view: V,
73
17315132
    ) -> Result<(), Error> {
74
17315132
        self.define_view_with_schema(view.clone(), view)
75
17315132
    }
76

            
77
    /// Adds the view `V`.
78
17315132
    pub fn define_view_with_schema<
79
17315132
        V: SerializedView + 'static,
80
17315132
        S: MapReduce + ViewSchema<View = V> + 'static,
81
17315132
    >(
82
17315132
        &mut self,
83
17315132
        view: V,
84
17315132
        schema: S,
85
17315132
    ) -> Result<(), Error> {
86
17315132
        let instance = ViewInstance { view, schema };
87
17315132
        let name = instance.view_name();
88
17315132
        if self.views_by_name.contains_key(&name) {
89
            return Err(Error::ViewAlreadyRegistered(name));
90
17315132
        }
91
17315132

            
92
17315132
        let collection = instance.collection();
93
17315132
        let eager = instance.update_policy().is_eager();
94
17315132
        self.views.insert(TypeId::of::<V>(), Box::new(instance));
95
17315132
        self.views_by_name.insert(name, TypeId::of::<V>());
96
17315132

            
97
17315132
        if eager {
98
3851489
            let unique_views = self
99
3851489
                .eager_views_by_collection
100
3851489
                .entry(collection.clone())
101
3851489
                .or_insert_with(Vec::new);
102
3851489
            unique_views.push(TypeId::of::<V>());
103
13463643
        }
104
17315132
        let views = self
105
17315132
            .views_by_collection
106
17315132
            .entry(collection)
107
17315132
            .or_insert_with(Vec::new);
108
17315132
        views.push(TypeId::of::<V>());
109
17315132

            
110
17315132
        Ok(())
111
17315132
    }
112

            
113
    /// Returns `true` if this schema contains the collection `C`.
114
    #[must_use]
115
    pub fn contains_collection<C: Collection + 'static>(&self) -> bool {
116
        self.collections_by_type_id.contains_key(&TypeId::of::<C>())
117
    }
118

            
119
    /// Returns the description of the primary keyof the collection with the
120
    /// given name, or `None` if the collection can't be found.
121
    #[must_use]
122
1348880
    pub fn collection_primary_key_description<'a>(
123
1348880
        &'a self,
124
1348880
        collection: &CollectionName,
125
1348880
    ) -> Option<&'a KeyDescription> {
126
1348880
        self.contained_collections.get(collection)
127
1348880
    }
128

            
129
    /// Returns the next id in sequence for the collection, if the primary key
130
    /// type supports the operation and the next id would not overflow.
131
649160
    pub fn next_id_for_collection(
132
649160
        &self,
133
649160
        collection: &CollectionName,
134
649160
        id: Option<DocumentId>,
135
649160
    ) -> Result<DocumentId, Error> {
136
649160
        let generator = self
137
649160
            .collection_id_generators
138
649160
            .get(collection)
139
649160
            .ok_or(Error::CollectionNotFound)?;
140
649160
        generator.next_id(id)
141
649160
    }
142

            
143
    /// Looks up a [`view::Serialized`] by name.
144
1384960
    pub fn view_by_name(&self, name: &ViewName) -> Result<&'_ dyn view::Serialized, Error> {
145
1384960
        self.views_by_name
146
1384960
            .get(name)
147
1384960
            .and_then(|type_id| self.views.get(type_id))
148
1384960
            .map(AsRef::as_ref)
149
1384960
            .ok_or(Error::ViewNotFound)
150
1384960
    }
151

            
152
    /// Looks up a [`view::Serialized`] through the the type `V`.
153
53872
    pub fn view<V: View + 'static>(&self) -> Result<&'_ dyn view::Serialized, Error> {
154
53872
        self.views
155
53872
            .get(&TypeId::of::<V>())
156
53872
            .map(AsRef::as_ref)
157
53872
            .ok_or(Error::ViewNotFound)
158
53872
    }
159

            
160
    /// Iterates over all registered views.
161
160
    pub fn views(&self) -> impl Iterator<Item = &'_ dyn view::Serialized> {
162
160
        self.views.values().map(AsRef::as_ref)
163
160
    }
164

            
165
    /// Iterates over all views that belong to `collection`.
166
2266400
    pub fn views_in_collection(
167
2266400
        &self,
168
2266400
        collection: &CollectionName,
169
2266400
    ) -> impl Iterator<Item = &'_ dyn view::Serialized> {
170
2266400
        self.views_by_collection
171
2266400
            .get(collection)
172
2266400
            .into_iter()
173
2266400
            .flat_map(|view_ids| {
174
1379480
                view_ids
175
1379480
                    .iter()
176
6636400
                    .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
177
2266400
            })
178
2266400
    }
179

            
180
    /// Iterates over all views that are eagerly updated that belong to
181
    /// `collection`.
182
2262120
    pub fn eager_views_in_collection(
183
2262120
        &self,
184
2262120
        collection: &CollectionName,
185
2262120
    ) -> impl Iterator<Item = &'_ dyn view::Serialized> {
186
2262120
        self.eager_views_by_collection
187
2262120
            .get(collection)
188
2262120
            .into_iter()
189
2262120
            .flat_map(|view_ids| {
190
828991
                view_ids
191
828991
                    .iter()
192
828991
                    .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
193
2262120
            })
194
2262120
    }
195

            
196
    /// Returns a collection's default encryption key, if one was defined.
197
    #[must_use]
198
4687920
    pub fn encryption_key_for_collection(&self, collection: &CollectionName) -> Option<&KeyId> {
199
4687920
        self.collection_encryption_keys.get(collection)
200
4687920
    }
201

            
202
    /// Returns a list of all collections contained in this schematic.
203
1640
    pub fn collections(&self) -> impl Iterator<Item = &CollectionName> {
204
1640
        self.contained_collections.keys()
205
1640
    }
206
}
207

            
208
impl Debug for Schematic {
209
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210
        let mut views = self
211
            .views
212
            .values()
213
            .map(|v| v.view_name())
214
            .collect::<Vec<_>>();
215
        views.sort();
216

            
217
        f.debug_struct("Schematic")
218
            .field("name", &self.name)
219
            .field("contained_collections", &self.contained_collections)
220
            .field("collections_by_type_id", &self.collections_by_type_id)
221
            .field(
222
                "collection_encryption_keys",
223
                &self.collection_encryption_keys,
224
            )
225
            .field("collection_id_generators", &self.collection_id_generators)
226
            .field("views", &views)
227
            .field("views_by_name", &self.views_by_name)
228
            .field("views_by_collection", &self.views_by_collection)
229
            .field("eager_views_by_collection", &self.eager_views_by_collection)
230
            .finish()
231
    }
232
}
233

            
234
/// Bundles a view with the schema that maps and reduces it, so both can be
/// boxed together as one `view::Serialized` trait object.
#[derive(Debug)]
struct ViewInstance<V, S> {
    /// The view definition; the view's name is read from it.
    view: V,
    /// The schema that supplies the map/reduce logic, update policy and version.
    schema: S,
}
239

            
240
impl<V, S> Serialized for ViewInstance<V, S>
241
where
242
    V: SerializedView,
243
    S: MapReduce + ViewSchema<View = V>,
244
{
245
18203373
    fn collection(&self) -> CollectionName {
246
18203373
        <<V as View>::Collection as Collection>::collection_name()
247
18203373
    }
248

            
249
4609
    fn key_description(&self) -> KeyDescription {
250
4609
        KeyDescription::for_key::<<V as View>::Key>()
251
4609
    }
252

            
253
24270750
    fn update_policy(&self) -> ViewUpdatePolicy {
254
24270750
        self.schema.update_policy()
255
24270750
    }
256

            
257
139894
    fn version(&self) -> u64 {
258
139894
        self.schema.version()
259
139894
    }
260

            
261
24772975
    fn view_name(&self) -> ViewName {
262
24772975
        self.view.view_name()
263
24772975
    }
264

            
265
563648
    fn map(&self, document: &BorrowedDocument<'_>) -> Result<Vec<map::Serialized>, view::Error> {
266
563648
        let mappings = self.schema.map(document)?;
267

            
268
563648
        mappings
269
563648
            .iter()
270
563648
            .map(map::Map::serialized::<V>)
271
563648
            .collect::<Result<_, _>>()
272
563648
            .map_err(view::Error::key_serialization)
273
563648
    }
274

            
275
567654
    fn reduce(&self, mappings: &[(&[u8], &[u8])], rereduce: bool) -> Result<Vec<u8>, view::Error> {
276
567654
        let mappings = mappings
277
567654
            .iter()
278
160933429
            .map(|(key, value)| {
279
160933429
                match <S::MappedKey<'_> as Key>::from_ord_bytes(ByteSource::Borrowed(key)) {
280
160933429
                    Ok(key) => {
281
160933429
                        let value = V::deserialize(value)?;
282
160933429
                        Ok(MappedValue::new(key, value))
283
                    }
284
                    Err(err) => Err(view::Error::key_serialization(err)),
285
                }
286
160933429
            })
287
567654
            .collect::<Result<Vec<_>, view::Error>>()?;
288

            
289
567654
        let reduced_value = self.schema.reduce(&mappings, rereduce)?;
290

            
291
509767
        V::serialize(&reduced_value).map_err(view::Error::from)
292
567654
    }
293
}
294

            
295
pub trait IdGenerator: Debug + Send + Sync {
296
    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error>;
297
}
298

            
299
5790439
#[derive_where(Default, Debug)]
300
pub struct KeyIdGenerator<C: Collection>(PhantomData<C>);
301

            
302
impl<C> IdGenerator for KeyIdGenerator<C>
303
where
304
    C: Collection,
305
{
306
467979
    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error> {
307
467979
        let key = id.map(|id| id.deserialize::<C::PrimaryKey>()).transpose()?;
308
467979
        let key = if let Some(key) = key {
309
429891
            key
310
        } else {
311
38088
            <C::PrimaryKey as Key<'_>>::first_value()
312
38088
                .map_err(|err| Error::DocumentPush(C::collection_name(), err))?
313
        };
314
467979
        let next_value = key
315
467979
            .next_value()
316
467979
            .map_err(|err| Error::DocumentPush(C::collection_name(), err))?;
317
467978
        DocumentId::new(&next_value)
318
467979
    }
319
}
320

            
321
1
#[test]
fn schema_tests() -> anyhow::Result<()> {
    use crate::test_util::{Basic, BasicCount};

    let schematic = Schematic::from_schema::<Basic>()?;

    // Exactly one collection is registered, under `Basic`'s type.
    assert_eq!(schematic.collections_by_type_id.len(), 1);
    let basic_name = &schematic.collections_by_type_id[&TypeId::of::<Basic>()];
    assert_eq!(basic_name, &Basic::collection_name());

    // Six views are registered, and `BasicCount`'s entry reports its own name.
    assert_eq!(schematic.views.len(), 6);
    let count_view = &schematic.views[&TypeId::of::<BasicCount>()];
    assert_eq!(count_view.view_name(), View::view_name(&BasicCount));

    Ok(())
}