1
use std::{
2
    any::TypeId,
3
    collections::{HashMap, HashSet},
4
    fmt::Debug,
5
    marker::PhantomData,
6
};
7

            
8
use derive_where::derive_where;
9

            
10
use crate::{
11
    document::{BorrowedDocument, DocumentId, KeyId},
12
    key::Key,
13
    schema::{
14
        collection::Collection,
15
        view::{
16
            self,
17
            map::{self, MappedValue},
18
            Serialized, SerializedView, ViewSchema,
19
        },
20
        CollectionName, Schema, SchemaName, View, ViewName,
21
    },
22
    Error,
23
};
24

            
25
/// A collection of defined collections and views.
26
#[derive(Debug)]
27
pub struct Schematic {
28
    /// The name of the schema this was built from.
29
    pub name: SchemaName,
30
    contained_collections: HashSet<CollectionName>,
31
    collections_by_type_id: HashMap<TypeId, CollectionName>,
32
    collection_encryption_keys: HashMap<CollectionName, KeyId>,
33
    collection_id_generators: HashMap<CollectionName, Box<dyn IdGenerator>>,
34
    views: HashMap<TypeId, Box<dyn view::Serialized>>,
35
    views_by_name: HashMap<ViewName, TypeId>,
36
    views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
37
    unique_views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
38
}
39

            
40
impl Schematic {
41
    /// Returns an initialized version from `S`.
42
92677
    pub fn from_schema<S: Schema + ?Sized>() -> Result<Self, Error> {
43
92677
        let mut schematic = Self {
44
92677
            name: S::schema_name(),
45
92677
            contained_collections: HashSet::new(),
46
92677
            collections_by_type_id: HashMap::new(),
47
92677
            collection_encryption_keys: HashMap::new(),
48
92677
            collection_id_generators: HashMap::new(),
49
92677
            views: HashMap::new(),
50
92677
            views_by_name: HashMap::new(),
51
92677
            views_by_collection: HashMap::new(),
52
92677
            unique_views_by_collection: HashMap::new(),
53
92677
        };
54
92677
        S::define_collections(&mut schematic)?;
55
92677
        Ok(schematic)
56
92677
    }
57

            
58
    /// Adds the collection `C` and its views.
59
2695113
    pub fn define_collection<C: Collection + 'static>(&mut self) -> Result<(), Error> {
60
2695113
        let name = C::collection_name();
61
2695113
        if self.contained_collections.contains(&name) {
62
            Err(Error::CollectionAlreadyDefined)
63
        } else {
64
2695113
            self.collections_by_type_id
65
2695113
                .insert(TypeId::of::<C>(), name.clone());
66
2695113
            if let Some(key) = C::encryption_key() {
67
845764
                self.collection_encryption_keys.insert(name.clone(), key);
68
1849375
            }
69
2695139
            self.collection_id_generators
70
2695139
                .insert(name.clone(), Box::new(KeyIdGenerator::<C>::default()));
71
2695139
            self.contained_collections.insert(name);
72
2695139
            C::define_views(self)
73
        }
74
2695139
    }
75

            
76
    /// Adds the view `V`.
77
6571112
    pub fn define_view<V: ViewSchema<View = V> + SerializedView + Clone + 'static>(
78
6571112
        &mut self,
79
6571112
        view: V,
80
6571112
    ) -> Result<(), Error> {
81
6571112
        self.define_view_with_schema(view.clone(), view)
82
6571112
    }
83

            
84
    /// Adds the view `V`.
85
6571112
    pub fn define_view_with_schema<
86
6571112
        V: SerializedView + 'static,
87
6571112
        S: ViewSchema<View = V> + 'static,
88
6571112
    >(
89
6571112
        &mut self,
90
6571112
        view: V,
91
6571112
        schema: S,
92
6571112
    ) -> Result<(), Error> {
93
6571112
        let instance = ViewInstance { view, schema };
94
6571112
        let name = instance.view_name();
95
6571112
        let collection = instance.collection();
96
6571112
        let unique = instance.unique();
97
6571112
        self.views.insert(TypeId::of::<V>(), Box::new(instance));
98
6571112
        // TODO check for name collision
99
6571112
        self.views_by_name.insert(name, TypeId::of::<V>());
100
6571112
        if unique {
101
1017820
            let unique_views = self
102
1017820
                .unique_views_by_collection
103
1017820
                .entry(collection.clone())
104
1017820
                .or_insert_with(Vec::new);
105
1017820
            unique_views.push(TypeId::of::<V>());
106
5553292
        }
107
6571112
        let views = self
108
6571112
            .views_by_collection
109
6571112
            .entry(collection)
110
6571112
            .or_insert_with(Vec::new);
111
6571112
        views.push(TypeId::of::<V>());
112
6571112

            
113
6571112
        Ok(())
114
6571112
    }
