1
use std::{
2
    any::TypeId,
3
    collections::{HashMap, HashSet},
4
    fmt::Debug,
5
    marker::PhantomData,
6
};
7

            
8
use derive_where::derive_where;
9

            
10
use crate::{
11
    document::{BorrowedDocument, DocumentId, KeyId},
12
    key::Key,
13
    schema::{
14
        collection::Collection,
15
        view::{
16
            self,
17
            map::{self, MappedValue},
18
            Serialized, SerializedView, ViewSchema,
19
        },
20
        CollectionName, Schema, SchemaName, View, ViewName,
21
    },
22
    Error,
23
};
24

            
25
/// A collection of defined collections and views.
26
#[derive(Debug)]
27
pub struct Schematic {
28
    /// The name of the schema this was built from.
29
    pub name: SchemaName,
30
    contained_collections: HashSet<CollectionName>,
31
    collections_by_type_id: HashMap<TypeId, CollectionName>,
32
    collection_encryption_keys: HashMap<CollectionName, KeyId>,
33
    collection_id_generators: HashMap<CollectionName, Box<dyn IdGenerator>>,
34
    views: HashMap<TypeId, Box<dyn view::Serialized>>,
35
    views_by_name: HashMap<ViewName, TypeId>,
36
    views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
37
    unique_views_by_collection: HashMap<CollectionName, Vec<TypeId>>,
38
}
39

            
40
impl Schematic {
41
    /// Returns an initialized version from `S`.
42
125965
    pub fn from_schema<S: Schema + ?Sized>() -> Result<Self, Error> {
43
125965
        let mut schematic = Self {
44
125965
            name: S::schema_name(),
45
125965
            contained_collections: HashSet::new(),
46
125965
            collections_by_type_id: HashMap::new(),
47
125965
            collection_encryption_keys: HashMap::new(),
48
125965
            collection_id_generators: HashMap::new(),
49
125965
            views: HashMap::new(),
50
125965
            views_by_name: HashMap::new(),
51
125965
            views_by_collection: HashMap::new(),
52
125965
            unique_views_by_collection: HashMap::new(),
53
125965
        };
54
125965
        S::define_collections(&mut schematic)?;
55
125965
        Ok(schematic)
56
125965
    }
57

            
58
    /// Adds the collection `C` and its views.
59
4376085
    pub fn define_collection<C: Collection + 'static>(&mut self) -> Result<(), Error> {
60
4376085
        let name = C::collection_name();
61
4376085
        if self.contained_collections.contains(&name) {
62
            Err(Error::CollectionAlreadyDefined)
63
        } else {
64
4376085
            self.collections_by_type_id
65
4376085
                .insert(TypeId::of::<C>(), name.clone());
66
4376085
            if let Some(key) = C::encryption_key() {
67
1380907
                self.collection_encryption_keys.insert(name.clone(), key);
68
2995209
            }
69
4376116
            self.collection_id_generators
70
4376116
                .insert(name.clone(), Box::new(KeyIdGenerator::<C>::default()));
71
4376116
            self.contained_collections.insert(name);
72
4376116
            C::define_views(self)
73
        }
74
4376116
    }
75

            
76
    /// Adds the view `V`.
77
10763627
    pub fn define_view<V: ViewSchema<View = V> + SerializedView + Clone + 'static>(
78
10763627
        &mut self,
79
10763627
        view: V,
80
10763627
    ) -> Result<(), Error> {
81
10763627
        self.define_view_with_schema(view.clone(), view)
82
10763627
    }
83

            
84
    /// Adds the view `V`.
85
10763627
    pub fn define_view_with_schema<
86
10763627
        V: SerializedView + 'static,
87
10763627
        S: ViewSchema<View = V> + 'static,
88
10763627
    >(
89
10763627
        &mut self,
90
10763627
        view: V,
91
10763627
        schema: S,
92
10763627
    ) -> Result<(), Error> {
93
10763627
        let instance = ViewInstance { view, schema };
94
10763627
        let name = instance.view_name();
95
10763627
        let collection = instance.collection();
96
10763627
        let unique = instance.unique();
97
10763627
        self.views.insert(TypeId::of::<V>(), Box::new(instance));
98
10763627

            
99
10763627
        if self.views_by_name.contains_key(&name) {
100
            return Err(Error::ViewAlreadyRegistered(name));
101
10763627
        }
102
10763627
        self.views_by_name.insert(name, TypeId::of::<V>());
103
10763627

            
104
10763627
        if unique {
105
1648441
            let unique_views = self
106
1648441
                .unique_views_by_collection
107
1648441
                .entry(collection.clone())
108
1648441
                .or_insert_with(Vec::new);
109
1648441
            unique_views.push(TypeId::of::<V>());
110
9115155
        }
111
10763596
        let views = self
112
10763596
            .views_by_collection
113
10763596
            .entry(collection)
114
10763596
            .or_insert_with(Vec::new);
115
10763596
        views.push(TypeId::of::<V>());
116
10763596

            
117
10763596
        Ok(())
118
10763596
    }
