1
use std::sync::Arc;
2

            
3
use async_trait::async_trait;
4
use rustls::server::ResolvesServerCert;
5
use tokio::io::{AsyncRead, AsyncWrite};
6
use tokio::net::TcpListener;
7

            
8
use crate::{Backend, CustomServer, Error};
9

            
10
impl<B: Backend> CustomServer<B> {
11
    /// Listens for HTTP traffic on `port`. This port will also receive
12
    /// `WebSocket` connections if feature `websockets` is enabled.
13
21
    pub async fn listen_for_tcp_on<S: TcpService, T: tokio::net::ToSocketAddrs + Send + Sync>(
14
21
        &self,
15
21
        addr: T,
16
21
        service: S,
17
21
    ) -> Result<(), Error> {
18
21
        let listener = TcpListener::bind(&addr).await?;
19
20
        let mut shutdown_watcher = self
20
20
            .data
21
20
            .shutdown
22
20
            .watcher()
23
            .await
24
20
            .expect("server already shutdown");
25

            
26
        loop {
27
122
            tokio::select! {
28
                _ = shutdown_watcher.wait_for_shutdown() => {
29
                    break;
30
                }
31
102
                incoming = listener.accept() => {
32
                    if incoming.is_err() {
33
                        continue;
34
                    }
35
                    let (connection, remote_addr) = incoming.unwrap();
36

            
37
                    let peer = Peer {
38
                        address: remote_addr,
39
                        protocol: service.available_protocols()[0].clone(),
40
                        secure: false,
41
                    };
42

            
43
                    let task_self = self.clone();
44
                    let task_service = service.clone();
45
102
                    tokio::spawn(async move {
46
25323
                        if let Err(err) = task_self.handle_tcp_connection(connection, peer, &task_service).await {
47
                            log::error!("[server] closing connection {}: {:?}", remote_addr, err);
48
89
                        }
49
89
                    });
50
                }
51
            }
52
        }
53

            
54
3
        Ok(())
55
4
    }
56

            
57
    /// Listens for HTTPS traffic on `port`. This port will also receive
58
    /// `WebSocket` connections if feature `websockets` is enabled. If feature
59
    /// `acme` is enabled, this connection will automatically manage the
60
    /// server's private key and certificate, which is also used for the
61
    /// QUIC-based protocol.
62
    #[cfg_attr(not(feature = "websockets"), allow(unused_variables))]
63
    #[cfg_attr(not(feature = "acme"), allow(unused_mut))]
64
1
    pub async fn listen_for_secure_tcp_on<
65
1
        S: TcpService,
66
1
        T: tokio::net::ToSocketAddrs + Send + Sync,
67
1
    >(
68
1
        &self,
69
1
        addr: T,
70
1
        service: S,
71
1
    ) -> Result<(), Error> {
72
        // We may not have a certificate yet, so we ignore any errors.
73
3
        drop(self.refresh_certified_key().await);
74

            
75
        #[cfg(feature = "acme")]
76
1
        {
77
1
            let task_self = self.clone();
78
1
            tokio::task::spawn(async move {
79
1
                if let Err(err) = task_self.update_acme_certificates().await {
80
                    log::error!("[server] acme task error: {0}", err);
81
                }
82
1
            });
83
1
        }
84
1

            
85
1
        let mut config = rustls::ServerConfig::builder()
86
1
            .with_safe_defaults()
87
1
            .with_no_client_auth()
88
1
            .with_cert_resolver(Arc::new(self.clone()));
89
1
        config.alpn_protocols = service
90
1
            .available_protocols()
91
1
            .iter()
92
2
            .map(|proto| proto.alpn_name().to_vec())
93
1
            .collect();
94
1

            
95
1
        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
96
1
        let listener = TcpListener::bind(&addr).await?;
97
        loop {
98
            let (stream, peer_addr) = listener.accept().await?;
99
            let acceptor = acceptor.clone();
100

            
101
            let task_self = self.clone();
102
            let task_service = service.clone();
103
            tokio::task::spawn(async move {
104
                let stream = match acceptor.accept(stream).await {
105
                    Ok(stream) => stream,
106
                    Err(err) => {
107
                        log::error!("[server] error during tls handshake: {:?}", err);
108
                        return;
109
                    }
110
                };
111

            
112
                let available_protocols = task_service.available_protocols();
113
                let protocol = stream
114
                    .get_ref()
115
                    .1
116
                    .alpn_protocol()
117
                    .and_then(|protocol| {
118
                        available_protocols
119
                            .iter()
120
                            .find(|p| p.alpn_name() == protocol)
121
                            .cloned()
122
                    })
123
                    .unwrap_or_else(|| available_protocols[0].clone());
124
                let peer = Peer {
125
                    address: peer_addr,
126
                    secure: true,
127
                    protocol,
128
                };
129
                if let Err(err) = task_self
130
                    .handle_tcp_connection(stream, peer, &task_service)
131
                    .await
132
                {
133
                    log::error!("[server] error for client {}: {:?}", peer_addr, err);
134
                }
135
            });
136
        }
137
1
    }
138

            
139
    #[cfg_attr(not(feature = "websockets"), allow(unused_variables))]
140
102
    async fn handle_tcp_connection<
141
102
        S: TcpService,
142
102
        C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
143
102
    >(
144
102
        &self,
145
102
        connection: C,
146
102
        peer: Peer<S::ApplicationProtocols>,
147
102
        service: &S,
148
102
    ) -> Result<(), Error> {
149
102
        // For ACME, don't send any traffic over the connection.
150
102
        #[cfg(feature = "acme")]
151
102
        if peer.protocol.alpn_name() == async_acme::acme::ACME_TLS_ALPN_NAME {
152
            log::info!("received acme challenge connection");
153
            return Ok(());
154
102
        }
155

            
156
107
        if let Err(connection) = service.handle_connection(connection, &peer).await {
157
            #[cfg(feature = "websockets")]
158
99
            if let Err(err) = self
159
99
                .handle_raw_websocket_connection(connection, peer.address)
160
25315
                .await
161
            {
162
                log::error!(
163
1
                    "[server] error on websocket for {}: {:?}",
164
                    peer.address,
165
                    err
166
                );
167
85
            }
168
3
        }
169

            
170
89
        Ok(())
171
89
    }
172
}
173

            
174
impl<B: Backend> ResolvesServerCert for CustomServer<B> {
175
    #[cfg_attr(not(feature = "acme"), allow(unused_variables))]
176
    fn resolve(
177
        &self,
178
        client_hello: rustls::server::ClientHello<'_>,
179
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
180
        #[cfg(feature = "acme")]
181
        if client_hello
182
            .alpn()
183
            .map(|mut iter| iter.any(|n| n == async_acme::acme::ACME_TLS_ALPN_NAME))
184
            .unwrap_or_default()
185
        {
186
            let server_name = client_hello.server_name()?.to_owned();
187
            let keys = self.data.alpn_keys.lock();
188
            if let Some(key) = keys.get(AsRef::<str>::as_ref(&server_name)) {
189
                log::info!("returning acme challenge");
190
                return Some(key.clone());
191
            }
192

            
193
            log::error!(
194
                "acme alpn challenge received with no key for {}",
195
                server_name
196
            );
197
            return None;
198
        }
199

            
200
        let cached_key = self.data.primary_tls_key.lock();
201
        if let Some(key) = cached_key.as_ref() {
202
            Some(key.clone())
203
        } else {
204
            log::error!("[server] inbound tls connection with no certificate installed");
205
            None
206
        }
207
    }
208
}
209

            
210
/// A service that can handle incoming TCP connections.
211
#[async_trait]
212
pub trait TcpService: Clone + Send + Sync + 'static {
213
    /// The application layer protocols that this service supports.
