1
use std::{
2
    any::TypeId,
3
    collections::{HashMap, HashSet},
4
    fmt::Debug,
5
    marker::PhantomData,
6
};
7

            
8
use derive_where::derive_where;
9

            
10
use crate::{
11
    document::{BorrowedDocument, DocumentId, KeyId},
12
    key::Key,
13
    schema::{
14
        collection::Collection,
15
        view::{
16
            self,
17
            map::{self, MappedValue},
18
            Serialized, SerializedView, ViewSchema,
19
        },
20
        CollectionName, Schema, SchemaName, View, ViewName,
21
    },
22
    Error,
23
};
24

            
25
/// A collection of defined collections and views.
26
#[derive(Debug)]
27
pub struct Schematic {
28
    /// The name of the schema this was built from.
29
    pub name: SchemaName,
30
    contained_collections: HashSet<CollectionName>,
31
    collections_by_type_id: HashMap<TypeId, CollectionName>,
32
    collection_encryption_keys: HashMap<CollectionName, KeyId>,
33
    collection_id_generators: HashMap<CollectionName, Box<dyn IdGenerator>>,
34
    views: HashMap<TypeId, Box<dyn view::Serialized>>,
35
    views_by_name: HashMap<ViewName, TypeId>,
36
    views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
37
    unique_views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
38
}
39

            
40
impl Schematic {
41
    /// Returns an initialized version from `S`.
42
126503
    pub fn from_schema<S: Schema + ?Sized>() -> Result<Self, Error> {
43
126503
        let mut schematic = Self {
44
126503
            name: S::schema_name(),
45
126503
            contained_collections: HashSet::new(),
46
126503
            collections_by_type_id: HashMap::new(),
47
126503
            collection_encryption_keys: HashMap::new(),
48
126503
            collection_id_generators: HashMap::new(),
49
126503
            views: HashMap::new(),
50
126503
            views_by_name: HashMap::new(),
51
126503
            views_by_collection: HashMap::new(),
52
126503
            unique_views_by_collection: HashMap::new(),
53
126503
        };
54
126503
        S::define_collections(&mut schematic)?;
55
126503
        Ok(schematic)
56
126503
    }
57

            
58
    /// Adds the collection `C` and its views.
59
4379133
    pub fn define_collection<C: Collection + 'static>(&mut self) -> Result<(), Error> {
60
4379133
        let name = C::collection_name();
61
4379133
        if self.contained_collections.contains(&name) {
62
            Err(Error::CollectionAlreadyDefined)
63
        } else {
64
4379133
            self.collections_by_type_id
65
4379133
                .insert(TypeId::of::<C>(), name.clone());
66
4379133
            if let Some(key) = C::encryption_key() {
67
1380907
                self.collection_encryption_keys.insert(name.clone(), key);
68
2998257
            }
69
4379164
            self.collection_id_generators
70
4379164
                .insert(name.clone(), Box::new(KeyIdGenerator::<C>::default()));
71
4379164
            self.contained_collections.insert(name);
72
4379164
            C::define_views(self)
73
        }
74
4379164
    }
75

            
76
    /// Adds the view `V`.
77
10765120
    pub fn define_view<V: ViewSchema<View = V> + SerializedView + Clone + 'static>(
78
10765120
        &mut self,
79
10765120
        view: V,
80
10765120
    ) -> Result<(), Error> {
81
10765120
        self.define_view_with_schema(view.clone(), view)
82
10765120
    }
83

            
84
    /// Adds the view `V`.
85
10765151
    pub fn define_view_with_schema<
86
10765151
        V: SerializedView + 'static,
87
10765151
        S: ViewSchema<View = V> + 'static,
88
10765151
    >(
89
10765151
        &mut self,
90
10765151
        view: V,
91
10765151
        schema: S,
92
10765151
    ) -> Result<(), Error> {
93
10765151
        let instance = ViewInstance { view, schema };
94
10765151
        let name = instance.view_name();
95
10765151
        let collection = instance.collection();
96
10765151
        let unique = instance.unique();
97
10765151
        self.views.insert(TypeId::of::<V>(), Box::new(instance));
98
10765151

            
99
10765151
        if self.views_by_name.contains_key(&name) {
100
            return Err(Error::ViewAlreadyRegistered(name));
101
10765151
        }
102
10765151
        self.views_by_name.insert(name, TypeId::of::<V>());
103
10765151

            
104
10765151
        if unique {
105
1648949
            let unique_views = self
106
1648949
                .unique_views_by_collection
107
1648949
                .entry(collection.clone())
108
1648949
                .or_insert_with(Vec::new);
109
1648949
            unique_views.push(TypeId::of::<V>());
110
9116171
        }
111
10765120
        let views = self
112
10765120
            .views_by_collection
113
10765120
            .entry(collection)
114
10765120
            .or_insert_with(Vec::new);
115
10765120
        views.push(TypeId::of::<V>());
116
10765120

            
117
10765120
        Ok(())
118
10765120
    }
