1
use std::collections::HashMap;
2
use std::net::SocketAddr;
3
use std::ops::{Deref, DerefMut};
4
use std::sync::atomic::{AtomicBool, Ordering};
5
use std::sync::Arc;
6

            
7
use async_lock::{Mutex, MutexGuard};
8
use bonsaidb_core::api;
9
use bonsaidb_core::api::ApiName;
10
use bonsaidb_core::arc_bytes::serde::Bytes;
11
use bonsaidb_core::connection::{Session, SessionId};
12
use bonsaidb_core::networking::MessageReceived;
13
use bonsaidb_core::pubsub::{Receiver, Subscriber as _};
14
use bonsaidb_local::Subscriber;
15
use bonsaidb_utils::fast_async_lock;
16
use derive_where::derive_where;
17
use flume::Sender;
18
use parking_lot::RwLock;
19

            
20
use crate::{Backend, CustomServer, Error, NoBackend};
21

            
22
/// The ways a client can be connected to the server.
///
/// This is a plain fieldless enum, so it is cheap to copy, compare, and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transport {
    /// A connection over BonsaiDb's QUIC-based protocol.
    Bonsai,
    /// A connection over WebSockets.
    #[cfg(feature = "websockets")]
    WebSocket,
}
31

            
32
/// A connected database client.
33
#[derive(Debug)]
34
72461
#[derive_where(Clone)]
35
pub struct ConnectedClient<B: Backend = NoBackend> {
36
    data: Arc<ConnectedClientData<B>>,
37
}
38

            
39
#[derive(Debug)]
40
struct ConnectedClientData<B: Backend = NoBackend> {
41
    id: u32,
42
    sessions: RwLock<HashMap<Option<SessionId>, ClientSession>>,
43
    address: SocketAddr,
44
    transport: Transport,
45
    response_sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
46
    client_data: Mutex<Option<B::ClientData>>,
47
    connected: AtomicBool,
48
}
49

            
50
#[derive(Debug)]
51
struct ClientSession {
52
    session: Session,
53
    subscribers: HashMap<u64, Subscriber>,
54
}
55

            
56
impl<B: Backend> ConnectedClient<B> {
57
    /// Returns the address of the connected client.
58
    #[must_use]
59
    pub fn address(&self) -> &SocketAddr {
60
        &self.data.address
61
    }
62

            
63
    /// Returns the transport method the client is connected via.
64
    #[must_use]
65
    pub fn transport(&self) -> &Transport {
66
        &self.data.transport
67
    }
68

            
69
    /// Returns true if the server still believes the client is connected.
70
    #[must_use]
71
4
    pub fn connected(&self) -> bool {
72
4
        self.data.connected.load(Ordering::Relaxed)
73
4
    }
74

            
75
112
    pub(crate) fn set_disconnected(&self) {
76
112
        self.data.connected.store(false, Ordering::Relaxed);
77
112
    }
78

            
79
21
    pub(crate) fn logged_in_as(&self, session: Session) {
80
21
        let mut sessions = self.data.sessions.write();
81
21
        sessions.insert(
82
21
            session.id,
83
21
            ClientSession {
84
21
                session,
85
21
                subscribers: HashMap::default(),
86
21
            },
87
21
        );
88
21
    }
89

            
90
3
    pub(crate) fn log_out(&self, session: SessionId) -> Option<Session> {
91
3
        let mut sessions = self.data.sessions.write();
92
3
        sessions.remove(&Some(session)).map(|cs| cs.session)
93
3
    }
94

            
95
    /// Sends a custom API response to the client.
96
45
    pub fn send<Api: api::Api>(
97
45
        &self,
98
45
        session: Option<&Session>,
99
45
        response: &Api::Response,
100
45
    ) -> Result<(), Error> {
101
45
        let encoded = pot::to_vec(&Result::<&Api::Response, Api::Error>::Ok(response))?;
102
45
        self.data.response_sender.send((
103
45
            session.and_then(|session| session.id),
104
45
            Api::name(),
105
45
            Bytes::from(encoded),
106
45
        ))?;
107
45
        Ok(())
108
45
    }
109

            
110
    /// Returns a locked reference to the stored client data.
111
2
    pub async fn client_data(&self) -> LockedClientDataGuard<'_, B::ClientData> {
112
        LockedClientDataGuard(fast_async_lock!(self.data.client_data))
113
2
    }
