1
use std::collections::{hash_map, HashMap};
2
use std::fmt::Debug;
3
use std::net::SocketAddr;
4
use std::ops::Deref;
5
use std::path::PathBuf;
6
use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
7
use std::sync::Arc;
8
use std::time::Duration;
9

            
10
use async_trait::async_trait;
11
use bonsaidb_core::admin::{Admin, ADMIN_DATABASE_NAME};
12
use bonsaidb_core::api;
13
use bonsaidb_core::api::ApiName;
14
use bonsaidb_core::arc_bytes::serde::Bytes;
15
use bonsaidb_core::connection::{
16
    self, AsyncConnection, AsyncStorageConnection, HasSession, IdentityReference, Session,
17
    SessionId,
18
};
19
use bonsaidb_core::networking::{self, Payload, CURRENT_PROTOCOL_VERSION};
20
use bonsaidb_core::permissions::bonsai::{bonsaidb_resource_name, BonsaiAction, ServerAction};
21
use bonsaidb_core::permissions::Permissions;
22
use bonsaidb_core::schema::{self, Nameable, NamedCollection, Schema, SchemaSummary};
23
use bonsaidb_local::config::Builder;
24
use bonsaidb_local::{AsyncStorage, Storage, StorageNonBlocking};
25
use bonsaidb_utils::fast_async_lock;
26
use derive_where::derive_where;
27
use fabruic::{self, CertificateChain, Endpoint, KeyPair, PrivateKey};
28
use flume::Sender;
29
use futures::{Future, StreamExt};
30
use parking_lot::{Mutex, RwLock};
31
use rustls::sign::CertifiedKey;
32
use schema::SchemaName;
33
#[cfg(not(windows))]
34
use signal_hook::consts::SIGQUIT;
35
use signal_hook::consts::{SIGINT, SIGTERM};
36
use tokio::sync::{oneshot, Notify};
37

            
38
use crate::api::{AnyHandler, HandlerSession};
39
use crate::backend::ConnectionHandling;
40
#[cfg(feature = "acme")]
41
use crate::config::AcmeConfiguration;
42
use crate::dispatch::{register_api_handlers, ServerDispatcher};
43
use crate::error::Error;
44
use crate::hosted::{Hosted, SerializablePrivateKey, TlsCertificate, TlsCertificatesByDomain};
45
use crate::server::shutdown::{Shutdown, ShutdownState, ShutdownStateWatcher};
46
use crate::{Backend, BackendError, BonsaiListenConfig, NoBackend, ServerConfiguration};
47

            
48
#[cfg(feature = "acme")]
49
pub mod acme;
50
mod connected_client;
51
mod database;
52

            
53
mod shutdown;
54
mod tcp;
55
#[cfg(feature = "websockets")]
56
mod websockets;
57

            
58
use self::connected_client::OwnedClient;
59
pub use self::connected_client::{ConnectedClient, LockedClientDataGuard, Transport};
60
pub use self::database::ServerDatabase;
61
pub use self::tcp::{ApplicationProtocols, HttpService, Peer, StandardTcpProtocols, TcpService};
62

            
63
/// Source of the ids assigned to newly connected clients. The counter can
/// wrap around, so callers must skip ids that are still in use.
static CONNECTED_CLIENT_ID_COUNTER: AtomicU32 = AtomicU32::new(0_u32);
64

            
65
/// A BonsaiDb server.
66
#[derive(Debug)]
67
73429
#[derive_where(Clone)]
68
pub struct CustomServer<B: Backend = NoBackend> {
69
    data: Arc<Data<B>>,
70
    pub(crate) storage: AsyncStorage,
71
}
72

            
73
impl<'a, B: Backend> From<&'a CustomServer<B>> for Storage {
74
    fn from(server: &'a CustomServer<B>) -> Self {
75
        Self::from(server.storage.clone())
76
    }
77
}
78

            
79
impl<B: Backend> From<CustomServer<B>> for Storage {
    /// Consumes the server and converts its storage handle into a
    /// [`Storage`].
    fn from(server: CustomServer<B>) -> Self {
        server.storage.into()
    }
}
84

            
85
/// A BonsaiDb server without a custom backend.
86
pub type Server = CustomServer<NoBackend>;
87

            
88
#[derive(Debug)]
89
struct Data<B: Backend = NoBackend> {
90
    backend: B,
91
    clients: RwLock<HashMap<u32, ConnectedClient<B>>>,
92
    request_processor: flume::Sender<ClientRequest<B>>,
93
    default_session: Session,
94
    client_simultaneous_request_limit: usize,
95
    primary_tls_key: CachedCertifiedKey,
96
    primary_domain: String,
97
    custom_apis: RwLock<HashMap<ApiName, Arc<dyn AnyHandler<B>>>>,
98
    #[cfg(feature = "acme")]
99
    acme: AcmeConfiguration,
100
    #[cfg(feature = "acme")]
101
    alpn_keys: AlpnKeys,
102
    shutdown: Shutdown,
103
}
104

            
105
92
#[derive(Default)]
106
struct CachedCertifiedKey(Mutex<Option<Arc<CertifiedKey>>>);
107

            
108
impl Debug for CachedCertifiedKey {
109
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110
        f.debug_tuple("CachedCertifiedKey").finish()
111
    }
