1
use std::collections::HashMap;
2
use std::marker::PhantomData;
3
use std::sync::Arc;
4
use std::time::Duration;
5

            
6
use bonsaidb_core::api;
7
use bonsaidb_core::api::ApiName;
8
use bonsaidb_core::networking::CURRENT_PROTOCOL_VERSION;
9
#[cfg(not(target_arch = "wasm32"))]
10
use fabruic::Certificate;
11
#[cfg(not(target_arch = "wasm32"))]
12
use tokio::runtime::Handle;
13
use url::Url;
14

            
15
use crate::client::{AnyApiCallback, ApiCallback};
16
#[cfg(not(target_arch = "wasm32"))]
17
use crate::BlockingClient;
18
use crate::{AsyncClient, Error};
19

            
20
/// A type marker for [`Builder`] indicating the returned client should be an
/// [`AsyncClient`].
///
/// This is a unit struct used only as a type parameter; it carries no data.
pub struct Async;
23

            
24
/// A type marker for [`Builder`] indicating the returned client should be a
/// [`BlockingClient`].
///
/// This is a unit struct used only as a type parameter; it carries no data.
/// It is not compiled for `wasm32` targets.
#[cfg(not(target_arch = "wasm32"))]
pub struct Blocking;
28

            
29
/// Builder for a [`BlockingClient`] or an [`AsyncClient`].
30
#[must_use]
31
pub struct Builder<AsyncMode> {
32
    url: Url,
33
    protocol_version: &'static str,
34
    custom_apis: HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>,
35
    connect_timeout: Option<Duration>,
36
    request_timeout: Option<Duration>,
37
    #[cfg(not(target_arch = "wasm32"))]
38
    certificate: Option<fabruic::Certificate>,
39
    #[cfg(not(target_arch = "wasm32"))]
40
    tokio: Option<Handle>,
41
    mode: PhantomData<AsyncMode>,
42
}
43

            
44
impl<AsyncMode> Builder<AsyncMode> {
45
    /// Creates a new builder for a client connecting to `url`.
46
3390
    pub(crate) fn new(url: Url) -> Self {
47
3390
        Self {
48
3390
            url,
49
3390
            protocol_version: CURRENT_PROTOCOL_VERSION,
50
3390
            custom_apis: HashMap::new(),
51
3390
            request_timeout: None,
52
3390
            connect_timeout: None,
53
3390
            #[cfg(not(target_arch = "wasm32"))]
54
3390
            certificate: None,
55
3390
            #[cfg(not(target_arch = "wasm32"))]
56
3390
            tokio: None,
57
3390
            mode: PhantomData,
58
3390
        }
59
3390
    }
60

            
61
    /// Specifies the tokio runtime this client should use for its async tasks.
62
    /// If not specified, `Client` will try to acquire a handle via
63
    /// `tokio::runtime::Handle::try_current()`.
64
    #[cfg(not(target_arch = "wasm32"))]
65
    #[allow(clippy::missing_const_for_fn)]
66
    pub fn with_runtime(mut self, handle: Handle) -> Self {
67
        self.tokio = Some(handle);
68
        self
69
    }
70

            
71
    /// Enables using a [`Api`](api::Api) with this client. If you want to
72
    /// receive out-of-band API requests, set a callback using
73
    /// `with_custom_api_callback` instead.
74
1
    pub fn with_api<Api: api::Api>(mut self) -> Self {
75
1
        self.custom_apis.insert(Api::name(), None);
76
1
        self
77
1
    }
78

            
79
    /// Enables using a [`Api`](api::Api) with this client. `callback` will be
80
    /// invoked when custom API responses are received from the server.
81
    pub fn with_api_callback<Api: api::Api>(mut self, callback: ApiCallback<Api>) -> Self {
82
        self.custom_apis
83
            .insert(Api::name(), Some(Arc::new(callback)));
84
        self
85
    }
86

            
87
    /// Connects to a server using a pinned `certificate`. Only supported with BonsaiDb protocol-based connections.
88
    #[cfg(not(target_arch = "wasm32"))]
89
    #[allow(clippy::missing_const_for_fn)]
90
78
    pub fn with_certificate(mut self, certificate: Certificate) -> Self {
91
78
        self.certificate = Some(certificate);
92
78
        self
93
78
    }
94

            
95
    /// Overrides the protocol version. Only for testing purposes.
96
    #[cfg(feature = "test-util")]
97
    #[allow(clippy::missing_const_for_fn)]
98
2
    pub fn with_protocol_version(mut self, version: &'static str) -> Self {
99
2
        self.protocol_version = version;
100
2
        self
101
2
    }
102

            
103
    /// Sets the request timeout for the client.
104
    ///
105
    /// If not specified, requests will time out after 60 seconds.
106
4
    pub fn with_request_timeout(mut self, timeout: impl Into<Duration>) -> Self {
107
4
        self.request_timeout = Some(timeout.into());
108
4
        self
109
4
    }
110

            
111
    /// Sets the connection timeout for the client.
112
    ///
113
    /// If not specified, the client will time out after 60 seconds if a
114
    /// connection cannot be established.
115
4
    pub fn with_connect_timeout(mut self, timeout: impl Into<Duration>) -> Self {
116
4
        self.connect_timeout = Some(timeout.into());
117
4
        self
118
4
    }
119

            
120
3390
    fn finish_internal(self) -> Result<AsyncClient, Error> {
121
3390
        AsyncClient::new_from_parts(
122
3390
            self.url,
123
3390
            self.protocol_version,
124
3390
            self.custom_apis,
125
3390
            self.connect_timeout,
126
3390
            self.request_timeout,
127
3390
            #[cfg(not(target_arch = "wasm32"))]
128
3390
            self.certificate,
129
3390
            #[cfg(not(target_arch = "wasm32"))]
130
3390
            self.tokio.or_else(|| Handle::try_current().ok()),
131
3390
        )
132
3390
    }
133
}
134

            
135
#[cfg(not(target_arch = "wasm32"))]
136
impl Builder<Blocking> {
137
    /// Finishes building the client for use in a blocking (not async) context.
138
120
    pub fn build(self) -> Result<BlockingClient, Error> {
139
120
        self.finish_internal().map(BlockingClient)
140
120
    }
141
}
142

            
143
impl Builder<Async> {
144
    /// Finishes building the client for use in a tokio async context.
145
3270
    pub fn build(self) -> Result<AsyncClient, Error> {
146
3270
        self.finish_internal()
147
3270
    }
148
}