1
#[cfg(feature = "test-util")]
2
use std::sync::atomic::AtomicBool;
3
use std::{
4
    any::TypeId,
5
    collections::HashMap,
6
    fmt::Debug,
7
    marker::PhantomData,
8
    ops::Deref,
9
    sync::{
10
        atomic::{AtomicU32, Ordering},
11
        Arc,
12
    },
13
};
14

            
15
use async_lock::Mutex;
16
use async_trait::async_trait;
17
#[cfg(feature = "password-hashing")]
18
use bonsaidb_core::connection::{Authenticated, Authentication};
19
use bonsaidb_core::{
20
    connection::{Database, StorageConnection},
21
    custom_api::{CustomApi, CustomApiResult},
22
    networking::{
23
        self, Payload, Request, Response, ServerRequest, ServerResponse, CURRENT_PROTOCOL_VERSION,
24
    },
25
    permissions::Permissions,
26
    schema::{NamedReference, Schema, SchemaName, Schematic},
27
};
28
use bonsaidb_utils::fast_async_lock;
29
use derive_where::derive_where;
30
use flume::Sender;
31
#[cfg(not(target_arch = "wasm32"))]
32
use tokio::task::JoinHandle;
33
use url::Url;
34

            
35
pub use self::remote_database::{RemoteDatabase, RemoteSubscriber};
36
use crate::{error::Error, Builder};
37

            
38
#[cfg(not(target_arch = "wasm32"))]
39
mod quic_worker;
40
mod remote_database;
41
#[cfg(all(feature = "websockets", not(target_arch = "wasm32")))]
42
mod tungstenite_worker;
43
#[cfg(all(feature = "websockets", target_arch = "wasm32"))]
44
mod wasm_websocket_worker;
45

            
46
234
#[derive(Debug, Clone, Default)]
47
pub struct SubscriberMap(Arc<Mutex<HashMap<u64, flume::Sender<Arc<Message>>>>>);
48

            
49
impl SubscriberMap {
50
2299
    pub async fn clear(&self) {
51
121
        let mut data = fast_async_lock!(self);
52
121
        data.clear();
53
121
    }
54
}
55

            
56
impl Deref for SubscriberMap {
57
    type Target = Mutex<HashMap<u64, flume::Sender<Arc<Message>>>>;
58

            
59
3097
    fn deref(&self) -> &Self::Target {
60
3097
        &self.0
61
3097
    }
62
}
63

            
64
use bonsaidb_core::{circulate::Message, networking::DatabaseRequest};
65

            
66
/// Error type reported by the WebSocket transport on non-`wasm32` targets
/// (backed by `tokio-tungstenite`).
#[cfg(all(feature = "websockets", not(target_arch = "wasm32")))]
pub type WebSocketError = tokio_tungstenite::tungstenite::Error;

/// Error type reported by the WebSocket transport on `wasm32` targets.
#[cfg(all(feature = "websockets", target_arch = "wasm32"))]
pub type WebSocketError = wasm_websocket_worker::WebSocketError;
71

            
72
/// Client for connecting to a `BonsaiDb` server.
73
#[derive(Debug)]
74
916
#[derive_where(Clone)]
75
pub struct Client<A: CustomApi = ()> {
76
    pub(crate) data: Arc<Data<A>>,
77
}
78

            
79
impl<A> PartialEq for Client<A>
80
where
81
    A: CustomApi,