112
}
113

            
114
impl Deref for CachedCertifiedKey {
115
    type Target = Mutex<Option<Arc<CertifiedKey>>>;
116

            
117
710
    fn deref(&self) -> &Self::Target {
118
710
        &self.0
119
710
    }
120
}
121

            
122
impl<B: Backend> CustomServer<B> {
123
    /// Opens a server using `directory` for storage.
124
92
    pub async fn open(
125
92
        configuration: ServerConfiguration<B>,
126
92
    ) -> Result<Self, BackendError<B::Error>> {
127
92
        let configuration = register_api_handlers(B::configure(configuration)?)?;
128
92
        let (request_sender, request_receiver) = flume::unbounded::<ClientRequest<B>>();
129
1472
        for _ in 0..configuration.request_workers {
130
1472
            let request_receiver = request_receiver.clone();
131
1472
            tokio::task::spawn(async move {
132
100696
                while let Ok(mut client_request) = request_receiver.recv_async().await {
133
72053
                    let request = client_request.request.take().unwrap();
134
72053
                    let session = client_request.session.clone();
135
                    // TODO we should be able to upgrade a session-less Storage to one with a Session.
136
                    // The Session needs to be looked up from the client based on the request's session id.
137
72053
                    let result = match client_request.server.storage.assume_session(session) {
138
72049
                        Ok(storage) => {
139
72049
                            let client = HandlerSession {
140
72049
                                server: &client_request.server,
141
72049
                                client: &client_request.client,
142
72049
                                as_client: Self {
143
72049
                                    data: client_request.server.data.clone(),
144
72049
                                    storage,
145
72049
                                },
146
72049
                            };
147
72049
                            ServerDispatcher::dispatch_api_request(
148
72049
                                client,
149
72049
                                &request.name,
150
72049
                                request.value.unwrap(),
151
72049
                            )
152
132355
                            .await
153
72045
                            .map_err(bonsaidb_core::Error::from)
154
                        }
155
4
                        Err(err) => Err(err),
156
                    };
157
72049
                    drop(client_request.result_sender.send((request.name, result)));
158
                }
159
1472
            });
160
1472
        }
161

            
162
92
        let storage = AsyncStorage::open(configuration.storage.with_schema::<Hosted>()?).await?;
163

            
164
184
        storage.create_database::<Hosted>("_hosted", true).await?;
165

            
166
92
        let default_permissions = Permissions::from(configuration.default_permissions);
167
92

            
168
92
        let server = Self {
169
92
            storage,
170
92
            data: Arc::new(Data {
171
92
                backend: configuration.backend,
172
92
                clients: parking_lot::RwLock::default(),
173
92
                request_processor: request_sender,
174
92
                default_session: Session {
175
92
                    permissions: default_permissions,
176
92
                    ..Session::default()
177
92
                },
178
92
                client_simultaneous_request_limit: configuration.client_simultaneous_request_limit,
179
92
                primary_tls_key: CachedCertifiedKey::default(),
180
92
                primary_domain: configuration.server_name,
181
92
                custom_apis: parking_lot::RwLock::new(configuration.custom_apis),
182
92
                #[cfg(feature = "acme")]
183
92
                acme: configuration.acme,
184
92
                #[cfg(feature = "acme")]
185
92
                alpn_keys: AlpnKeys::default(),
186
92
                shutdown: Shutdown::new(),
187
92
            }),
188
92
        };
189
92

            
190
98
        server.data.backend.initialize(&server).await?;
191
92
        Ok(server)
192
92
    }
193

            
194
    /// Returns the path to the public pinned certificate, if this server has
195
    /// one. Note: this function will always succeed, but the file may not
196
    /// exist.
197
    #[must_use]
198
168
    pub fn pinned_certificate_path(&self) -> PathBuf {
199
168
        self.path().join("pinned-certificate.der")
200
168
    }
201

            
202
    /// Returns the primary domain configured for this server.
203
    #[must_use]
204
38
    pub fn primary_domain(&self) -> &str {
205
38
        &self.data.primary_domain
206
38
    }
207

            
208
    /// Returns the [`Backend`] implementor for this server.
209
    #[must_use]
210
24
    pub fn backend(&self) -> &B {
211
24
        &self.data.backend
212
24
    }
213

            
214
    /// Returns the administration database.
215
2
    pub async fn admin(&self) -> ServerDatabase<B> {
216
2
        let db = self.storage.admin().await;
217
2
        ServerDatabase {
218
2
            server: self.clone(),
219
2
            db,
220
2
        }
221
2
    }
222

            
223
326
    pub(crate) async fn hosted(&self) -> ServerDatabase<B> {
224
326
        let db = self.storage.database::<Hosted>("_hosted").await.unwrap();
225
326
        ServerDatabase {
226
326
            server: self.clone(),
227
326
            db,
228
326
        }
229
326
    }
230

            
231
72049
    /// Looks up the handler registered for the custom API `name`, returning
    /// a shared reference to it if one exists.
    pub(crate) fn custom_api_dispatcher(&self, name: &ApiName) -> Option<Arc<dyn AnyHandler<B>>> {
        self.data.custom_apis.read().get(name).map(Arc::clone)
    }
235

            
236
    /// Installs an X.509 certificate used for general purpose connections.
237
85
    pub async fn install_self_signed_certificate(&self, overwrite: bool) -> Result<(), Error> {
238
85
        let keypair = KeyPair::new_self_signed(&self.data.primary_domain);
239
85

            
240
170
        if self.certificate_chain().await.is_ok() && !overwrite {
241
1
            return Err(Error::Core(bonsaidb_core::Error::other("bonsaidb-server config", "Certificate already installed. Enable overwrite if you wish to replace the existing certificate.")));
242
84
        }
243
84

            
244
84
        self.install_certificate(keypair.certificate_chain(), keypair.private_key())
245
589
            .await?;
246

            
247
84
        tokio::fs::write(
248
84
            self.pinned_certificate_path(),
249
84
            keypair.end_entity_certificate().as_ref(),
250
84
        )
251
84
        .await?;
252

            
253
84
        Ok(())
254
85
    }
255

            
256
    /// Installs a PEM-encoded certificate chain and private key for TLS
    /// connections.
    ///
    /// If `private_key` does not parse as PEM, it is treated as DER instead.
    ///
    /// # Errors
    ///
    /// Returns an error if the key or certificate chain can't be decoded, or
    /// if installing them fails.
    #[cfg(feature = "pem")]
    pub async fn install_pem_certificate(
        &self,
        certificate_chain: &[u8],
        private_key: &[u8],
    ) -> Result<(), Error> {
        let private_key = if let Ok(pem) = pem::parse(private_key) {
            PrivateKey::unchecked_from_der(pem.contents())
        } else {
            PrivateKey::from_der(private_key)?
        };

        let mut certificates = Vec::new();
        for entry in pem::parse_many(certificate_chain)? {
            certificates.push(fabruic::Certificate::unchecked_from_der(entry.contents()));
        }
        let chain = CertificateChain::unchecked_from_certificates(certificates);

        self.install_certificate(&chain, &private_key).await
    }
278

            
279
    /// Installs a certificate chain and private key used for TLS connections.
280
84
    pub async fn install_certificate(
281
84
        &self,
282
84
        certificate_chain: &CertificateChain,
283
84
        private_key: &PrivateKey,
284
84
    ) -> Result<(), Error> {
285
84
        let db = self.hosted().await;
286

            
287
84
        TlsCertificate::entry_async(&self.data.primary_domain, &db)
288
84
            .update_with(|cert: &mut TlsCertificate| {
289
1
                cert.certificate_chain = certificate_chain.clone();
290
1
                cert.private_key = SerializablePrivateKey(private_key.clone());
291
84
            })
292
84
            .or_insert_with(|| TlsCertificate {
293
83
                domains: vec![self.data.primary_domain.clone()],
294
83
                private_key: SerializablePrivateKey(private_key.clone()),
295
83
                certificate_chain: certificate_chain.clone(),
296
84
            })
297
252
            .await?;
298

            
299
252
        self.refresh_certified_key().await?;
300

            
301
84
        let pinned_certificate_path = self.pinned_certificate_path();
302
84
        if pinned_certificate_path.exists() {
303
1
            tokio::fs::remove_file(&pinned_certificate_path).await?;
304
83
        }
305

            
306
84
        Ok(())
307
84
    }
308

            
309
85
    async fn refresh_certified_key(&self) -> Result<(), Error> {
310
255
        let certificate = self.tls_certificate().await?;
311

            
312
85
        let mut cached_key = self.data.primary_tls_key.lock();
313
85
        let private_key = rustls::PrivateKey(
314
85
            fabruic::dangerous::PrivateKey::as_ref(&certificate.private_key.0).to_vec(),
315
85
        );
316
85
        let private_key = rustls::sign::any_ecdsa_type(&Arc::new(private_key))?;
317

            
318
85
        let certificates = certificate
319
85
            .certificate_chain
320
85
            .iter()
321
85
            .map(|cert| rustls::Certificate(cert.as_ref().to_vec()))
322
85
            .collect::<Vec<_>>();
323
85

            
324
85
        let certified_key = Arc::new(CertifiedKey::new(certificates, private_key));
325
85
        *cached_key = Some(certified_key);
326
85
        Ok(())
327
85
    }
328

            
329
111
    async fn tls_certificate(&self) -> Result<TlsCertificate, Error> {
330
111
        let db = self.hosted().await;
331
111
        let (_, certificate) = db
332
111
            .view::<TlsCertificatesByDomain>()
333
111
            .with_key(&self.data.primary_domain)
334
111
            .query_with_collection_docs()
335
222
            .await?
336
            .documents
337
111
            .into_iter()
338
111
            .next()
339
111
            .ok_or_else(|| {
340
                Error::Core(bonsaidb_core::Error::other(
341
                    "bonsaidb-server config",
342
                    format!("no certificate found for {}", self.data.primary_domain),
343
                ))
344
111
            })?;
345
111
        Ok(certificate.contents)
346
111
    }
347

            
348
    /// Returns the current certificate chain.
349
131
    pub async fn certificate_chain(&self) -> Result<CertificateChain, Error> {
350
131
        let db = self.hosted().await;
351
131
        if let Some(mapping) = db
352
131
            .view::<TlsCertificatesByDomain>()
353
131
            .with_key(&self.data.primary_domain)
354
131
            .query()
355
131
            .await?
356
131
            .into_iter()
357
131
            .next()
358
        {
359
45
            Ok(mapping.value)
360
        } else {
361
86
            Err(Error::Core(bonsaidb_core::Error::other(
362
86
                "bonsaidb-server config",
363
86
                format!("no certificate found for {}", self.data.primary_domain),
364
86
            )))
365
        }
366
131
    }
367

            
368
    /// Listens for incoming client connections. Does not return until the
369
    /// server shuts down.
370
    ///
371
    /// ## Listening on a port
372
    ///
373
    /// When passing a `u16` to this function, the server will begin listening
374
    /// on an "unspecified" address. This typically is accessible to other
375
    /// machines on the network/internet, so care should be taken to ensure this
376
    /// is what is intended.
377
    ///
378
    /// To ensure that the server only listens for local traffic, specify a
379
    /// local IP or localhost in addition to the port number.
380
26
    pub async fn listen_on(&self, config: impl Into<BonsaiListenConfig>) -> Result<(), Error> {
381
26
        let config = config.into();
382
78
        let certificate = self.tls_certificate().await?;
383
26
        let keypair =
384
26
            KeyPair::from_parts(certificate.certificate_chain, certificate.private_key.0)?;
385
26
        let mut builder = Endpoint::builder();
386
26
        builder.set_protocols([CURRENT_PROTOCOL_VERSION.as_bytes().to_vec()]);
387
26
        builder.set_address(config.address);
388
26
        builder.set_max_idle_timeout(None)?;
389
26
        builder.set_server_key_pair(Some(keypair));
390
26
        builder.set_reuse_address(config.reuse_address);
391
26
        let mut server = builder.build()?;
392

            
393
26
        let mut shutdown_watcher = self
394
26
            .data
395
26
            .shutdown
396
26
            .watcher()
397
            .await
398
26
            .expect("server already shut down");
399

            
400
102
        while let Some(incoming) = tokio::select! {
401
5
            shutdown_state = shutdown_watcher.wait_for_shutdown() => {
402
                drop(server.close_incoming());
403
                if matches!(shutdown_state, ShutdownState::GracefulShutdown) {
404
                    server.wait_idle().await;
405
                }
406
                None
407
            },
408
76
            msg = server.next() => msg
409
        } {
410
76
            let address = incoming.remote_address();
411
76
            let connection = match incoming.accept::<()>().await {
412
76
                Ok(connection) => connection,
413
                Err(err) => {
414
                    log::error!("[server] error on incoming connection from {address}: {err:?}");
415
                    continue;
416
                }
417
            };
418
76
            let task_self = self.clone();
419
76
            tokio::spawn(async move {
420
76
                if let Err(err) = task_self.handle_bonsai_connection(connection).await {
421
                    log::error!("[server] closing connection {address}: {err:?}");
422
76
                }
423
76
            });
424
        }
425

            
426
5
        Ok(())
427
5
    }
428

            
429
    /// Returns all of the currently connected clients.
430
    #[must_use]
431
1
    pub fn connected_clients(&self) -> Vec<ConnectedClient<B>> {
432
1
        let clients = self.data.clients.read();
433
1
        clients.values().cloned().collect()
434
1
    }
435

            
436
    /// Sends a custom API response to all connected clients.
437
    pub fn broadcast<Api: api::Api>(&self, response: &Api::Response) {
438
        let clients = self.data.clients.read();
439
        for client in clients.values() {
440
            // TODO should this broadcast to all sessions too rather than only the global session?
441
            drop(client.send::<Api>(None, response));
442
        }
443
    }
444

            
445
181
    async fn initialize_client(
446
181
        &self,
447
181
        transport: Transport,
448
181
        address: SocketAddr,
449
181
        sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
450
181
    ) -> Option<OwnedClient<B>> {
451
181
        if !self.data.default_session.allowed_to(
452
181
            bonsaidb_resource_name(),
453
181
            &BonsaiAction::Server(ServerAction::Connect),
454
181
        ) {
455
            return None;
456
181
        }
457

            
458
181
        let client = loop {
459
181
            let next_id = CONNECTED_CLIENT_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
460
181
            let mut clients = self.data.clients.write();
461
181
            if let hash_map::Entry::Vacant(e) = clients.entry(next_id) {
462
181
                let client = OwnedClient::new(
463
181
                    next_id,
464
181
                    address,
465
181
                    transport,
466
181
                    sender,
467
181
                    self.clone(),
468
181
                    self.data.default_session.clone(),
469
181
                );
470
181
                e.insert(client.clone());
471
181
                break client;
472
            }
473
        };
474

            
475
181
        match self.data.backend.client_connected(&client, self).await {
476
181
            Ok(ConnectionHandling::Accept) => Some(client),
477
            Ok(ConnectionHandling::Reject) => None,
478
            Err(err) => {
479
                log::error!(
480
                    "[server] Rejecting connection due to error in `client_connected`: {err:?}"
481
                );
482
                None
483
            }
484
        }
485
181
    }
486

            
487
112
    async fn disconnect_client(&self, id: u32) {
488
112
        let removed_client = {
489
112
            let mut clients = self.data.clients.write();
490
112
            clients.remove(&id)
491
        };
492

            
493
112
        if let Some(client) = removed_client {
494
112
            client.set_disconnected();
495
128
            for session in client.all_sessions::<Vec<_>>() {
496
128
                if let Err(err) = self
497
128
                    .data
498
128
                    .backend
499
128
                    .client_session_ended(session, &client, true, self)
500
                    .await
501
                {
502
                    log::error!("[server] Error in `client_session_ended`: {err:?}");
503
128
                }
504
            }
505

            
506
112
            if let Err(err) = self.data.backend.client_disconnected(client, self).await {
507
                log::error!("[server] Error in `client_disconnected`: {err:?}");
508
112
            }
509
        }
510
112
    }
511

            
512
76
    async fn handle_bonsai_connection(
513
76
        &self,
514
76
        mut connection: fabruic::Connection<()>,
515
76
    ) -> Result<(), Error> {
516
76
        if let Some(incoming) = connection.next().await {
517
76
            let incoming = match incoming {
518
76
                Ok(incoming) => incoming,
519
                Err(err) => {
520
                    log::error!("[server] Error establishing a stream: {err:?}");
521
                    return Ok(());
522
                }
523
            };
524

            
525
76
            match incoming
526
76
                .accept::<networking::Payload, networking::Payload>()
527
                .await
528
            {
529
76
                Ok((sender, receiver)) => {
530
76
                    let (api_response_sender, api_response_receiver) = flume::unbounded();
531
76
                    if let Some(disconnector) = self
532
76
                        .initialize_client(
533
76
                            Transport::Bonsai,
534
76
                            connection.remote_address(),
535
76
                            api_response_sender,
536
76
                        )
537
                        .await
538
                    {
539
76
                        let task_sender = sender.clone();
540
76
                        tokio::spawn(async move {
541
15
                            while let Ok((session_id, name, bytes)) =
542
91
                                api_response_receiver.recv_async().await
543
                            {
544
15
                                if task_sender
545
15
                                    .send(&Payload {
546
15
                                        id: None,
547
15
                                        session_id,
548
15
                                        name,
549
15
                                        value: Ok(bytes),
550
15
                                    })
551
15
                                    .is_err()
552
                                {
553
                                    break;
554
15
                                }
555
                            }
556
76
                        });
557
76

            
558
76
                        let task_self = self.clone();
559
76
                        let Some(shutdown) = self.data.shutdown.watcher().await else {
560
                            return Ok(());
561
                        };
562
76
                        tokio::spawn(async move {
563
76
                            if let Err(err) = task_self
564
76
                                .handle_stream(disconnector, sender, receiver, shutdown)
565
29538
                                .await
566
                            {
567
                                log::error!("[server] Error handling stream: {err:?}");
568
22
                            }
569
76
                        });
570
                    } else {
571
                        log::error!("[server] Backend rejected connection.");
572
                        return Ok(());
573
                    }
574
                }
575
                Err(err) => {
576
                    log::error!("[server] Error accepting incoming stream: {err:?}");
577
                    return Ok(());
578
                }
579
            }
580
        }
581
76
        Ok(())
582
76
    }
583

            
584
181
    async fn handle_client_requests(
585
181
        &self,
586
181
        client: ConnectedClient<B>,
587
181
        request_receiver: flume::Receiver<Payload>,
588
181
        response_sender: flume::Sender<Payload>,
589
181
        mut shutdown: ShutdownStateWatcher,
590
181
    ) {
591
181
        let notify = Arc::new(Notify::new());
592
181
        let requests_in_queue = Arc::new(AtomicUsize::new(0));
593
101306
        loop {
594
101306
            let current_requests = requests_in_queue.load(Ordering::SeqCst);
595
101306
            if current_requests == self.data.client_simultaneous_request_limit {
596
                // Wait for requests to finish.
597
29069
                notify.notified().await;
598
72237
            } else if requests_in_queue
599
72237
                .compare_exchange(
600
72237
                    current_requests,
601
72237
                    current_requests + 1,
602
72237
                    Ordering::SeqCst,
603
72237
                    Ordering::SeqCst,
604
72237
                )
605
72237
                .is_ok()
606
72053
            {
607
72234
                let payload = 'payload: loop {
608
108664
                    tokio::select! {
609
72161
                        payload = request_receiver.recv_async() => {
610
                            if let Ok(payload) = payload {
611
                                break 'payload payload
612
                            }
613

            
614
                            return
615
                        },
616
4
                        state = shutdown.wait_for_shutdown() => {
617
                            if matches!(state, ShutdownState::Shutdown | ShutdownState::GracefulShutdown) {
618
                                return
619
                            }
620
                        }
621
72234
                    }
622
72234
                };
623
72053
                let session_id = payload.session_id;
624
72053
                let id = payload.id;
625
72053
                let task_sender = response_sender.clone();
626
72053

            
627
72053
                let notify = notify.clone();
628
72053
                let requests_in_queue = requests_in_queue.clone();
629
72053
                self.handle_request_through_worker(
630
72053
                    payload,
631
72053
                    move |name, value| async move {
632
72049
                        drop(task_sender.send(Payload {
633
72049
                            session_id,
634
72049
                            id,
635
72049
                            name,
636
72049
                            value,
637
72049
                        }));
638
72049

            
639
72049
                        requests_in_queue.fetch_sub(1, Ordering::SeqCst);
640
72049

            
641
72049
                        notify.notify_one();
642
72049

            
643
72049
                        Ok(())
644
72053
                    },
645
72053
                    client.clone(),
646
72053
                )
647
72053
                .unwrap();
648
3
            }
649
        }
650
112
    }
