1
use bonsaidb_core::networking::CURRENT_PROTOCOL_VERSION;
2
use tokio::io::{AsyncRead, AsyncWrite};
3

            
4
use crate::{Backend, CustomServer, Error};
5

            
6
impl<B: Backend> CustomServer<B> {
7
    /// Listens for websocket connections on `addr`.
8
16
    pub async fn listen_for_websockets_on<T: tokio::net::ToSocketAddrs + Send + Sync>(
9
16
        &self,
10
16
        addr: T,
11
16
        with_tls: bool,
12
16
    ) -> Result<(), Error> {
13
16
        if with_tls {
14
            self.listen_for_secure_tcp_on(addr, ()).await
15
        } else {
16
90
            self.listen_for_tcp_on(addr, ()).await
17
        }
18
2
    }
19

            
20
86
    pub(crate) async fn handle_raw_websocket_connection<
21
86
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
22
86
    >(
23
86
        &self,
24
86
        connection: S,
25
86
        peer_address: std::net::SocketAddr,
26
86
    ) -> Result<(), Error> {
27
86
        let stream = tokio_tungstenite::accept_hdr_async(connection, VersionChecker).await?;
28
28236
        self.handle_websocket(stream, peer_address).await;
29
72
        Ok(())
30
73
    }
31

            
32
    /// Handles upgrading an HTTP connection to the `WebSocket` protocol based
33
    /// on the upgrade `request`. Requires feature `hyper` to be enabled.
34
    #[cfg(feature = "hyper")]
35
1
    pub async fn upgrade_websocket(
36
1
        &self,
37
1
        peer_address: std::net::SocketAddr,
38
1
        mut request: hyper::Request<hyper::Body>,
39
1
    ) -> hyper::Response<hyper::Body> {
40
1
        use hyper::{
41
1
            header::{HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, UPGRADE},
42
1
            StatusCode,
43
1
        };
44
1
        use tokio_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
45
1

            
46
1
        let mut response = hyper::Response::new(hyper::Body::empty());
47
1
        // Send a 400 to any request that doesn't have
48
1
        // an `Upgrade` header.
49
1
        if !request.headers().contains_key(UPGRADE) {
50
            *response.status_mut() = StatusCode::BAD_REQUEST;
51
            return response;
52
1
        }
53

            
54
1
        let sec_websocket_key = if let Some(key) = request.headers_mut().remove(SEC_WEBSOCKET_KEY) {
55
1
            key
56
        } else {
57
            *response.status_mut() = StatusCode::BAD_REQUEST;
58
            return response;
59
        };
60

            
61
1
        let task_self = self.clone();
62
1
        tokio::spawn(async move {
63
1
            match hyper::upgrade::on(&mut request).await {
64
1
                Ok(upgraded) => {
65
1
                    let ws = WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await;
66
2
                    task_self.handle_websocket(ws, peer_address).await;
67
                }
68
                Err(err) => {
69
                    log::error!("Error upgrading websocket: {:?}", err);
70
                }
71
            }
72
1
        });
73
1

            
74
1
        *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
75
1
        response
76
1
            .headers_mut()
77
1
            .insert(UPGRADE, HeaderValue::from_static("websocket"));
78
1
        response
79
1
            .headers_mut()
80
1
            .insert(CONNECTION, HeaderValue::from_static("upgrade"));
81
1
        response.headers_mut().insert(
82
1
            SEC_WEBSOCKET_ACCEPT,
83
1
            compute_websocket_accept_header(sec_websocket_key.as_bytes()),
84
1
        );
85
1

            
86
1
        response
87
1
    }
88

            
89
    /// Handles an established `tokio-tungstenite` `WebSocket` stream.
90
86
    pub async fn handle_websocket<
91
86
        S: futures::Stream<Item = Result<tokio_tungstenite::tungstenite::Message, E>>
92
86
            + futures::Sink<tokio_tungstenite::tungstenite::Message>
93
86
            + Send
94
86
            + 'static,
95
86
        E: std::fmt::Debug + Send,
96
86
    >(
97
86
        &self,
98
86
        connection: S,
99
86
        peer_address: std::net::SocketAddr,
100
86
    ) {
101
86
        use bonsaidb_core::networking::Payload;
102
86
        use futures::{SinkExt, StreamExt};
103
86
        use tokio_tungstenite::tungstenite::Message;
104
86

            
105
86
        use crate::Transport;
106
86

            
107
86
        let (mut sender, mut receiver) = connection.split();
108
86
        let (response_sender, response_receiver) = flume::unbounded();
109
86
        let (message_sender, message_receiver) = flume::unbounded();
110
86

            
111
86
        let (api_response_sender, api_response_receiver) = flume::unbounded();
112
86
        let client = if let Some(client) = self
113
86
            .initialize_client(Transport::WebSocket, peer_address, api_response_sender)
114
            .await
115
        {
116
86
            client
117
        } else {
118
            return;
119
        };
120
86
        let task_sender = response_sender.clone();
121
86
        tokio::spawn(async move {
122
116
            while let Ok((session_id, name, value)) = api_response_receiver.recv_async().await {
123
30
                if task_sender
124
30
                    .send(Payload {
125
30
                        id: None,
126
30
                        session_id,
127
30
                        name,
128
30
                        value: Ok(value),
129
30
                    })
130
30
                    .is_err()
131
                {
132
                    break;
133
30
                }
134
            }
135
86
        });
136
86

            
137
86
        tokio::spawn(async move {
138
51380
            while let Ok(response) = message_receiver.recv_async().await {
139
37005
                if sender.send(response).await.is_err() {
140
1
                    break;
141
37004
                }
142
            }
143

            
144
68
            Result::<(), Error>::Ok(())
145
86
        });
146
86

            
147
86
        let task_sender = message_sender.clone();
148
86
        tokio::spawn(async move {
149
37091
            while let Ok(response) = response_receiver.recv_async().await {
150
                if task_sender
151
37005
                    .send(Message::Binary(bincode::serialize(&response)?))
152
37005
                    .is_err()
153
                {
154
                    break;
155
37005
                }
156
            }
157

            
158
68
            Result::<(), Error>::Ok(())
159
86
        });
160
86

            
161
86
        let (request_sender, request_receiver) =
162
86
            flume::bounded::<Payload>(self.data.client_simultaneous_request_limit);
163
86
        let task_self = self.clone();
164
86
        tokio::spawn(async move {
165
86
            task_self
166
35375
                .handle_client_requests(client.clone(), request_receiver, response_sender)
167
35374
                .await;
168
86
        });
169

            
170
37162
        while let Some(payload) = receiver.next().await {
171
36975
            match payload {
172
36975
                Ok(Message::Binary(binary)) => match bincode::deserialize::<Payload>(&binary) {
173
36975
                    Ok(payload) => drop(request_sender.send_async(payload).await),
174
                    Err(err) => {
175
                        log::error!("[server] error decoding message: {:?}", err);
176
                        break;
177
                    }
178
                },
179
                Ok(Message::Close(_)) => break,
180
                Ok(Message::Ping(payload)) => {
181
                    drop(message_sender.send(Message::Pong(payload)));
182
                }
183
72
                other => {
184
64
                    log::error!("[server] unexpected message: {:?}", other);
185
                }
186
            }
187
        }
188
72
    }
189
}
190

            
191
/// Derives the `Sec-WebSocket-Accept` header value from a client's
/// `Sec-WebSocket-Key` bytes.
///
/// Following RFC 6455, the key is concatenated with the fixed protocol GUID,
/// hashed with SHA-1, and the digest is base64-encoded.
#[cfg(feature = "hyper")]
fn compute_websocket_accept_header(key: &[u8]) -> hyper::header::HeaderValue {
    use sha1::{Digest, Sha1};

    /// Magic GUID that RFC 6455 appends to every client key.
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::default();
    hasher.update(key);
    hasher.update(WEBSOCKET_GUID);
    let hash = hasher.finalize();

    let accept = base64::encode(&hash);
    hyper::header::HeaderValue::from_str(&accept).expect("base64 is a valid value")
}
201

            
202
struct VersionChecker;
203

            
204
impl tokio_tungstenite::tungstenite::handshake::server::Callback for VersionChecker {
205
    fn on_request(
206
        self,
207
        request: &tokio_tungstenite::tungstenite::handshake::server::Request,
208
        mut response: tokio_tungstenite::tungstenite::handshake::server::Response,
209
    ) -> Result<
210
        tokio_tungstenite::tungstenite::handshake::server::Response,
211
        tokio_tungstenite::tungstenite::handshake::server::ErrorResponse,
212
    > {
213
1118
        if let Some(protocols) = request.headers().get("Sec-WebSocket-Protocol") {
214
1118
            if let Ok(protocols) = protocols.to_str() {
215
1118
                for protocol in protocols.split(',').map(str::trim) {
216
1118
                    if protocol == CURRENT_PROTOCOL_VERSION {
217
1105
                        response.headers_mut().insert(
218
1105
                            "Sec-WebSocket-Protocol",
219
1105
                            CURRENT_PROTOCOL_VERSION.try_into().unwrap(),
220
1105
                        );
221
1105
                        return Ok(response);
222
13
                    }
223
                }
224
            }
225
        }
226

            
227
13
        let mut err = tokio_tungstenite::tungstenite::handshake::server::ErrorResponse::new(None);
228
13
        *err.status_mut() = 406_u16.try_into().unwrap();
229
13
        Err(err)
230
1118
    }
231
}