214
    type ApplicationProtocols: ApplicationProtocols;
215

            
216
    /// Returns all available protocols for this service. The first will be the
217
    /// default used if a connection is made without negotiating the application
218
    /// protocol.
219
    fn available_protocols(&self) -> &[Self::ApplicationProtocols];
220

            
221
    /// Handle an incoming `connection` for `peer`. Return `Err(connection)` to
222
    /// have BonsaiDb handle the connection internally.
223
    async fn handle_connection<
224
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
225
    >(
226
        &self,
227
        connection: S,
228
        peer: &Peer<Self::ApplicationProtocols>,
229
    ) -> Result<(), S>;
230
}
231

            
232
/// A service that can handle incoming HTTP connections. A convenience
233
/// implementation of [`TcpService`] that is useful is you are only serving HTTP
234
/// and WebSockets over a service.
235
#[async_trait]
236
pub trait HttpService: Clone + Send + Sync + 'static {
237
    /// Handle an incoming `connection` for `peer`. Return `Err(connection)` to
238
    /// have BonsaiDb handle the connection internally.
239
    async fn handle_connection<
240
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
241
    >(
242
        &self,
243
        connection: S,
244
        peer: &Peer,
245
    ) -> Result<(), S>;
246
}
247

            
248
#[async_trait]
249
impl<T> TcpService for T
250
where
251
    T: HttpService,
