1
use std::{
2
    collections::{hash_map, HashMap},
3
    fmt::Debug,
4
    net::SocketAddr,
5
    ops::Deref,
6
    path::PathBuf,
7
    sync::{
8
        atomic::{AtomicU32, AtomicUsize, Ordering},
9
        Arc,
10
    },
11
    time::Duration,
12
};
13

            
14
use async_lock::{Mutex, RwLock};
15
use async_trait::async_trait;
16
#[cfg(feature = "password-hashing")]
17
use bonsaidb_core::connection::Authentication;
18
use bonsaidb_core::{
19
    admin::{Admin, ADMIN_DATABASE_NAME},
20
    api,
21
    arc_bytes::serde::Bytes,
22
    connection::{
23
        self, AsyncConnection, AsyncStorageConnection, HasSession, IdentityReference, Session,
24
        SessionId,
25
    },
26
    networking::{self, Payload, CURRENT_PROTOCOL_VERSION},
27
    permissions::{
28
        bonsai::{bonsaidb_resource_name, BonsaiAction, ServerAction},
29
        Permissions,
30
    },
31
    schema::{self, ApiName, Nameable, NamedCollection, Schema},
32
};
33
use bonsaidb_local::{config::Builder, AsyncStorage, Storage, StorageNonBlocking};
34
use bonsaidb_utils::{fast_async_lock, fast_async_read, fast_async_write};
35
use derive_where::derive_where;
36
use fabruic::{self, CertificateChain, Endpoint, KeyPair, PrivateKey};
37
use flume::Sender;
38
use futures::{Future, StreamExt};
39
use rustls::sign::CertifiedKey;
40
use schema::SchemaName;
41
#[cfg(not(windows))]
42
use signal_hook::consts::SIGQUIT;
43
use signal_hook::consts::{SIGINT, SIGTERM};
44
use tokio::sync::{oneshot, Notify};
45

            
46
#[cfg(feature = "acme")]
47
use crate::config::AcmeConfiguration;
48
use crate::{
49
    api::{AnyHandler, HandlerSession},
50
    backend::ConnectionHandling,
51
    dispatch::{register_api_handlers, ServerDispatcher},
52
    error::Error,
53
    hosted::{Hosted, SerializablePrivateKey, TlsCertificate, TlsCertificatesByDomain},
54
    server::shutdown::{Shutdown, ShutdownState},
55
    Backend, BackendError, NoBackend, ServerConfiguration,
56
};
57

            
58
#[cfg(feature = "acme")]
59
pub mod acme;
60
mod connected_client;
61
mod database;
62

            
63
mod shutdown;
64
mod tcp;
65
#[cfg(feature = "websockets")]
66
mod websockets;
67

            
68
use self::connected_client::OwnedClient;
69
pub use self::{
70
    connected_client::{ConnectedClient, LockedClientDataGuard, Transport},
71
    database::ServerDatabase,
72
    tcp::{ApplicationProtocols, HttpService, Peer, StandardTcpProtocols, TcpService},
73
};
74

            
75
/// Counter from which ids for newly connected clients are drawn.
static CONNECTED_CLIENT_ID_COUNTER: AtomicU32 = AtomicU32::new(u32::MIN);
76

            
77
/// A BonsaiDb server.
78
#[derive(Debug)]
79
72160
#[derive_where(Clone)]
80
pub struct CustomServer<B: Backend = NoBackend> {
81
    data: Arc<Data<B>>,
82
    pub(crate) storage: AsyncStorage,
83
}
84

            
85
impl<'a, B: Backend> From<&'a CustomServer<B>> for Storage {
86
    fn from(server: &'a CustomServer<B>) -> Self {
87
        Self::from(server.storage.clone())
88
    }
89
}
90

            
91
impl<B: Backend> From<CustomServer<B>> for Storage {
92
    fn from(server: CustomServer<B>) -> Self {
93
        Self::from(server.storage)
94
    }
95
}
96

            
97
/// A BonsaiDb server without a custom backend.
98
pub type Server = CustomServer<NoBackend>;
99

            
100
#[derive(Debug)]
101
struct Data<B: Backend = NoBackend> {
102
    clients: RwLock<HashMap<u32, ConnectedClient<B>>>,
103
    request_processor: flume::Sender<ClientRequest<B>>,
104
    default_session: Session,
105
    endpoint: RwLock<Option<Endpoint>>,
106
    client_simultaneous_request_limit: usize,
107
    primary_tls_key: CachedCertifiedKey,
108
    primary_domain: String,
109
    custom_apis: parking_lot::RwLock<HashMap<ApiName, Arc<dyn AnyHandler<B>>>>,
110
    #[cfg(feature = "acme")]
111
    acme: AcmeConfiguration,
112
    #[cfg(feature = "acme")]
113
    alpn_keys: AlpnKeys,
114
    shutdown: Shutdown,
115
}
116

            
117
82
#[derive(Default)]
118
struct CachedCertifiedKey(parking_lot::Mutex<Option<Arc<CertifiedKey>>>);
119

            
120
impl Debug for CachedCertifiedKey {
121
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122
        f.debug_tuple("CachedCertifiedKey").finish()
123
    }
