1
use std::{
2
    collections::HashMap,
3
    net::SocketAddr,
4
    ops::{Deref, DerefMut},
5
    sync::Arc,
6
};
7

            
8
use async_lock::{Mutex, MutexGuard};
9
use bonsaidb_core::{
10
    api,
11
    arc_bytes::serde::Bytes,
12
    connection::{Session, SessionId},
13
    networking::MessageReceived,
14
    pubsub::{Receiver, Subscriber as _},
15
    schema::ApiName,
16
};
17
use bonsaidb_local::Subscriber;
18
use bonsaidb_utils::fast_async_lock;
19
use derive_where::derive_where;
20
use flume::Sender;
21
use parking_lot::RwLock;
22

            
23
use crate::{Backend, CustomServer, Error, NoBackend};
24

            
25
/// The ways a client can be connected to the server.
///
/// This is a small field-less enum, so it is `Copy` and can be used as a
/// `HashMap`/`HashSet` key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Transport {
    /// A connection over BonsaiDb's QUIC-based protocol.
    Bonsai,
    /// A connection over WebSockets.
    #[cfg(feature = "websockets")]
    WebSocket,
}
34

            
35
/// A connected database client.
36
#[derive(Debug)]
37
71457
#[derive_where(Clone)]
38
pub struct ConnectedClient<B: Backend = NoBackend> {
39
    data: Arc<Data<B>>,
40
}
41

            
42
#[derive(Debug)]
43
struct Data<B: Backend = NoBackend> {
44
    id: u32,
45
    sessions: RwLock<HashMap<Option<SessionId>, ClientSession>>,
46
    address: SocketAddr,
47
    transport: Transport,
48
    response_sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
49
    client_data: Mutex<Option<B::ClientData>>,
50
}
51

            
52
#[derive(Debug)]
53
struct ClientSession {
54
    session: Session,
55
    subscribers: HashMap<u64, Subscriber>,
56
}
57

            
58
impl<B: Backend> ConnectedClient<B> {
59
    /// Returns the address of the connected client.
60
    #[must_use]
61
    pub fn address(&self) -> &SocketAddr {
62
        &self.data.address
63
    }
64

            
65
    /// Returns the transport method the client is connected via.
66
    #[must_use]
67
    pub fn transport(&self) -> &Transport {
68
        &self.data.transport
69
    }
70

            
71
7
    pub(crate) fn logged_in_as(&self, session: Session) {
72
7
        let mut sessions = self.data.sessions.write();
73
7
        sessions.insert(
74
7
            session.id,
75
7
            ClientSession {
76
7
                session,
77
7
                subscribers: HashMap::default(),
78
7
            },
79
7
        );
80
7
    }
81

            
82
2
    pub(crate) fn log_out(&self, session: SessionId) {
83
2
        let mut sessions = self.data.sessions.write();
84
2
        sessions.remove(&Some(session));
85
2
    }
86

            
87
    /// Sends a custom API response to the client.
88
45
    pub fn send<Api: api::Api>(
89
45
        &self,
90
45
        session: Option<&Session>,
91
45
        response: &Api::Response,
92
45
    ) -> Result<(), Error> {
93
45
        let encoded = pot::to_vec(&Result::<&Api::Response, Api::Error>::Ok(response))?;
94
45
        self.data.response_sender.send((
95
45
            session.and_then(|session| session.id),
96
45
            Api::name(),
97
45
            Bytes::from(encoded),
98
45
        ))?;
99
45
        Ok(())
100
45
    }
101

            
102
    /// Returns a locked reference to the stored client data.
103
2
    pub async fn client_data(&self) -> LockedClientDataGuard<'_, B::ClientData> {
104
        LockedClientDataGuard(fast_async_lock!(self.data.client_data))
105
2
    }
106

            
107
    /// Looks up an active authentication session by its unique id. `None`
108
    /// represents the unauthenticated session, and the result can be used to
109
    /// check what permissions are allowed by default.
110
    #[must_use]
111
71163
    pub fn session(&self, session_id: Option<SessionId>) -> Option<Session> {
112
71163
        let sessions = self.data.sessions.read();
113
71163
        sessions.get(&session_id).map(|data| data.session.clone())
114
71163
    }
115

            
116
24
    pub(crate) fn register_subscriber(
117
24
        &self,
118
24
        subscriber: Subscriber,
119
24
        session_id: Option<SessionId>,
120
24
    ) {
121
24
        let subscriber_id = subscriber.id();
122
24
        let receiver = subscriber.receiver().clone();
123
24
        {
124
24
            let mut sessions = self.data.sessions.write();
125
24
            if let Some(client_session) = sessions.get_mut(&session_id) {
126
24
                client_session
127
24
                    .subscribers
128
24
                    .insert(subscriber.id(), subscriber);
129
24
            } else {
130
                // TODO return error for session not found.
131
                return;
132
            }
133
        }
134
24
        let task_self = self.clone();
135
24
        tokio::task::spawn(async move {
136
24
            task_self
137
46
                .forward_notifications_for(session_id, subscriber_id, receiver)
138
46
                .await;
139
24
        });
140
24
    }
141

            
142
    /// Sets the associated data for this client.
143
    pub async fn set_client_data(&self, data: B::ClientData) {
144
        let mut client_data = fast_async_lock!(self.data.client_data);
145
        *client_data = Some(data);
146
    }