252
{
253
    type ApplicationProtocols = StandardTcpProtocols;
254

            
255
103
    fn available_protocols(&self) -> &[Self::ApplicationProtocols] {
256
103
        StandardTcpProtocols::all()
257
103
    }
258

            
259
102
    async fn handle_connection<
260
102
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
261
102
    >(
262
102
        &self,
263
102
        connection: S,
264
102
        peer: &Peer<Self::ApplicationProtocols>,
265
102
    ) -> Result<(), S> {
266
107
        HttpService::handle_connection(self, connection, peer).await
267
204
    }
268
}
269

            
270
#[async_trait]
271
impl HttpService for () {
272
99
    async fn handle_connection<
273
99
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
274
99
    >(
275
99
        &self,
276
99
        connection: S,
277
99
        _peer: &Peer<StandardTcpProtocols>,
278
99
    ) -> Result<(), S> {
279
99
        Err(connection)
280
99
    }
281
}
282

            
283
/// The set of application-layer protocols a network service can speak.
pub trait ApplicationProtocols
where
    Self: Clone + std::fmt::Debug + Send + Sync,
{
    /// The identifier that represents this protocol during TLS ALPN
    /// negotiation.
    fn alpn_name(&self) -> &'static [u8];
}
288

            
289
/// A connected network peer.
290
#[derive(Debug, Clone)]
291
pub struct Peer<P: ApplicationProtocols = StandardTcpProtocols> {
292
    /// The remote address of the peer.
293
    pub address: std::net::SocketAddr,
294
    /// If true, the connection is secured with TLS.
295
    pub secure: bool,
296
    /// The application protocol to use for this connection.
297
    pub protocol: P,
298
}
299

            
300
/// The TCP [`ApplicationProtocols`] that BonsaiDb knows about.
#[derive(Debug, Clone)]
pub enum StandardTcpProtocols {
    /// HTTP/1.1, identified as `http/1.1` during ALPN negotiation.
    Http1,
    /// The ACME TLS-ALPN certificate challenge protocol.
    #[cfg(feature = "acme")]
    Acme,
    /// Any other protocol. No ALPN identifier is defined for this variant.
    Other,
}
309

            
310
impl StandardTcpProtocols {
311
    #[cfg(feature = "acme")]
312
1442
    const fn all() -> &'static [Self] {
313
1442
        &[Self::Http1, Self::Acme]
314
1442
    }
315

            
316
    #[cfg(not(feature = "acme"))]
317
    const fn all() -> &'static [Self] {
318
        &[Self::Http1]
319
    }
320
}
321

            
322
impl Default for StandardTcpProtocols {
323
    fn default() -> Self {
324
        Self::Http1
325
    }
326
}
327

            
328
impl ApplicationProtocols for StandardTcpProtocols {
    /// Returns the identifier this protocol uses during TLS ALPN negotiation.
    ///
    /// # Panics
    ///
    /// Panics for [`StandardTcpProtocols::Other`], which has no ALPN
    /// identifier.
    fn alpn_name(&self) -> &'static [u8] {
        match self {
            Self::Http1 => b"http/1.1",
            #[cfg(feature = "acme")]
            Self::Acme => async_acme::acme::ACME_TLS_ALPN_NAME,
            // `Other` is public and can be constructed by callers, so give the
            // panic a message that explains why it happened.
            Self::Other => {
                unreachable!("StandardTcpProtocols::Other has no ALPN identifier")
            }
        }
    }
}