1
use std::sync::Arc;
2

            
3
use async_trait::async_trait;
4
use rustls::server::ResolvesServerCert;
5
use tokio::{
6
    io::{AsyncRead, AsyncWrite},
7
    net::TcpListener,
8
};
9

            
10
use crate::{Backend, CustomServer, Error};
11

            
12
impl<B: Backend> CustomServer<B> {
13
    /// Listens for HTTP traffic on `port`. This port will also receive
14
    /// `WebSocket` connections if feature `websockets` is enabled.
15
18
    pub async fn listen_for_tcp_on<S: TcpService, T: tokio::net::ToSocketAddrs + Send + Sync>(
16
18
        &self,
17
18
        addr: T,
18
18
        service: S,
19
18
    ) -> Result<(), Error> {
20
18
        let listener = TcpListener::bind(&addr).await?;
21
17
        let mut shutdown_watcher = self
22
17
            .data
23
17
            .shutdown
24
17
            .watcher()
25
            .await
26
17
            .expect("server already shutdown");
27

            
28
        loop {
29
106
            tokio::select! {
30
                _ = shutdown_watcher.wait_for_shutdown() => {
31
                    break;
32
                }
33
89
                incoming = listener.accept() => {
34
                    if incoming.is_err() {
35
                        continue;
36
                    }
37
                    let (connection, remote_addr) = incoming.unwrap();
38

            
39
                    let peer = Peer {
40
                        address: remote_addr,
41
                        protocol: service.available_protocols()[0].clone(),
42
                        secure: false,
43
                    };
44

            
45
                    let task_self = self.clone();
46
                    let task_service = service.clone();
47
89
                    tokio::spawn(async move {
48
28328
                        if let Err(err) = task_self.handle_tcp_connection(connection, peer, &task_service).await {
49
                            log::error!("[server] closing connection {}: {:?}", remote_addr, err);
50
76
                        }
51
76
                    });
52
                }
53
            }
54
        }
55

            
56
2
        Ok(())
57
3
    }
58

            
59
    /// Listens for HTTPS traffic on `port`. This port will also receive
60
    /// `WebSocket` connections if feature `websockets` is enabled. If feature
61
    /// `acme` is enabled, this connection will automatically manage the
62
    /// server's private key and certificate, which is also used for the
63
    /// QUIC-based protocol.
64
    #[cfg_attr(not(feature = "websockets"), allow(unused_variables))]
65
    #[cfg_attr(not(feature = "acme"), allow(unused_mut))]
66
1
    pub async fn listen_for_secure_tcp_on<
67
1
        S: TcpService,
68
1
        T: tokio::net::ToSocketAddrs + Send + Sync,
69
1
    >(
70
1
        &self,
71
1
        addr: T,
72
1
        service: S,
73
1
    ) -> Result<(), Error> {
74
        // We may not have a certificate yet, so we ignore any errors.
75
3
        drop(self.refresh_certified_key().await);
76

            
77
        #[cfg(feature = "acme")]
78
1
        {
79
1
            let task_self = self.clone();
80
1
            tokio::task::spawn(async move {
81
1
                if let Err(err) = task_self.update_acme_certificates().await {
82
                    log::error!("[server] acme task error: {0}", err);
83
                }
84
1
            });
85
1
        }
86
1

            
87
1
        let mut config = rustls::ServerConfig::builder()
88
1
            .with_safe_defaults()
89
1
            .with_no_client_auth()
90
1
            .with_cert_resolver(Arc::new(self.clone()));
91
1
        config.alpn_protocols = service
92
1
            .available_protocols()
93
1
            .iter()
94
2
            .map(|proto| proto.alpn_name().to_vec())
95
1
            .collect();
96
1

            
97
1
        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
98
1
        let listener = TcpListener::bind(&addr).await?;
99
        loop {
100
            let (stream, peer_addr) = listener.accept().await?;
101
            let acceptor = acceptor.clone();
102

            
103
            let task_self = self.clone();
104
            let task_service = service.clone();
105
            tokio::task::spawn(async move {
106
                let stream = match acceptor.accept(stream).await {
107
                    Ok(stream) => stream,
108
                    Err(err) => {
109
                        log::error!("[server] error during tls handshake: {:?}", err);
110
                        return;
111
                    }
112
                };
113

            
114
                let available_protocols = task_service.available_protocols();
115
                let protocol = stream
116
                    .get_ref()
117
                    .1
118
                    .alpn_protocol()
119
                    .and_then(|protocol| {
120
                        available_protocols
121
                            .iter()
122
                            .find(|p| p.alpn_name() == protocol)
123
                            .cloned()
124
                    })
125
                    .unwrap_or_else(|| available_protocols[0].clone());
126
                let peer = Peer {
127
                    address: peer_addr,
128
                    secure: true,
129
                    protocol,
130
                };
131
                if let Err(err) = task_self
132
                    .handle_tcp_connection(stream, peer, &task_service)
133
                    .await
134
                {
135
                    log::error!("[server] error for client {}: {:?}", peer_addr, err);
136
                }
137
            });
138
        }
139
1
    }
140

            
141
    #[cfg_attr(not(feature = "websockets"), allow(unused_variables))]
142
89
    async fn handle_tcp_connection<
143
89
        S: TcpService,
144
89
        C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
145
89
    >(
146
89
        &self,
147
89
        connection: C,
148
89
        peer: Peer<S::ApplicationProtocols>,
149
89
        service: &S,
150
89
    ) -> Result<(), Error> {
151
89
        // For ACME, don't send any traffic over the connection.
152
89
        #[cfg(feature = "acme")]
153
89
        if peer.protocol.alpn_name() == async_acme::acme::ACME_TLS_ALPN_NAME {
154
            log::info!("received acme challenge connection");
155
            return Ok(());
156
89
        }
157

            
158
92
        if let Err(connection) = service.handle_connection(connection, &peer).await {
159
            #[cfg(feature = "websockets")]
160
86
            if let Err(err) = self
161
28322
                .handle_raw_websocket_connection(connection, peer.address)
162
28322
                .await
163
            {
164
                log::error!(
165
1
                    "[server] error on websocket for {}: {:?}",
166
                    peer.address,
167
                    err
168
                );
169
72
            }
170
3
        }
171

            
172
76
        Ok(())
173
76
    }
174
}
175

            
176
impl<B: Backend> ResolvesServerCert for CustomServer<B> {
177
    #[cfg_attr(not(feature = "acme"), allow(unused_variables))]
178
    fn resolve(
179
        &self,
180
        client_hello: rustls::server::ClientHello<'_>,
181
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
182
        #[cfg(feature = "acme")]
183
        if client_hello
184
            .alpn()
185
            .map(|mut iter| iter.any(|n| n == async_acme::acme::ACME_TLS_ALPN_NAME))
186
            .unwrap_or_default()
187
        {
188
            let server_name = client_hello.server_name()?.to_owned();
189
            let keys = self.data.alpn_keys.lock().unwrap();
190
            if let Some(key) = keys.get(AsRef::<str>::as_ref(&server_name)) {
191
                log::info!("returning acme challenge");
192
                return Some(key.clone());
193
            }
194

            
195
            log::error!(
196
                "acme alpn challenge received with no key for {}",
197
                server_name
198
            );
199
            return None;
200
        }
201

            
202
        let cached_key = self.data.primary_tls_key.lock();
203
        if let Some(key) = cached_key.as_ref() {
204
            Some(key.clone())
205
        } else {
206
            log::error!("[server] inbound tls connection with no certificate installed");
207
            None
208
        }
209
    }
210
}
211

            
212
/// A service that can handle incoming TCP connections.
213
#[async_trait]
214
pub trait TcpService: Clone + Send + Sync + 'static {
215
    /// The application layer protocols that this service supports.
