1
use std::{
2
    collections::HashSet,
3
    net::SocketAddr,
4
    ops::{Deref, DerefMut},
5
    sync::Arc,
6
};
7

            
8
use async_lock::{Mutex, MutexGuard, RwLock};
9
use bonsaidb_core::{custom_api::CustomApiResult, permissions::Permissions};
10
use bonsaidb_utils::{fast_async_lock, fast_async_read, fast_async_write};
11
use derive_where::derive_where;
12
use flume::Sender;
13

            
14
use crate::{Backend, CustomServer, NoBackend};
15

            
16
/// The ways a client can be connected to the server.
///
/// This is a fieldless enum, so it is cheap to copy: callers holding the
/// `&Transport` returned by `ConnectedClient::transport` can simply
/// dereference it to obtain an owned value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transport {
    /// A connection over BonsaiDb's QUIC-based protocol.
    Bonsai,
    /// A connection over WebSockets.
    ///
    /// Only present when the `websockets` feature is enabled.
    #[cfg(feature = "websockets")]
    WebSocket,
}
25

            
26
/// A connected database client.
27
#[derive(Debug)]
28
52165
#[derive_where(Clone)]
29
pub struct ConnectedClient<B: Backend = NoBackend> {
30
    data: Arc<Data<B>>,
31
}
32

            
33
#[derive(Debug)]
34
struct Data<B: Backend = NoBackend> {
35
    id: u32,
36
    address: SocketAddr,
37
    transport: Transport,
38
    response_sender: Sender<CustomApiResult<B::CustomApi>>,
39
    auth_state: RwLock<AuthenticationState>,
40
    client_data: Mutex<Option<B::ClientData>>,
41
    subscriber_ids: Mutex<HashSet<u64>>,
42
}
43

            
44
#[derive(Debug, Default)]
45
struct AuthenticationState {
46
    user_id: Option<u64>,
47
    permissions: Permissions,
48
}
49

            
50
impl<B: Backend> ConnectedClient<B> {
51
    /// Returns the address of the connected client.
52
    #[must_use]
53
    pub fn address(&self) -> &SocketAddr {
54
        &self.data.address
55
    }
56

            
57
    /// Returns the transport method the client is connected via.
58
    #[must_use]
59
    pub fn transport(&self) -> &Transport {
60
        &self.data.transport
61
    }
62

            
63
    /// Returns the current permissions for this client. Will reflect the
64
    /// current state of authentication.
65
51931
    pub async fn permissions(&self) -> Permissions {
66
51931
        let auth_state = fast_async_read!(self.data.auth_state);
67
51931
        auth_state.permissions.clone()
68
51931
    }
69

            
70
    /// Returns the unique id of the user this client is connected as. Returns
71
    /// None if the connection isn't authenticated.
72
    pub async fn user_id(&self) -> Option<u64> {
73
        let auth_state = fast_async_read!(self.data.auth_state);
74
        auth_state.user_id
75
    }
76

            
77
5
    pub(crate) async fn logged_in_as(&self, user_id: u64, new_permissions: Permissions) {
78
5
        let mut auth_state = fast_async_write!(self.data.auth_state);
79
5
        auth_state.user_id = Some(user_id);
80
5
        auth_state.permissions = new_permissions;
81
5
    }
82

            
83
26
    pub(crate) async fn owns_subscriber(&self, subscriber_id: u64) -> bool {
84
26
        let subscriber_ids = fast_async_lock!(self.data.subscriber_ids);
85
26
        subscriber_ids.contains(&subscriber_id)
86
26
    }
87

            
88
14
    pub(crate) async fn register_subscriber(&self, subscriber_id: u64) {
89
14
        let mut subscriber_ids = fast_async_lock!(self.data.subscriber_ids);
90
14
        subscriber_ids.insert(subscriber_id);
91
14
    }
92

            
93
    pub(crate) async fn remove_subscriber(&self, subscriber_id: u64) -> bool {
94
        let mut subscriber_ids = fast_async_lock!(self.data.subscriber_ids);
95
        subscriber_ids.remove(&subscriber_id)
96
    }
97

            
98
    /// Sends a custom API response to the client.
99
    pub fn send(
100
        &self,
101
        response: CustomApiResult<B::CustomApi>,
102
    ) -> Result<(), flume::SendError<CustomApiResult<B::CustomApi>>> {
103
        self.data.response_sender.send(response)
104
    }
105

            
106
    /// Returns a locked reference to the stored client data.
107
2
    pub async fn client_data(&self) -> LockedClientDataGuard<'_, B::ClientData> {
108
        LockedClientDataGuard(fast_async_lock!(self.data.client_data))
109
2
    }