147

            
148
24
    async fn forward_notifications_for(
149
24
        &self,
150
24
        session_id: Option<SessionId>,
151
24
        subscriber_id: u64,
152
24
        receiver: Receiver,
153
24
    ) {
154
24
        let session = self.session(session_id);
155
69
        while let Ok(message) = receiver.receive_async().await {
156
45
            if self
157
45
                .send::<MessageReceived>(
158
45
                    session.as_ref(),
159
45
                    &MessageReceived {
160
45
                        subscriber_id,
161
45
                        topic: Bytes::from(message.topic.0.into_vec()),
162
45
                        payload: Bytes::from(&message.payload[..]),
163
45
                    },
164
45
                )
165
45
                .is_err()
166
            {
167
                break;
168
45
            }
169
        }
170
10
    }
171

            
172
39
    pub(crate) fn subscribe_by_id(
173
39
        &self,
174
39
        subscriber_id: u64,
175
39
        topic: Bytes,
176
39
        check_session_id: Option<SessionId>,
177
39
    ) -> Result<(), crate::Error> {
178
39
        let mut sessions = self.data.sessions.write();
179
39
        if let Some(client_session) = sessions.get_mut(&check_session_id) {
180
39
            if let Some(subscriber) = client_session.subscribers.get(&subscriber_id) {
181
39
                subscriber.subscribe_to_bytes(topic.0)?;
182
39
                Ok(())
183
            } else {
184
                Err(Error::Transport(String::from("invalid subscriber id")))
185
            }
186
        } else {
187
            Err(Error::Transport(String::from("invalid session id")))
188
        }
189
39
    }
190

            
191
3
    pub(crate) fn unsubscribe_by_id(
192
3
        &self,
193
3
        subscriber_id: u64,
194
3
        topic: &[u8],
195
3
        check_session_id: Option<SessionId>,
196
3
    ) -> Result<(), crate::Error> {
197
3
        let mut sessions = self.data.sessions.write();
198
3
        if let Some(client_session) = sessions.get_mut(&check_session_id) {
199
3
            if let Some(subscriber) = client_session.subscribers.get(&subscriber_id) {
200
3
                subscriber.unsubscribe_from_bytes(topic)?;
201
3
                Ok(())
202
            } else {
203
                Err(Error::Transport(String::from("invalid subscriber id")))
204
            }
205
        } else {
206
            Err(Error::Transport(String::from("invalid session id")))
207
        }
208
3
    }
209

            
210
10
    pub(crate) fn unregister_subscriber_by_id(
211
10
        &self,
212
10
        subscriber_id: u64,
213
10
        check_session_id: Option<SessionId>,
214
10
    ) -> Result<(), crate::Error> {
215
10
        let mut sessions = self.data.sessions.write();
216
10
        if let Some(client_session) = sessions.get_mut(&check_session_id) {
217
10
            if client_session.subscribers.remove(&subscriber_id).is_some() {
218
10
                Ok(())
219
            } else {
220
                Err(Error::Transport(String::from("invalid subscriber id")))
221
            }
222
        } else {
223
            Err(Error::Transport(String::from("invalid session id")))
224
        }
225
10
    }
226
}
227

            
228
/// A locked reference to associated client data.
229
pub struct LockedClientDataGuard<'client, ClientData>(MutexGuard<'client, Option<ClientData>>);
230

            
231
impl<'client, ClientData> Deref for LockedClientDataGuard<'client, ClientData> {
232
    type Target = Option<ClientData>;
233

            
234
    fn deref(&self) -> &Self::Target {
235
        &self.0
236
    }
237
}
238

            
239
impl<'client, ClientData> DerefMut for LockedClientDataGuard<'client, ClientData> {
240
2
    fn deref_mut(&mut self) -> &mut Self::Target {
241
2
        &mut self.0
242
2
    }
243
}
244

            
245
#[derive(Debug)]
246
pub struct OwnedClient<B: Backend> {
247
    client: ConnectedClient<B>,
248
    runtime: Arc<tokio::runtime::Handle>,
249
    server: Option<CustomServer<B>>,
250
}
251

            
252
impl<B: Backend> OwnedClient<B> {
253
147
    pub(crate) fn new(
254
147
        id: u32,
255
147
        address: SocketAddr,
256
147
        transport: Transport,
257
147
        response_sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
258
147
        server: CustomServer<B>,
259
147
        default_session: Session,
260
147
    ) -> Self {
261
147
        let mut session = HashMap::new();
262
147
        session.insert(
263
147
            None,
264
147
            ClientSession {
265
147
                session: default_session,
266
147
                subscribers: HashMap::default(),
267
147
            },
268
147
        );
269
147
        Self {
270
147
            client: ConnectedClient {
271
147
                data: Arc::new(Data {
272
147
                    id,
273
147
                    address,
274
147
                    transport,
275
147
                    response_sender,
276
147
                    sessions: RwLock::new(session),
277
147
                    client_data: Mutex::default(),
278
147
                }),
279
147
            },
280
147
            runtime: Arc::new(tokio::runtime::Handle::current()),
281
147
            server: Some(server),
282
147
        }
283
147
    }
284

            
285
294
    pub fn clone(&self) -> ConnectedClient<B> {
286
294
        self.client.clone()
287
294
    }
288
}
289

            
290
impl<B: Backend> Drop for OwnedClient<B> {
    /// Schedules `disconnect_client` for this client's id on the runtime that
    /// was current when the client was created.
    fn drop(&mut self) {
        // In this file `server` is set in `new` and only taken here, so it is
        // expected to be `Some`. Avoid `unwrap()` anyway: a panic inside
        // `drop` while the thread is already unwinding aborts the process.
        if let Some(server) = self.server.take() {
            let id = self.client.data.id;
            self.runtime.spawn(async move {
                server.disconnect_client(id).await;
            });
        }
    }
}
299

            
300
impl<B: Backend> Deref for OwnedClient<B> {
301
    type Target = ConnectedClient<B>;
302

            
303
147
    fn deref(&self) -> &Self::Target {
304
147
        &self.client
305
147
    }
306
}