651

            
652
72053
    fn handle_request_through_worker<
653
72053
        F: FnOnce(ApiName, Result<Bytes, bonsaidb_core::Error>) -> R + Send + 'static,
654
72053
        R: Future<Output = Result<(), Error>> + Send,
655
72053
    >(
656
72053
        &self,
657
72053
        request: Payload,
658
72053
        callback: F,
659
72053
        client: ConnectedClient<B>,
660
72053
    ) -> Result<(), Error> {
661
72053
        let (result_sender, result_receiver) = oneshot::channel();
662
72053
        let session = client
663
72053
            .session(request.session_id)
664
72053
            .unwrap_or_else(|| self.data.default_session.clone());
665
72053
        self.data
666
72053
            .request_processor
667
72053
            .send(ClientRequest::<B>::new(
668
72053
                request,
669
72053
                self.clone(),
670
72053
                client,
671
72053
                session,
672
72053
                result_sender,
673
72053
            ))
674
72053
            .map_err(|_| Error::InternalCommunication)?;
675
72053
        tokio::spawn(async move {
676
72053
            let (name, result) = result_receiver.await?;
677
            // Map the error into a Response::Error. The jobs system supports
678
            // multiple receivers receiving output, and wraps Err to avoid
679
            // requiring the error to be cloneable. As such, we have to unwrap
680
            // it. Thankfully, we can guarantee nothing else is waiting on a
681
            // response to a request than the original requestor, so this can be
682
            // safely unwrapped.
683
72049
            callback(name, result).await?;
684
72049
            Result::<(), Error>::Ok(())
685
72053
        });
686
72053
        Ok(())
687
72053
    }
