1
use std::{
2
    any::TypeId,
3
    collections::{HashMap, HashSet},
4
    fmt::Debug,
5
    marker::PhantomData,
6
};
7

            
8
use derive_where::derive_where;
9

            
10
use crate::{
11
    document::{BorrowedDocument, DocumentId, KeyId},
12
    key::Key,
13
    schema::{
14
        collection::Collection,
15
        view::{
16
            self,
17
            map::{self, MappedValue},
18
            Serialized, SerializedView, ViewSchema,
19
        },
20
        CollectionName, Schema, SchemaName, View, ViewName,
21
    },
22
    Error,
23
};
24

            
25
/// A collection of defined collections and views.
///
/// Built from a [`Schema`] via `Schematic::from_schema`, and populated as the
/// schema defines its collections and views.
#[derive(Debug)]
pub struct Schematic {
    /// The name of the schema this was built from.
    pub name: SchemaName,
    /// Names of every collection that has been defined.
    contained_collections: HashSet<CollectionName>,
    /// Maps each collection's Rust type (`TypeId`) to its collection name.
    collections_by_type_id: HashMap<TypeId, CollectionName>,
    /// Default encryption key for each collection that declares one.
    collection_encryption_keys: HashMap<CollectionName, KeyId>,
    /// The id generator used to produce the next primary key for each collection.
    collection_id_generators: HashMap<CollectionName, Box<dyn IdGenerator>>,
    /// Every registered view, keyed by the `TypeId` of the view type.
    views: HashMap<TypeId, Box<dyn view::Serialized>>,
    /// Lookup from a view's name to its `TypeId` key in `views`.
    views_by_name: HashMap<ViewName, TypeId>,
    /// `TypeId`s of the views registered for each collection.
    views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
    /// `TypeId`s of only those views whose schema reports `unique()`, per collection.
    unique_views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
}
39

            
40
impl Schematic {
41
    /// Returns an initialized version from `S`.
42
126131
    pub fn from_schema<S: Schema + ?Sized>() -> Result<Self, Error> {
43
126131
        let mut schematic = Self {
44
126131
            name: S::schema_name(),
45
126131
            contained_collections: HashSet::new(),
46
126131
            collections_by_type_id: HashMap::new(),
47
126131
            collection_encryption_keys: HashMap::new(),
48
126131
            collection_id_generators: HashMap::new(),
49
126131
            views: HashMap::new(),
50
126131
            views_by_name: HashMap::new(),
51
126131
            views_by_collection: HashMap::new(),
52
126131
            unique_views_by_collection: HashMap::new(),
53
126131
        };
54
126131
        S::define_collections(&mut schematic)?;
55
126131
        Ok(schematic)
56
126131
    }
57

            
58
    /// Adds the collection `C` and its views.
59
4376901
    pub fn define_collection<C: Collection + 'static>(&mut self) -> Result<(), Error> {
60
4376901
        let name = C::collection_name();
61
4376901
        if self.contained_collections.contains(&name) {
62
            Err(Error::CollectionAlreadyDefined)
63
        } else {
64
4376901
            self.collections_by_type_id
65
4376901
                .insert(TypeId::of::<C>(), name.clone());
66
4376901
            if let Some(key) = C::encryption_key() {
67
1380907
                self.collection_encryption_keys.insert(name.clone(), key);
68
2996025
            }
69
4376932
            self.collection_id_generators
70
4376932
                .insert(name.clone(), Box::new(KeyIdGenerator::<C>::default()));
71
4376932
            self.contained_collections.insert(name);
72
4376932
            C::define_views(self)
73
        }
74
4376932
    }
75

            
76
    /// Adds the view `V`.
77
10764035
    pub fn define_view<V: ViewSchema<View = V> + SerializedView + Clone + 'static>(
78
10764035
        &mut self,
79
10764035
        view: V,
80
10764035
    ) -> Result<(), Error> {
81
10764035
        self.define_view_with_schema(view.clone(), view)
82
10764035
    }
83

            
84
    /// Adds the view `V`.
85
10764035
    pub fn define_view_with_schema<
86
10764035
        V: SerializedView + 'static,
87
10764035
        S: ViewSchema<View = V> + 'static,
