1
use std::collections::HashMap;
2
use std::sync::atomic::{AtomicU32, Ordering};
3
use std::sync::Arc;
4
use std::time::Duration;
5

            
6
use bonsaidb_core::api::ApiName;
7
use bonsaidb_core::networking::Payload;
8
use bonsaidb_utils::fast_async_lock;
9
use fabruic::{self, Certificate, Endpoint};
10
use flume::Receiver;
11
use futures::StreamExt;
12
use url::Url;
13

            
14
use super::PendingRequest;
15
use crate::client::{
16
    disconnect_pending_requests, AnyApiCallback, ConnectionInfo, OutstandingRequestMapHandle,
17
};
18
use crate::Error;
19

            
20
/// This function will establish a connection and try to keep it active. If an
21
/// error occurs, any queries that come in while reconnecting will have the
22
/// error replayed to them.
23
1296
pub(super) async fn reconnecting_client_loop(
24
1296
    mut server: ConnectionInfo,
25
1296
    protocol_version: &'static str,
26
1296
    certificate: Option<Certificate>,
27
1296
    request_receiver: Receiver<PendingRequest>,
28
1296
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
29
1296
    connection_counter: Arc<AtomicU32>,
30
1296
) -> Result<(), Error> {
31
1296
    if server.url.port().is_none() && server.url.scheme() == "bonsaidb" {
32
72
        let _: Result<_, _> = server.url.set_port(Some(5645));
33
1224
    }
34

            
35
1296
    server.subscribers.clear();
36
1296
    let mut pending_error = None;
37
1818
    while let Ok(request) = request_receiver.recv_async().await {
38
1332
        if let Some(pending_error) = pending_error.take() {
39
18
            drop(request.responder.send(Err(pending_error)));
40
18
            continue;
41
1314
        }
42
1314
        connection_counter.fetch_add(1, Ordering::SeqCst);
43
1314
        if let Err((failed_request, Some(err))) = connect_and_process(
44
1314
            &server.url,
45
1314
            protocol_version,
46
1314
            certificate.as_ref(),
47
1314
            request,
48
1314
            &request_receiver,
49
1314
            custom_apis.clone(),
50
1314
            server.connect_timeout,
51
1314
        )
52
439434
        .await
53
        {
54
432
            if let Some(failed_request) = failed_request {
55
54
                drop(failed_request.responder.send(Err(err)));
56
378
            } else {
57
378
                pending_error = Some(err);
58
378
            }
59
72
        }
60
    }
61

            
62
450
    Ok(())
63
450
}
64

            
65
1314
async fn connect_and_process(
66
1314
    url: &Url,
67
1314
    protocol_version: &str,
68
1314
    certificate: Option<&Certificate>,
69
1314
    initial_request: PendingRequest,
70
1314
    request_receiver: &Receiver<PendingRequest>,
71
1314
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
72
1314
    connect_timeout: Duration,
73
1314
) -> Result<(), (Option<PendingRequest>, Option<Error>)> {
74
1260
    let (_connection, payload_sender, payload_receiver) =
75
1314
        match tokio::time::timeout(connect_timeout, connect(url, certificate, protocol_version))
76
1314
            .await
77
        {
78
1260
            Ok(Ok(result)) => result,
79
18
            Ok(Err(err)) => return Err((Some(initial_request), Some(err))),
80
36
            Err(_) => return Err((Some(initial_request), Some(Error::connect_timeout()))),
81
        };
82

            
83
1260
    let outstanding_requests = OutstandingRequestMapHandle::default();
84
1260
    let request_processor = tokio::spawn(process(
85
1260
        outstanding_requests.clone(),
86
1260
        payload_receiver,
87
1260
        custom_apis,
88
1260
    ));
89

            
90
1260
    if let Err(err) = payload_sender.send(&initial_request.request) {
91
        return Err((Some(initial_request), Some(Error::from(err))));
92
1260
    }
93

            
94
1260
    {
95
1260
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
96
1260
        outstanding_requests.insert(
97
1260
            initial_request
98
1260
                .request
99
1260
                .id
100
1260
                .expect("all requests require ids"),
101
1260
            initial_request,
102
1260
        );
103
    }
104

            
105
1260
    if let Err(err) = futures::try_join!(
106
439380
        process_requests(
107
439380
            outstanding_requests.clone(),
108
439380
            request_receiver,
109
439380
            payload_sender
110
439380
        ),
111
439380
        async { request_processor.await.map_err(|_| Error::disconnected())? }
112
439380
    ) {
113
450
        let mut pending_error = Some(err);
114
450
        // Our socket was disconnected, clear the outstanding requests before returning.
115
450
        disconnect_pending_requests(&outstanding_requests, &mut pending_error).await;
116
450
        return Err((None, pending_error));
117
    }
118

            
119
    Ok(())
120
504
}
121

            
122
1260
async fn process_requests(
123
1260
    outstanding_requests: OutstandingRequestMapHandle,
124
1260
    request_receiver: &Receiver<PendingRequest>,
125
1260
    payload_sender: fabruic::Sender<Payload>,
126
1260
) -> Result<(), Error> {
127
625986
    while let Ok(client_request) = request_receiver.recv_async().await {
128
624726
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
129
624726
        payload_sender.send(&client_request.request)?;
130
624726
        outstanding_requests.insert(
131
624726
            client_request.request.id.expect("all requests require ids"),
132
624726
            client_request,
133
624726
        );
134
    }
135

            
136
432
    drop(payload_sender.finish());
137
432

            
138
432
    // Return an error to make sure try_join returns.
139
432
    Err(Error::disconnected())
140
432
}
141

            
142
1260
pub async fn process(
143
1260
    outstanding_requests: OutstandingRequestMapHandle,
144
1260
    mut payload_receiver: fabruic::Receiver<Payload>,
145
1260
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
146
1260
) -> Result<(), Error> {
147
627390
    while let Some(payload) = payload_receiver.next().await {
148
626310
        let payload = payload?;
149
626130
        super::process_response_payload(payload, &outstanding_requests, &custom_apis).await;
150
    }
151

            
152
90
    Err(Error::disconnected())
153
270
}
154

            
155
1314
async fn connect(
156
1314
    url: &Url,
157
1314
    certificate: Option<&Certificate>,
158
1314
    protocol_version: &str,
159
1314
) -> Result<
160
1314
    (
161
1314
        fabruic::Connection<()>,
162
1314
        fabruic::Sender<Payload>,
163
1314
        fabruic::Receiver<Payload>,
164
1314
    ),
165
1314
    Error,
166
1314
> {
167
1314
    let mut endpoint = Endpoint::builder();
168
1314
    endpoint
169
1314
        .set_max_idle_timeout(None)
170
1314
        .map_err(|err| Error::Core(bonsaidb_core::Error::other("quic", err)))?;
171
1314
    endpoint.set_protocols([protocol_version.as_bytes().to_vec()]);
172
1314
    let endpoint = endpoint
173
1314
        .build()
174
1314
        .map_err(|err| Error::Core(bonsaidb_core::Error::other("quic", err)))?;
175
1314
    let connecting = if let Some(certificate) = certificate {
176
1278
        endpoint.connect_pinned(url, certificate, None).await?
177
    } else {
178
36
        endpoint.connect(url).await?
179
    };
180

            
181
1314
    let connection = connecting.accept::<()>().await.map_err(|err| {
182
18
        if matches!(err, fabruic::error::Connecting::ProtocolMismatch) {
183
18
            Error::ProtocolVersionMismatch
184
        } else {
185
            Error::from(err)
186
        }
187
1278
    })?;
188
1260
    let (sender, receiver) = connection.open_stream(&()).await?;
189

            
190
1260
    Ok((connection, sender, receiver))
191
1278
}