110

            
111
    /// Sets the associated data for this client.
112
    pub async fn set_client_data(&self, data: B::ClientData) {
113
        let mut client_data = fast_async_lock!(self.data.client_data);
114
        *client_data = Some(data);
115
    }
116
}
117

            
118
/// A locked reference to associated client data.
119
pub struct LockedClientDataGuard<'client, ClientData>(MutexGuard<'client, Option<ClientData>>);
120

            
121
impl<'client, ClientData> Deref for LockedClientDataGuard<'client, ClientData> {
122
    type Target = Option<ClientData>;
123

            
124
    fn deref(&self) -> &Self::Target {
125
        &self.0
126
    }
127
}
128

            
129
impl<'client, ClientData> DerefMut for LockedClientDataGuard<'client, ClientData> {
130
2
    fn deref_mut(&mut self) -> &mut Self::Target {
131
2
        &mut self.0
132
2
    }
133
}
134

            
135
#[derive(Debug)]
136
pub struct OwnedClient<B: Backend> {
137
    client: ConnectedClient<B>,
138
    runtime: tokio::runtime::Handle,
139
    server: Option<CustomServer<B>>,
140
}
141

            
142
impl<B: Backend> OwnedClient<B> {
143
116
    pub(crate) fn new(
144
116
        id: u32,
145
116
        address: SocketAddr,
146
116
        transport: Transport,
147
116
        response_sender: Sender<CustomApiResult<B::CustomApi>>,
148
116
        server: CustomServer<B>,
149
116
    ) -> Self {
150
116
        Self {
151
116
            client: ConnectedClient {
152
116
                data: Arc::new(Data {
153
116
                    id,
154
116
                    address,
155
116
                    transport,
156
116
                    response_sender,
157
116
                    auth_state: RwLock::new(AuthenticationState {
158
116
                        permissions: server.data.default_permissions.clone(),
159
116
                        user_id: None,
160
116
                    }),
161
116
                    client_data: Mutex::default(),
162
116
                    subscriber_ids: Mutex::default(),
163
116
                }),
164
116
            },
165
116
            runtime: tokio::runtime::Handle::current(),
166
116
            server: Some(server),
167
116
        }
168
116
    }
169

            
170
232
    pub fn clone(&self) -> ConnectedClient<B> {
171
232
        self.client.clone()
172
232
    }
173
}
174

            
175
impl<B: Backend> Drop for OwnedClient<B> {
    /// Schedules the server-side disconnect for this client.
    ///
    /// `drop` cannot await, so the disconnect is spawned onto the runtime
    /// captured at construction instead of being awaited here. The task's
    /// join handle is not kept.
    fn drop(&mut self) {
        // `server` is only taken here and `drop` runs at most once, so this is
        // always `Some`. Skip quietly instead of unwrapping: a panic inside
        // `drop` while already unwinding would abort the process.
        if let Some(server) = self.server.take() {
            let id = self.client.data.id;
            self.runtime.spawn(async move {
                server.disconnect_client(id).await;
            });
        }
    }
}
184

            
185
impl<B: Backend> Deref for OwnedClient<B> {
186
    type Target = ConnectedClient<B>;
187

            
188
116
    fn deref(&self) -> &Self::Target {
189
116
        &self.client
190
116
    }
191
}