114

            
115
    /// Looks up an active authentication session by its unique id. `None`
116
    /// represents the unauthenticated session, and the result can be used to
117
    /// check what permissions are allowed by default.
118
    #[must_use]
119
72077
    pub fn session(&self, session_id: Option<SessionId>) -> Option<Session> {
120
72077
        let sessions = self.data.sessions.read();
121
72077
        sessions.get(&session_id).map(|data| data.session.clone())
122
72077
    }
123

            
124
    /// Returns a collection of all active [`Session`]s for this client.
125
    #[must_use]
126
112
    pub fn all_sessions<C: FromIterator<Session>>(&self) -> C {
127
112
        let sessions = self.data.sessions.read();
128
128
        sessions.values().map(|s| s.session.clone()).collect()
129
112
    }
130

            
131
24
    pub(crate) fn register_subscriber(
132
24
        &self,
133
24
        subscriber: Subscriber,
134
24
        session_id: Option<SessionId>,
135
24
    ) {
136
24
        let subscriber_id = subscriber.id();
137
24
        let receiver = subscriber.receiver().clone();
138
24
        {
139
24
            let mut sessions = self.data.sessions.write();
140
24
            if let Some(client_session) = sessions.get_mut(&session_id) {
141
24
                client_session
142
24
                    .subscribers
143
24
                    .insert(subscriber.id(), subscriber);
144
24
            } else {
145
                // TODO return error for session not found.
146
                return;
147
            }
148
        }
149
24
        let task_self = self.clone();
150
24
        tokio::task::spawn(async move {
151
24
            task_self
152
24
                .forward_notifications_for(session_id, subscriber_id, receiver)
153
46
                .await;
154
24
        });
155
24
    }
156

            
157
    /// Sets the associated data for this client.
158
    pub async fn set_client_data(&self, data: B::ClientData) {
159
        let mut client_data = fast_async_lock!(self.data.client_data);
160
        *client_data = Some(data);
161
    }
162

            
163
24
    async fn forward_notifications_for(
164
24
        &self,
165
24
        session_id: Option<SessionId>,
166
24
        subscriber_id: u64,
167
24
        receiver: Receiver,
168
24
    ) {
169
24
        let session = self.session(session_id);
170
69
        while let Ok(message) = receiver.receive_async().await {
171
45
            if self
172
45
                .send::<MessageReceived>(
173
45
                    session.as_ref(),
174
45
                    &MessageReceived {
175
45
                        subscriber_id,
176
45
                        topic: Bytes::from(message.topic.0.into_vec()),
177
45
                        payload: Bytes::from(&message.payload[..]),
178
45
                    },
179
45
                )
180
45
                .is_err()
181
            {
182
                break;
183
45
            }
184
        }
185
10
    }
186

            
187
39
    pub(crate) fn subscribe_by_id(
188
39
        &self,
189
39
        subscriber_id: u64,
190
39
        topic: Bytes,
191
39
        check_session_id: Option<SessionId>,
192
39
    ) -> Result<(), crate::Error> {
193
39
        let mut sessions = self.data.sessions.write();
194
39
        if let Some(client_session) = sessions.get_mut(&check_session_id) {
195
39
            if let Some(subscriber) = client_session.subscribers.get(&subscriber_id) {
196
39
                subscriber.subscribe_to_bytes(topic.0)?;
197
39
                Ok(())
198
            } else {
199
                Err(Error::other(
200
                    "bonsaidb-server pubsub",
201
                    "invalid subscriber id",
202
                ))
203
            }
204
        } else {
205
            Err(Error::other("bonsaidb-server auth", "invalid session id"))
206
        }
207
39
    }
208

            
209
3
    pub(crate) fn unsubscribe_by_id(
210
3
        &self,
211
3
        subscriber_id: u64,
212
3
        topic: &[u8],
213
3
        check_session_id: Option<SessionId>,
214
3
    ) -> Result<(), crate::Error> {
215
3
        let mut sessions = self.data.sessions.write();
216
3
        if let Some(client_session) = sessions.get_mut(&check_session_id) {
217
3
            if let Some(subscriber) = client_session.subscribers.get(&subscriber_id) {
218
3
                subscriber.unsubscribe_from_bytes(topic)?;
219
3
                Ok(())
220
            } else {
221
                Err(Error::other(
222
                    "bonsaidb-server pubsub",
223
                    "invalid subscriber id",
224
                ))
225
            }
226
        } else {
227
            Err(Error::other("bonsaidb-server auth", "invalid session id"))
228
        }
229
3
    }
