1
use std::{borrow::Cow, collections::HashSet, convert::Infallible, hash::Hash, sync::Arc};
2

            
3
use async_lock::Mutex;
4
use async_trait::async_trait;
5
use bonsaidb_core::{
6
    document::DocumentId,
7
    schema::{CollectionName, ViewName},
8
};
9
use nebari::{
10
    io::any::AnyFile,
11
    tree::{KeyEvaluation, Operation, Unversioned, Versioned},
12
    ArcBytes, Tree,
13
};
14
use serde::{Deserialize, Serialize};
15

            
16
use super::{
17
    mapper::{Map, Mapper},
18
    view_document_map_tree_name, view_invalidated_docs_tree_name, view_versions_tree_name,
19
};
20
use crate::{
21
    database::{document_tree_name, Database},
22
    jobs::{task, Job, Keyed},
23
    tasks::Task,
24
    Error,
25
};
26

            
27
#[derive(Debug)]
28
pub struct IntegrityScanner {
29
    pub database: Database,
30
    pub scan: IntegrityScan,
31
}
32

            
33
21236
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
34
pub struct IntegrityScan {
35
    pub view_version: u64,
36
    pub database: Arc<Cow<'static, str>>,
37
    pub collection: CollectionName,
38
    pub view_name: ViewName,
39
}
40

            
41
/// The output of an [`IntegrityScanner`] job: `Some` with a shared handle to
/// the view mapping task when the scan found entries that need to be mapped,
/// or `None` when nothing needed updating.
// NOTE(review): the inner `Option` presumably lets a consumer take the handle
// out from behind the mutex — confirm against callers.
pub type OptionalViewMapHandle = Option<Arc<Mutex<Option<task::Handle<u64, Error, Task>>>>>;
42

            
43
#[async_trait]
44
impl Job for IntegrityScanner {
45
    type Output = OptionalViewMapHandle;
46
    type Error = Error;
47

            
48
24447
    #[cfg_attr(feature = "tracing", tracing::instrument)]
49
    #[allow(clippy::too_many_lines)]
50
8149
    async fn execute(&mut self) -> Result<Self::Output, Self::Error> {
51
8149
        let documents =
52
8149
            self.database
53
8149
                .roots()
54
8149
                .tree(self.database.collection_tree::<Versioned, _>(
55
8149
                    &self.scan.collection,
56
8149
                    document_tree_name(&self.scan.collection),
57
8149
                )?)?;
58

            
59
8149
        let view_versions_tree = self.database.collection_tree::<Unversioned, _>(
60
8149
            &self.scan.collection,
61
8149
            view_versions_tree_name(&self.scan.collection),
62
8149
        )?;
63
8149
        let view_versions = self.database.roots().tree(view_versions_tree.clone())?;
64

            
65
8149
        let document_map =
66
8149
            self.database
67
8149
                .roots()
68
8149
                .tree(self.database.collection_tree::<Unversioned, _>(
69
8149
                    &self.scan.collection,
70
8149
                    view_document_map_tree_name(&self.scan.view_name),
71
8149
                )?)?;
72

            
73
8149
        let invalidated_entries_tree = self.database.collection_tree::<Unversioned, _>(
74
8149
            &self.scan.collection,
75
8149
            view_invalidated_docs_tree_name(&self.scan.view_name),
76
8149
        )?;
77

            
78
8149
        let view_name = self.scan.view_name.clone();
79
8149
        let view_version = self.scan.view_version;
80
8149
        let roots = self.database.roots().clone();
81

            
82
8149
        let needs_update = tokio::task::spawn_blocking::<_, Result<bool, Error>>(move || {
83
8149
            let document_ids = tree_keys::<Versioned>(&documents)?;
84
8149
            let view_is_current_version =
85
8149
                if let Some(version) = view_versions.get(view_name.to_string().as_bytes())? {
86
52
                    if let Ok(version) = ViewVersion::from_bytes(&version) {
87
52
                        version.is_current(view_version)
88
                    } else {
89
                        false
90
                    }
91
                } else {
92
8097
                    false
93
                };
94

            
95
8149
            let missing_entries = if view_is_current_version {
96
51
                let stored_document_ids = tree_keys::<Unversioned>(&document_map)?;
97

            
98
51
                document_ids
99
51
                    .difference(&stored_document_ids)
100
51
                    .copied()
101
51
                    .collect::<HashSet<_>>()
102
            } else {
103
                // The view isn't the current version, queue up all documents.
104
8098
                document_ids
105
            };
106

            
107
8149
            if !missing_entries.is_empty() {
108
                // Add all missing entries to the invalidated list. The view
109
                // mapping job will update them on the next pass.
110
1293
                let mut transaction =
111
1293
                    roots.transaction(&[invalidated_entries_tree, view_versions_tree])?;
112
1293
                let view_versions = transaction.tree::<Unversioned>(1).unwrap();
113
1293
                view_versions.set(
114
1293
                    view_name.to_string().as_bytes().to_vec(),
115
1293
                    ViewVersion::current_for(view_version).to_vec()?,
116
                )?;
117
1293
                let invalidated_entries = transaction.tree::<Unversioned>(0).unwrap();
118
1293
                let mut missing_entries = missing_entries
119
1293
                    .into_iter()
120
38416
                    .map(|id| ArcBytes::from(id.to_vec()))
121
1293
                    .collect::<Vec<_>>();
122
1293
                missing_entries.sort();
123
1293
                invalidated_entries.modify(missing_entries, Operation::Set(ArcBytes::default()))?;
124
1293
                transaction.commit()?;
125

            
126
1293
                return Ok(true);
127
6856
            }
128
6856

            
129
6856
            Ok(false)
130
8149
        })
131
7688
        .await??;
132

            
133
8149
        let task = if needs_update {
134
            Some(Arc::new(Mutex::new(Some(
135
1293
                self.database
136
1293
                    .data
137
1293
                    .storage
138
1293
                    .tasks()
139
1293
                    .jobs
140
1293
                    .lookup_or_enqueue(Mapper {
141
1293
                        database: self.database.clone(),
142
1293
                        map: Map {
143
1293
                            database: self.database.data.name.clone(),
144
1293
                            collection: self.scan.collection.clone(),
145
1293
                            view_name: self.scan.view_name.clone(),
146
1293
                        },
147
1293
                    })
148
                    .await,
149
            ))))
150
        } else {
151
6856
            None
152
        };
153

            
154
8149
        self.database
155
8149
            .data
156
8149
            .storage
157
8149
            .tasks()
158
8149
            .mark_integrity_check_complete(
159
8149
                self.database.data.name.clone(),
160
8149
                self.scan.collection.clone(),
161
8149
                self.scan.view_name.clone(),
162
8149
            )
163
            .await;
164

            
165
8149
        Ok(task)
166
16298
    }
167
}
168

            
169
1293
#[derive(Serialize, Deserialize, Debug)]
170
pub struct ViewVersion {
171
    internal_version: u8,
172
    schema_version: u64,
173
}
174

            
175
impl ViewVersion {
176
    const CURRENT_VERSION: u8 = 1;
177
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, crate::Error> {
178
52
        match pot::from_slice(bytes) {
179
52
            Ok(version) => Ok(version),
180
            Err(err) if matches!(err, pot::Error::NotAPot) && bytes.len() == 8 => {
181
                let mut be_bytes = [0_u8; 8];
182
                be_bytes.copy_from_slice(bytes);
183
                let schema_version = u64::from_be_bytes(be_bytes);
184
                Ok(Self {
185
                    internal_version: 0,
186
                    schema_version,
187
                })
188
            }
189
            Err(err) => Err(crate::Error::from(err)),
190
        }
191
52
    }
192

            
193
1293
    pub fn to_vec(&self) -> Result<Vec<u8>, crate::Error> {
194
1293
        pot::to_vec(self).map_err(crate::Error::from)
195
1293
    }
196

            
197
1293
    pub fn current_for(schema_version: u64) -> Self {
198
1293
        Self {
199
1293
            internal_version: Self::CURRENT_VERSION,
200
1293
            schema_version,
201
1293
        }
202
1293
    }
203

            
204
52
    pub fn is_current(&self, schema_version: u64) -> bool {
205
52
        self.internal_version == Self::CURRENT_VERSION && self.schema_version == schema_version
206
52
    }
207
}
208

            
209
8200
fn tree_keys<R: nebari::tree::Root>(
210
8200
    tree: &Tree<R, AnyFile>,
211
8200
) -> Result<HashSet<DocumentId>, crate::Error> {
212
8200
    let mut ids = Vec::new();
213
8200
    tree.scan::<Infallible, _, _, _, _>(
214
8200
        &(..),
215
8200
        true,
216
11675
        |_, _, _| true,
217
38775
        |key, _| {
218
38720
            ids.push(key.clone());
219
38720
            KeyEvaluation::Skip
220
38775
        },
221
8200
        |_, _, _| unreachable!(),
222
8200
    )?;
223

            
224
8200
    Ok(ids
225
8200
        .into_iter()
226
38775
        .map(|key| DocumentId::try_from(key.as_slice()))
227
8200
        .collect::<Result<HashSet<_>, bonsaidb_core::Error>>()?)
228
8200
}
229

            
230
impl Keyed<Task> for IntegrityScanner {
231
8175
    fn key(&self) -> Task {
232
8175
        Task::IntegrityScan(self.scan.clone())
233
8175
    }
234
}
235

            
236
// The reason we use jobs like this is to make sure we can tweak how much is
// happening at any given time.
//
// On the Server level, we'll need to cooperate with all the databases in a
// shared pool of workers. So, we need to come up with a design for the view
// updaters to work within this limitation.
//
// Integrity scan is simple: have a shared structure on Database that keeps
// track of all integrity scan results. It can check for an existing value and
// return, or make you wait until the job is finished. For views, I suppose the
// best that can be done is a similar approach, but the indexer's output is the
// last transaction id it synced. When a request comes in, a check can be done
// to see if any docs are outdated; if so, the client can get the current
// transaction id and ask the ViewScanning service manager to wait until that
// txid is scanned.
//
// The view can then scan and return the results it finds with confidence it
// was updated to that time.
252
// If new requests come in while the current batch is being caught up to,