119

            
120
    /// Returns `true` if this schema contains the collection `C`.
121
    #[must_use]
122
    pub fn contains_collection<C: Collection + 'static>(&self) -> bool {
123
        self.collections_by_type_id.contains_key(&TypeId::of::<C>())
124
    }
125

            
126
    /// Returns `true` if this schema contains the collection `C`.
127
    #[must_use]
128
886910
    pub fn contains_collection_name(&self, collection: &CollectionName) -> bool {
129
886910
        self.contained_collections.contains(collection)
130
886910
    }
131

            
132
    /// Returns the next id in sequence for the collection, if the primary key
133
    /// type supports the operation and the next id would not overflow.
134
442804
    pub fn next_id_for_collection(
135
442804
        &self,
136
442804
        collection: &CollectionName,
137
442804
        id: Option<DocumentId>,
138
442804
    ) -> Result<DocumentId, Error> {
139
442804
        let generator = self
140
442804
            .collection_id_generators
141
442804
            .get(collection)
142
442804
            .ok_or(Error::CollectionNotFound)?;
143
442804
        generator.next_id(id)
144
442804
    }
145

            
146
    /// Looks up a [`view::Serialized`] by name.
147
1202025
    pub fn view_by_name(&self, name: &ViewName) -> Result<&'_ dyn view::Serialized, Error> {
148
1202025
        self.views_by_name
149
1202025
            .get(name)
150
1202025
            .and_then(|type_id| self.views.get(type_id))
151
1202025
            .map(AsRef::as_ref)
152
1202025
            .ok_or(Error::ViewNotFound)
153
1202025
    }
154

            
155
    /// Looks up a [`view::Serialized`] through the the type `V`.
156
46337
    pub fn view<V: View + 'static>(&self) -> Result<&'_ dyn view::Serialized, Error> {
157
46337
        self.views
158
46337
            .get(&TypeId::of::<V>())
159
46337
            .map(AsRef::as_ref)
160
46337
            .ok_or(Error::ViewNotFound)
161
46337
    }
162

            
163
    /// Iterates over all registered views.
164
124
    pub fn views(&self) -> impl Iterator<Item = &'_ dyn view::Serialized> {
165
124
        self.views.values().map(AsRef::as_ref)
166
124
    }
167

            
168
    /// Iterates over all views that belong to `collection`.
169
    #[must_use]
170
2273230
    pub fn views_in_collection(
171
2273230
        &self,
172
2273230
        collection: &CollectionName,
173
2273230
    ) -> Option<Vec<&'_ dyn view::Serialized>> {
174
2273230
        self.views_by_collection.get(collection).map(|view_ids| {
175
1495936
            view_ids
176
1495936
                .iter()
177
5096865
                .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
178
1495936
                .collect()
179
2273230
        })
180
2273230
    }
181

            
182
    /// Iterates over all views that are unique that belong to `collection`.
183
    #[must_use]
184
869829
    pub fn unique_views_in_collection(
185
869829
        &self,
186
869829
        collection: &CollectionName,
187
869829
    ) -> Option<Vec<&'_ dyn view::Serialized>> {
188
869829
        self.unique_views_by_collection
189
869829
            .get(collection)
190
869829
            .map(|view_ids| {
191
114297
                view_ids
192
114297
                    .iter()
193
114297
                    .filter_map(|id| self.views.get(id).map(AsRef::as_ref))
194
114297
                    .collect()
195
869829
            })
196
869829
    }
197

            
198
    /// Returns a collection's default encryption key, if one was defined.
199
    #[must_use]
200
3262099
    pub fn encryption_key_for_collection(&self, collection: &CollectionName) -> Option<&KeyId> {
201
3262099
        self.collection_encryption_keys.get(collection)
202
3262099
    }
203

            
204
    /// Returns a list of all collections contained in this schematic.
205
    #[must_use]