82
{
83
    fn eq(&self, other: &Self) -> bool {
84
        Arc::ptr_eq(&self.data, &other.data)
85
    }
86
}
87

            
88
#[derive(Debug)]
89
pub struct Data<A: CustomApi> {
90
    request_sender: Sender<PendingRequest<A>>,
91
    #[cfg(not(target_arch = "wasm32"))]
92
    _worker: CancellableHandle<Result<(), Error<A::Error>>>,
93
    effective_permissions: Mutex<Option<Permissions>>,
94
    schemas: Mutex<HashMap<TypeId, Arc<Schematic>>>,
95
    request_id: AtomicU32,
96
    subscribers: SubscriberMap,
97
    #[cfg(feature = "test-util")]
98
    background_task_running: Arc<AtomicBool>,
99
}
100

            
101
impl Client<()> {
102
    /// Returns a builder for a new client connecting to `url`.
103
1653
    pub fn build(url: Url) -> Builder<()> {
104
1653
        Builder::new(url)
105
1653
    }
106
}
107

            
108
impl<A: CustomApi> Client<A> {
109
    /// Initialize a client connecting to `url`. This client can be shared by
110
    /// cloning it. All requests are done asynchronously over the same
111
    /// connection.
112
    ///
113
    /// If the client has an error connecting, the first request made will
114
    /// present that error. If the client disconnects while processing requests,
115
    /// all requests being processed will exit and return
116
    /// [`Error::Disconnected`]. The client will automatically try reconnecting.
117
    ///
118
    /// The goal of this design of this reconnection strategy is to make it
119
    /// easier to build resilliant apps. By allowing existing Client instances
120
    /// to recover and reconnect, each component of the apps built can adopt a
121
    /// "retry-to-recover" design, or "abort-and-fail" depending on how critical
122
    /// the database is to operation.
123
31
    pub async fn new(url: Url) -> Result<Self, Error<A::Error>> {
124
31
        Self::new_from_parts(
125
31
            url,
126
31
            CURRENT_PROTOCOL_VERSION,
127
31
            #[cfg(not(target_arch = "wasm32"))]
128
31
            None,
129
31
            None,
130
31
        )
131
        .await
132
31
    }
133

            
134
    /// Initialize a client connecting to `url` with `certificate` being used to
135
    /// validate and encrypt the connection. This client can be shared by
136
    /// cloning it. All requests are done asynchronously over the same
137
    /// connection.
138
    ///
139
    /// If the client has an error connecting, the first request made will
140
    /// present that error. If the client disconnects while processing requests,
141
    /// all requests being processed will exit and return
142
    /// [`Error::Disconnected`]. The client will automatically try reconnecting.
143
    ///
144
    /// The goal of this design of this reconnection strategy is to make it
145
    /// easier to build resilliant apps. By allowing existing Client instances
146
    /// to recover and reconnect, each component of the apps built can adopt a
147
    /// "retry-to-recover" design, or "abort-and-fail" depending on how critical
148
    /// the database is to operation.
149
118
    pub(crate) async fn new_from_parts(
150
118
        url: Url,
151
118
        protocol_version: &'static str,
152
118
        custom_api_callback: Option<Arc<dyn CustomApiCallback<A>>>,
153
118
        #[cfg(not(target_arch = "wasm32"))] certificate: Option<fabruic::Certificate>,
154
118
    ) -> Result<Self, Error<A::Error>> {
155
118
        match url.scheme() {
156
118
            #[cfg(not(target_arch = "wasm32"))]
157
118
            "bonsaidb" => Ok(Self::new_bonsai_client(
158
62
                url,
159
62
                protocol_version,
160
62
                certificate,
161
62
                custom_api_callback,
162
62
            )),
163
            #[cfg(feature = "websockets")]
164
56
            "wss" | "ws" => {
165
56
                Self::new_websocket_client(url, protocol_version, custom_api_callback).await
166
            }
167
            other => {
168
                return Err(Error::InvalidUrl(format!("unsupported scheme {}", other)));
169
            }
170
        }
171
118
    }
172

            
173
    #[cfg(not(target_arch = "wasm32"))]
174
62
    fn new_bonsai_client(
175
62
        url: Url,
176
62
        protocol_version: &'static str,
177
62
        certificate: Option<fabruic::Certificate>,
178
62
        custom_api_callback: Option<Arc<dyn CustomApiCallback<A>>>,
179
62
    ) -> Self {
180
62
        let (request_sender, request_receiver) = flume::unbounded();
181
62

            
182
62
        let subscribers = SubscriberMap::default();
183
62
        let worker = tokio::task::spawn(quic_worker::reconnecting_client_loop(
184
62
            url,
185
62
            protocol_version,
186
62
            certificate,
187
62
            request_receiver,
188
62
            custom_api_callback,
189
62
            subscribers.clone(),
190
62
        ));
191
62

            
192
62
        #[cfg(feature = "test-util")]
193
62
        let background_task_running = Arc::new(AtomicBool::new(true));
194
62

            
195
62
        Self {
196
62
            data: Arc::new(Data {
197
62
                request_sender,
198
62
                _worker: CancellableHandle {
199
62
                    worker,
200
62
                    #[cfg(feature = "test-util")]
201
62
                    background_task_running: background_task_running.clone(),
202
62
                },
203
62
                schemas: Mutex::default(),
204
62
                request_id: AtomicU32::default(),
205
62
                effective_permissions: Mutex::default(),
206
62
                subscribers,
207
62
                #[cfg(feature = "test-util")]
208
62
                background_task_running,
209
62
            }),
210
62
        }
211
62
    }
212

            
213
    #[cfg(all(feature = "websockets", not(target_arch = "wasm32")))]
214
56
    async fn new_websocket_client(
215
56
        url: Url,
216
56
        protocol_version: &'static str,
217
56
        custom_api_callback: Option<Arc<dyn CustomApiCallback<A>>>,
218
56
    ) -> Result<Self, Error<A::Error>> {
219
56
        let (request_sender, request_receiver) = flume::unbounded();
220
56

            
221
56
        let subscribers = SubscriberMap::default();
222
56

            
223
56
        let worker = tokio::task::spawn(tungstenite_worker::reconnecting_client_loop(
224
56
            url,
225
56
            protocol_version,
226
56
            request_receiver,
227
56
            custom_api_callback,
228
56
            subscribers.clone(),
229
56
        ));
230
56

            
231
56
        #[cfg(feature = "test-util")]
232
56
        let background_task_running = Arc::new(AtomicBool::new(true));
233
56

            
234
56
        let client = Self {
235
56
            data: Arc::new(Data {
236
56
                request_sender,
237
56
                #[cfg(not(target_arch = "wasm32"))]
238
56
                _worker: CancellableHandle {
239
56
                    worker,
240
56
                    #[cfg(feature = "test-util")]
241
56
                    background_task_running: background_task_running.clone(),
242
56
                },
243
56
                schemas: Mutex::default(),
244
56
                request_id: AtomicU32::default(),
245
56
                effective_permissions: Mutex::default(),
246
56
                subscribers,
247
56
                #[cfg(feature = "test-util")]
248
56
                background_task_running,
249
56
            }),
250
56
        };
251
56

            
252
56
        Ok(client)
253
56
    }
254

            
255
    #[cfg(all(feature = "websockets", target_arch = "wasm32"))]
256
    async fn new_websocket_client(
257
        url: Url,
258
        protocol_version: &'static str,
259
        custom_api_callback: Option<Arc<dyn CustomApiCallback<A>>>,
260
    ) -> Result<Self, Error<A::Error>> {
261
        let (request_sender, request_receiver) = flume::unbounded();
262

            
263
        let subscribers = SubscriberMap::default();
264

            
265
        wasm_websocket_worker::spawn_client(
266
            Arc::new(url),
267
            protocol_version,
268
            request_receiver,
269
            custom_api_callback.clone(),
270
            subscribers.clone(),
271
        );
272

            
273
        #[cfg(feature = "test-util")]
274
        let background_task_running = Arc::new(AtomicBool::new(true));
275

            
276
        let client = Self {
277
            data: Arc::new(Data {
278
                request_sender,
279
                #[cfg(not(target_arch = "wasm32"))]
280
                worker: CancellableHandle {
281
                    worker,
282
                    #[cfg(feature = "test-util")]
283
                    background_task_running: background_task_running.clone(),
284
                },
285
                schemas: Mutex::default(),
286
                request_id: AtomicU32::default(),
287
                effective_permissions: Mutex::default(),
288
                subscribers,
289
                #[cfg(feature = "test-util")]
290
                background_task_running,
291
            }),
292
        };
293

            
294
        Ok(client)
295
    }
296

            
297
51446
    async fn send_request(
298
51446
        &self,
299
51446
        request: Request<<A as CustomApi>::Request>,
300
51446
    ) -> Result<Response<CustomApiResult<A>>, Error<A::Error>> {
301
51446
        let (result_sender, result_receiver) = flume::bounded(1);
302
51446
        let id = self.data.request_id.fetch_add(1, Ordering::SeqCst);
303
51446
        self.data.request_sender.send(PendingRequest {
304
51446
            request: Payload {
305
51446
                id: Some(id),
306
51446
                wrapped: request,
307
51446
            },
308
51446
            responder: result_sender.clone(),
309
51446
            _phantom: PhantomData,
310
51446
        })?;
311

            
312
69880
        result_receiver.recv_async().await?
313
51447
    }
314

            
315
    /// Sends an api `request`.
316
14
    pub async fn send_api_request(
317
14
        &self,
318
14
        request: <A as CustomApi>::Request,
319
14
    ) -> Result<A::Response, Error<A::Error>> {
320
18
        match self.send_request(Request::Api(request)).await? {
321
10
            Response::Api(response) => response.map_err(Error::Api),
322
4
            Response::Error(err) => Err(Error::Core(err)),
323
            other => Err(Error::Network(networking::Error::UnexpectedResponse(
324
                format!("{:?}", other),
325
            ))),
326
        }
327
14
    }
328

            
329
    /// Returns the current effective permissions for the client. Returns None
330
    /// if unauthenticated.
331
    pub async fn effective_permissions(&self) -> Option<Permissions> {
332
        let effective_permissions = fast_async_lock!(self.data.effective_permissions);
333
        effective_permissions.clone()
334
    }
335

            
336
    #[cfg(feature = "test-util")]
337
    #[doc(hidden)]
338
    #[must_use]
339
    pub fn background_task_running(&self) -> Arc<AtomicBool> {
340
        self.data.background_task_running.clone()
341
    }
342

            
343
14
    pub(crate) async fn register_subscriber(&self, id: u64, sender: flume::Sender<Arc<Message>>) {
344
14
        let mut subscribers = fast_async_lock!(self.data.subscribers);
345
14
        subscribers.insert(id, sender);
346
14
    }
347

            
348
    pub(crate) async fn unregister_subscriber(&self, database: String, id: u64) {
349
        drop(
350
            self.send_request(Request::Database {
351
                database,
352
                request: DatabaseRequest::UnregisterSubscriber { subscriber_id: id },
353
            })
354
            .await,
355
        );
356
        let mut subscribers = fast_async_lock!(self.data.subscribers);
357
        subscribers.remove(&id);
358
    }
359
}
360

            
361
#[async_trait]
362
impl<A: CustomApi> StorageConnection for Client<A> {
363
    type Database = RemoteDatabase<A>;
364

            
365
570
    async fn create_database_with_schema(
366
570
        &self,
367
570
        name: &str,
368
570
        schema: SchemaName,
369
570
        only_if_needed: bool,
370
570
    ) -> Result<(), bonsaidb_core::Error> {
371
570
        match self
372
570
            .send_request(Request::Server(ServerRequest::CreateDatabase {
373
570
                database: Database {
374
570
                    name: name.to_string(),
375
570
                    schema,
376
570
                },
377
570
                only_if_needed,
378
1872
            }))
379
1872
            .await?
380
        {
381
564
            Response::Server(ServerResponse::DatabaseCreated { .. }) => Ok(()),
382
6
            Response::Error(err) => Err(err),
383
            other => Err(bonsaidb_core::Error::Networking(
384
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
385
            )),
386
        }
387
1140
    }
388

            
389
616
    async fn database<DB: Schema>(
390
616
        &self,
391
616
        name: &str,
392
616
    ) -> Result<Self::Database, bonsaidb_core::Error> {
393
616
        let mut schemas = fast_async_lock!(self.data.schemas);
394
616
        let type_id = TypeId::of::<DB>();
395
616
        let schematic = if let Some(schematic) = schemas.get(&type_id) {
396
500
            schematic.clone()
397
        } else {
398
116
            let schematic = Arc::new(DB::schematic()?);
399
116
            schemas.insert(type_id, schematic.clone());
400
116
            schematic
401
        };
402
616
        Ok(RemoteDatabase::new(
403
616
            self.clone(),
404
616
            name.to_string(),
405
616
            schematic,
406
616
        ))
407
1232
    }
408

            
409
504
    async fn delete_database(&self, name: &str) -> Result<(), bonsaidb_core::Error> {
410
504
        match self
411
504
            .send_request(Request::Server(ServerRequest::DeleteDatabase {
412
504
                name: name.to_string(),
413
1830
            }))
414
1830
            .await?
415
        {
416
502
            Response::Server(ServerResponse::DatabaseDeleted { .. }) => Ok(()),
417
2
            Response::Error(err) => Err(err),
418
            other => Err(bonsaidb_core::Error::Networking(
419
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
420
            )),
421
        }
422
1008
    }
423

            
424
2
    async fn list_databases(&self) -> Result<Vec<Database>, bonsaidb_core::Error> {
425
2
        match self
426
2
            .send_request(Request::Server(ServerRequest::ListDatabases))
427
2
            .await?
428
        {
429
2
            Response::Server(ServerResponse::Databases(databases)) => Ok(databases),
430
            Response::Error(err) => Err(err),
431
            other => Err(bonsaidb_core::Error::Networking(
432
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
433
            )),
434
        }
435
4
    }
436

            
437
2
    async fn list_available_schemas(&self) -> Result<Vec<SchemaName>, bonsaidb_core::Error> {
438
2
        match self
439
2
            .send_request(Request::Server(ServerRequest::ListAvailableSchemas))
440
2
            .await?
441
        {
442
2
            Response::Server(ServerResponse::AvailableSchemas(schemas)) => Ok(schemas),
443
            Response::Error(err) => Err(err),
444
            other => Err(bonsaidb_core::Error::Networking(
445
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
446
            )),
447
        }
448
4
    }
449

            
450
4
    async fn create_user(&self, username: &str) -> Result<u64, bonsaidb_core::Error> {
451
4
        match self
452
4
            .send_request(Request::Server(ServerRequest::CreateUser {
453
4
                username: username.to_string(),
454
4
            }))
455
4
            .await?
456
        {
457
3
            Response::Server(ServerResponse::UserCreated { id }) => Ok(id),
458
1
            Response::Error(err) => Err(err),
459
            other => Err(bonsaidb_core::Error::Networking(
460
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
461
            )),
462
        }
463
8
    }
464

            
465
    #[cfg(feature = "password-hashing")]
466
    async fn set_user_password<'user, U: Into<NamedReference<'user>> + Send + Sync>(
467
        &self,
468
        user: U,
469
        password: bonsaidb_core::connection::SensitiveString,
470
    ) -> Result<(), bonsaidb_core::Error> {
471
        match self
472
            .send_request(Request::Server(ServerRequest::SetUserPassword {
473
                user: user.into().into_owned(),
474
                password,
475
            }))
476
            .await?
477
        {
478
            Response::Ok => Ok(()),
479
            Response::Error(err) => Err(err),
480
            other => Err(bonsaidb_core::Error::Networking(
481
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
482
            )),
483
        }
484
    }
485

            
486
    #[cfg(feature = "password-hashing")]
487
5
    async fn authenticate<'user, U: Into<NamedReference<'user>> + Send + Sync>(
488
5
        &self,
489
5
        user: U,
490
5
        authentication: Authentication,
491
5
    ) -> Result<Authenticated, bonsaidb_core::Error> {
492
5
        match self
493
5
            .send_request(Request::Server(ServerRequest::Authenticate {
494
5
                user: user.into().into_owned(),
495
5
                authentication,
496
12
            }))
497
12
            .await?
498
        {
499
5
            Response::Server(ServerResponse::Authenticated(response)) => Ok(response),
500
            Response::Error(err) => Err(err),
501
            other => Err(bonsaidb_core::Error::Networking(
502
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
503
            )),
504
        }
505
10
    }
506

            
507
4
    async fn add_permission_group_to_user<
508
4
        'user,
509
4
        'group,
510
4
        U: Into<NamedReference<'user>> + Send + Sync,
511
4
        G: Into<NamedReference<'group>> + Send + Sync,
512
4
    >(
513
4
        &self,
514
4
        user: U,
515
4
        permission_group: G,
516
4
    ) -> Result<(), bonsaidb_core::Error> {
517
4
        match self
518
4
            .send_request(Request::Server(
519
4
                ServerRequest::AlterUserPermissionGroupMembership {
520
4
                    user: user.into().into_owned(),
521
4
                    group: permission_group.into().into_owned(),
522
4
                    should_be_member: true,
523
4
                },
524
4
            ))
525
4
            .await?
526
        {
527
4
            Response::Ok => Ok(()),
528
            Response::Error(err) => Err(err),
529
            other => Err(bonsaidb_core::Error::Networking(
530
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
531
            )),
532
        }
533
8
    }
534

            
535
4
    async fn remove_permission_group_from_user<
536
4
        'user,
537
4
        'group,
538
4
        U: Into<NamedReference<'user>> + Send + Sync,
539
4
        G: Into<NamedReference<'group>> + Send + Sync,
540
4
    >(
541
4
        &self,
542
4
        user: U,
543
4
        permission_group: G,
544
4
    ) -> Result<(), bonsaidb_core::Error> {
545
4
        match self
546
4
            .send_request(Request::Server(
547
4
                ServerRequest::AlterUserPermissionGroupMembership {
548
4
                    user: user.into().into_owned(),
549
4
                    group: permission_group.into().into_owned(),
550
4
                    should_be_member: false,
551
4
                },
552
4
            ))
553
4
            .await?
554
        {
555
4
            Response::Ok => Ok(()),
556
            Response::Error(err) => Err(err),
557
            other => Err(bonsaidb_core::Error::Networking(
558
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
559
            )),
560
        }
561
8
    }
562

            
563
4
    async fn add_role_to_user<
564
4
        'user,
565
4
        'group,
566
4
        U: Into<NamedReference<'user>> + Send + Sync,
567
4
        G: Into<NamedReference<'group>> + Send + Sync,
568
4
    >(
569
4
        &self,
570
4
        user: U,
571
4
        role: G,
572
4
    ) -> Result<(), bonsaidb_core::Error> {
573
4
        match self
574
4
            .send_request(Request::Server(ServerRequest::AlterUserRoleMembership {
575
4
                user: user.into().into_owned(),
576
4
                role: role.into().into_owned(),
577
4
                should_be_member: true,
578
4
            }))
579
4
            .await?
580
        {
581
4
            Response::Ok => Ok(()),
582
            Response::Error(err) => Err(err),
583
            other => Err(bonsaidb_core::Error::Networking(
584
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
585
            )),
586
        }
587
8
    }
588

            
589
4
    async fn remove_role_from_user<
590
4
        'user,
591
4
        'group,
592
4
        U: Into<NamedReference<'user>> + Send + Sync,
593
4
        G: Into<NamedReference<'group>> + Send + Sync,
594
4
    >(
595
4
        &self,
596
4
        user: U,
597
4
        role: G,
598
4
    ) -> Result<(), bonsaidb_core::Error> {
599
4
        match self
600
4
            .send_request(Request::Server(ServerRequest::AlterUserRoleMembership {
601
4
                user: user.into().into_owned(),
602
4
                role: role.into().into_owned(),
603
4
                should_be_member: false,
604
4
            }))
605
4
            .await?
606
        {
607
4
            Response::Ok => Ok(()),
608
            Response::Error(err) => Err(err),
609
            other => Err(bonsaidb_core::Error::Networking(
610
                networking::Error::UnexpectedResponse(format!("{:?}", other)),
611
            )),
612
        }
613
8
    }
614
}
615

            
616
type OutstandingRequestMap<Api> = HashMap<u32, PendingRequest<Api>>;
617
type OutstandingRequestMapHandle<Api> = Arc<Mutex<OutstandingRequestMap<Api>>>;
618
type PendingRequestResponder<Api> =
619
    Sender<Result<Response<CustomApiResult<Api>>, Error<<Api as CustomApi>::Error>>>;