119

            
120
    /// Returns `true` if this schema contains the collection `C`.
121
    #[must_use]
122
    pub fn contains_collection<C: Collection + 'static>(&self) -> bool {
123
        self.collections_by_type_id.contains_key(&TypeId::of::<C>())
124
    }
125

            
126
    /// Returns `true` if this schema contains the collection `C`.
127
    #[must_use]
128
1051334
    pub fn contains_collection_name(&self, collection: &CollectionName) -> bool {
129
1051334
        self.contained_collections.contains(collection)
130
1051334
    }
131

            
132
    /// Returns the next id in sequence for the collection, if the primary key
133
    /// type supports the operation and the next id would not overflow.
134
448260
    pub fn next_id_for_collection(
135
448260
        &self,
136
448260
        collection: &CollectionName,
137
448260
        id: Option<DocumentId>,
138
448260
    ) -> Result<DocumentId, Error> {
139
448260
        let generator = self
140
448260
            .collection_id_generators
141
448260
            .get(collection)
142
448260
            .ok_or(Error::CollectionNotFound)?;
143
448260
        generator.next_id(id)
144
448260
    }
145

            
146
    /// Looks up a [`view::Serialized`] by name.
147
1201033
    pub fn view_by_name(&self, name: &ViewName) -> Result<&'_ dyn view::Serialized, Error> {
148
1201033
        self.views_by_name
149
1201033
            .get(name)
150
1201033
            .and_then(|type_id| self.views.get(type_id))
151
1201033
            .map(AsRef::as_ref)
152
1201033
            .ok_or(Error::ViewNotFound)
153
1201033
    }
154

            
155
    /// Looks up a [`view::Serialized`] through the the type `V`.
156
46845
    pub fn view<V: View + 'static>(&self) -> Result<&'_ dyn view::Serialized, Error> {
157
46845
        self.views
158
46845
            .get(&TypeId::of::<V>())
159
46845
            .map(AsRef::as_ref)
160
46845
            .ok_or(Error::ViewNotFound)
161
46845
    }
162

            
163
    /// Iterates over all registered views.
164
124
    pub fn views(&self) -> impl Iterator<Item = &'_ dyn view::Serialized> {
165
124
        self.views.values().map(AsRef::as_ref)
166
124
    }
167

            
168
    /// Iterates over all views that belong to `collection`.
169
    #[must_use]
170
2439638
    pub fn views_in_collection(
171
2439638
        &self,
172
2439638
        collection: &CollectionName,
173
2439638
    ) -> Option<Vec<&'_ dyn view::Serialized>> {
174
2439638
        self.views_by_collection.get(collection).map(|view_ids| {
175
1500865
            view_ids
176
1500865
                .iter()
177
5102414
                .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
178
1500865
                .collect()
179
2439638
        })
180
2439638
    }
181

            
182
    /// Iterates over all views that are unique that belong to `collection`.
183
    #[must_use]
184
1034253
    pub fn unique_views_in_collection(
185
1034253
        &self,
186
1034253
        collection: &CollectionName,
187
1034253
    ) -> Option<Vec<&'_ dyn view::Serialized>> {
188
1034253
        self.unique_views_by_collection
189
1034253
            .get(collection)
190
1034253
            .map(|view_ids| {
191
114793
                view_ids
192
114793
                    .iter()
193
114793
                    .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
194
114793
                    .collect()
195
1034253
            })
196
1034253
    }
197

            
198
    /// Returns a collection's default encryption key, if one was defined.
199
    #[must_use]
200
3460995
    pub fn encryption_key_for_collection(&self, collection: &CollectionName) -> Option<&KeyId> {
201
3460995
        self.collection_encryption_keys.get(collection)
202
3460995
    }
203

            
204
    /// Returns a list of all collections contained in this schematic.
205
    #[must_use]
