1
#[cfg(feature = "test-util")]
2
use std::sync::atomic::AtomicBool;
3
use std::{
4
    any::TypeId,
5
    collections::HashMap,
6
    fmt::Debug,
7
    marker::PhantomData,
8
    ops::Deref,
9
    sync::{
10
        atomic::{AtomicU32, Ordering},
11
        Arc,
12
    },
13
};
14

            
15
use async_lock::Mutex;
16
use async_trait::async_trait;
17
#[cfg(feature = "password-hashing")]
18
use bonsaidb_core::connection::{Authenticated, Authentication};
19
use bonsaidb_core::{
20
    connection::{Database, StorageConnection},
21
    custom_api::{CustomApi, CustomApiResult},
22
    networking::{
23
        self, Payload, Request, Response, ServerRequest, ServerResponse, CURRENT_PROTOCOL_VERSION,
24
    },
25
    permissions::Permissions,
26
    schema::{Nameable, Schema, SchemaName, Schematic},
27
};
28
use bonsaidb_utils::fast_async_lock;
29
use derive_where::derive_where;
30
use flume::Sender;
31
#[cfg(not(target_arch = "wasm32"))]
32
use tokio::task::JoinHandle;
33
use url::Url;
34

            
35
pub use self::remote_database::{RemoteDatabase, RemoteSubscriber};
36
use crate::{error::Error, Builder};
37

            
38
#[cfg(not(target_arch = "wasm32"))]
39
mod quic_worker;
40
mod remote_database;
41
#[cfg(all(feature = "websockets", not(target_arch = "wasm32")))]
42
mod tungstenite_worker;
43
#[cfg(all(feature = "websockets", target_arch = "wasm32"))]
44
mod wasm_websocket_worker;
45

            
46
234
#[derive(Debug, Clone, Default)]
47
pub struct SubscriberMap(Arc<Mutex<HashMap<u64, flume::Sender<Arc<Message>>>>>);
48

            
49
impl SubscriberMap {
50
2541
    pub async fn clear(&self) {
51
121
        let mut data = fast_async_lock!(self);
52
121
        data.clear();
53
121
    }
54
}
55

            
56
impl Deref for SubscriberMap {
57
    type Target = Mutex<HashMap<u64, flume::Sender<Arc<Message>>>>;
58

            
59
3423
    fn deref(&self) -> &Self::Target {
60
3423
        &self.0
61
3423
    }
62
}
63

            
64
use bonsaidb_core::{circulate::Message, networking::DatabaseRequest};
65

            
66
/// The error type produced by the native (non-`wasm32`) WebSocket transport,
/// which is built on `tokio-tungstenite`.
#[cfg(all(feature = "websockets", not(target_arch = "wasm32")))]
pub type WebSocketError = tokio_tungstenite::tungstenite::Error;