216
    type ApplicationProtocols: ApplicationProtocols;
217

            
218
    /// Returns all available protocols for this service. The first will be the
219
    /// default used if a connection is made without negotiating the application
220
    /// protocol.
221
    fn available_protocols(&self) -> &[Self::ApplicationProtocols];
222

            
223
    /// Handle an incoming `connection` for `peer`. Return `Err(connection)` to
224
    /// have BonsaiDb handle the connection internally.
225
    async fn handle_connection<
226
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
227
    >(
228
        &self,
229
        connection: S,
230
        peer: &Peer<Self::ApplicationProtocols>,
231
    ) -> Result<(), S>;
232
}
233

            
234
/// A service that can handle incoming HTTP connections. A convenience
235
/// implementation of [`TcpService`] that is useful is you are only serving HTTP
236
/// and WebSockets over a service.
237
#[async_trait]
238
pub trait HttpService: Clone + Send + Sync + 'static {
239
    /// Handle an incoming `connection` for `peer`. Return `Err(connection)` to
240
    /// have BonsaiDb handle the connection internally.
241
    async fn handle_connection<
242
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
243
    >(
244
        &self,
245
        connection: S,
246
        peer: &Peer,
247
    ) -> Result<(), S>;
248
}
249

            
250
#[async_trait]
251
impl<T> TcpService for T
252
where
253
    T: HttpService,