620

            
621
#[derive(Debug)]
622
pub struct PendingRequest<Api: CustomApi> {
623
    request: Payload<Request<Api::Request>>,
624
    responder: PendingRequestResponder<Api>,
625
    _phantom: PhantomData<Api>,
626
}
627

            
628
#[cfg(not(target_arch = "wasm32"))]
629
#[derive(Debug)]
630
struct CancellableHandle<T> {
631
    worker: JoinHandle<T>,
632
    #[cfg(feature = "test-util")]
633
    background_task_running: Arc<AtomicBool>,
634
}
635

            
636
#[cfg(not(target_arch = "wasm32"))]
637
impl<T> Drop for CancellableHandle<T> {
638
1755
    fn drop(&mut self) {
639
1755
        self.worker.abort();
640
1755
        #[cfg(feature = "test-util")]
641
1755
        self.background_task_running.store(false, Ordering::Release);
642
1755
    }
643
}
644

            
645
51473
async fn process_response_payload<A: CustomApi>(
646
51473
    payload: Payload<Response<CustomApiResult<A>>>,
647
51473
    outstanding_requests: &OutstandingRequestMapHandle<A>,
648
51473
    custom_api_callback: Option<&dyn CustomApiCallback<A>>,
649
51473
    subscribers: &SubscriberMap,
650
51473
) {
651
51473
    if let Some(payload_id) = payload.id {
652
51445
        if let Response::Api(response) = &payload.wrapped {
653
11
            if let Some(custom_api_callback) = custom_api_callback {
654
1
                custom_api_callback
655
1
                    .request_response_received(response)
656
                    .await;
657
10
            }
658
51434
        }
659

            
660
51445
        let request = {
661
51445
            let mut outstanding_requests = fast_async_lock!(outstanding_requests);
662
51445
            outstanding_requests
663
51445
                .remove(&payload_id)
664
51445
                .expect("missing responder")
665
51445
        };
666
51445
        drop(request.responder.send(Ok(payload.wrapped)));
667
    } else {
668
28
        match payload.wrapped {
669
            Response::Api(response) => {
670
                if let Some(custom_api_callback) = custom_api_callback {
671
                    custom_api_callback.response_received(response).await;
672
                }
673
            }
674
            Response::Database(bonsaidb_core::networking::DatabaseResponse::MessageReceived {
675
28
                subscriber_id,
676
28
                topic,
677
28
                payload,
678
            }) => {
679
28
                let mut subscribers = fast_async_lock!(subscribers);
680
28
                if let Some(sender) = subscribers.get(&subscriber_id) {
681
28
                    if sender
682
28
                        .send(std::sync::Arc::new(bonsaidb_core::circulate::Message {
683
28
                            topic,
684
28
                            payload: payload.into_vec(),
685
28
                        }))
686
28
                        .is_err()
687
                    {
688
                        subscribers.remove(&subscriber_id);
689
28
                    }
690
                }
691
            }
692
            _ => {
693
                log::error!("unexpected adhoc response");
694
            }
695
        }
696
    }
697
51473
}
698

            
699
/// A handler of [`CustomApi`] responses.
700
#[async_trait]
701
pub trait CustomApiCallback<A: CustomApi>: Send + Sync + 'static {
702
    /// An out-of-band `response` was received. This happens when the server
703
    /// sends a response that isn't in response to a request.
704
    async fn response_received(&self, response: CustomApiResult<A>);
705

            
706
    /// A response was received. Unlike in `response_received` this response
707
    /// will be returned to the original requestor. This is invoked before the
708
    /// requestor recives the response.
709
    #[allow(unused_variables)]
710
    async fn request_response_received(&self, response: &CustomApiResult<A>) {
711
        // This is provided in case you'd like to see a response always, even if
712
        // it is also being handled by the code that made the request.
713
    }
714
}
715

            
716
#[async_trait]
717
impl<F, T> CustomApiCallback<T> for F
718
where
719
    F: Fn(CustomApiResult<T>) + Send + Sync + 'static,
720
    T: CustomApi,
721
{
722
    async fn response_received(&self, response: CustomApiResult<T>) {
723
        self(response);
724
    }
725
}
726

            
727
#[async_trait]
728
impl<T> CustomApiCallback<T> for ()
729
where
730
    T: CustomApi,
731
{
732
    async fn response_received(&self, _response: CustomApiResult<T>) {}
733
}