115

            
116
    /// Returns `true` if this schema contains the collection `C`.
117
    #[must_use]
118
    pub fn contains<C: Collection + 'static>(&self) -> bool {
119
        self.collections_by_type_id.contains_key(&TypeId::of::<C>())
120
    }
121

            
122
    /// Returns `true` if this schema contains the collection `C`.
123
    #[must_use]
124
672828
    pub fn contains_collection_id(&self, collection: &CollectionName) -> bool {
125
672828
        self.contained_collections.contains(collection)
126
672828
    }
127

            
128
    /// Returns the next id in sequence for the collection, if the primary key
129
    /// type supports the operation and the next id would not overflow.
130
    pub fn next_id_for_collection(
131
        &self,
132
        collection: &CollectionName,
133
        id: Option<DocumentId>,
134
    ) -> Result<DocumentId, Error> {
135
284102
        if let Some(generator) = self.collection_id_generators.get(collection) {
136
284102
            generator.next_id(id)
137
        } else {
138
            Err(Error::CollectionNotFound)
139
        }
140
284102
    }
141

            
142
    /// Looks up a [`view::Serialized`] by name.
143
    #[must_use]
144
1014780
    pub fn view_by_name(&self, name: &ViewName) -> Option<&'_ dyn view::Serialized> {
145
1014780
        self.views_by_name
146
1014780
            .get(name)
147
1014858
            .and_then(|type_id| self.views.get(type_id))
148
1014780
            .map(AsRef::as_ref)
149
1014780
    }
150

            
151
    /// Looks up a [`view::Serialized`] through the the type `V`.
152
    #[must_use]
153
59917
    pub fn view<V: View + 'static>(&self) -> Option<&'_ dyn view::Serialized> {
154
59917
        self.views.get(&TypeId::of::<V>()).map(AsRef::as_ref)
155
59917
    }
156

            
157
    /// Iterates over all registered views.
158
104
    pub fn views(&self) -> impl Iterator<Item = &'_ dyn view::Serialized> {
159
104
        self.views.values().map(AsRef::as_ref)
160
104
    }
161

            
162
    /// Iterates over all views that belong to `collection`.
163
    #[must_use]
164
1604304
    pub fn views_in_collection(
165
1604304
        &self,
166
1604304
        collection: &CollectionName,
167
1604304
    ) -> Option<Vec<&'_ dyn view::Serialized>> {
168
1604304
        self.views_by_collection.get(collection).map(|view_ids| {
169
979108
            view_ids
170
979108
                .iter()
171
3256370
                .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
172
979108
                .collect()
173
1604304
        })
174
1604304
    }
175

            
176
    /// Iterates over all views that are unique that belong to `collection`.
177
    #[must_use]
178
658970
    pub fn unique_views_in_collection(
179
658970
        &self,
180
658970
        collection: &CollectionName,
181
658970
    ) -> Option<Vec<&'_ dyn view::Serialized>> {
182
658970
        self.unique_views_by_collection
183
658970
            .get(collection)
184
658970
            .map(|view_ids| {
185
79690
                view_ids
186
79690
                    .iter()
187
79690
                    .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
188
79690
                    .collect()
189
658970
            })
190
658970
    }
191

            
192
    /// Returns a collection's default encryption key, if one was defined.
193
    #[must_use]
194
2628002
    pub fn encryption_key_for_collection(&self, collection: &CollectionName) -> Option<&KeyId> {
195
2628002
        self.collection_encryption_keys.get(collection)
196
2628002
    }
197

            
198
    /// Returns a list of all collections contained in this schematic.
199
    #[must_use]
200
442
    pub fn collections(&self) -> Vec<CollectionName> {
201
442
        self.contained_collections.iter().cloned().collect()
202
442
    }