688

            
689
76
    async fn handle_stream(
690
76
        &self,
691
76
        client: OwnedClient<B>,
692
76
        sender: fabruic::Sender<Payload>,
693
76
        mut receiver: fabruic::Receiver<Payload>,
694
76
        mut shutdown: ShutdownStateWatcher,
695
76
    ) -> Result<(), Error> {
696
76
        let (payload_sender, payload_receiver) = flume::unbounded();
697
76
        tokio::spawn({
698
76
            let mut shutdown = shutdown.clone();
699
76
            async move {
700
                'stream: loop {
701
34672
                    let payload = loop {
702
68935
                        tokio::select! {
703
34612
                            payload = payload_receiver.recv_async() => {
704
                                if let Ok(payload) = payload {
705
                                    break payload
706
                                }
707
                                break 'stream
708
                            }
709
5
                            shutdown = shutdown.wait_for_shutdown() => {
710
                                if matches!(shutdown, ShutdownState::Shutdown | ShutdownState::GracefulShutdown) {
711
                                    break 'stream
712
                                }
713
                            }
714
34672
                        }
715
34672
                    };
716
34596
                    if sender.send(&payload).is_err() {
717
                        break;
718
34596
                    }
719
                }
720
76
            }
721
76
        });
722
76

            
723
76
        let (request_sender, request_receiver) =