/// The error type produced by the WebSocket worker used on `wasm32` targets.
#[cfg(all(feature = "websockets", target_arch = "wasm32"))]
pub type WebSocketError = wasm_websocket_worker::WebSocketError;
71

            
72
/// Client for connecting to a BonsaiDb server.
73
///
74
///
75
///
76
/// ## Connecting via QUIC
77
///
78
/// The URL scheme to connect via QUIC is `bonsaidb`. If no port is specified,
79
/// port 5645 is assumed.
80
///
81
/// ### With a valid TLS certificate
82
///
83
/// ```rust
84
/// # use bonsaidb_client::{Client, fabruic::Certificate, url::Url};
85
/// # async fn test_fn() -> anyhow::Result<()> {
86
/// let client = Client::build(Url::parse("bonsaidb://my-server.com")?)
87
///     .finish()
88
///     .await?;
89
/// # Ok(())
90
/// # }
91
/// ```
92
///
93
/// ### With a Self-Signed Pinned Certificate
94
///
95
/// When using `install_self_signed_certificate()`, clients will need the
96
/// contents of the `pinned-certificate.der` file within the database. It can be
97
/// specified when building the client:
98
///
99
/// ```rust
100
/// # use bonsaidb_client::{Client, fabruic::Certificate, url::Url};
101
/// # async fn test_fn() -> anyhow::Result<()> {
102
/// let certificate =
103
///     Certificate::from_der(std::fs::read("mydb.bonsaidb/pinned-certificate.der")?)?;
104
/// let client = Client::build(Url::parse("bonsaidb://localhost")?)
105
///     .with_certificate(certificate)
106
///     .finish()
107
///     .await?;
108
/// # Ok(())
109
/// # }
110
/// ```
111
///
112
/// ## Connecting via WebSockets
113
///
114
/// WebSockets are built atop the HTTP protocol. There are two URL schemes for
115
/// WebSockets:
116
///
117
/// - `ws`: Insecure WebSockets. Port 80 is assumed if no port is specified.
118
/// - `wss`: Secure WebSockets. Port 443 is assumed if no port is specified.
119
///
120
/// ### Without TLS
121
///
122
/// ```rust
123
/// # use bonsaidb_client::{Client, fabruic::Certificate, url::Url};
124
/// # async fn test_fn() -> anyhow::Result<()> {
125
/// let client = Client::build(Url::parse("ws://localhost")?)
126
///     .finish()
127
///     .await?;
128
/// # Ok(())
129
/// # }
130
/// ```
131
///
132
/// ### With TLS
133
///
134
/// ```rust
135
/// # use bonsaidb_client::{Client, fabruic::Certificate, url::Url};
136
/// # async fn test_fn() -> anyhow::Result<()> {
137
/// let client = Client::build(Url::parse("wss://my-server.com")?)
138
///     .finish()
139
///     .await?;
140
/// # Ok(())
141
/// # }
142
/// ```
143
///
144
/// ## Using a `CustomApi`
145
///
146
/// Our user guide has a [section on creating and using a
147
/// CustomApi](https://dev.bonsaidb.io/release/guide/about/access-models/custom-api-server.html).
148
///
149
/// ```rust
150
/// # use bonsaidb_client::{Client, fabruic::Certificate, url::Url};
151
/// // `bonsaidb_core` is re-exported to `bonsaidb::core` or `bonsaidb_client::core`.
152
/// use bonsaidb_core::custom_api::{CustomApi, Infallible};
153
/// use serde::{Deserialize, Serialize};
154
///
155
/// #[derive(Serialize, Deserialize, Debug)]
156
/// pub enum Request {
157
///     Ping,
158
/// }
159
///
160
/// #[derive(Serialize, Deserialize, Clone, Debug)]
161
/// pub enum Response {
162
///     Pong,
163
/// }
164
///
165
/// #[derive(Debug)]
166
/// pub enum MyApi {}
167
///
168
/// impl CustomApi for MyApi {
169
///     type Request = Request;
170
///     type Response = Response;
171
///     type Error = Infallible;
172
/// }
173
///
174
/// # async fn test_fn() -> anyhow::Result<()> {
175
/// let client = Client::build(Url::parse("bonsaidb://localhost")?)
176
///     .with_custom_api::<MyApi>()
177
///     .finish()
178
///     .await?;
179
/// let Response::Pong = client.send_api_request(Request::Ping).await?;
180
/// # Ok(())
181
/// # }
182
/// ```
183
///
184
/// ### Receiving out-of-band messages from the server
185
///
186
/// If the server sends a message that isn't in response to a request, the
187
/// client will invoke it's [custom api
188
/// callback](Builder::with_custom_api_callback):
189
///
190
/// ```rust
191
/// # use bonsaidb_client::{Client, fabruic::Certificate, url::Url};
192
/// # // `bonsaidb_core` is re-exported to `bonsaidb::core` or `bonsaidb_client::core`.
193
/// # use bonsaidb_core::custom_api::{CustomApi, Infallible};
194
/// # use serde::{Serialize, Deserialize};
195
/// # #[derive(Serialize, Deserialize, Debug)]
196
/// # pub enum Request {
197
/// #     Ping
198
/// # }
199
/// # #[derive(Serialize, Deserialize, Clone, Debug)]
200
/// # pub enum Response {
201
/// #     Pong
202
/// # }
203
/// # #[derive(Debug)]
204
/// # pub enum MyApi {}
205
/// # impl CustomApi for MyApi {
206
/// #     type Request = Request;
207
/// #     type Response = Response;
208
/// #     type Error = Infallible;
209
/// # }
210
/// # async fn test_fn() -> anyhow::Result<()> {
211
/// let client = Client::build(Url::parse("bonsaidb://localhost")?)
212
///     .with_custom_api_callback::<MyApi, _>(|result: Result<Response, Infallible>| {
213
///         let Response::Pong = result.unwrap();
214
///     })
215
///     .finish()
216
///     .await?;
217
/// # Ok(())
218
/// # }
219
/// ```
220
#[derive(Debug)]
221
916
#[derive_where(Clone)]
222
pub struct Client<A: CustomApi = ()> {
223
    pub(crate) data: Arc<Data<A>>,
224
}
225

            
226
impl<A> PartialEq for Client<A>
227
where
228
    A: CustomApi,
