1
use bonsaidb_core::networking::{Payload, CURRENT_PROTOCOL_VERSION};
2
use futures::{SinkExt, StreamExt};
3
use tokio::io::{AsyncRead, AsyncWrite};
4
use tokio_tungstenite::tungstenite::Message;
5

            
6
use crate::server::connected_client::OwnedClient;
7
use crate::server::shutdown::{ShutdownState, ShutdownStateWatcher};
8
use crate::{Backend, CustomServer, Error, Transport};
9

            
10
impl<B: Backend> CustomServer<B> {
11
    /// Listens for websocket connections on `addr`.
12
19
    pub async fn listen_for_websockets_on<T: tokio::net::ToSocketAddrs + Send + Sync>(
13
19
        &self,
14
19
        addr: T,
15
19
        with_tls: bool,
16
19
    ) -> Result<(), Error> {
17
19
        if with_tls {
18
            self.listen_for_secure_tcp_on(addr, ()).await
19
        } else {
20
103
            self.listen_for_tcp_on(addr, ()).await
21
        }
22
3
    }
23

            
24
99
    pub(crate) async fn handle_raw_websocket_connection<
25
99
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
26
99
    >(
27
99
        &self,
28
99
        connection: S,
29
99
        peer_address: std::net::SocketAddr,
30
99
    ) -> Result<(), Error> {
31
99
        let stream = tokio_tungstenite::accept_hdr_async(connection, VersionChecker).await?;
32
25219
        self.handle_websocket(stream, peer_address).await;
33
85
        Ok(())
34
86
    }
35

            
36
    /// Handles upgrading an HTTP connection to the `WebSocket` protocol based
37
    /// on the upgrade `request`. Requires feature `hyper` to be enabled.
38
    #[cfg(feature = "hyper")]
39
1
    pub fn upgrade_websocket(
40
1
        &self,
41
1
        peer_address: std::net::SocketAddr,
42
1
        mut request: hyper::Request<hyper::Body>,
43
1
    ) -> hyper::Response<hyper::Body> {
44
1
        use hyper::header::{
45
1
            HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, UPGRADE,
46
1
        };
47
1
        use hyper::StatusCode;
48
1
        use tokio_tungstenite::tungstenite::protocol::Role;
49
1
        use tokio_tungstenite::WebSocketStream;
50
1

            
51
1
        let mut response = hyper::Response::new(hyper::Body::empty());
52
1
        // Send a 400 to any request that doesn't have
53
1
        // an `Upgrade` header.
54
1
        if !request.headers().contains_key(UPGRADE) {
55
            *response.status_mut() = StatusCode::BAD_REQUEST;
56
            return response;
57
1
        }
58

            
59
1
        let Some(sec_websocket_key) = request.headers_mut().remove(SEC_WEBSOCKET_KEY) else {
60
            *response.status_mut() = StatusCode::BAD_REQUEST;
61
            return response;
62
        };
63

            
64
1
        let task_self = self.clone();
65
1
        tokio::spawn(async move {
66
1
            match hyper::upgrade::on(&mut request).await {
67
1
                Ok(upgraded) => {
68
1
                    let ws = WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await;
69
2
                    task_self.handle_websocket(ws, peer_address).await;
70
                }
71
                Err(err) => {
72
                    log::error!("Error upgrading websocket: {:?}", err);
73
                }
74
            }
75
1
        });
76
1

            
77
1
        *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
78
1
        response
79
1
            .headers_mut()
80
1
            .insert(UPGRADE, HeaderValue::from_static("websocket"));
81
1
        response
82
1
            .headers_mut()
83
1
            .insert(CONNECTION, HeaderValue::from_static("upgrade"));
84
1
        response.headers_mut().insert(
85
1
            SEC_WEBSOCKET_ACCEPT,
86
1
            compute_websocket_accept_header(sec_websocket_key.as_bytes()),
87
1
        );
88
1

            
89
1
        response
90
1
    }
91

            
92
    /// Handles an established `tokio-tungstenite` `WebSocket` stream.
93
99
    pub async fn handle_websocket<
94
99
        S: futures::Stream<Item = Result<tokio_tungstenite::tungstenite::Message, E>>
95
99
            + futures::Sink<tokio_tungstenite::tungstenite::Message>
96
99
            + Send
97
99
            + 'static,
98
99
        E: std::fmt::Debug + Send,
99
99
    >(
100
99
        &self,
101
99
        connection: S,
102
99
        peer_address: std::net::SocketAddr,
103
99
    ) {
104
99
        let mut shutdown = self
105
99
            .data
106
99
            .shutdown
107
99
            .watcher()
108
            .await
109
99
            .expect("watcher shut down");
110
99

            
111
99
        let (mut sender, mut receiver) = connection.split();
112
99
        let (response_sender, response_receiver) = flume::unbounded();
113
99
        let (message_sender, message_receiver) = flume::unbounded();
114
99

            
115
99
        let (api_response_sender, api_response_receiver) = flume::unbounded();
116
99
        let Some(client) = self
117
99
            .initialize_client(Transport::WebSocket, peer_address, api_response_sender)
118
            .await
119
        else {
120
            return;
121
        };
122
99
        let task_sender = response_sender.clone();
123
99
        tokio::spawn(async move {
124
129
            while let Ok((session_id, name, value)) = api_response_receiver.recv_async().await {
125
30
                if task_sender
126
30
                    .send(Payload {
127
30
                        id: None,
128
30
                        session_id,
129
30
                        name,
130
30
                        value: Ok(value),
131
30
                    })
132
30
                    .is_err()
133
                {
134
                    break;
135
30
                }
136
            }
137
99
        });
138
99

            
139
99
        tokio::spawn(async move {
140
52561
            while let Ok(response) = message_receiver.recv_async().await {
141
37655
                if sender.send(response).await.is_err() {
142
2
                    break;
143
37653
                }
144
            }
145

            
146
79
            Result::<(), Error>::Ok(())
147
99
        });
148
99

            
149
99
        let task_sender = message_sender.clone();
150
99
        tokio::spawn(async move {
151
37754
            while let Ok(response) = response_receiver.recv_async().await {
152
37656
                if task_sender
153
37656
                    .send(Message::Binary(bincode::serialize(&response)?))
154
37656
                    .is_err()
155
                {
156
1
                    break;
157
37655
                }
158
            }
159

            
160
79
            Result::<(), Error>::Ok(())
161
99
        });
162
99

            
163
99
        let (request_sender, request_receiver) =
164
99
            flume::bounded::<Payload>(self.data.client_simultaneous_request_limit);
165
99

            
166
99
        self.spawn_client_request_handler(client, request_receiver, response_sender, &shutdown);
167

            
168
37728
        loop {
169
54802
            tokio::select! {
170
37713
                payload = receiver.next() => {
171
                    if let Some(payload) = payload {
172
                        match payload {
173
                            Ok(Message::Binary(binary)) => match bincode::deserialize::<Payload>(&binary) {
174
                                Ok(payload) => drop(request_sender.send_async(payload).await),
175
                                Err(err) => {
176
                                    log::error!("[server] error decoding message: {:?}", err);
177
                                    break;
178
                                }
179
                            },
180
                            Ok(Message::Close(_)) => break,
181
                            Ok(Message::Ping(payload)) => {
182
                                drop(message_sender.send(Message::Pong(payload)));
183
                            }
184
                            other => {
185
                                log::error!("[server] unexpected message: {:?}", other);
186
                                break;
187
                            }
188
                        }
189
                    } else {
190
                        return;
191
                    }
192
                },
193
1
                shutdown = shutdown.wait_for_shutdown() => {
194
                    if matches!(shutdown, ShutdownState::Shutdown) {
195
                        return;
196
                    }
197
                }
198
37728
            }
199
37728
        }
200
85
    }
201

            
202
99
    fn spawn_client_request_handler(
203
99
        &self,
204
99
        client: OwnedClient<B>,
205
99
        request_receiver: flume::Receiver<Payload>,
206
99
        response_sender: flume::Sender<Payload>,
207
99
        shutdown: &ShutdownStateWatcher,
208
99
    ) {
209
99
        tokio::spawn({
210
99
            let task_self = self.clone();
211
99
            let shutdown = shutdown.clone();
212
99
            async move {
213
99
                task_self
214
99
                    .handle_client_requests(
215
99
                        client.clone(),
216
99
                        request_receiver,
217
99
                        response_sender,
218
99
                        shutdown,
219
99
                    )
220
33901
                    .await;
221
99
            }
222
99
        });
223
99
    }
224
}
225

            
226
/// Computes the `Sec-WebSocket-Accept` header value for a client's
/// `Sec-WebSocket-Key`.
///
/// The key is followed by the fixed GUID from RFC 6455, hashed with SHA-1,
/// and encoded with the standard base64 alphabet.
#[cfg(feature = "hyper")]
fn compute_websocket_accept_header(key: &[u8]) -> hyper::header::HeaderValue {
    use base64::engine::general_purpose::STANDARD as BASE64;
    use base64::Engine;
    use sha1::{Digest, Sha1};

    // GUID that RFC 6455 appends to the client key during the opening handshake.
    const HANDSHAKE_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::new();
    hasher.update(key);
    hasher.update(HANDSHAKE_GUID);
    let accept = BASE64.encode(hasher.finalize());

    // Base64 output only contains characters that are legal in header values.
    hyper::header::HeaderValue::from_str(&accept).expect("base64 is a valid value")
}
238

            
239
struct VersionChecker;
240

            
241
impl tokio_tungstenite::tungstenite::handshake::server::Callback for VersionChecker {
242
    fn on_request(
243
        self,
244
        request: &tokio_tungstenite::tungstenite::handshake::server::Request,
245
        mut response: tokio_tungstenite::tungstenite::handshake::server::Response,
246
    ) -> Result<
247
        tokio_tungstenite::tungstenite::handshake::server::Response,
248
        tokio_tungstenite::tungstenite::handshake::server::ErrorResponse,
249
    > {
250
1386
        if let Some(protocols) = request.headers().get("Sec-WebSocket-Protocol") {
251
1386
            if let Ok(protocols) = protocols.to_str() {
252
1386
                for protocol in protocols.split(',').map(str::trim) {
253
1386
                    if protocol == CURRENT_PROTOCOL_VERSION {
254
1372
                        response.headers_mut().insert(
255
1372
                            "Sec-WebSocket-Protocol",
256
1372
                            CURRENT_PROTOCOL_VERSION.try_into().unwrap(),
257
1372
                        );
258
1372
                        return Ok(response);
259
14
                    }
260
                }
261
            }
262
        }
263

            
264
14
        let mut err = tokio_tungstenite::tungstenite::handshake::server::ErrorResponse::new(None);
265
14
        *err.status_mut() = 406_u16.try_into().unwrap();
266
14
        Err(err)
267
1386
    }
268
}