724
76
            flume::bounded::<Payload>(self.data.client_simultaneous_request_limit);
725
76
        let task_self = self.clone();
726
76
        tokio::spawn({
727
76
            let shutdown = shutdown.clone();
728
76
            async move {
729
76
                task_self
730
76
                    .handle_client_requests(
731
76
                        client.clone(),
732
76
                        request_receiver,
733
76
                        payload_sender,
734
76
                        shutdown,
735
76
                    )
736
30146
                    .await;
737
76
            }
738
76
        });
739

            
740
        loop {
741
34674
            let payload = loop {
742
54362
                tokio::select! {
743
34615
                    payload = receiver.next() => {
744
                        if let Some(payload) = payload {
745
                            break payload
746
                        }
747

            
748
                        receiver.finish().await?;
749

            
750
                        return Ok(());
751
                    }
752
5
                    shutdown = shutdown.wait_for_shutdown() => {
753
                        if matches!(shutdown, ShutdownState::Shutdown | ShutdownState::GracefulShutdown) {
754
                            return Ok(());
755
                        }
756
                    }
757
34674
                }
758
34674
            };
759
34598
            drop(request_sender.send_async(payload?).await);
760
        }
761
22
    }
762

            
763
    /// Shuts the server down. If a `timeout` is provided, the server will stop