124
}
125

            
126
impl Deref for CachedCertifiedKey {
127
    type Target = parking_lot::Mutex<Option<Arc<CertifiedKey>>>;
128

            
129
606
    fn deref(&self) -> &Self::Target {
130
606
        &self.0
131
606
    }
132
}
133

            
134
impl<B: Backend> CustomServer<B> {
135
    /// Opens a server using `directory` for storage.
136
82
    pub async fn open(
137
82
        configuration: ServerConfiguration<B>,
138
82
    ) -> Result<Self, BackendError<B::Error>> {
139
82
        let configuration = register_api_handlers(B::configure(configuration)?)?;
140
82
        let (request_sender, request_receiver) = flume::unbounded::<ClientRequest<B>>();
141
1312
        for _ in 0..configuration.request_workers {
142
1312
            let request_receiver = request_receiver.clone();
143
1312
            tokio::task::spawn(async move {
144
101854
                while let Ok(mut client_request) = request_receiver.recv_async().await {
145
71139
                    let request = client_request.request.take().unwrap();
146
71139
                    let session = client_request.session.clone();
147
                    // TODO we should be able to upgrade a session-less Storage to one with a Session.
148
                    // The Session needs to be looked up from the client based on the request's session id.
149
71139
                    let result = match client_request.server.storage.assume_session(session) {
150
71139
                        Ok(storage) => {
151
71139
                            let client = HandlerSession {
152
71139
                                server: &client_request.server,
153
71139
                                client: &client_request.client,
154
71139
                                as_client: Self {
155
71139
                                    data: client_request.server.data.clone(),
156
71139
                                    storage,
157
71139
                                },
158
71139
                            };
159
71139
                            ServerDispatcher::dispatch_api_request(
160
71139
                                client,
161
71139
                                &request.name,
162
71139
                                request.value.unwrap(),
163
97459
                            )
164
97449
                            .await
165
71139
                            .map_err(bonsaidb_core::Error::from)
166
                        }
167
                        Err(err) => Err(err),
168
                    };
169
71139
                    drop(client_request.result_sender.send((request.name, result)));
170
                }
171
1312
            });
172
1312
        }
173

            
174
82
        let storage = AsyncStorage::open(configuration.storage.with_schema::<Hosted>()?).await?;
175

            
176
164
        storage.create_database::<Hosted>("_hosted", true).await?;
177

            
178
82
        let default_permissions = Permissions::from(configuration.default_permissions);
179
82

            
180
82
        let server = Self {
181
82
            storage,
182
82
            data: Arc::new(Data {
183
82
                clients: RwLock::default(),
184
82
                endpoint: RwLock::default(),
185
82
                request_processor: request_sender,
186
82
                default_session: Session {
187
82
                    permissions: default_permissions,
188
82
                    ..Session::default()
189
82
                },
190
82
                client_simultaneous_request_limit: configuration.client_simultaneous_request_limit,
191
82
                primary_tls_key: CachedCertifiedKey::default(),
192
82
                primary_domain: configuration.server_name,
193
82
                custom_apis: parking_lot::RwLock::new(configuration.custom_apis),
194
82
                #[cfg(feature = "acme")]
195
82
                acme: configuration.acme,
196
82
                #[cfg(feature = "acme")]
197
82
                alpn_keys: AlpnKeys::default(),
198
82
                shutdown: Shutdown::new(),
199
82
            }),
200
82
        };
201
88
        B::initialize(&server).await?;
202
82
        Ok(server)
203
82
    }
204

            
205
    /// Returns the path to the public pinned certificate, if this server has
206
    /// one. Note: this function will always succeed, but the file may not
207
    /// exist.
208
    #[must_use]
209
152
    pub fn pinned_certificate_path(&self) -> PathBuf {
210
152
        self.path().join("pinned-certificate.der")
211
152
    }
212

            
213
    /// Returns the primary domain configured for this server.
214
    #[must_use]
215
34
    pub fn primary_domain(&self) -> &str {
216
34
        &self.data.primary_domain
217
34
    }
218

            
219
    /// Returns the administration database.
220
2
    pub async fn admin(&self) -> ServerDatabase<B> {
221
2
        let db = self.storage.admin().await;
222
2
        ServerDatabase {
223
2
            server: self.clone(),
224
2
            db,
225
2
        }
226
2
    }
227

            
228
288
    pub(crate) async fn hosted(&self) -> ServerDatabase<B> {
229
288
        let db = self.storage.database::<Hosted>("_hosted").await.unwrap();
230
288
        ServerDatabase {
231
288
            server: self.clone(),
232
288
            db,
233
288
        }
234
288
    }
235

            
236
71138
    pub(crate) fn custom_api_dispatcher(&self, name: &ApiName) -> Option<Arc<dyn AnyHandler<B>>> {
237
71138
        let dispatchers = self.data.custom_apis.read();
238
71138
        dispatchers.get(name).cloned()
239
71138
    }
240

            
241
    /// Installs an X.509 certificate used for general purpose connections.
242
77
    pub async fn install_self_signed_certificate(&self, overwrite: bool) -> Result<(), Error> {
243
77
        let keypair = KeyPair::new_self_signed(&self.data.primary_domain);
244
77

            
245
153
        if self.certificate_chain().await.is_ok() && !overwrite {
246
1
            return Err(Error::Core(bonsaidb_core::Error::Configuration(String::from("Certificate already installed. Enable overwrite if you wish to replace the existing certificate."))));
247
76
        }
248
76

            
249
533
        self.install_certificate(keypair.certificate_chain(), keypair.private_key())
250
533
            .await?;
251

            
252
76
        tokio::fs::write(
253
76
            self.pinned_certificate_path(),
254
76
            keypair.end_entity_certificate().as_ref(),
255
76
        )
256
76
        .await?;
257

            
258
76
        Ok(())
259
77
    }
260

            
261
    /// Installs a certificate chain and private key used for TLS connections.
262
    #[cfg(feature = "pem")]
263
    pub async fn install_pem_certificate(
264
        &self,
265
        certificate_chain: &[u8],
266
        private_key: &[u8],
267
    ) -> Result<(), Error> {
268
        let private_key = match pem::parse(private_key) {
269
            Ok(pem) => PrivateKey::unchecked_from_der(pem.contents),
270
            Err(_) => PrivateKey::from_der(private_key)?,
271
        };
272
        let certificates = pem::parse_many(&certificate_chain)?
273
            .into_iter()
274
            .map(|entry| fabruic::Certificate::unchecked_from_der(entry.contents))
275
            .collect::<Vec<_>>();
276

            
277
        self.install_certificate(
278
            &CertificateChain::unchecked_from_certificates(certificates),
279
            &private_key,
280
        )
281
        .await
282
    }
283

            
284
    /// Installs a certificate chain and private key used for TLS connections.
285
76
    pub async fn install_certificate(
286
76
        &self,
287
76
        certificate_chain: &CertificateChain,
288
76
        private_key: &PrivateKey,
289
76
    ) -> Result<(), Error> {
290
76
        let db = self.hosted().await;
291

            
292
76
        TlsCertificate::entry_async(&self.data.primary_domain, &db)
293
76
            .update_with(|cert: &mut TlsCertificate| {
294
1
                cert.certificate_chain = certificate_chain.clone();
295
1
                cert.private_key = SerializablePrivateKey(private_key.clone());
296
76
            })
297
76
            .or_insert_with(|| TlsCertificate {
298
75
                domains: vec![self.data.primary_domain.clone()],
299
75
                private_key: SerializablePrivateKey(private_key.clone()),
300
75
                certificate_chain: certificate_chain.clone(),
301
228
            })
302
228
            .await?;
303

            
304
228
        self.refresh_certified_key().await?;
305

            
306
76
        let pinned_certificate_path = self.pinned_certificate_path();
307
76
        if pinned_certificate_path.exists() {
308
1
            tokio::fs::remove_file(&pinned_certificate_path).await?;
309
75
        }
310

            
311
76
        Ok(())
312
76
    }
313

            
314
77
    async fn refresh_certified_key(&self) -> Result<(), Error> {
315
231
        let certificate = self.tls_certificate().await?;
316

            
317
77
        let mut cached_key = self.data.primary_tls_key.lock();
318
77
        let private_key = rustls::PrivateKey(
319
77
            fabruic::dangerous::PrivateKey::as_ref(&certificate.private_key.0).to_vec(),
320
77
        );
321
77
        let private_key = rustls::sign::any_ecdsa_type(&Arc::new(private_key))?;
322

            
323
77
        let certificates = certificate
324
77
            .certificate_chain
325
77
            .iter()
326
77
            .map(|cert| rustls::Certificate(cert.as_ref().to_vec()))
327
77
            .collect::<Vec<_>>();
328
77

            
329
77
        let certified_key = Arc::new(CertifiedKey::new(certificates, private_key));
330
77
        *cached_key = Some(certified_key);
331
77
        Ok(())
332
77
    }
333

            
334
99
    async fn tls_certificate(&self) -> Result<TlsCertificate, Error> {
335
99
        let db = self.hosted().await;
336
99
        let (_, certificate) = db
337
99
            .view::<TlsCertificatesByDomain>()
338
99
            .with_key(self.data.primary_domain.clone())
339
198
            .query_with_collection_docs()
340
198
            .await?
341
            .documents
342
99
            .into_iter()
343
99
            .next()
344
99
            .ok_or_else(|| {
345
                Error::Core(bonsaidb_core::Error::Configuration(format!(
346
                    "no certificate found for {}",
347
                    self.data.primary_domain
348
                )))
349
99
            })?;
350
99
        Ok(certificate.contents)
351
99
    }
352

            
353
    /// Returns the current certificate chain.
354
113
    pub async fn certificate_chain(&self) -> Result<CertificateChain, Error> {
355
113
        let db = self.hosted().await;
356
113
        if let Some(mapping) = db
357
113
            .view::<TlsCertificatesByDomain>()
358
113
            .with_key(self.data.primary_domain.clone())
359
113
            .query()
360
107
            .await?
361
113
            .into_iter()
362
113
            .next()
363
        {
364
35
            Ok(mapping.value)
365
        } else {
366
78
            Err(Error::Core(bonsaidb_core::Error::Configuration(format!(
367
78
                "no certificate found for {}",
368
78
                self.data.primary_domain
369
78
            ))))
370
        }
371
113
    }
372

            
373
    /// Listens for incoming client connections. Does not return until the
374
    /// server shuts down.
375
22
    pub async fn listen_on(&self, port: u16) -> Result<(), Error> {
376
66
        let certificate = self.tls_certificate().await?;
377
22
        let keypair =
378
22
            KeyPair::from_parts(certificate.certificate_chain, certificate.private_key.0)?;
379
22
        let mut builder = Endpoint::builder();
380
22
        builder.set_protocols([CURRENT_PROTOCOL_VERSION.as_bytes().to_vec()]);
381
22
        builder.set_address(([0; 8], port).into());
382
22
        builder
383
22
            .set_max_idle_timeout(None)
384
22
            .map_err(|err| Error::Core(bonsaidb_core::Error::Transport(err.to_string())))?;
385
22
        builder.set_server_key_pair(Some(keypair));
386
22
        let mut server = builder
387
22
            .build()
388
22
            .map_err(|err| Error::Core(bonsaidb_core::Error::Transport(err.to_string())))?;
389
22
        {
390
22
            let mut endpoint = fast_async_write!(self.data.endpoint);
391
22
            *endpoint = Some(server.clone());
392
        }
393

            
394
22
        let mut shutdown_watcher = self
395
22
            .data
396
22
            .shutdown
397
22
            .watcher()
398
            .await
399
22
            .expect("server already shut down");
400

            
401
83
        while let Some(result) = tokio::select! {
402
4
            shutdown_state = shutdown_watcher.wait_for_shutdown() => {
403
                drop(server.close_incoming());
404
                if matches!(shutdown_state, ShutdownState::GracefulShutdown) {
405
                    server.wait_idle().await;
406
                }
407
                drop(server.close());
408
                None
409
            },
410
61
            msg = server.next() => msg
411
61
        } {
412
61
            let connection = result.accept::<()>().await?;
413
61
            let task_self = self.clone();
414
61
            tokio::spawn(async move {
415
61
                let address = connection.remote_address();
416
61
                if let Err(err) = task_self.handle_bonsai_connection(connection).await {
417
                    log::error!("[server] closing connection {}: {:?}", address, err);
418
61
                }
419
61
            });
420
        }
421

            
422
4
        Ok(())
423
4
    }
424

            
425
    /// Returns all of the currently connected clients.
426
    pub async fn connected_clients(&self) -> Vec<ConnectedClient<B>> {
427
        let clients = fast_async_read!(self.data.clients);
428
        clients.values().cloned().collect()
429
    }
430

            
431
    /// Sends a custom API response to all connected clients.
432
    pub async fn broadcast<Api: api::Api>(&self, response: &Api::Response) {
433
        let clients = fast_async_read!(self.data.clients);
434
        for client in clients.values() {
435
            // TODO should this broadcast to all sessions too rather than only the global session?
436
            drop(client.send::<Api>(None, response));
437
        }
438
    }
439

            
440
147
    async fn initialize_client(
441
147
        &self,
442
147
        transport: Transport,
443
147
        address: SocketAddr,
444
147
        sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
445
147
    ) -> Option<OwnedClient<B>> {
446
147
        if !self.data.default_session.allowed_to(
447
147
            &bonsaidb_resource_name(),
448
147
            &BonsaiAction::Server(ServerAction::Connect),
449
147
        ) {
450
            return None;
451
147
        }
452

            
453
147
        let client = loop {
454
147
            let next_id = CONNECTED_CLIENT_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
455
147
            let mut clients = fast_async_write!(self.data.clients);
456
147
            if let hash_map::Entry::Vacant(e) = clients.entry(next_id) {
457
147
                let client = OwnedClient::new(
458
147
                    next_id,
459
147
                    address,
460
147
                    transport,
461
147
                    sender,
462
147
                    self.clone(),
463
147
                    self.data.default_session.clone(),
464
147
                );
465
147
                e.insert(client.clone());
466
147
                break client;
467
            }
468
        };
469

            
470
147
        match B::client_connected(&client, self).await {
471
147
            Ok(ConnectionHandling::Accept) => Some(client),
472
            Ok(ConnectionHandling::Reject) => None,
473
            Err(err) => {
474
                log::error!(
475
                    "[server] Rejecting connection due to error in `client_connected`: {err:?}"
476
                );
477
                None
478
            }
479
        }
480
147
    }
481

            
482
86
    async fn disconnect_client(&self, id: u32) {
483
86
        if let Some(client) = {
484
86
            let mut clients = fast_async_write!(self.data.clients);
485
86
            clients.remove(&id)
486
        } {
487
86
            if let Err(err) = B::client_disconnected(client, self).await {
488
                log::error!("[server] Error in `client_disconnected`: {err:?}");
489
86
            }
490
        }
491
86
    }
492

            
493
61
    async fn handle_bonsai_connection(
494
61
        &self,
495
61
        mut connection: fabruic::Connection<()>,
496
61
    ) -> Result<(), Error> {
497
61
        if let Some(incoming) = connection.next().await {
498
61
            let incoming = match incoming {
499
61
                Ok(incoming) => incoming,
500
                Err(err) => {
501
                    log::error!("[server] Error establishing a stream: {err:?}");
502
                    return Ok(());
503
                }
504
            };
505

            
506
61
            match incoming
507
61
                .accept::<networking::Payload, networking::Payload>()
508
                .await
509
            {
510
61
                Ok((sender, receiver)) => {
511
61
                    let (api_response_sender, api_response_receiver) = flume::unbounded();
512
61
                    if let Some(disconnector) = self
513
61
                        .initialize_client(
514
61
                            Transport::Bonsai,
515
61
                            connection.remote_address(),
516
61
                            api_response_sender,
517
61
                        )
518
                        .await
519
61
                    {
520
61
                        let task_sender = sender.clone();
521
61
                        tokio::spawn(async move {
522
15
                            while let Ok((session_id, name, bytes)) =
523
76
                                api_response_receiver.recv_async().await
524
                            {
525
15
                                if task_sender
526
15
                                    .send(&Payload {
527
15
                                        id: None,
528
15
                                        session_id,
529
15
                                        name,
530
15
                                        value: Ok(bytes),
531
15
                                    })
532
15
                                    .is_err()
533
                                {
534
                                    break;
535
15
                                }
536
                            }
537
14
                            let _ = connection.close().await;
538
61
                        });
539
61

            
540
61
                        let task_self = self.clone();
541
61
                        tokio::spawn(async move {
542
61
                            if let Err(err) = task_self
543
25568
                                .handle_stream(disconnector, sender, receiver)
544
25568
                                .await
545
                            {
546
                                log::error!("[server] Error handling stream: {err:?}");
547
14
                            }
548
61
                        });
549
61
                    } else {
550
                        log::error!("[server] Backend rejected connection.");
551
                        return Ok(());
552
                    }
553
                }
554
                Err(err) => {
555
                    log::error!("[server] Error accepting incoming stream: {err:?}");
556
                    return Ok(());
557
                }
558
            }
559
        }
560
61
        Ok(())
561
61
    }
562

            
563
147
    async fn handle_client_requests(
564
147
        &self,
565
147
        client: ConnectedClient<B>,
566
147
        request_receiver: flume::Receiver<Payload>,
567
147
        response_sender: flume::Sender<Payload>,
568
147
    ) {
569
147
        let notify = Arc::new(Notify::new());
570
147
        let requests_in_queue = Arc::new(AtomicUsize::new(0));
571
101694
        loop {
572
101694
            let current_requests = requests_in_queue.load(Ordering::SeqCst);
573
101694
            if current_requests == self.data.client_simultaneous_request_limit {
574
                // Wait for requests to finish.
575
30399
                notify.notified().await;
576
71295
            } else if requests_in_queue
577
71295
                .compare_exchange(
578
71295
                    current_requests,
579
71295
                    current_requests + 1,
580
71295
                    Ordering::SeqCst,
581
71295
                    Ordering::SeqCst,
582
71295
                )
583
71295
                .is_ok()
584
            {
585
71286
                let payload = match request_receiver.recv_async().await {
586
71139
                    Ok(payload) => payload,
587
86
                    Err(_) => break,
588
                };
589
71139
                let session_id = payload.session_id;
590
71139
                let id = payload.id;
591
71139
                let task_sender = response_sender.clone();
592
71139

            
593
71139
                let notify = notify.clone();
594
71139
                let requests_in_queue = requests_in_queue.clone();
595
71139
                self.handle_request_through_worker(
596
71139
                    payload,
597
71139
                    move |name, value| async move {
598
71134
                        drop(task_sender.send(Payload {
599
71134
                            session_id,
600
71134
                            id,
601
71134
                            name,
602
71134
                            value,
603
71134
                        }));
604
71134

            
605
71134
                        requests_in_queue.fetch_sub(1, Ordering::SeqCst);
606
71134

            
607
71134
                        notify.notify_one();
608
71134

            
609
71134
                        Ok(())
610
71139
                    },
611
71139
                    client.clone(),
612
71139
                )
613
                .await
614
71139
                .unwrap();
615
9
            }
616
        }
617
86
    }
618

            
619
71139
    async fn handle_request_through_worker<
620
71139
        F: FnOnce(ApiName, Result<Bytes, bonsaidb_core::Error>) -> R + Send + 'static,
621
71139
        R: Future<Output = Result<(), Error>> + Send,
622
71139
    >(
623
71139
        &self,
624
71139
        request: Payload,
625
71139
        callback: F,
626
71139
        client: ConnectedClient<B>,
627
71139
    ) -> Result<(), Error> {
628
71139
        let (result_sender, result_receiver) = oneshot::channel();
629
71139
        let session = client
630
71139
            .session(request.session_id)
631
71139
            .unwrap_or_else(|| self.data.default_session.clone());
632
71139
        self.data
633
71139
            .request_processor
634
71139
            .send(ClientRequest::<B>::new(
635
71139
                request,
636
71139
                self.clone(),
637
71139
                client,
638
71139
                session,
639
71139
                result_sender,
640
71139
            ))
641
71139
            .map_err(|_| Error::InternalCommunication)?;
642
71139
        tokio::spawn(async move {
643
71139
            let (name, result) = result_receiver.await?;
644
            // Map the error into a Response::Error. The jobs system supports
645
            // multiple receivers receiving output, and wraps Err to avoid
646
            // requiring the error to be cloneable. As such, we have to unwrap
647
            // it. Thankfully, we can guarantee nothing else is waiting on a
648
            // response to a request than the original requestor, so this can be
649
            // safely unwrapped.
650
71139
            callback(name, result).await?;
651
71138
            Result::<(), Error>::Ok(())
652
71139
        });
653
71139
        Ok(())
654
71139
    }
655

            
656
61
    async fn handle_stream(
657
61
        &self,
658
61
        client: OwnedClient<B>,
659
61
        sender: fabruic::Sender<Payload>,
660
61
        mut receiver: fabruic::Receiver<Payload>,
661
61
    ) -> Result<(), Error> {
662
61
        let (payload_sender, payload_receiver) = flume::unbounded();
663
61
        tokio::spawn(async move {
664
34225
            while let Ok(payload) = payload_receiver.recv_async().await {
665
34164
                if sender.send(&payload).is_err() {
666
                    break;
667
34164
                }
668
            }
669
61
        });
670
61

            
671
61
        let (request_sender, request_receiver) =
672
61
            flume::bounded::<Payload>(self.data.client_simultaneous_request_limit);
673
61
        let task_self = self.clone();
674
61
        tokio::spawn(async move {
675
61
            task_self
676
28120
                .handle_client_requests(client.clone(), request_receiver, payload_sender)
677
28120
                .await;
678
61
        });
679

            
680
34225
        while let Some(payload) = receiver.next().await {
681
34164
            drop(request_sender.send_async(payload?).await);
682
        }
683

            
684
14
        Ok(())
685
14
    }
686

            
687
    /// Shuts the server down. If a `timeout` is provided, the server will stop
688
    /// accepting new connections and attempt to respond to any outstanding
689
    /// requests already being processed. After the `timeout` has elapsed or if
690
    /// no `timeout` was provided, the server is forcefully shut down.
691
30
    pub async fn shutdown(&self, timeout: Option<Duration>) -> Result<(), Error> {
692
30
        if let Some(timeout) = timeout {
693
5
            self.data.shutdown.graceful_shutdown(timeout).await;
694
        } else {
695
26
            self.data.shutdown.shutdown().await;
696
        }
697

            
698
30
        Ok(())
699
30
    }
700

            
701
    /// Listens for signals from the operating system that the server should
702
    /// shut down and attempts to gracefully shut down.
703
1
    pub async fn listen_for_shutdown(&self) -> Result<(), Error> {
704
1
        const GRACEFUL_SHUTDOWN: usize = 1;
705
1
        const TERMINATE: usize = 2;
706
1

            
707
1
        enum SignalShutdownState {
708
1
            Running,
709
1
            ShuttingDown(flume::Receiver<()>),
710
1
        }
711
1

            
712
1
        let shutdown_state = Arc::new(Mutex::new(SignalShutdownState::Running));
713
1
        let flag = Arc::new(AtomicUsize::default());
714
1
        let register_hook = |flag: &Arc<AtomicUsize>| {
715
1
            signal_hook::flag::register_usize(SIGINT, flag.clone(), GRACEFUL_SHUTDOWN)?;
716
1
            signal_hook::flag::register_usize(SIGTERM, flag.clone(), TERMINATE)?;
717
            #[cfg(not(windows))]
718
1
            signal_hook::flag::register_usize(SIGQUIT, flag.clone(), TERMINATE)?;
719
1
            Result::<(), std::io::Error>::Ok(())
720
1
        };
721
1
        if let Err(err) = register_hook(&flag) {
722
            log::error!("Error installing signals for graceful shutdown: {err:?}");
723
            tokio::time::sleep(Duration::MAX).await;
724
        } else {
725
            loop {
726
4
                match flag.load(Ordering::Relaxed) {
727
4
                    0 => {
728
4
                        // No signal
729
4
                    }
730
                    GRACEFUL_SHUTDOWN => {
731
                        let mut state = fast_async_lock!(shutdown_state);
732
                        match *state {
733
                            SignalShutdownState::Running => {
734
                                log::error!("Interrupt signal received. Shutting down gracefully.");
735
                                let task_server = self.clone();
736
                                let (shutdown_sender, shutdown_receiver) = flume::bounded(1);
737
                                tokio::task::spawn(async move {
738
                                    task_server.shutdown(Some(Duration::from_secs(30))).await?;
739
                                    let _ = shutdown_sender.send(());
740
                                    Result::<(), Error>::Ok(())
741
                                });
742
                                *state = SignalShutdownState::ShuttingDown(shutdown_receiver);
743
                            }
744
                            SignalShutdownState::ShuttingDown(_) => {
745
                                // Two interrupts, go ahead and force the shutdown
746
                                break;
747
                            }
748
                        }
749
                    }
750
                    TERMINATE => {
751
                        log::error!("Quit signal received. Shutting down.");
752
                        break;
753
                    }
754
                    _ => unreachable!(),
755
                }
756

            
757
4
                let state = fast_async_lock!(shutdown_state);
758
4
                if let SignalShutdownState::ShuttingDown(receiver) = &*state {
759
                    if receiver.try_recv().is_ok() {
760
                        // Fully shut down.
761
                        return Ok(());
762
                    }
763
4
                }
764

            
765
4
                tokio::time::sleep(Duration::from_millis(300)).await;
766
            }
767
            self.shutdown(None).await?;
768
        }
769

            
770
        Ok(())
771
    }
772
}
773

            
774
impl<B: Backend> Deref for CustomServer<B> {
775
    type Target = AsyncStorage;
776

            
777
70068
    fn deref(&self) -> &Self::Target {
778
70068
        &self.storage
779
70068
    }
780
}
781

            
782
#[derive(Debug)]
783
struct ClientRequest<B: Backend> {
784
    request: Option<Payload>,
785
    client: ConnectedClient<B>,
786
    session: Session,
787
    server: CustomServer<B>,
788
    result_sender: oneshot::Sender<(ApiName, Result<Bytes, bonsaidb_core::Error>)>,
789
}
790

            
791
impl<B: Backend> ClientRequest<B> {
792
71139
    pub fn new(
793
71139
        request: Payload,
794
71139
        server: CustomServer<B>,
795
71139
        client: ConnectedClient<B>,
796
71139
        session: Session,
797
71139
        result_sender: oneshot::Sender<(ApiName, Result<Bytes, bonsaidb_core::Error>)>,
798
71139
    ) -> Self {
799
71139
        Self {
800
71139
            request: Some(request),
801
71139
            server,
802
71139
            client,
803
71139
            session,
804
71139
            result_sender,
805
71139
        }
806
71139
    }
807
}
808

            
809
impl<B: Backend> HasSession for CustomServer<B> {
810
91
    fn session(&self) -> Option<&Session> {
811
91
        self.storage.session()
812
91
    }
813
}
814

            
815
#[async_trait]
816
impl<B: Backend> AsyncStorageConnection for CustomServer<B> {
817
    type Database = ServerDatabase<B>;
818
    type Authenticated = Self;
819

            
820
    async fn admin(&self) -> Self::Database {
821
        self.database::<Admin>(ADMIN_DATABASE_NAME).await.unwrap()
822
    }
823

            
824
692
    async fn create_database_with_schema(
825
692
        &self,
826
692
        name: &str,
827
692
        schema: SchemaName,
828
692
        only_if_needed: bool,
829
692
    ) -> Result<(), bonsaidb_core::Error> {
830
692
        self.storage
831
692
            .create_database_with_schema(name, schema, only_if_needed)
832
688
            .await
833
1384
    }
834

            
835
154
    async fn database<DB: Schema>(
836
154
        &self,
837
154
        name: &str,
838
154
    ) -> Result<Self::Database, bonsaidb_core::Error> {
839
154
        let db = self.storage.database::<DB>(name).await?;
840
154
        Ok(ServerDatabase {
841
154
            server: self.clone(),
842
154
            db,
843
154
        })
844
308
    }
845

            
846
508
    async fn delete_database(&self, name: &str) -> Result<(), bonsaidb_core::Error> {
847
508
        self.storage.delete_database(name).await
848
1016
    }
849

            
850
4
    async fn list_databases(&self) -> Result<Vec<connection::Database>, bonsaidb_core::Error> {
851
4
        self.storage.list_databases().await
852
8
    }
853

            
854
4
    async fn list_available_schemas(&self) -> Result<Vec<SchemaName>, bonsaidb_core::Error> {
855
4
        self.storage.list_available_schemas().await
856
8
    }
857

            
858
10
    async fn create_user(&self, username: &str) -> Result<u64, bonsaidb_core::Error> {
859
10
        self.storage.create_user(username).await
860
20
    }
861

            
862
4
    async fn delete_user<'user, U: Nameable<'user, u64> + Send + Sync>(
863
4
        &self,
864
4
        user: U,
865
4
    ) -> Result<(), bonsaidb_core::Error> {
866
4
        self.storage.delete_user(user).await
867
8
    }
868

            
869
    #[cfg(feature = "password-hashing")]
870
3
    async fn set_user_password<'user, U: Nameable<'user, u64> + Send + Sync>(
871
3
        &self,
872
3
        user: U,
873
3
        password: bonsaidb_core::connection::SensitiveString,
874
3
    ) -> Result<(), bonsaidb_core::Error> {
875
3
        self.storage.set_user_password(user, password).await
876
6
    }
877

            
878
    #[cfg(feature = "password-hashing")]
879
5
    async fn authenticate<'user, U: Nameable<'user, u64> + Send + Sync>(
880
5
        &self,
881
5
        user: U,
882
5
        authentication: Authentication,
883
5
    ) -> Result<Self::Authenticated, bonsaidb_core::Error> {
884
5
        let storage = self.storage.authenticate(user, authentication).await?;
885
5
        Ok(Self {
886
5
            data: self.data.clone(),
887
5
            storage,
888
5
        })
889
10
    }
890

            
891
2
    async fn assume_identity(
892
2
        &self,
893
2
        identity: IdentityReference<'_>,
894
2
    ) -> Result<Self::Authenticated, bonsaidb_core::Error> {
895
2
        let storage = self.storage.assume_identity(identity).await?;
896
2
        Ok(Self {
897
2
            data: self.data.clone(),
898
2
            storage,
899
2
        })
900
4
    }
901

            
902
12
    async fn add_permission_group_to_user<
903
12
        'user,
904
12
        'group,
905
12
        U: Nameable<'user, u64> + Send + Sync,
906
12
        G: Nameable<'group, u64> + Send + Sync,
907
12
    >(
908
12
        &self,
909
12
        user: U,
910
12
        permission_group: G,
911
12
    ) -> Result<(), bonsaidb_core::Error> {
912
12
        self.storage
913
12
            .add_permission_group_to_user(user, permission_group)
914
12
            .await
915
24
    }
916

            
917
8
    async fn remove_permission_group_from_user<
918
8
        'user,
919
8
        'group,
920
8
        U: Nameable<'user, u64> + Send + Sync,
921
8
        G: Nameable<'group, u64> + Send + Sync,
922
8
    >(
923
8
        &self,
924
8
        user: U,
925
8
        permission_group: G,
926
8
    ) -> Result<(), bonsaidb_core::Error> {
927
8
        self.storage
928
8
            .remove_permission_group_from_user(user, permission_group)
929
8
            .await
930
16
    }
931

            
932
8
    async fn add_role_to_user<
933
8
        'user,
934
8
        'group,
935
8
        U: Nameable<'user, u64> + Send + Sync,
936
8
        G: Nameable<'group, u64> + Send + Sync,
937
8
    >(
938
8
        &self,
939
8
        user: U,
940
8
        role: G,
941
8
    ) -> Result<(), bonsaidb_core::Error> {
942
8
        self.storage.add_role_to_user(user, role).await
943
16
    }
944

            
945
8
    async fn remove_role_from_user<
946
8
        'user,
947
8
        'group,
948
8
        U: Nameable<'user, u64> + Send + Sync,
949
8
        G: Nameable<'group, u64> + Send + Sync,
950
8
    >(
951
8
        &self,
952
8
        user: U,
953
8
        role: G,
954
8
    ) -> Result<(), bonsaidb_core::Error> {
955
8
        self.storage.remove_role_from_user(user, role).await
956
16
    }
957
}
958

            
959
82
#[derive(Default)]
960
struct AlpnKeys(Arc<std::sync::Mutex<HashMap<String, Arc<rustls::sign::CertifiedKey>>>>);
961

            
962
impl Debug for AlpnKeys {
963
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
964
        f.debug_tuple("AlpnKeys").finish()
965
    }
966
}
967

            
968
impl Deref for AlpnKeys {
969
    type Target = Arc<std::sync::Mutex<HashMap<String, Arc<rustls::sign::CertifiedKey>>>>;
970

            
971
    fn deref(&self) -> &Self::Target {
972
        &self.0
973
    }
974
}