88
10764035
    >(
89
10764035
        &mut self,
90
10764035
        view: V,
91
10764035
        schema: S,
92
10764035
    ) -> Result<(), Error> {
93
10764035
        let instance = ViewInstance { view, schema };
94
10764035
        let name = instance.view_name();
95
10764035
        let collection = instance.collection();
96
10764035
        let unique = instance.unique();
97
10764035
        self.views.insert(TypeId::of::<V>(), Box::new(instance));
98
10764035

            
99
10764035
        if self.views_by_name.contains_key(&name) {
100
            return Err(Error::ViewAlreadyRegistered(name));
101
10764035
        }
102
10764035
        self.views_by_name.insert(name, TypeId::of::<V>());
103
10764035

            
104
10764035
        if unique {
105
1648608
            let unique_views = self
106
1648608
                .unique_views_by_collection
107
1648608
                .entry(collection.clone())
108
1648608
                .or_insert_with(Vec::new);
109
1648608
            unique_views.push(TypeId::of::<V>());
110
9115427
        }
111
10764035
        let views = self
112
10764035
            .views_by_collection
113
10764035
            .entry(collection)
114
10764035
            .or_insert_with(Vec::new);
115
10764035
        views.push(TypeId::of::<V>());
116
10764035

            
117
10764035
        Ok(())
118
10764035
    }
119

            
120
    /// Returns `true` if this schema contains the collection `C`.
121
    #[must_use]
122
    pub fn contains_collection<C: Collection + 'static>(&self) -> bool {
123
        self.collections_by_type_id.contains_key(&TypeId::of::<C>())
124
    }
125

            
126
    /// Returns `true` if this schema contains the collection `C`.
127
    #[must_use]
128
1017668
    pub fn contains_collection_name(&self, collection: &CollectionName) -> bool {
129
1017668
        self.contained_collections.contains(collection)
130
1017668
    }
131

            
132
    /// Returns the next id in sequence for the collection, if the primary key
133
    /// type supports the operation and the next id would not overflow.
134
475044
    pub fn next_id_for_collection(
135
475044
        &self,
136
475044
        collection: &CollectionName,
137
475044
        id: Option<DocumentId>,
138
475044
    ) -> Result<DocumentId, Error> {
139
475044
        let generator = self
140
475044
            .collection_id_generators
141
475044
            .get(collection)
142
475044
            .ok_or(Error::CollectionNotFound)?;
143
475044
        generator.next_id(id)
144
475044
    }
145

            
146
    /// Looks up a [`view::Serialized`] by name.
147
1070182
    pub fn view_by_name(&self, name: &ViewName) -> Result<&'_ dyn view::Serialized, Error> {
148
1070182
        self.views_by_name
149
1070182
            .get(name)
150
1070182
            .and_then(|type_id| self.views.get(type_id))
151
1070182
            .map(AsRef::as_ref)
152
1070182
            .ok_or(Error::ViewNotFound)
153
1070182
    }
154

            
155
    /// Looks up a [`view::Serialized`] through the the type `V`.
156
46665
    pub fn view<V: View + 'static>(&self) -> Result<&'_ dyn view::Serialized, Error> {
157
46665
        self.views
158
46665
            .get(&TypeId::of::<V>())
159
46665
            .map(AsRef::as_ref)
160
46665
            .ok_or(Error::ViewNotFound)
161
46665
    }
162

            
163
    /// Iterates over all registered views.
164
124
    pub fn views(&self) -> impl Iterator<Item = &'_ dyn view::Serialized> {
165
124
        self.views.values().map(AsRef::as_ref)
166
124
    }
167

            
168
    /// Iterates over all views that belong to `collection`.
169
    #[must_use]
170
2392115
    pub fn views_in_collection(
171
2392115
        &self,
172
2392115
        collection: &CollectionName,
173
2392115
    ) -> Option<Vec<&'_ dyn view::Serialized>> {
174
2392115
        self.views_by_collection.get(collection).map(|view_ids| {
175
1526192
            view_ids
176
1526192
                .iter()
177
5124858
                .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
178
1526192
                .collect()
179
2392115
        })
180
2392115
    }
181

            
182
    /// Iterates over all views that are unique that belong to `collection`.
183
    #[must_use]
184
1000649
    pub fn unique_views_in_collection(
185
1000649
        &self,
186
1000649
        collection: &CollectionName,
187
1000649
    ) -> Option<Vec<&'_ dyn view::Serialized>> {
188
1000649
        self.unique_views_by_collection
189
1000649
            .get(collection)
190
1000649
            .map(|view_ids| {
191
111693
                view_ids
192
111693
                    .iter()
193
111693
                    .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
194
111693
                    .collect()
195
1000649
            })
196
1000649
    }
197

            
198
    /// Returns a collection's default encryption key, if one was defined.
199
    #[must_use]
200
3416448
    pub fn encryption_key_for_collection(&self, collection: &CollectionName) -> Option<&KeyId> {
201
3416448
        self.collection_encryption_keys.get(collection)
202
3416448
    }
203

            
204
    /// Returns a list of all collections contained in this schematic.
205
    #[must_use]
206
651
    pub fn collections(&self) -> Vec<CollectionName> {
207
651
        self.contained_collections.iter().cloned().collect()
208
651
    }
