1
1
//! Shows how to use the axum web framework with BonsaiDb. Any hyper-compatible
2
//! framework should be usable.
3

            
4
use async_trait::async_trait;
5
use axum::{extract, extract::Extension, routing::get, Router};
6
use bonsaidb::{
7
    core::{connection::AsyncStorageConnection, keyvalue::AsyncKeyValue},
8
    local::config::Builder,
9
    server::{
10
        DefaultPermissions, HttpService, Peer, Server, ServerConfiguration, StandardTcpProtocols,
11
    },
12
};
13
use hyper::{server::conn::Http, Body, Request, Response};
14
#[cfg(feature = "client")]
15
use ::{std::time::Duration, url::Url};
16

            
17
3
#[derive(Debug, Clone)]
18
pub struct AxumService {
19
    server: Server,
20
}
21

            
22
#[async_trait]
23
impl HttpService for AxumService {
24
3
    async fn handle_connection<
25
3
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
26
3
    >(
27
3
        &self,
28
3
        connection: S,
29
3
        peer: &Peer<StandardTcpProtocols>,
30
3
    ) -> Result<(), S> {
31
3
        let server = self.server.clone();
32
3
        let app = Router::new()
33
3
            .route("/", get(uptime_handler))
34
3
            .route("/ws", get(upgrade_websocket))
35
3
            // Attach the server and the remote address as extractable data for the /ws route
36
3
            .layer(Extension(server))
37
3
            .layer(Extension(peer.address));
38

            
39
3
        if let Err(err) = Http::new()
40
3
            .serve_connection(connection, app)
41
9
            .with_upgrades()
42
9
            .await
43
        {
44
            log::error!("[http] error serving {}: {:?}", peer.address, err);
45
3
        }
46

            
47
3
        Ok(())
48
6
    }
49
}
50

            
51
#[tokio::main]
52
1
async fn main() -> anyhow::Result<()> {
53
1
    env_logger::init();
54
1
    let server = Server::open(
55
1
        ServerConfiguration::new("http-server-data.bonsaidb")
56
1
            .default_permissions(DefaultPermissions::AllowAll)
57
1
            .with_schema::<()>()?,
58
3
    )
59
3
    .await?;
60
2
    server.create_database::<()>("storage", true).await?;
61

            
62
    #[cfg(feature = "client")]
63
1
    {
64
1
        // This is silly to do over a websocket connection, because it can
65
1
        // easily be done by just using `server` instead. However, this is to
66
1
        // demonstrate that websocket connections work in this example.
67
1
        let client =
68
1
            bonsaidb::client::Client::build(Url::parse("ws://localhost:8080/ws")?).finish()?;
69
1
        tokio::spawn(async move {
70
            loop {
71
3
                tokio::time::sleep(Duration::from_secs(1)).await;
72
2
                let db = client.database::<()>("storage").await.unwrap();
73
2
                db.increment_key_by("uptime", 1_u64).await.unwrap();
74
            }
75
1
        });
76
1
    }
77
1

            
78
1
    server
79
1
        .listen_for_tcp_on(
80
1
            "localhost:8080",
81
1
            AxumService {
82
1
                server: server.clone(),
83
1
            },
84
4
        )
85
4
        .await?;
86

            
87
    Ok(())
88
}
89

            
90
2
async fn uptime_handler(server: extract::Extension<Server>) -> String {
91
2
    let db = server.database::<()>("storage").await.unwrap();
92
2
    format!(
93
2
        "Current uptime: {} seconds",
94
2
        db.get_key("uptime")
95
2
            .into_u64()
96
2
            .await
97
2
            .unwrap()
98
2
            .unwrap_or_default()
99
2
    )
100
2
}
101

            
102
1
async fn upgrade_websocket(
103
1
    server: extract::Extension<Server>,
104
1
    peer_address: extract::Extension<std::net::SocketAddr>,
105
1
    req: Request<Body>,
106
1
) -> Response<Body> {
107
1
    server.upgrade_websocket(*peer_address, req).await
108
1
}
109

            
110
1
#[tokio::test]
111
#[cfg_attr(not(feature = "client"), allow(unused_variables))]
112
1
async fn test() {
113
1
    use axum::body::HttpBody;
114
1

            
115
1
    std::thread::spawn(|| main().unwrap());
116
1

            
117
3
    let retrieve_uptime = || async {
118
3
        let client = hyper::Client::new();
119
11
        let mut response = match client.get("http://localhost:8080/".parse().unwrap()).await {
120
2
            Ok(response) => response,
121
1
            Err(err) if err.is_connect() => {
122
1
                return None;
123
            }
124
            Err(other) => unreachable!("{}", other),
125
        };
126

            
127
2
        assert_eq!(response.status(), 200);
128

            
129
2
        let body = response
130
2
            .body_mut()
131
2
            .data()
132
            .await
133
2
            .expect("no response")
134
2
            .unwrap();
135
2
        let body = String::from_utf8(body.to_vec()).unwrap();
136
2
        assert!(body.contains("Current uptime: "));
137
2
        Some(body)
138
3
    };
139

            
140
1
    let mut retries_left = 5;
141
1
    let original_uptime = loop {
142
7
        if let Some(uptime) = retrieve_uptime().await {
143
1
            break uptime;
144
1
        } else if retries_left > 0 {
145
1
            println!("Waiting for server to start");
146
1
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
147

            
148
1
            retries_left -= 1;
149
        } else {
150
            unreachable!("Unable to connect to axum server.")
151
        }
152
    };
153

            
154
    #[cfg(feature = "client")]
155
    {
156
        // If we have the client, we're expecting the uptime to increase every second
157
1
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
158
4
        let new_uptime = retrieve_uptime().await.unwrap();
159
1
        assert_ne!(original_uptime, new_uptime);
160
    }
161
1
}