203
}
204

            
205
/// Pairs a view with the schema that supplies its behavior.
///
/// `Schematic::define_view_with_schema` boxes one of these as a
/// `dyn view::Serialized` for every view it registers.
#[derive(Debug)]
struct ViewInstance<V, S> {
    /// The view definition; provides the view's name.
    view: V,
    /// The schema for `view`; provides uniqueness, version, map and reduce.
    schema: S,
}
210

            
211
impl<V, S> Serialized for ViewInstance<V, S>
212
where
213
    V: SerializedView,
214
    S: ViewSchema<View = V>,
215
    <V as View>::Key: 'static,
216
{
217
6852998
    fn collection(&self) -> CollectionName {
218
6852998
        <<V as View>::Collection as Collection>::collection_name()
219
6852998
    }
220

            
221
9711132
    fn unique(&self) -> bool {
222
9711132
        self.schema.unique()
223
9711132
    }
224

            
225
39096
    fn version(&self) -> u64 {
226
39096
        self.schema.version()
227
39096
    }
228

            
229
8936703
    fn view_name(&self) -> ViewName {
230
8936703
        self.view.view_name()
231
8936703
    }
232

            
233
34869
    fn map(&self, document: &BorrowedDocument<'_>) -> Result<Vec<map::Serialized>, view::Error> {
234
34869
        let map = self.schema.map(document)?;
235

            
236
34869
        map.into_iter()
237
34869
            .map(|map| map.serialized::<V>())
238
34869
            .collect::<Result<Vec<_>, view::Error>>()
239
34869
    }
240

            
241
41038
    fn reduce(&self, mappings: &[(&[u8], &[u8])], rereduce: bool) -> Result<Vec<u8>, view::Error> {
242
41038
        let mappings = mappings
243
41038
            .iter()
244
42159
            .map(|(key, value)| match <V::Key as Key>::from_ord_bytes(key) {
245
37585
                Ok(key) => {
246
37585
                    let value = V::deserialize(value)?;
247
37585
                    Ok(MappedValue::new(key, value))
248
                }
249
                Err(err) => Err(view::Error::key_serialization(err)),
250
42159
            })
251
41038
            .collect::<Result<Vec<_>, view::Error>>()?;
252

            
253
41038
        let reduced_value = self.schema.reduce(&mappings, rereduce)?;
254

            
255
10778
        V::serialize(&reduced_value).map_err(view::Error::from)
256
41038
    }
257
}
258

            
259
pub trait IdGenerator: Debug + Send + Sync {
260
    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error>;
261
}
262

            
263
#[derive(Debug)]
264
2695139
#[derive_where(Default)]
265
pub struct KeyIdGenerator<C: Collection>(PhantomData<C>);
266

            
267
impl<C> IdGenerator for KeyIdGenerator<C>
268
where
269
    C: Collection,
270
{
271
189141
    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error> {
272
189141
        let key = id.map(|id| id.deserialize::<C::PrimaryKey>()).transpose()?;
273
189141
        let key = if let Some(key) = key {
274
169171
            key
275
        } else {
276
19970
            <C::PrimaryKey as Key<'_>>::first_value()
277
19970
                .map_err(|err| Error::DocumentPush(C::collection_name(), err))?
278
        };
279
189141
        let next_value = key
280
189141
            .next_value()
281
189141
            .map_err(|err| Error::DocumentPush(C::collection_name(), err))?;
282
189140
        DocumentId::new(next_value)
283
189141
    }
284
}
285

            
286
1
#[test]
fn schema_tests() -> anyhow::Result<()> {
    use crate::test_util::{Basic, BasicCount};

    let schema = Schematic::from_schema::<Basic>()?;

    // Exactly one collection is registered, and it is indexed under
    // `Basic`'s type id with `Basic`'s collection name.
    let basic_type = TypeId::of::<Basic>();
    assert_eq!(schema.collections_by_type_id.len(), 1);
    assert_eq!(
        schema.collections_by_type_id[&basic_type],
        Basic::collection_name()
    );

    // Four views are registered; the one stored under `BasicCount`'s type id
    // reports the same name as `BasicCount` itself.
    assert_eq!(schema.views.len(), 4);
    let basic_count = &schema.views[&TypeId::of::<BasicCount>()];
    assert_eq!(basic_count.view_name(), View::view_name(&BasicCount));

    Ok(())
}