206
651
    pub fn collections(&self) -> Vec<CollectionName> {
207
651
        self.contained_collections.iter().cloned().collect()
208
651
    }
209
}
210

            
211
/// Pairs a view definition with the [`ViewSchema`] implementation that maps
/// and reduces it.
#[derive(Debug)]
struct ViewInstance<V, S> {
    /// The view definition; supplies the view's name.
    view: V,
    /// The schema implementation; supplies mapping, reduction, uniqueness and
    /// version.
    schema: S,
}
216

            
217
impl<V, S> Serialized for ViewInstance<V, S>
218
where
219
    V: SerializedView,
220
    S: ViewSchema<View = V>,
221
    <V as View>::Key: 'static,
222
{
223
11021613
    fn collection(&self) -> CollectionName {
224
11021613
        <<V as View>::Collection as Collection>::collection_name()
225
11021613
    }
226

            
227
15700901
    fn unique(&self) -> bool {
228
15700901
        self.schema.unique()
229
15700901
    }
230

            
231
56880
    fn version(&self) -> u64 {
232
56880
        self.schema.version()
233
56880
    }
234

            
235
14392228
    fn view_name(&self) -> ViewName {
236
14392228
        self.view.view_name()
237
14392228
    }
238

            
239
49287
    fn map(&self, document: &BorrowedDocument<'_>) -> Result<Vec<map::Serialized>, view::Error> {
240
49287
        let map = self.schema.map(document)?;
241

            
242
49287
        map.into_iter()
243
49295
            .map(|map| map.serialized::<V>())
244
49287
            .collect::<Result<Vec<_>, view::Error>>()
245
49287
    }
246

            
247
58956
    fn reduce(&self, mappings: &[(&[u8], &[u8])], rereduce: bool) -> Result<Vec<u8>, view::Error> {
248
58956
        let mappings = mappings
249
58956
            .iter()
250
59784
            .map(|(key, value)| match <V::Key as Key>::from_ord_bytes(key) {
251
49790
                Ok(key) => {
252
49790
                    let value = V::deserialize(value)?;
253
49790
                    Ok(MappedValue::new(key, value))
254
                }
255
                Err(err) => Err(view::Error::key_serialization(err)),
256
59784
            })
257
58956
            .collect::<Result<Vec<_>, view::Error>>()?;
258

            
259
58956
        let reduced_value = self.schema.reduce(&mappings, rereduce)?;
260

            
261
17208
        V::serialize(&reduced_value).map_err(view::Error::from)
262
58925
    }
263
}
264

            
265
pub trait IdGenerator: Debug + Send + Sync {
266
    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error>;
267
}
268

            
269
#[derive(Debug)]
270
4376116
#[derive_where(Default)]
271
pub struct KeyIdGenerator<C: Collection>(PhantomData<C>);
272

            
273
impl<C> IdGenerator for KeyIdGenerator<C>
274
where
275
    C: Collection,
276
{
277
326458
    fn next_id(&self, id: Option<DocumentId>) -> Result<DocumentId, Error> {
278
326458
        let key = id.map(|id| id.deserialize::<C::PrimaryKey>()).transpose()?;
279
326458
        let key = if let Some(key) = key {
280
299257
            key
281
        } else {
282
27201
            <C::PrimaryKey as Key<'_>>::first_value()
283
27201
                .map_err(|err| Error::DocumentPush(C::collection_name(), err))?
284
        };
285
326458
        let next_value = key
286
326458
            .next_value()
287
326458
            .map_err(|err| Error::DocumentPush(C::collection_name(), err))?;
288
326457
        DocumentId::new(next_value)
289
326458
    }
290
}
291

            
292
1
#[test]
fn schema_tests() -> anyhow::Result<()> {
    use crate::test_util::{Basic, BasicCount};

    let schematic = Schematic::from_schema::<Basic>()?;

    // The `Basic` test schema is expected to define a single collection,
    // registered under `Basic`'s own type id.
    assert_eq!(schematic.collections_by_type_id.len(), 1);
    let basic_name = &schematic.collections_by_type_id[&TypeId::of::<Basic>()];
    assert_eq!(basic_name, &Basic::collection_name());

    // Four views are expected, and the one stored for `BasicCount` must report
    // the same name as the view itself.
    assert_eq!(schematic.views.len(), 4);
    let count_view = &schematic.views[&TypeId::of::<BasicCount>()];
    assert_eq!(count_view.view_name(), View::view_name(&BasicCount));

    Ok(())
}