1
use std::{
2
    any::TypeId,
3
    collections::{HashMap, HashSet},
4
    fmt::Debug,
5
    marker::PhantomData,
6
};
7

            
8
use derive_where::derive_where;
9

            
10
use crate::{
11
    document::{BorrowedDocument, DocumentId, KeyId},
12
    key::Key,
13
    schema::{
14
        collection::Collection,
15
        view::{
16
            self,
17
            map::{self, MappedValue},
18
            Serialized, SerializedView, ViewSchema,
19
        },
20
        CollectionName, Schema, SchemaName, View, ViewName,
21
    },
22
    Error,
23
};
24

            
25
/// A collection of defined collections and views.
26
#[derive(Debug)]
27
pub struct Schematic {
28
    /// The name of the schema this was built from.
29
    pub name: SchemaName,
30
    contained_collections: HashSet<CollectionName>,
31
    collections_by_type_id: HashMap<TypeId, CollectionName>,
32
    collection_encryption_keys: HashMap<CollectionName, KeyId>,
33
    collection_id_generators: HashMap<CollectionName, Box<dyn IdGenerator>>,
34
    views: HashMap<TypeId, Box<dyn view::Serialized>>,
35
    views_by_name: HashMap<ViewName, TypeId>,
36
    views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
37
    unique_views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
38
}
39

            
40
impl Schematic {
41
    /// Returns an initialized version from `S`.
42
127772
    pub fn from_schema<S: Schema + ?Sized>() -> Result<Self, Error> {
43
127772
        let mut schematic = Self {
44
127772
            name: S::schema_name(),
45
127772
            contained_collections: HashSet::new(),
46
127772
            collections_by_type_id: HashMap::new(),
47
127772
            collection_encryption_keys: HashMap::new(),
48
127772
            collection_id_generators: HashMap::new(),
49
127772
            views: HashMap::new(),
50
127772
            views_by_name: HashMap::new(),
51
127772
            views_by_collection: HashMap::new(),
52
127772
            unique_views_by_collection: HashMap::new(),
53
127772
        };
54
127772
        S::define_collections(&mut schematic)?;
55
127772
        Ok(schematic)
56
127772
    }
57

            
58
    /// Adds the collection `C` and its views.
59
4385685
    pub fn define_collection<C: Collection + 'static>(&mut self) -> Result<(), Error> {
60
4385685
        let name = C::collection_name();
61
4385685
        if self.contained_collections.contains(&name) {
62
            Err(Error::CollectionAlreadyDefined)
63
        } else {
64
4385685
            self.collections_by_type_id
65
4385685
                .insert(TypeId::of::<C>(), name.clone());
66
4385685
            if let Some(key) = C::encryption_key() {
67
1380566
                self.collection_encryption_keys.insert(name.clone(), key);
68
3005181
            }
69
4385747
            self.collection_id_generators
70
4385747
                .insert(name.clone(), Box::new(KeyIdGenerator::<C>::default()));
71
4385747
            self.contained_collections.insert(name);
72
4385747
            C::define_views(self)
73
        }
74
4385747
    }
75

            
76
    /// Adds the view `V`.
77
10766412
    pub fn define_view<V: ViewSchema<View = V> + SerializedView + Clone + 'static>(
78
10766412
        &mut self,
79
10766412
        view: V,
80
10766412
    ) -> Result<(), Error> {
81
10766412
        self.define_view_with_schema(view.clone(), view)
82
10766412
    }
83

            
84
    /// Adds the view `V`.
85
10766443
    pub fn define_view_with_schema<
86
10766443
        V: SerializedView + 'static,
87
10766443
        S: ViewSchema<View = V> + 'static,
88
10766443
    >(
89
10766443
        &mut self,
90
10766443
        view: V,
91
10766443
        schema: S,
92
10766443
    ) -> Result<(), Error> {
93
10766443
        let instance = ViewInstance { view, schema };
94
10766443
        let name = instance.view_name();
95
10766443
        let collection = instance.collection();
96
10766443
        let unique = instance.unique();
97
10766443
        self.views.insert(TypeId::of::<V>(), Box::new(instance));
98
10766443

            
99
10766443
        if self.views_by_name.contains_key(&name) {
100
            return Err(Error::ViewAlreadyRegistered(name));
101
10766443
        }
102
10766443
        self.views_by_name.insert(name, TypeId::of::<V>());
103
10766443

            
104
10766443
        if unique {
105
1649700
            let unique_views = self
106
1649700
                .unique_views_by_collection
107
1649700
                .entry(collection.clone())
108
1649700
                .or_insert_with(Vec::new);
109
1649700
            unique_views.push(TypeId::of::<V>());
110
9116774
        }
111
10766474
        let views = self
112
10766474
            .views_by_collection
113
10766474
            .entry(collection)
114
10766474
            .or_insert_with(Vec::new);
115
10766474
        views.push(TypeId::of::<V>());
116
10766474

            
117
10766474
        Ok(())
118
10766474
    }