209
}
210

            
211
/// Pairs a view definition `V` with the schema `S` that implements its
/// map/reduce logic, so both can be stored together as one boxed
/// [`Serialized`] trait object.
#[derive(Debug)]
struct ViewInstance<V, S> {
    /// The view itself; supplies the view name and value (de)serialization.
    view: V,
    /// The view's schema; supplies `map`, `reduce`, `unique`, and `version`.
    schema: S,
}
216

            
217
impl<V, S> Serialized for ViewInstance<V, S>
218
where
219
    V: SerializedView,
220
    S: ViewSchema<View = V>,
221
    <V as View>::Key: 'static,
222
{
223
11022736
    fn collection(&self) -> CollectionName {
224
11022736
        <<V as View>::Collection as Collection>::collection_name()
225
11022736
    }
226

            
227
15703404
    fn unique(&self) -> bool {
228
15703404
        self.schema.unique()
229
15703404
    }
230

            
231
56937
    fn version(&self) -> u64 {
232
56937
        self.schema.version()
233
56937
    }
234

            
235
14394881
    fn view_name(&self) -> ViewName {
236
14394881
        self.view.view_name()
237
14394881
    }
238

            
239
50335
    fn map(&self, document: &BorrowedDocument<'_>) -> Result<Vec<map::Serialized>, view::Error> {
240
50335
        let map = self.schema.map(document)?;
241

            
242
50335
        map.into_iter()
243
50343
            .map(|map| map.serialized::<V>())
244
50335
            .collect::<Result<Vec<_>, view::Error>>()
245
50335
    }
246

            
247
54556
    fn reduce(&self, mappings: &[(&[u8], &[u8])], rereduce: bool) -> Result<Vec<u8>, view::Error> {
248
54556
        let mappings = mappings
249
54556
            .iter()
250
55352
            .map(|(key, value)| match <V::Key as Key>::from_ord_bytes(key) {
251
51143
                Ok(key) => {
252
51143
                    let value = V::deserialize(value)?;
253
51143
                    Ok(MappedValue::new(key, value))
254
                }
255
                Err(err) => Err(view::Error::key_serialization(err)),
256
55352
            })
257
54556
            .collect::<Result<Vec<_>, view::Error>>()?;
258

            
259
54556
        let reduced_value = self.schema.reduce(&mappings, rereduce)?;
260

            
261
12892
        V::serialize(&reduced_value).map_err(view::Error::from)
262
54556
    }
263
}
264

            
265
/// Produces successive document ids for a collection.
///
/// Stored as `Box<dyn IdGenerator>` per collection in a `Schematic`.
pub trait IdGenerator: Debug + Send + Sync {
    /// Returns the id that follows `id`.
    ///
    /// What `None` means is up to the implementation; for [`KeyIdGenerator`] it
    /// yields the value after the primary key type's first value.
    ///
    /// # Errors
    ///
    /// Returns an error when the next id cannot be produced.
    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error>;
}
268

            
269
/// An [`IdGenerator`] that derives new ids from the primary key type of the
/// collection `C`, via the key's `first_value` and `next_value`.
#[derive(Debug)]
// NOTE(review): `derive_where` is presumably used so that `Default` does not
// require `C: Default` (it only holds `PhantomData<C>`) — confirm.
#[derive_where(Default)]
pub struct KeyIdGenerator<C: Collection>(PhantomData<C>);
272

            
273
impl<C> IdGenerator for KeyIdGenerator<C>
274
where
275
    C: Collection,
276
{
277
327498
    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error> {
278
327498
        let key = id.map(|id| id.deserialize::<C::PrimaryKey>()).transpose()?;
279
327498
        let key = if let Some(key) = key {
280
300297
            key
281
        } else {
282
27201
            <C::PrimaryKey as Key<'_>>::first_value()
283
27201
                .map_err(|err| Error::DocumentPush(C::collection_name(), err))?
284
        };
285
327498
        let next_value = key
286
327498
            .next_value()
287
327498
            .map_err(|err| Error::DocumentPush(C::collection_name(), err))?;
288
327497
        DocumentId::new(next_value)
289
327498
    }
290
}
291

            
292
1
#[test]
fn schema_tests() -> anyhow::Result<()> {
    use crate::test_util::{Basic, BasicCount};

    // Build the schematic for the `Basic` test schema.
    let schematic = Schematic::from_schema::<Basic>()?;

    // Exactly one collection is registered, and its type is `Basic`.
    let basic_type_id = TypeId::of::<Basic>();
    assert_eq!(schematic.collections_by_type_id.len(), 1);
    assert_eq!(
        schematic.collections_by_type_id[&basic_type_id],
        Basic::collection_name()
    );

    // Four views are registered; the entry for `BasicCount` reports the same
    // name as the view type itself.
    let count_type_id = TypeId::of::<BasicCount>();
    assert_eq!(schematic.views.len(), 4);
    assert_eq!(
        schematic.views[&count_type_id].view_name(),
        View::view_name(&BasicCount)
    );

    Ok(())
}