1
1
//! Shows how to use the axum web framework with BonsaiDb. Any hyper-compatible
//! framework should be usable.
3

            
4
use async_trait::async_trait;
5
use axum::{extract, routing::get, AddExtensionLayer, Router};
6
use bonsaidb::{
7
    core::{connection::StorageConnection, keyvalue::KeyValue},
8
    local::config::Builder,
9
    server::{
10
        DefaultPermissions, HttpService, Peer, Server, ServerConfiguration, StandardTcpProtocols,
11
    },
12
};
13
use hyper::{server::conn::Http, Body, Request, Response};
14
#[cfg(feature = "client")]
15
use ::{std::time::Duration, url::Url};
16

            
17
3
#[derive(Debug, Clone)]
18
pub struct AxumService {
19
    server: Server,
20
}
21

            
22
#[async_trait]
23
impl HttpService for AxumService {
24
3
    async fn handle_connection<
25
3
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
26
3
    >(
27
3
        &self,
28
3
        connection: S,
29
3
        peer: &Peer<StandardTcpProtocols>,
30
3
    ) -> Result<(), S> {
31
3
        let server = self.server.clone();
32
3
        let app = Router::new()
33
3
            .route("/", get(uptime_handler))
34
3
            .route("/ws", get(upgrade_websocket))
35
3
            // Attach the server and the remote address as extractable data for the /ws route
36
3
            .layer(AddExtensionLayer::new(server))
37
3
            .layer(AddExtensionLayer::new(peer.address));
38

            
39
3
        if let Err(err) = Http::new()
40
3
            .serve_connection(connection, app)
41
6
            .with_upgrades()
42
6
            .await
43
        {
44
            log::error!("[http] error serving {}: {:?}", peer.address, err);
45
3
        }
46

            
47
3
        Ok(())
48
6
    }
49
}
50

            
51
#[tokio::main]
52
1
async fn main() -> anyhow::Result<()> {
53
1
    env_logger::init();
54
1
    let server = Server::open(
55
1
        ServerConfiguration::new("http-server-data.bonsaidb")
56
1
            .default_permissions(DefaultPermissions::AllowAll)
57
1
            .with_schema::<()>()?,
58
19
    )
59
19
    .await?;
60
2
    server.create_database::<()>("storage", true).await?;
61

            
62
    #[cfg(feature = "client")]
63
1
    {
64
1
        // This is silly to do over a websocket connection, because it can
65
1
        // easily be done by just using `server` instead. However, this is to
66
1
        // demonstrate that websocket connections work in this example.
67
1
        let client = bonsaidb::client::Client::build(Url::parse("ws://localhost:8080/ws")?)
68
1
            .finish()
69
            .await?;
70
1
        tokio::spawn(async move {
71
            loop {
72
3
                tokio::time::sleep(Duration::from_secs(1)).await;
73
2
                let db = client.database::<()>("storage").await.unwrap();
74
2
                db.increment_key_by("uptime", 1_u64).await.unwrap();
75
            }
76
1
        });
77
1
    }
78
1

            
79
1
    server
80
1
        .listen_for_tcp_on(
81
1
            "localhost:8080",
82
1
            AxumService {
83
1
                server: server.clone(),
84
1
            },
85
4
        )
86
4
        .await?;
87

            
88
    Ok(())
89
}
90

            
91
2
async fn uptime_handler(server: extract::Extension<Server>) -> String {
92
2
    let db = server.database::<()>("storage").await.unwrap();
93
2
    format!(
94
2
        "Current uptime: {} seconds",
95
2
        db.get_key("uptime")
96
2
            .into_u64()
97
            .await
98
2
            .unwrap()
99
2
            .unwrap_or_default()
100
2
    )
101
2
}
102

            
103
1
async fn upgrade_websocket(
104
1
    server: extract::Extension<Server>,
105
1
    peer_address: extract::Extension<std::net::SocketAddr>,
106
1
    req: Request<Body>,
107
1
) -> Response<Body> {
108
1
    server.upgrade_websocket(*peer_address, req).await
109
1
}
110

            
111
1
#[tokio::test]
112
#[cfg_attr(not(feature = "client"), allow(unused_variables))]
113
1
async fn test() {
114
1
    use axum::body::HttpBody;
115
1

            
116
1
    std::thread::spawn(|| main().unwrap());
117
1

            
118
3
    let retrieve_uptime = || async {
119
3
        let client = hyper::Client::new();
120
11
        let mut response = match client.get("http://localhost:8080/".parse().unwrap()).await {
121
2
            Ok(response) => response,
122
1
            Err(err) if err.is_connect() => {
123
1
                return None;
124
            }
125
            Err(other) => unreachable!(other),
126
        };
127

            
128
2
        assert_eq!(response.status(), 200);
129

            
130
2
        let body = response
131
2
            .body_mut()
132
2
            .data()
133
            .await
134
2
            .expect("no response")
135
2
            .unwrap();
136
2
        let body = String::from_utf8(body.to_vec()).unwrap();
137
2
        assert!(body.contains("Current uptime: "));
138
2
        Some(body)
139
3
    };
140

            
141
1
    let mut retries_left = 5;
142
1
    let original_uptime = loop {
143
7
        if let Some(uptime) = retrieve_uptime().await {
144
1
            break uptime;
145
1
        } else if retries_left > 0 {
146
1
            println!("Waiting for server to start");
147
1
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
148

            
149
1
            retries_left -= 1;
150
        } else {
151
            unreachable!("Unable to connect to axum server.")
152
        }
153
    };
154

            
155
    #[cfg(feature = "client")]
156
    {
157
        // If we have the client, we're expecting the uptime to increase every second
158
1
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
159
4
        let new_uptime = retrieve_uptime().await.unwrap();
160
1
        assert_ne!(original_uptime, new_uptime);
161
    }
162
1
}