764
    /// accepting new connections and attempt to respond to any outstanding
765
    /// requests already being processed. After the `timeout` has elapsed or if
766
    /// no `timeout` was provided, the server is forcefully shut down.
767
36
    pub async fn shutdown(&self, timeout: Option<Duration>) -> Result<(), Error> {
768
36
        if let Some(timeout) = timeout {
769
4
            self.data.shutdown.graceful_shutdown(timeout).await;
770
        } else {
771
32
            self.data.shutdown.shutdown().await;
772
        }
773

            
774
36
        Ok(())
775
36
    }
776

            
777
    /// Listens for signals from the operating system that the server should
778
    /// shut down and attempts to gracefully shut down.
779
1
    pub async fn listen_for_shutdown(&self) -> Result<(), Error> {
780
1
        const GRACEFUL_SHUTDOWN: usize = 1;
781
1
        const TERMINATE: usize = 2;
782
1

            
783
1
        enum SignalShutdownState {
784
1
            Running,
785
1
            ShuttingDown(flume::Receiver<()>),
786
1
        }
787
1

            
788
1
        let shutdown_state = Arc::new(async_lock::Mutex::new(SignalShutdownState::Running));
789
1
        let flag = Arc::new(AtomicUsize::default());
790
1
        let register_hook = |flag: &Arc<AtomicUsize>| {
791
1
            signal_hook::flag::register_usize(SIGINT, flag.clone(), GRACEFUL_SHUTDOWN)?;
792
1
            signal_hook::flag::register_usize(SIGTERM, flag.clone(), TERMINATE)?;
793
            #[cfg(not(windows))]
794
1
            signal_hook::flag::register_usize(SIGQUIT, flag.clone(), TERMINATE)?;
795
1
            Result::<(), std::io::Error>::Ok(())
796
1
        };
797
1
        if let Err(err) = register_hook(&flag) {
798
            log::error!("Error installing signals for graceful shutdown: {err:?}");
799
            tokio::time::sleep(Duration::MAX).await;
800
        } else {
801
            loop {
802
1
                match flag.load(Ordering::Relaxed) {
803
1
                    0 => {
804
1
                        // No signal
805
1
                    }
806
                    GRACEFUL_SHUTDOWN => {
807
                        let mut state = fast_async_lock!(shutdown_state);
808
                        match *state {
809
                            SignalShutdownState::Running => {
810
                                log::error!("Interrupt signal received. Shutting down gracefully.");
811
                                let task_server = self.clone();
812
                                let (shutdown_sender, shutdown_receiver) = flume::bounded(1);
813
                                tokio::task::spawn(async move {
814
                                    task_server.shutdown(Some(Duration::from_secs(30))).await?;
815
                                    let _: Result<_, _> = shutdown_sender.send(());
816
                                    Result::<(), Error>::Ok(())
817
                                });
818
                                *state = SignalShutdownState::ShuttingDown(shutdown_receiver);
819
                            }
820
                            SignalShutdownState::ShuttingDown(_) => {
821
                                // Two interrupts, go ahead and force the shutdown
822
                                break;
823
                            }
824
                        }
825
                    }
826
                    TERMINATE => {
827
                        log::error!("Quit signal received. Shutting down.");
828
                        break;
829
                    }
830
                    _ => unreachable!(),
831
                }
832

            
833
1
                let state = fast_async_lock!(shutdown_state);
834
1
                if let SignalShutdownState::ShuttingDown(receiver) = &*state {
835
                    if receiver.try_recv().is_ok() {
836
                        // Fully shut down.
837
                        return Ok(());
838
                    }
839
1
                } else if self.data.shutdown.should_shutdown() {
840
                    return Ok(());
841
1
                }
842

            
843
1
                tokio::time::sleep(Duration::from_millis(300)).await;
844
            }
845
            self.shutdown(None).await?;
846
        }
847

            
848
        Ok(())
849
    }
