1
//! Shows how to use the axum web framework with BonsaiDb. Any hyper-compatible
2
//! framework should be usable.
3

            
4
use async_trait::async_trait;
5
use axum::extract::Extension;
6
use axum::routing::get;
7
use axum::{extract, Router};
8
use bonsaidb::core::connection::AsyncStorageConnection;
9
use bonsaidb::core::keyvalue::AsyncKeyValue;
10
use bonsaidb::local::config::Builder;
11
use bonsaidb::server::{
12
    DefaultPermissions, HttpService, Peer, Server, ServerConfiguration, StandardTcpProtocols,
13
};
14
use hyper::server::conn::Http;
15
use hyper::{Body, Request, Response};
16
#[cfg(feature = "client")]
17
use ::{std::time::Duration, url::Url};
18

            
19
3
#[derive(Debug, Clone)]
20
pub struct AxumService {
21
    server: Server,
22
}
23

            
24
#[async_trait]
25
impl HttpService for AxumService {
26
3
    async fn handle_connection<
27
3
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
28
3
    >(
29
3
        &self,
30
3
        connection: S,
31
3
        peer: &Peer<StandardTcpProtocols>,
32
3
    ) -> Result<(), S> {
33
3
        let server = self.server.clone();
34
3
        let app = Router::new()
35
3
            .route("/", get(uptime_handler))
36
3
            .route("/ws", get(upgrade_websocket))
37
3
            // Attach the server and the remote address as extractable data for the /ws route
38
3
            .layer(Extension(server))
39
3
            .layer(Extension(peer.address));
40

            
41
3
        if let Err(err) = Http::new()
42
3
            .serve_connection(connection, app)
43
3
            .with_upgrades()
44
9
            .await
45
        {
46
            log::error!("[http] error serving {}: {:?}", peer.address, err);
47
3
        }
48

            
49
3
        Ok(())
50
6
    }
51
}
52

            
53
#[tokio::main]
54
1
async fn main() -> anyhow::Result<()> {
55
1
    env_logger::init();
56
1
    let server = Server::open(
57
1
        ServerConfiguration::new("http-server-data.bonsaidb")
58
1
            .default_permissions(DefaultPermissions::AllowAll)
59
1
            .with_schema::<()>()?,
60
    )
61
3
    .await?;
62
2
    server.create_database::<()>("storage", true).await?;
63

            
64
    #[cfg(feature = "client")]
65
1
    {
66
1
        // This is silly to do over a websocket connection, because it can
67
1
        // easily be done by just using `server` instead. However, this is to
68
1
        // demonstrate that websocket connections work in this example.
69
1
        let client =
70
1
            bonsaidb::client::AsyncClient::build(Url::parse("ws://localhost:8080/ws")?).build()?;
71
1
        tokio::spawn(async move {
72
            loop {
73
3
                tokio::time::sleep(Duration::from_secs(1)).await;
74
2
                let db = client.database::<()>("storage").await.unwrap();
75
2
                db.increment_key_by("uptime", 1_u64).await.unwrap();
76
            }
77
1
        });
78
1
    }
79
1

            
80
1
    server
81
1
        .listen_for_tcp_on(
82
1
            "localhost:8080",
83
1
            AxumService {
84
1
                server: server.clone(),
85
1
            },
86
1
        )
87
4
        .await?;
88

            
89
    Ok(())
90
}
91

            
92
2
async fn uptime_handler(server: extract::Extension<Server>) -> String {
93
2
    let db = server.database::<()>("storage").await.unwrap();
94
2
    format!(
95
2
        "Current uptime: {} seconds",
96
2
        db.get_key("uptime")
97
2
            .into_u64()
98
2
            .await
99
2
            .unwrap()
100
2
            .unwrap_or_default()
101
2
    )
102
2
}
103

            
104
1
async fn upgrade_websocket(
105
1
    server: extract::Extension<Server>,
106
1
    peer_address: extract::Extension<std::net::SocketAddr>,
107
1
    req: Request<Body>,
108
1
) -> Response<Body> {
109
1
    server.upgrade_websocket(*peer_address, req)
110
1
}
111

            
112
1
#[tokio::test]
113
#[cfg_attr(not(feature = "client"), allow(unused_variables))]
114
1
async fn test() {
115
1
    use axum::body::HttpBody;
116
1

            
117
1
    std::thread::spawn(|| main().unwrap());
118
1

            
119
3
    let retrieve_uptime = || async {
120
3
        let client = hyper::Client::new();
121
11
        let mut response = match client.get("http://localhost:8080/".parse().unwrap()).await {
122
2
            Ok(response) => response,
123
1
            Err(err) if err.is_connect() => {
124
1
                return None;
125
            }
126
            Err(other) => unreachable!("{}", other),
127
        };
128

            
129
2
        assert_eq!(response.status(), 200);
130

            
131
2
        let body = response
132
2
            .body_mut()
133
2
            .data()
134
            .await
135
2
            .expect("no response")
136
2
            .unwrap();
137
2
        let body = String::from_utf8(body.to_vec()).unwrap();
138
2
        assert!(body.contains("Current uptime: "));
139
2
        Some(body)
140
3
    };
141

            
142
1
    let mut retries_left = 5;
143
1
    let original_uptime = loop {
144
7
        if let Some(uptime) = retrieve_uptime().await {
145
1
            break uptime;
146
1
        } else if retries_left > 0 {
147
1
            println!("Waiting for server to start");
148
1
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
149

            
150
1
            retries_left -= 1;
151
        } else {
152
            unreachable!("Unable to connect to axum server.")
153
        }
154
    };
155

            
156
    #[cfg(feature = "client")]
157
    {
158
        // If we have the client, we're expecting the uptime to increase every second
159
1
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
160
4
        let new_uptime = retrieve_uptime().await.unwrap();
161
1
        assert_ne!(original_uptime, new_uptime);
162
    }
163
}