206
651
    pub fn collections(&self) -> Vec<CollectionName> {
207
651
        self.contained_collections.iter().cloned().collect()
208
651
    }
209
}
210

            
211
/// Pairs a view with the schema implementation that drives it.
///
/// The `view` supplies the name and (de)serialization, while `schema` supplies
/// uniqueness, versioning, mapping and reducing.
#[derive(Debug)]
struct ViewInstance<V, S> {
    /// The view definition.
    view: V,
    /// The [`ViewSchema`] implementation used for `view`.
    schema: S,
}
216

            
217
impl<V, S> Serialized for ViewInstance<V, S>
218
where
219
    V: SerializedView,
220
    S: ViewSchema<View = V>,
221
    <V as View>::Key: 'static,
222
{
223
11023964
    fn collection(&self) -> CollectionName {
224
11023964
        <<V as View>::Collection as Collection>::collection_name()
225
11023964
    }
226

            
227
15703018
    fn unique(&self) -> bool {
228
15703018
        self.schema.unique()
229
15703018
    }
230

            
231
56805
    fn version(&self) -> u64 {
232
56805
        self.schema.version()
233
56805
    }
234

            
235
14395797
    fn view_name(&self) -> ViewName {
236
14395797
        self.view.view_name()
237
14395797
    }
238

            
239
49487
    fn map(&self, document: &BorrowedDocument<'_>) -> Result<Vec<map::Serialized>, view::Error> {
240
49487
        let map = self.schema.map(document)?;
241

            
242
49487
        map.into_iter()
243
49495
            .map(|map| map.serialized::<V>())
244
49487
            .collect::<Result<Vec<_>, view::Error>>()
245
49487
    }
246

            
247
58511
    fn reduce(&self, mappings: &[(&[u8], &[u8])], rereduce: bool) -> Result<Vec<u8>, view::Error> {
248
58511
        let mappings = mappings
249
58511
            .iter()
250
59313
            .map(|(key, value)| match <V::Key as Key>::from_ord_bytes(key) {
251
49998
                Ok(key) => {
252
49998
                    let value = V::deserialize(value)?;
253
49998
                    Ok(MappedValue::new(key, value))
254
                }
255
                Err(err) => Err(view::Error::key_serialization(err)),
256
59313
            })
257
58511
            .collect::<Result<Vec<_>, view::Error>>()?;
258

            
259
58511
        let reduced_value = self.schema.reduce(&mappings, rereduce)?;
260

            
261
16747
        V::serialize(&reduced_value).map_err(view::Error::from)
262
58480
    }
263
}
264

            
265
pub trait IdGenerator: Debug + Send + Sync {
266
    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error>;
267
}
268

            
269
#[derive(Debug)]
270
4379164
#[derive_where(Default)]
271
pub struct KeyIdGenerator<C: Collection>(PhantomData<C>);
272

            
273
impl<C> IdGenerator for KeyIdGenerator<C>
274
where
275
    C: Collection,
276
{
277
326634
    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error> {
278
326634
        let key = id.map(|id| id.deserialize::<C::PrimaryKey>()).transpose()?;
279
326634
        let key = if let Some(key) = key {
280
299433
            key
281
        } else {
282
27201
            <C::PrimaryKey as Key<'_>>::first_value()
283
27201
                .map_err(|err| Error::DocumentPush(C::collection_name(), err))?
284
        };
285
326634
        let next_value = key
286
326634
            .next_value()
287
326634
            .map_err(|err| Error::DocumentPush(C::collection_name(), err))?;
288
326633
        DocumentId::new(next_value)
289
326634
    }
290
}
291

            
292
1
#[test]
fn schema_tests() -> anyhow::Result<()> {
    use crate::test_util::{Basic, BasicCount};

    let schematic = Schematic::from_schema::<Basic>()?;

    // The `Basic` schema is expected to register exactly one collection,
    // reachable through its own type id.
    let collections = &schematic.collections_by_type_id;
    assert_eq!(collections.len(), 1);
    assert_eq!(
        collections[&TypeId::of::<Basic>()],
        Basic::collection_name()
    );

    // Four views are expected, and `BasicCount` must be stored under its own
    // type id with a matching name.
    let views = &schematic.views;
    assert_eq!(views.len(), 4);
    assert_eq!(
        views[&TypeId::of::<BasicCount>()].view_name(),
        View::view_name(&BasicCount)
    );

    Ok(())
}