230

            
231
10
    pub(crate) fn unregister_subscriber_by_id(
232
10
        &self,
233
10
        subscriber_id: u64,
234
10
        check_session_id: Option<SessionId>,
235
10
    ) -> Result<(), crate::Error> {
236
10
        let mut sessions = self.data.sessions.write();
237
10
        if let Some(client_session) = sessions.get_mut(&check_session_id) {
238
10
            if client_session.subscribers.remove(&subscriber_id).is_some() {
239
10
                Ok(())
240
            } else {
241
                Err(Error::other(
242
                    "bonsaidb-server pubsub",
243
                    "invalid subscriber id",
244
                ))
245
            }
246
        } else {
247
            Err(Error::other("bonsaidb-server auth", "invalid session id"))
248
        }
249
10
    }
250
}
251

            
252
/// A locked reference to associated client data.
253
pub struct LockedClientDataGuard<'client, ClientData>(MutexGuard<'client, Option<ClientData>>);
254

            
255
impl<'client, ClientData> Deref for LockedClientDataGuard<'client, ClientData> {
256
    type Target = Option<ClientData>;
257

            
258
    fn deref(&self) -> &Self::Target {
259
        &self.0
260
    }
261
}
262

            
263
impl<'client, ClientData> DerefMut for LockedClientDataGuard<'client, ClientData> {
264
2
    fn deref_mut(&mut self) -> &mut Self::Target {
265
2
        &mut self.0
266
2
    }
267
}
268

            
269
#[derive(Debug)]
270
pub struct OwnedClient<B: Backend> {
271
    client: ConnectedClient<B>,
272
    runtime: Arc<tokio::runtime::Handle>,
273
    server: Option<CustomServer<B>>,
274
}
275

            
276
impl<B: Backend> OwnedClient<B> {
277
181
    pub(crate) fn new(
278
181
        id: u32,
279
181
        address: SocketAddr,
280
181
        transport: Transport,
281
181
        response_sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
282
181
        server: CustomServer<B>,
283
181
        default_session: Session,
284
181
    ) -> Self {
285
181
        let mut session = HashMap::new();
286
181
        session.insert(
287
181
            None,
288
181
            ClientSession {
289
181
                session: default_session,
290
181
                subscribers: HashMap::default(),
291
181
            },
292
181
        );
293
181
        Self {
294
181
            client: ConnectedClient {
295
181
                data: Arc::new(ConnectedClientData {
296
181
                    id,
297
181
                    address,
298
181
                    transport,
299
181
                    response_sender,
300
181
                    sessions: RwLock::new(session),
301
181
                    client_data: Mutex::default(),
302
181
                    connected: AtomicBool::new(true),
303
181
                }),
304
181
            },
305
181
            runtime: Arc::new(tokio::runtime::Handle::current()),
306
181
            server: Some(server),
307
181
        }
308
181
    }
309

            
310
362
    pub fn clone(&self) -> ConnectedClient<B> {
311
362
        self.client.clone()
312
362
    }
313
}
314

            
315
impl<B: Backend> Drop for OwnedClient<B> {
    /// Spawns a task on the captured runtime that tells the server to
    /// disconnect this client. The task is detached; its result is not
    /// observed.
    fn drop(&mut self) {
        // `server` is only taken here, so it is normally `Some`. Avoid
        // `unwrap()` anyway: a panic inside `drop` while the thread is already
        // unwinding aborts the whole process.
        if let Some(server) = self.server.take() {
            let id = self.client.data.id;
            self.runtime
                .spawn(async move { server.disconnect_client(id).await });
        }
    }
}
323

            
324
impl<B: Backend> Deref for OwnedClient<B> {
325
    type Target = ConnectedClient<B>;
326

            
327
181
    fn deref(&self) -> &Self::Target {
328
181
        &self.client
329
181
    }
330
}