254
{
255
    type ApplicationProtocols = StandardTcpProtocols;
256

            
257
90
    fn available_protocols(&self) -> &[Self::ApplicationProtocols] {
258
90
        StandardTcpProtocols::all()
259
90
    }
260

            
261
89
    async fn handle_connection<
262
89
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
263
89
    >(
264
89
        &self,
265
89
        connection: S,
266
89
        peer: &Peer<Self::ApplicationProtocols>,
267
89
    ) -> Result<(), S> {
268
92
        HttpService::handle_connection(self, connection, peer).await
269
178
    }
270
}
271

            
272
#[async_trait]
273
impl HttpService for () {
274
86
    async fn handle_connection<
275
86
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
276
86
    >(
277
86
        &self,
278
86
        connection: S,
279
86
        _peer: &Peer<StandardTcpProtocols>,
280
86
    ) -> Result<(), S> {
281
86
        Err(connection)
282
86
    }
283
}
284

            
285
/// A collection of supported protocols for a network service.
///
/// Each value names one application-layer protocol and must be able to
/// report the identifier used for it in TLS ALPN negotiation.
pub trait ApplicationProtocols
where
    Self: Clone + std::fmt::Debug + Send + Sync,
{
    /// Returns the identifier advertised for this protocol during TLS ALPN
    /// negotiation.
    fn alpn_name(&self) -> &'static [u8];
}
290

            
291
/// A connected network peer.
292
#[derive(Debug, Clone)]
293
pub struct Peer<P: ApplicationProtocols = StandardTcpProtocols> {
294
    /// The remote address of the peer.
295
    pub address: std::net::SocketAddr,
296
    /// If true, the connection is secured with TLS.
297
    pub secure: bool,
298
    /// The application protocol to use for this connection.
299
    pub protocol: P,
300
}
301

            
302
/// TCP [`ApplicationProtocols`] that BonsaiDb has some knowledge of.
#[derive(Debug, Clone)]
pub enum StandardTcpProtocols {
    /// HTTP/1.1, advertised via ALPN as `http/1.1`.
    Http1,
    /// The ACME TLS-ALPN challenge protocol.
    #[cfg(feature = "acme")]
    Acme,
    /// A protocol with no ALPN identifier. It is never advertised, and
    /// requesting its ALPN name panics.
    Other,
}
311

            
312
impl StandardTcpProtocols {
313
    #[cfg(feature = "acme")]
314
1170
    const fn all() -> &'static [Self] {
315
1170
        &[Self::Http1, Self::Acme]
316
1170
    }
317

            
318
    #[cfg(not(feature = "acme"))]
319
    const fn all() -> &'static [Self] {
320
        &[Self::Http1]
321
    }
322
}
323

            
324
impl Default for StandardTcpProtocols {
325
    fn default() -> Self {
326
        Self::Http1
327
    }
328
}
329

            
330
impl ApplicationProtocols for StandardTcpProtocols {
    /// Maps each standard protocol to its ALPN identifier.
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!()`) for [`StandardTcpProtocols::Other`],
    /// which has no ALPN identifier.
    fn alpn_name(&self) -> &'static [u8] {
        const HTTP1_ALPN: &[u8] = b"http/1.1";

        match *self {
            Self::Http1 => HTTP1_ALPN,
            #[cfg(feature = "acme")]
            Self::Acme => async_acme::acme::ACME_TLS_ALPN_NAME,
            Self::Other => unreachable!(),
        }
    }
}