229
{
230
    fn eq(&self, other: &Self) -> bool {
231
        Arc::ptr_eq(&self.data, &other.data)
232
    }
233
}
234

            
235
#[derive(Debug)]
236
pub struct Data<A: CustomApi> {
237
    request_sender: Sender<PendingRequest<A>>,
238
    #[cfg(not(target_arch = "wasm32"))]
239
    _worker: CancellableHandle<Result<(), Error<A::Error>>>,
240
    effective_permissions: Mutex<Option<Permissions>>,
241
    schemas: Mutex<HashMap<TypeId, Arc<Schematic>>>,
242
    request_id: AtomicU32,
243
    subscribers: SubscriberMap,
244
    #[cfg(feature = "test-util")]
245
    background_task_running: Arc<AtomicBool>,
246
}
247

            
248
impl Client<()> {
249
    /// Returns a builder for a new client connecting to `url`.
250
1827
    pub fn build(url: Url) -> Builder<()> {
251
1827
        Builder::new(url)
252
1827
    }
253
}
254

            
255
impl<A: CustomApi> Client<A> {
256
    /// Initialize a client connecting to `url`. This client can be shared by
257
    /// cloning it. All requests are done asynchronously over the same
258
    /// connection.
259
    ///
260
    /// If the client has an error connecting, the first request made will
261
    /// present that error. If the client disconnects while processing requests,
262
    /// all requests being processed will exit and return
263
    /// [`Error::Disconnected`]. The client will automatically try reconnecting.
264
    ///
265
    /// The goal of this design of this reconnection strategy is to make it
266
    /// easier to build resilliant apps. By allowing existing Client instances
267
    /// to recover and reconnect, each component of the apps built can adopt a
268
    /// "retry-to-recover" design, or "abort-and-fail" depending on how critical
269
    /// the database is to operation.
270
31
    pub async fn new(url: Url) -> Result<Self, Error<A::Error>> {
271
31
        Self::new_from_parts(
272
31
            url,
273
31
            CURRENT_PROTOCOL_VERSION,
274
31
            #[cfg(not(target_arch = "wasm32"))]
275
31
            None,
276
31
            None,
277
31
        )
278
        .await
279
31
    }
280

            
281
    /// Initialize a client connecting to `url` with `certificate` being used to
282
    /// validate and encrypt the connection. This client can be shared by
283
    /// cloning it. All requests are done asynchronously over the same
284
    /// connection.
285
    ///
286
    /// If the client has an error connecting, the first request made will
287
    /// present that error. If the client disconnects while processing requests,
288
    /// all requests being processed will exit and return
289
    /// [`Error::Disconnected`]. The client will automatically try reconnecting.
290
    ///
291
    /// The goal of this design of this reconnection strategy is to make it
292
    /// easier to build resilliant apps. By allowing existing Client instances
293
    /// to recover and reconnect, each component of the apps built can adopt a
294
    /// "retry-to-recover" design, or "abort-and-fail" depending on how critical
295
    /// the database is to operation.
296
118
    pub(crate) async fn new_from_parts(
297
118
        url: Url,
298
118
        protocol_version: &'static str,
299
118
        custom_api_callback: Option<Arc<dyn CustomApiCallback<A>>>,
300
118
        #[cfg(not(target_arch = "wasm32"))] certificate: Option<fabruic::Certificate>,
301
118
    ) -> Result<Self, Error<A::Error>> {
302
118
        match url.scheme() {
303
118
            #[cfg(not(target_arch = "wasm32"))]
304
118
            "bonsaidb" => Ok(Self::new_bonsai_client(
305
62
                url,
306
62
                protocol_version,
307
62
                certificate,
308
62
                custom_api_callback,
309
62
            )),
310
            #[cfg(feature = "websockets")]
311
56
            "wss" | "ws" => {
312
56
                Self::new_websocket_client(url, protocol_version, custom_api_callback).await
313
            }
314
            other => {
315
                return Err(Error::InvalidUrl(format!("unsupported scheme {}", other)));
316
            }
317
        }
318
118
    }
319

            
320
    #[cfg(not(target_arch = "wasm32"))]
321
62
    fn new_bonsai_client(
322
62
        url: Url,
323
62
        protocol_version: &'static str,
324
62
        certificate: Option<fabruic::Certificate>,
325
62
        custom_api_callback: Option<Arc<dyn CustomApiCallback<A>>>,
326
62
    ) -> Self {
327
62
        let (request_sender, request_receiver) = flume::unbounded();
328
62

            
329
62
        let subscribers = SubscriberMap::default();
330
62
        let worker = tokio::task::spawn(quic_worker::reconnecting_client_loop(
331
62
            url,
332
62
            protocol_version,
333
62
            certificate,
334
62
            request_receiver,
335
62
            custom_api_callback,
336
62
            subscribers.clone(),
337
62
        ));
338
62

            
339
62
        #[cfg(feature = "test-util")]
340
62
        let background_task_running = Arc::new(AtomicBool::new(true));
341
62

            
342
62
        Self {
343
62
            data: Arc::new(Data {
344
62
                request_sender,
345
62
                _worker: CancellableHandle {
346
62
                    worker,
347
62
                    #[cfg(feature = "test-util")]
348
62
                    background_task_running: background_task_running.clone(),
349
62
                },
350
62
                schemas: Mutex::default(),
351
62
                request_id: AtomicU32::default(),
352
62
                effective_permissions: Mutex::default(),
353
62
                subscribers,
354
62
                #[cfg(feature = "test-util")]
355
62
                background_task_running,
356
62
            }),
357
62
        }
358
62
    }
359

            
360
    #[cfg(all(feature = "websockets", not(target_arch = "wasm32")))]
361
56
    async fn new_websocket_client(
362
56
        url: Url,
363
56
        protocol_version: &'static str,
364
56
        custom_api_callback: Option<Arc<dyn CustomApiCallback<A>>>,
365
56
    ) -> Result<Self, Error<A::Error>> {
366
56
        let (request_sender, request_receiver) = flume::unbounded();
367
56

            
368
56
        let subscribers = SubscriberMap::default();
369
56

            
370
56
        let worker = tokio::task::spawn(tungstenite_worker::reconnecting_client_loop(
371
56
            url,
372
56
            protocol_version,
373
56
            request_receiver,
374
56
            custom_api_callback,
375
56
            subscribers.clone(),
376
56
        ));
377
56

            
378
56
        #[cfg(feature = "test-util")]
379
56
        let background_task_running = Arc::new(AtomicBool::new(true));
380
56

            
381
56
        let client = Self {
382
56
            data: Arc::new(Data {
383
56
                request_sender,
384
56
                #[cfg(not(target_arch = "wasm32"))]
385
56
                _worker: CancellableHandle {
386
56
                    worker,
387
56
                    #[cfg(feature = "test-util")]
388
56
                    background_task_running: background_task_running.clone(),
389
56
                },
390
56
                schemas: Mutex::default(),
391
56
                request_id: AtomicU32::default(),
392
56
                effective_permissions: Mutex::default(),
393
56
                subscribers,
394
56
                #[cfg(feature = "test-util")]
395
56
                background_task_running,
396
56
            }),
397
56
        };
398
56

            
399
56
        Ok(client)
400
56
    }
401

            
402
    #[cfg(all(feature = "websockets", target_arch = "wasm32"))]
403
    async fn new_websocket_client(
404
        url: Url,
405
        protocol_version: &'static str,
406
        custom_api_callback: Option<Arc<dyn CustomApiCallback<A>>>,
407
    ) -> Result<Self, Error<A::Error>> {
408
        let (request_sender, request_receiver) = flume::unbounded();
409

            
410
        let subscribers = SubscriberMap::default();
411

            
412
        wasm_websocket_worker::spawn_client(
413
            Arc::new(url),
414
            protocol_version,
415
            request_receiver,
416
            custom_api_callback.clone(),
417
            subscribers.clone(),
418
        );
419

            
420
        #[cfg(feature = "test-util")]
421
        let background_task_running = Arc::new(AtomicBool::new(true));
422

            
423
        let client = Self {
424
            data: Arc::new(Data {
425
                request_sender,
426
                #[cfg(not(target_arch = "wasm32"))]
427
                worker: CancellableHandle {
428
                    worker,
429
                    #[cfg(feature = "test-util")]
430
                    background_task_running: background_task_running.clone(),
431
                },
432
                schemas: Mutex::default(),
433
                request_id: AtomicU32::default(),
434
                effective_permissions: Mutex::default(),
435
                subscribers,
436
                #[cfg(feature = "test-util")]
437
                background_task_running,
438
            }),
439
        };
440

            
441
        Ok(client)
442
    }
443

            
444
52530
    async fn send_request(
445
52530
        &self,
446
52530
        request: Request<<A as CustomApi>::Request>,
447
52530
    ) -> Result<Response<CustomApiResult<A>>, Error<A::Error>> {
448
52530
        let (result_sender, result_receiver) = flume::bounded(1);
449
52530
        let id = self.data.request_id.fetch_add(1, Ordering::SeqCst);
450
52530
        self.data.request_sender.send(PendingRequest {
451
52530
            request: Payload {
452
52530
                id: Some(id),
453
52530
                wrapped: request,
454
52530
            },
455
52530
            responder: result_sender.clone(),
456
52530
            _phantom: PhantomData,
457
52530
        })?;
458

            
459
72777
        result_receiver.recv_async().await?
460
52530
    }
461

            
462
    /// Sends an api `request`.
463
14
    pub async fn send_api_request(
464
14
        &self,
465
14
        request: <A as CustomApi>::Request,
466
14
    ) -> Result<A::Response, Error<A::Error>> {
467
18
        match self.send_request(Request::Api(request)).await? {
468
10
            Response::Api(response) => response.map_err(Error::Api),
469
4
            Response::Error(err) => Err(Error::Core(err)),
470
            other => Err(Error::Network(networking::Error::UnexpectedResponse(
471
                format!("{:?}", other),
472
            ))),
473
        }
474
14
    }
475

            
476
    /// Returns the current effective permissions for the client. Returns None
477
    /// if unauthenticated.
478
    pub async fn effective_permissions(&self) -> Option<Permissions> {
479
        let effective_permissions = fast_async_lock!(self.data.effective_permissions);
480
        effective_permissions.clone()
481
    }
482

            
483
    #[cfg(feature = "test-util")]
484
    #[doc(hidden)]
485
    #[must_use]
486
    pub fn background_task_running(&self) -> Arc<AtomicBool> {
487
        self.data.background_task_running.clone()
488
    }
489

            
490
14
    pub(crate) async fn register_subscriber(&self, id: u64, sender: flume::Sender<Arc<Message>>) {
491
14
        let mut subscribers = fast_async_lock!(self.data.subscribers);
492
14
        subscribers.insert(id, sender);
493
14
    }
494

            
495
    pub(crate) async fn unregister_subscriber(&self, database: String, id: u64) {
496
        drop(
497
            self.send_request(Request::Database {
498
                database,
499
                request: DatabaseRequest::UnregisterSubscriber { subscriber_id: id },
500
            })
501
            .await,
502
        );
503
        let mut subscribers = fast_async_lock!(self.data.subscribers);
504
        subscribers.remove(&id);
505
    }
506
}
507

            
508
#[async_trait]
509
impl<A: CustomApi> StorageConnection for Client<A> {
510
    type Database = RemoteDatabase<A>;
511

            
512
570
    async fn create_database_with_schema(
513
570
        &self,
514
570
        name: &str,
515
570
        schema: SchemaName,
516
570
        only_if_needed: bool,
517
570
    ) -> Result<(), bonsaidb_core::Error> {
518
570
        match self
519
570
            .send_request(Request::Server(ServerRequest::CreateDatabase {
520
570
                database: Database {
521
570
                    name: name.to_string(),
522
570
                    schema,
523
570
                },
524
570
                only_if_needed,
525
2005
            }))
526
2005
            .await?
527
        {
528
564
            Response::Server(ServerResponse::DatabaseCreated { .. }) => Ok(()),
529
6
            Response::Error(err) => Err(err),
530
            other => Err(bonsaidb_core::Error::Networking(
531
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
532
            )),
533
        }
534
1140
    }
535

            
536
616
    async fn database<DB: Schema>(
537
616
        &self,
538
616
        name: &str,
539
616
    ) -> Result<Self::Database, bonsaidb_core::Error> {
540
616
        let mut schemas = fast_async_lock!(self.data.schemas);
541
616
        let type_id = TypeId::of::<DB>();
542
616
        let schematic = if let Some(schematic) = schemas.get(&type_id) {
543
500
            schematic.clone()
544
        } else {
545
116
            let schematic = Arc::new(DB::schematic()?);
546
116
            schemas.insert(type_id, schematic.clone());
547
116
            schematic
548
        };
549
616
        Ok(RemoteDatabase::new(
550
616
            self.clone(),
551
616
            name.to_string(),
552
616
            schematic,
553
616
        ))
554
1232
    }
555

            
556
504
    async fn delete_database(&self, name: &str) -> Result<(), bonsaidb_core::Error> {
557
504
        match self
558
504
            .send_request(Request::Server(ServerRequest::DeleteDatabase {
559
504
                name: name.to_string(),
560
1830
            }))
561
1830
            .await?
562
        {
563
502
            Response::Server(ServerResponse::DatabaseDeleted { .. }) => Ok(()),
564
2
            Response::Error(err) => Err(err),
565
            other => Err(bonsaidb_core::Error::Networking(
566
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
567
            )),
568
        }
569
1008
    }
570

            
571
2
    async fn list_databases(&self) -> Result<Vec<Database>, bonsaidb_core::Error> {
572
2
        match self
573
2
            .send_request(Request::Server(ServerRequest::ListDatabases))
574
2
            .await?
575
        {
576
2
            Response::Server(ServerResponse::Databases(databases)) => Ok(databases),
577
            Response::Error(err) => Err(err),
578
            other => Err(bonsaidb_core::Error::Networking(
579
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
580
            )),
581
        }
582
4
    }
583

            
584
2
    async fn list_available_schemas(&self) -> Result<Vec<SchemaName>, bonsaidb_core::Error> {
585
2
        match self
586
2
            .send_request(Request::Server(ServerRequest::ListAvailableSchemas))
587
2
            .await?
588
        {
589
2
            Response::Server(ServerResponse::AvailableSchemas(schemas)) => Ok(schemas),
590
            Response::Error(err) => Err(err),
591
            other => Err(bonsaidb_core::Error::Networking(
592
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
593
            )),
594
        }
595
4
    }
596

            
597
4
    async fn create_user(&self, username: &str) -> Result<u64, bonsaidb_core::Error> {
598
4
        match self
599
4
            .send_request(Request::Server(ServerRequest::CreateUser {
600
4
                username: username.to_string(),
601
4
            }))
602
4
            .await?
603
        {
604
3
            Response::Server(ServerResponse::UserCreated { id }) => Ok(id),
605
1
            Response::Error(err) => Err(err),
606
            other => Err(bonsaidb_core::Error::Networking(
607
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
608
            )),
609
        }
610
8
    }
611

            
612
2
    async fn delete_user<'user, U: Nameable<'user, u64> + Send + Sync>(
613
2
        &self,
614
2
        user: U,
615
2
    ) -> Result<(), bonsaidb_core::Error> {
616
        match self
617
            .send_request(Request::Server(ServerRequest::DeleteUser {
618
2
                user: user.name()?.into_owned(),
619
2
            }))
620
2
            .await?
621
        {
622
2
            Response::Ok => Ok(()),
623
            Response::Error(err) => Err(err),
624
            other => Err(bonsaidb_core::Error::Networking(
625
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
626
            )),
627
        }
628
4
    }
629

            
630
    #[cfg(feature = "password-hashing")]
631
    async fn set_user_password<'user, U: Nameable<'user, u64> + Send + Sync>(
632
        &self,
633
        user: U,
634
        password: bonsaidb_core::connection::SensitiveString,
635
    ) -> Result<(), bonsaidb_core::Error> {
636
        match self
637
            .send_request(Request::Server(ServerRequest::SetUserPassword {
638
                user: user.name()?.into_owned(),
639
                password,
640
            }))
641
            .await?
642
        {
643
            Response::Ok => Ok(()),
644
            Response::Error(err) => Err(err),
645
            other => Err(bonsaidb_core::Error::Networking(
646
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
647
            )),
648
        }
649
    }
650

            
651
    #[cfg(feature = "password-hashing")]
652
5
    async fn authenticate<'user, U: Nameable<'user, u64> + Send + Sync>(
653
5
        &self,
654
5
        user: U,
655
5
        authentication: Authentication,
656
5
    ) -> Result<Authenticated, bonsaidb_core::Error> {
657
        match self
658
            .send_request(Request::Server(ServerRequest::Authenticate {
659
5
                user: user.name()?.into_owned(),
660
5
                authentication,
661
12
            }))
662
12
            .await?
663
        {
664
5
            Response::Server(ServerResponse::Authenticated(response)) => Ok(response),
665
            Response::Error(err) => Err(err),
666
            other => Err(bonsaidb_core::Error::Networking(
667
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
668
            )),
669
        }
670
10
    }
671

            
672
4
    async fn add_permission_group_to_user<
673
4
        'user,
674
4
        'group,
675
4
        U: Nameable<'user, u64> + Send + Sync,
676
4
        G: Nameable<'group, u64> + Send + Sync,
677
4
    >(
678
4
        &self,
679
4
        user: U,
680
4
        permission_group: G,
681
4
    ) -> Result<(), bonsaidb_core::Error> {
682
        match self
683
            .send_request(Request::Server(
684
                ServerRequest::AlterUserPermissionGroupMembership {
685
4
                    user: user.name()?.into_owned(),
686
4
                    group: permission_group.name()?.into_owned(),
687
                    should_be_member: true,
688
                },
689
4
            ))
690
4
            .await?
691
        {
692
4
            Response::Ok => Ok(()),
693
            Response::Error(err) => Err(err),
694
            other => Err(bonsaidb_core::Error::Networking(
695
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
696
            )),
697
        }
698
8
    }
699

            
700
4
    async fn remove_permission_group_from_user<
701
4
        'user,
702
4
        'group,
703
4
        U: Nameable<'user, u64> + Send + Sync,
704
4
        G: Nameable<'group, u64> + Send + Sync,
705
4
    >(
706
4
        &self,
707
4
        user: U,
708
4
        permission_group: G,
709
4
    ) -> Result<(), bonsaidb_core::Error> {
710
        match self
711
            .send_request(Request::Server(
712
                ServerRequest::AlterUserPermissionGroupMembership {
713
4
                    user: user.name()?.into_owned(),
714
4
                    group: permission_group.name()?.into_owned(),
715
                    should_be_member: false,
716
                },
717
4
            ))
718
4
            .await?
719
        {
720
4
            Response::Ok => Ok(()),
721
            Response::Error(err) => Err(err),
722
            other => Err(bonsaidb_core::Error::Networking(
723
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
724
            )),
725
        }
726
8
    }
727

            
728
4
    async fn add_role_to_user<
729
4
        'user,
730
4
        'group,
731
4
        U: Nameable<'user, u64> + Send + Sync,
732
4
        G: Nameable<'group, u64> + Send + Sync,
733
4
    >(
734
4
        &self,
735
4
        user: U,
736
4
        role: G,
737
4
    ) -> Result<(), bonsaidb_core::Error> {
738
        match self
739
            .send_request(Request::Server(ServerRequest::AlterUserRoleMembership {
740
4
                user: user.name()?.into_owned(),
741
4
                role: role.name()?.into_owned(),
742
                should_be_member: true,
743
4
            }))
744
4
            .await?
745
        {
746
4
            Response::Ok => Ok(()),
747
            Response::Error(err) => Err(err),
748
            other => Err(bonsaidb_core::Error::Networking(
749
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
750
            )),
751
        }
752
8
    }
753

            
754
4
    async fn remove_role_from_user<
755
4
        'user,
756
4
        'group,
757
4
        U: Nameable<'user, u64> + Send + Sync,
758
4
        G: Nameable<'group, u64> + Send + Sync,
759
4
    >(
760
4
        &self,
761
4
        user: U,
762
4
        role: G,
763
4
    ) -> Result<(), bonsaidb_core::Error> {
764
        match self
765
            .send_request(Request::Server(ServerRequest::AlterUserRoleMembership {
766
4
                user: user.name()?.into_owned(),
767
4
                role: role.name()?.into_owned(),
768
                should_be_member: false,
769
4
            }))
770
4
            .await?
771
        {
772
4
            Response::Ok => Ok(()),
773
            Response::Error(err) => Err(err),
774
            other => Err(bonsaidb_core::Error::Networking(
775
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
776
            )),
777
        }
778
8
    }
779
}
780

            
781
type OutstandingRequestMap<Api> = HashMap<u32, PendingRequest<Api>>;
782
type OutstandingRequestMapHandle<Api> = Arc<Mutex<OutstandingRequestMap<Api>>>;
783
type PendingRequestResponder<Api> =
784
    Sender<Result<Response<CustomApiResult<Api>>, Error<<Api as CustomApi>::Error>>>;