119

            
120
    /// Returns `true` if this schema contains the collection `C`.
121
    #[must_use]
122
    pub fn contains_collection<C: Collection + 'static>(&self) -> bool {
123
        self.collections_by_type_id.contains_key(&TypeId::of::<C>())
124
    }
125

            
126
    /// Returns `true` if this schema contains the collection `C`.
127
    #[must_use]
128
935270
    pub fn contains_collection_name(&self, collection: &CollectionName) -> bool {
129
935270
        self.contained_collections.contains(collection)
130
935270
    }
131

            
132
    /// Returns the next id in sequence for the collection, if the primary key
133
    /// type supports the operation and the next id would not overflow.
134
469216
    pub fn next_id_for_collection(
135
469216
        &self,
136
469216
        collection: &CollectionName,
137
469216
        id: Option<DocumentId>,
138
469216
    ) -> Result<DocumentId, Error> {
139
469216
        let generator = self
140
469216
            .collection_id_generators
141
469216
            .get(collection)
142
469216
            .ok_or(Error::CollectionNotFound)?;
143
469216
        generator.next_id(id)
144
469216
    }
145

            
146
    /// Looks up a [`view::Serialized`] by name.
147
1086395
    pub fn view_by_name(&self, name: &ViewName) -> Result<&'_ dyn view::Serialized, Error> {
148
1086395
        self.views_by_name
149
1086395
            .get(name)
150
1086395
            .and_then(|type_id| self.views.get(type_id))
151
1086395
            .map(AsRef::as_ref)
152
1086395
            .ok_or(Error::ViewNotFound)
153
1086395
    }
154

            
155
    /// Looks up a [`view::Serialized`] through the the type `V`.
156
47989
    pub fn view<V: View + 'static>(&self) -> Result<&'_ dyn view::Serialized, Error> {
157
47989
        self.views
158
47989
            .get(&TypeId::of::<V>())
159
47989
            .map(AsRef::as_ref)
160
47989
            .ok_or(Error::ViewNotFound)
161
47989
    }
162

            
163
    /// Iterates over all registered views.
164
124
    pub fn views(&self) -> impl Iterator<Item = &'_ dyn view::Serialized> {
165
124
        self.views.values().map(AsRef::as_ref)
166
124
    }
167

            
168
    /// Iterates over all views that belong to `collection`.
169
    #[must_use]
170
2340159
    pub fn views_in_collection(
171
2340159
        &self,
172
2340159
        collection: &CollectionName,
173
2340159
    ) -> Option<Vec<&'_ dyn view::Serialized>> {
174
2340159
        self.views_by_collection.get(collection).map(|view_ids| {
175
1505546
            view_ids
176
1505546
                .iter()
177
5085984
                .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
178
1505546
                .collect()
179
2340159
        })
180
2340159
    }
181

            
182
    /// Iterates over all views that are unique that belong to `collection`.
183
    #[must_use]
184
918158
    pub fn unique_views_in_collection(
185
918158
        &self,
186
918158
        collection: &CollectionName,
187
918158
    ) -> Option<Vec<&'_ dyn view::Serialized>> {
188
918158
        self.unique_views_by_collection
189
918158
            .get(collection)
190
918158
            .map(|view_ids| {
191
98177
                view_ids
192
98177
                    .iter()
193
98177
                    .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
194
98177
                    .collect()
195
918158
            })
196
918158
    }
197

            
198
    /// Returns a collection's default encryption key, if one was defined.
199
    #[must_use]
200
3435544
    pub fn encryption_key_for_collection(&self, collection: &CollectionName) -> Option<&KeyId> {
201
3435544
        self.collection_encryption_keys.get(collection)
202
3435544
    }
203

            
204
    /// Returns a list of all collections contained in this schematic.
205
    #[must_use]
