1
use bonsaidb_core::networking::CURRENT_PROTOCOL_VERSION;
2
use tokio::io::{AsyncRead, AsyncWrite};
3

            
4
use crate::{Backend, CustomServer, Error};
5

            
6
impl<B: Backend> CustomServer<B> {
7
    /// Listens for websocket connections on `addr`.
8
16
    pub async fn listen_for_websockets_on<T: tokio::net::ToSocketAddrs + Send + Sync>(
9
16
        &self,
10
16
        addr: T,
11
16
        with_tls: bool,
12
16
    ) -> Result<(), Error> {
13
16
        if with_tls {
14
            self.listen_for_secure_tcp_on(addr, ()).await
15
        } else {
16
57
            self.listen_for_tcp_on(addr, ()).await
17
        }
18
2
    }
19

            
20
55
    pub(crate) async fn handle_raw_websocket_connection<
21
55
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
22
55
    >(
23
55
        &self,
24
55
        connection: S,
25
55
        peer_address: std::net::SocketAddr,
26
55
    ) -> Result<(), Error> {
27
55
        let stream = tokio_tungstenite::accept_hdr_async(connection, VersionChecker).await?;
28
15097
        self.handle_websocket(stream, peer_address).await;
29
41
        Ok(())
30
42
    }
31

            
32
    /// Handles upgrading an HTTP connection to the `WebSocket` protocol based
33
    /// on the upgrade `request`. Requires feature `hyper` to be enabled.
34
    #[cfg(feature = "hyper")]
35
1
    pub async fn upgrade_websocket(
36
1
        &self,
37
1
        peer_address: std::net::SocketAddr,
38
1
        mut request: hyper::Request<hyper::Body>,
39
1
    ) -> hyper::Response<hyper::Body> {
40
1
        use hyper::{
41
1
            header::{HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, UPGRADE},
42
1
            StatusCode,
43
1
        };
44
1
        use tokio_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
45
1

            
46
1
        let mut response = hyper::Response::new(hyper::Body::empty());
47
1
        // Send a 400 to any request that doesn't have
48
1
        // an `Upgrade` header.
49
1
        if !request.headers().contains_key(UPGRADE) {
50
            *response.status_mut() = StatusCode::BAD_REQUEST;
51
            return response;
52
1
        }
53

            
54
1
        let sec_websocket_key = if let Some(key) = request.headers_mut().remove(SEC_WEBSOCKET_KEY) {
55
1
            key
56
        } else {
57
            *response.status_mut() = StatusCode::BAD_REQUEST;
58
            return response;
59
        };
60

            
61
1
        let task_self = self.clone();
62
1
        tokio::spawn(async move {
63
1
            match hyper::upgrade::on(&mut request).await {
64
1
                Ok(upgraded) => {
65
1
                    let ws = WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await;
66
2
                    task_self.handle_websocket(ws, peer_address).await;
67
                }
68
                Err(err) => {
69
                    log::error!("Error upgrading websocket: {:?}", err);
70
                }
71
            }
72
1
        });
73
1

            
74
1
        *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
75
1
        response
76
1
            .headers_mut()
77
1
            .insert(UPGRADE, HeaderValue::from_static("websocket"));
78
1
        response
79
1
            .headers_mut()
80
1
            .insert(CONNECTION, HeaderValue::from_static("upgrade"));
81
1
        response.headers_mut().insert(
82
1
            SEC_WEBSOCKET_ACCEPT,
83
1
            compute_websocket_accept_header(sec_websocket_key.as_bytes()),
84
1
        );
85
1

            
86
1
        response
87
1
    }
88

            
89
    /// Handles an established `tokio-tungstenite` `WebSocket` stream.
90
55
    pub async fn handle_websocket<
91
55
        S: futures::Stream<Item = Result<tokio_tungstenite::tungstenite::Message, E>>
92
55
            + futures::Sink<tokio_tungstenite::tungstenite::Message>
93
55
            + Send
94
55
            + 'static,
95
55
        E: std::fmt::Debug + Send,
96
55
    >(
97
55
        &self,
98
55
        connection: S,
99
55
        peer_address: std::net::SocketAddr,
100
55
    ) {
101
55
        use bonsaidb_core::{
102
55
            custom_api::CustomApi,
103
55
            networking::{Payload, Request, Response},
104
55
        };
105
55
        use futures::{SinkExt, StreamExt};
106
55
        use tokio_tungstenite::tungstenite::Message;
107
55

            
108
55
        use crate::Transport;
109
55

            
110
55
        let (mut sender, mut receiver) = connection.split();
111
55
        let (response_sender, response_receiver) = flume::unbounded();
112
55
        let (message_sender, message_receiver) = flume::unbounded();
113
55

            
114
55
        let (api_response_sender, api_response_receiver) = flume::unbounded();
115
55
        let client = if let Some(client) = self
116
55
            .initialize_client(Transport::WebSocket, peer_address, api_response_sender)
117
            .await
118
        {
119
55
            client
120
        } else {
121
            return;
122
        };
123
55
        let task_sender = response_sender.clone();
124
55
        tokio::spawn(async move {
125
55
            while let Ok(response) = api_response_receiver.recv_async().await {
126
                if task_sender
127
                    .send(Payload {
128
                        id: None,
129
                        wrapped: Response::Api(response),
130
                    })
131
                    .is_err()
132
                {
133
                    break;
134
                }
135
            }
136
55
        });
137
55

            
138
55
        tokio::spawn(async move {
139
32146
            while let Ok(response) = message_receiver.recv_async().await {
140
21722
                if sender.send(response).await.is_err() {
141
                    break;
142
21722
                }
143
            }
144

            
145
37
            Result::<(), Error>::Ok(())
146
55
        });
147
55

            
148
55
        let task_sender = message_sender.clone();
149
55
        tokio::spawn(async move {
150
21777
            while let Ok(response) = response_receiver.recv_async().await {
151
                if task_sender
152
21722
                    .send(Message::Binary(bincode::serialize(&response)?))
153
21722
                    .is_err()
154
                {
155
                    break;
156
21722
                }
157
            }
158

            
159
37
            Result::<(), Error>::Ok(())
160
55
        });
161
55

            
162
55
        let (request_sender, request_receiver) =
163
55
            flume::bounded::<Payload<Request<<B::CustomApi as CustomApi>::Request>>>(
164
55
                self.data.client_simultaneous_request_limit,
165
55
            );
166
55
        let task_self = self.clone();
167
55
        tokio::spawn(async move {
168
55
            task_self
169
21027
                .handle_client_requests(client.clone(), request_receiver, response_sender)
170
21027
                .await;
171
55
        });
172

            
173
21817
        while let Some(payload) = receiver.next().await {
174
21708
            match payload {
175
21708
                Ok(Message::Binary(binary)) => {
176
21708
                    match bincode::deserialize::<
177
21708
                        Payload<Request<<B::CustomApi as CustomApi>::Request>>,
178
21708
                    >(&binary)
179
                    {
180
21708
                        Ok(payload) => drop(request_sender.send_async(payload).await),
181
                        Err(err) => {
182
                            log::error!("[server] error decoding message: {:?}", err);
183
                            break;
184
                        }
185
                    }
186
                }
187
                Ok(Message::Close(_)) => break,
188
                Ok(Message::Ping(payload)) => {
189
                    drop(message_sender.send(Message::Pong(payload)));
190
                }
191
41
                other => {
192
32
                    log::error!("[server] unexpected message: {:?}", other);
193
                }
194
            }
195
        }
196
41
    }
197
}
198

            
199
/// Derives the `Sec-WebSocket-Accept` header value from a client's
/// `Sec-WebSocket-Key`.
///
/// The value is the base64 encoding of the SHA-1 digest of `key` followed by
/// the fixed GUID from RFC 6455.
#[cfg(feature = "hyper")]
fn compute_websocket_accept_header(key: &[u8]) -> hyper::header::HeaderValue {
    use sha1::{Digest, Sha1};

    // Fixed GUID appended to the client key during the opening handshake.
    const HANDSHAKE_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::default();
    hasher.update(key);
    hasher.update(HANDSHAKE_GUID);
    let digest = hasher.finalize();

    let accept_value = base64::encode(&digest);
    hyper::header::HeaderValue::from_str(&accept_value).expect("base64 is a valid value")
}
209

            
210
struct VersionChecker;
211

            
212
impl tokio_tungstenite::tungstenite::handshake::server::Callback for VersionChecker {
213
    fn on_request(
214
        self,
215
        request: &tokio_tungstenite::tungstenite::handshake::server::Request,
216
        mut response: tokio_tungstenite::tungstenite::handshake::server::Response,
217
    ) -> Result<
218
        tokio_tungstenite::tungstenite::handshake::server::Response,
219
        tokio_tungstenite::tungstenite::handshake::server::ErrorResponse,
220
    > {
221
990
        if let Some(protocols) = request.headers().get("Sec-WebSocket-Protocol") {
222
990
            if let Ok(protocols) = protocols.to_str() {
223
990
                for protocol in protocols.split(',').map(str::trim) {
224
990
                    if protocol == CURRENT_PROTOCOL_VERSION {
225
972
                        response.headers_mut().insert(
226
972
                            "Sec-WebSocket-Protocol",
227
972
                            CURRENT_PROTOCOL_VERSION.try_into().unwrap(),
228
972
                        );
229
972
                        return Ok(response);
230
18
                    }
231
                }
232
            }
233
        }
234

            
235
18
        let mut err = tokio_tungstenite::tungstenite::handshake::server::ErrorResponse::new(None);
236
18
        *err.status_mut() = 406_u16.try_into().unwrap();
237
18
        Err(err)
238
990
    }
239
}