850
}
851

            
852
impl<B: Backend> Deref for CustomServer<B> {
853
    type Target = AsyncStorage;
854

            
855
70956
    fn deref(&self) -> &Self::Target {
856
70956
        &self.storage
857
70956
    }
858
}
859

            
860
#[derive(Debug)]
861
struct ClientRequest<B: Backend> {
862
    request: Option<Payload>,
863
    client: ConnectedClient<B>,
864
    session: Session,
865
    server: CustomServer<B>,
866
    result_sender: oneshot::Sender<(ApiName, Result<Bytes, bonsaidb_core::Error>)>,
867
}
868

            
869
impl<B: Backend> ClientRequest<B> {
870
72053
    pub fn new(
871
72053
        request: Payload,
872
72053
        server: CustomServer<B>,
873
72053
        client: ConnectedClient<B>,
874
72053
        session: Session,
875
72053
        result_sender: oneshot::Sender<(ApiName, Result<Bytes, bonsaidb_core::Error>)>,
876
72053
    ) -> Self {
877
72053
        Self {
878
72053
            request: Some(request),
879
72053
            server,
880
72053
            client,
881
72053
            session,
882
72053
            result_sender,
883
72053
        }
884
72053
    }
885
}
886

            
887
impl<B: Backend> HasSession for CustomServer<B> {
888
109
    fn session(&self) -> Option<&Session> {
889
109
        self.storage.session()
890
109
    }
891
}
892

            
893
#[async_trait]
894
impl<B: Backend> AsyncStorageConnection for CustomServer<B> {
895
    type Authenticated = Self;
896
    type Database = ServerDatabase<B>;
897

            
898
    async fn admin(&self) -> Self::Database {
899
        self.database::<Admin>(ADMIN_DATABASE_NAME).await.unwrap()
900
    }
901

            
902
711
    async fn create_database_with_schema(
903
711
        &self,
904
711
        name: &str,
905
711
        schema: SchemaName,
906
711
        only_if_needed: bool,
907
711
    ) -> Result<(), bonsaidb_core::Error> {
908
711
        self.storage
909
711
            .create_database_with_schema(name, schema, only_if_needed)
910
710
            .await
911
2133
    }
912

            
913
175
    async fn database<DB: Schema>(
914
175
        &self,
915
175
        name: &str,
916
175
    ) -> Result<Self::Database, bonsaidb_core::Error> {
917
175
        let db = self.storage.database::<DB>(name).await?;
918
175
        Ok(ServerDatabase {
919
175
            server: self.clone(),
920
175
            db,
921
175
        })
922
525
    }
923

            
924
508
    async fn delete_database(&self, name: &str) -> Result<(), bonsaidb_core::Error> {
925
508
        self.storage.delete_database(name).await
926
1524
    }
927

            
928
4
    async fn list_databases(&self) -> Result<Vec<connection::Database>, bonsaidb_core::Error> {
929
4
        self.storage.list_databases().await
930
12
    }
931

            
932
4
    async fn list_available_schemas(&self) -> Result<Vec<SchemaSummary>, bonsaidb_core::Error> {
933
4
        self.storage.list_available_schemas().await
934
12
    }
935

            
936
17
    async fn create_user(&self, username: &str) -> Result<u64, bonsaidb_core::Error> {
937
17
        self.storage.create_user(username).await
938
51
    }
939

            
940
4
    async fn delete_user<'user, U: Nameable<'user, u64> + Send + Sync>(
941
4
        &self,
942
4
        user: U,
943
4
    ) -> Result<(), bonsaidb_core::Error> {
944
4
        self.storage.delete_user(user).await
945
12
    }
946

            
947
    #[cfg(feature = "password-hashing")]
948
5
    async fn set_user_password<'user, U: Nameable<'user, u64> + Send + Sync>(
949
5
        &self,
950
5
        user: U,
951
5
        password: bonsaidb_core::connection::SensitiveString,
952
5
    ) -> Result<(), bonsaidb_core::Error> {
953
5
        self.storage.set_user_password(user, password).await
954
15
    }
955

            
956
    #[cfg(any(feature = "token-authentication", feature = "password-hashing"))]
957
23
    async fn authenticate(
958
23
        &self,
959
23
        authentication: bonsaidb_core::connection::Authentication,
960
23
    ) -> Result<Self::Authenticated, bonsaidb_core::Error> {
961
23
        let storage = self.storage.authenticate(authentication).await?;
962
23
        Ok(Self {
963
23
            data: self.data.clone(),
964
23
            storage,
965
23
        })
966
69
    }
967

            
968
2
    async fn assume_identity(
969
2
        &self,
970
2
        identity: IdentityReference<'_>,
971
2
    ) -> Result<Self::Authenticated, bonsaidb_core::Error> {
972
2
        let storage = self.storage.assume_identity(identity).await?;
973
2
        Ok(Self {
974
2
            data: self.data.clone(),
975
2
            storage,
976
2
        })
977
6
    }
978

            
979
12
    async fn add_permission_group_to_user<
980
12
        'user,
981
12
        'group,
982
12
        U: Nameable<'user, u64> + Send + Sync,
983
12
        G: Nameable<'group, u64> + Send + Sync,
984
12
    >(
985
12
        &self,
986
12
        user: U,
987
12
        permission_group: G,
988
12
    ) -> Result<(), bonsaidb_core::Error> {
989
12
        self.storage
990
12
            .add_permission_group_to_user(user, permission_group)
991
12
            .await
992
36
    }
993

            
994
8
    async fn remove_permission_group_from_user<
995
8
        'user,
996
8
        'group,
997
8
        U: Nameable<'user, u64> + Send + Sync,
998
8
        G: Nameable<'group, u64> + Send + Sync,
999
8
    >(
8
        &self,
8
        user: U,
8
        permission_group: G,
8
    ) -> Result<(), bonsaidb_core::Error> {
8
        self.storage
8
            .remove_permission_group_from_user(user, permission_group)
7
            .await
24
    }

            
8
    async fn add_role_to_user<
8
        'user,
8
        'group,
8
        U: Nameable<'user, u64> + Send + Sync,
8
        G: Nameable<'group, u64> + Send + Sync,
8
    >(
8
        &self,
8
        user: U,
8
        role: G,
8
    ) -> Result<(), bonsaidb_core::Error> {
8
        self.storage.add_role_to_user(user, role).await
24
    }

            
8
    async fn remove_role_from_user<
8
        'user,
8
        'group,
8
        U: Nameable<'user, u64> + Send + Sync,
8
        G: Nameable<'group, u64> + Send + Sync,
8
    >(
8
        &self,
8
        user: U,
8
        role: G,
8
    ) -> Result<(), bonsaidb_core::Error> {
8
        self.storage.remove_role_from_user(user, role).await
24
    }
}

            
92
#[derive(Default)]
struct AlpnKeys(Arc<Mutex<HashMap<String, Arc<rustls::sign::CertifiedKey>>>>);

            
impl Debug for AlpnKeys {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("AlpnKeys").finish()
    }
}

            
impl Deref for AlpnKeys {
    type Target = Arc<Mutex<HashMap<String, Arc<rustls::sign::CertifiedKey>>>>;

            
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}