206
620
    pub fn collections(&self) -> Vec<CollectionName> {
207
620
        self.contained_collections.iter().cloned().collect()
208
620
    }
209
}
210

            
211
/// Pairs a view with the schema that implements its logic, so both can be
/// stored together behind a [`view::Serialized`] trait object.
#[derive(Debug)]
struct ViewInstance<V, S> {
    /// The view definition; supplies the view's name.
    view: V,
    /// The view schema; supplies map/reduce, uniqueness and versioning.
    schema: S,
}
216

            
217
impl<V, S> Serialized for ViewInstance<V, S>
218
where
219
    V: SerializedView,
220
    S: ViewSchema<View = V>,
221
    <V as View>::Key: 'static,
222
{
223
11025316
    fn collection(&self) -> CollectionName {
224
11025316
        <<V as View>::Collection as Collection>::collection_name()
225
11025316
    }
226

            
227
15697894
    fn unique(&self) -> bool {
228
15697894
        self.schema.unique()
229
15697894
    }
230

            
231
56062
    fn version(&self) -> u64 {
232
56062
        self.schema.version()
233
56062
    }
234

            
235
14393880
    fn view_name(&self) -> ViewName {
236
14393880
        self.view.view_name()
237
14393880
    }
238

            
239
49391
    fn map(&self, document: &BorrowedDocument<'_>) -> Result<Vec<map::Serialized>, view::Error> {
240
49391
        let map = self.schema.map(document)?;
241

            
242
49391
        map.into_iter()
243
49399
            .map(|map| map.serialized::<V>())
244
49391
            .collect::<Result<Vec<_>, view::Error>>()
245
49391
    }
246

            
247
52431
    fn reduce(&self, mappings: &[(&[u8], &[u8])], rereduce: bool) -> Result<Vec<u8>, view::Error> {
248
52431
        let mappings = mappings
249
52431
            .iter()
250
53469
            .map(|(key, value)| match <V::Key as Key>::from_ord_bytes(key) {
251
52356
                Ok(key) => {
252
52387
                    let value = V::deserialize(value)?;
253
52387
                    Ok(MappedValue::new(key, value))
254
                }
255
                Err(err) => Err(view::Error::key_serialization(err)),
256
53499
            })
257
52431
            .collect::<Result<Vec<_>, view::Error>>()?;
258

            
259
52431
        let reduced_value = self.schema.reduce(&mappings, rereduce)?;
260

            
261
11202
        V::serialize(&reduced_value).map_err(view::Error::from)
262
52430
    }
263
}
264

            
265
pub trait IdGenerator: Debug + Send + Sync {
266
    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error>;
267
}
268

            
269
#[derive(Debug)]
270
4385685
#[derive_where(Default)]
271
pub struct KeyIdGenerator<C: Collection>(PhantomData<C>);
272

            
273
impl<C> IdGenerator for KeyIdGenerator<C>
274
where
275
    C: Collection,
276
{
277
327070
    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error> {
278
327070
        let key = id.map(|id| id.deserialize::<C::PrimaryKey>()).transpose()?;
279
327070
        let key = if let Some(key) = key {
280
299869
            key
281
        } else {
282
27201
            <C::PrimaryKey as Key<'_>>::first_value()
283
27201
                .map_err(|err| Error::DocumentPush(C::collection_name(), err))?
284
        };
285
327070
        let next_value = key
286
327070
            .next_value()
287
327070
            .map_err(|err| Error::DocumentPush(C::collection_name(), err))?;
288
327056
        DocumentId::new(next_value)
289
327057
    }
290
}
291

            
292
1
#[test]
fn schema_tests() -> anyhow::Result<()> {
    use crate::test_util::{Basic, BasicCount};

    let schematic = Schematic::from_schema::<Basic>()?;

    // Exactly one collection is registered, keyed by the `Basic` type.
    let collections = &schematic.collections_by_type_id;
    assert_eq!(collections.len(), 1);
    assert_eq!(
        collections[&TypeId::of::<Basic>()],
        Basic::collection_name()
    );

    // Four views are registered; the one stored under `BasicCount`'s type id
    // must report the same name as `BasicCount` itself.
    let views = &schematic.views;
    assert_eq!(views.len(), 4);
    let basic_count = &views[&TypeId::of::<BasicCount>()];
    assert_eq!(basic_count.view_name(), View::view_name(&BasicCount));

    Ok(())
}