1
use std::marker::PhantomData;
2
#[cfg(any(feature = "websockets", feature = "acme"))]
3
use std::net::{Ipv6Addr, SocketAddr, SocketAddrV6};
4
#[cfg(feature = "acme")]
5
use std::time::Duration;
6

            
7
use clap::Args;
8

            
9
use crate::{Backend, BackendError, CustomServer, TcpService};
10

            
11
/// Execute the server
12
8
#[derive(Args, Debug)]
13
pub struct Serve<B: Backend> {
14
2
    /// The UDP port for the BonsaiDb protocol. Defaults to UDP port 5645 (not
15
1
    /// an [officially registered
16
1
    /// port](https://github.com/khonsulabs/bonsaidb/issues/48)).
17
1
    #[clap(short = 'l', long = "listen-on")]
18
1
    pub listen_on: Option<u16>,
19

            
20
    #[cfg(any(feature = "websockets", feature = "acme"))]
21
    /// The bind port and address for HTTP traffic. Defaults to TCP port 80.
22
    #[clap(long = "http")]
23
    pub http_port: Option<SocketAddr>,
24

            
25
    #[cfg(any(feature = "websockets", feature = "acme"))]
26
    /// The bind port and address for HTTPS traffic. Defaults to TCP port 443.
27
    #[clap(long = "https")]
28
    pub https_port: Option<SocketAddr>,
29

            
30
    #[clap(skip)]
31
    _backend: PhantomData<B>,
32
}
33

            
34
impl<B: Backend> Serve<B> {
35
    /// Starts the server.
36
1
    pub async fn execute(&self, server: &CustomServer<B>) -> Result<(), BackendError<B::Error>> {
37
5
        self.execute_with(server, ()).await
38
    }
39

            
40
    /// Starts the server using `service` for websocket connections, if enabled.
41
    #[cfg_attr(
42
        not(any(feature = "websockets", feature = "acme")),
43
        allow(unused_variables)
44
    )]
45
1
    pub async fn execute_with<S: TcpService>(
46
1
        &self,
47
1
        server: &CustomServer<B>,
48
1
        service: S,
49
1
    ) -> Result<(), BackendError<B::Error>> {
50
1
        // Try to initialize a logger, but ignore it if it fails. This API is
51
1
        // public and another logger may already be installed.
52
1
        drop(env_logger::try_init());
53
1
        let listen_on = self.listen_on.unwrap_or(5645);
54
1

            
55
1
        #[cfg(any(feature = "websockets", feature = "acme"))]
56
1
        {
57
1
            let listen_address = self.http_port.unwrap_or_else(|| {
58
1
                SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 80, 0, 0))
59
1
            });
60
1
            let task_server = server.clone();
61
1
            let task_service = service.clone();
62
1
            tokio::task::spawn(async move {
63
1
                task_server
64
1
                    .listen_for_tcp_on(listen_address, task_service)
65
                    .await
66
1
            });
67
1

            
68
1
            let listen_address = self.https_port.unwrap_or_else(|| {
69
1
                SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 443, 0, 0))
70
1
            });
71
1
            let task_server = server.clone();
72
1
            tokio::task::spawn(async move {
73
1
                task_server
74
3
                    .listen_for_secure_tcp_on(listen_address, service)
75
3
                    .await
76
1
            });
77
1

            
78
1
            #[cfg(feature = "acme")]
79
2
            if server.certificate_chain().await.is_err() {
80
                log::warn!("Server has no certificate chain. Because acme is enabled, waiting for certificate to be acquired.");
81
                while server.certificate_chain().await.is_err() {
82
                    tokio::time::sleep(Duration::from_secs(1)).await;
83
                }
84
                log::info!("Server certificate acquired. Listening for certificate");
85
1
            }
86
        }
87

            
88
1
        let task_server = server.clone();
89
7
        tokio::task::spawn(async move { task_server.listen_on(listen_on).await });
90
1

            
91
3
        server.listen_for_shutdown().await?;
92

            
93
        Ok(())
94
    }
95
}