785

            
786
#[derive(Debug)]
787
pub struct PendingRequest<Api: CustomApi> {
788
    request: Payload<Request<Api::Request>>,
789
    responder: PendingRequestResponder<Api>,
790
    _phantom: PhantomData<Api>,
791
}
792

            
793
#[cfg(not(target_arch = "wasm32"))]
794
#[derive(Debug)]
795
struct CancellableHandle<T> {
796
    worker: JoinHandle<T>,
797
    #[cfg(feature = "test-util")]
798
    background_task_running: Arc<AtomicBool>,
799
}
800

            
801
#[cfg(not(target_arch = "wasm32"))]
802
impl<T> Drop for CancellableHandle<T> {
803
1989
    fn drop(&mut self) {
804
1989
        self.worker.abort();
805
1989
        #[cfg(feature = "test-util")]
806
1989
        self.background_task_running.store(false, Ordering::Release);
807
1989
    }
808
}
809

            
810
52556
async fn process_response_payload<A: CustomApi>(
811
52556
    payload: Payload<Response<CustomApiResult<A>>>,
812
52556
    outstanding_requests: &OutstandingRequestMapHandle<A>,
813
52556
    custom_api_callback: Option<&dyn CustomApiCallback<A>>,
814
52556
    subscribers: &SubscriberMap,
815
52556
) {
816
52556
    if let Some(payload_id) = payload.id {
817
52528
        if let Response::Api(response) = &payload.wrapped {
818
10
            if let Some(custom_api_callback) = custom_api_callback {
819
                custom_api_callback
820
                    .request_response_received(response)
821
                    .await;
822
10
            }
823
52518
        }
824

            
825
52528
        let request = {
826
52528
            let mut outstanding_requests = fast_async_lock!(outstanding_requests);
827
52528
            outstanding_requests
828
52528
                .remove(&payload_id)
829
52528
                .expect("missing responder")
830
52528
        };
831
52528
        drop(request.responder.send(Ok(payload.wrapped)));
832
    } else {
833
28
        match payload.wrapped {
834
            Response::Api(response) => {
835
                if let Some(custom_api_callback) = custom_api_callback {
836
                    custom_api_callback.response_received(response).await;
837
                }
838
            }
839
            Response::Database(bonsaidb_core::networking::DatabaseResponse::MessageReceived {
840
28
                subscriber_id,
841
28
                topic,
842
28
                payload,
843
            }) => {
844
28
                let mut subscribers = fast_async_lock!(subscribers);
845
28
                if let Some(sender) = subscribers.get(&subscriber_id) {
846
28
                    if sender
847
28
                        .send(std::sync::Arc::new(bonsaidb_core::circulate::Message {
848
28
                            topic,
849
28
                            payload: payload.into_vec(),
850
28
                        }))
851
28
                        .is_err()
852
                    {
853
                        subscribers.remove(&subscriber_id);
854
28
                    }
855
                }
856
            }
857
            _ => {
858
                log::error!("unexpected adhoc response");
859
            }
860
        }
861
    }
862
52556
}
863

            
864
/// A handler of [`CustomApi`] responses.
865
#[async_trait]
866
pub trait CustomApiCallback<A: CustomApi>: Send + Sync + 'static {
867
    /// An out-of-band `response` was received. This happens when the server
868
    /// sends a response that isn't in response to a request.
869
    async fn response_received(&self, response: CustomApiResult<A>);
870

            
871
    /// A response was received. Unlike in `response_received` this response
872
    /// will be returned to the original requestor. This is invoked before the
873
    /// requestor recives the response.
874
    #[allow(unused_variables)]
875
    async fn request_response_received(&self, response: &CustomApiResult<A>) {
876
        // This is provided in case you'd like to see a response always, even if
877
        // it is also being handled by the code that made the request.
878
    }
879
}
880

            
881
#[async_trait]
882
impl<F, T> CustomApiCallback<T> for F
883
where
884
    F: Fn(CustomApiResult<T>) + Send + Sync + 'static,
885
    T: CustomApi,
886
{
887
    async fn response_received(&self, response: CustomApiResult<T>) {
888
        self(response);
889
    }
890
}
891

            
892
#[async_trait]
893
impl<T> CustomApiCallback<T> for ()
894
where
895
    T: CustomApi,
896
{
897
    async fn response_received(&self, _response: CustomApiResult<T>) {}
898
}