1
use std::collections::HashMap;
2
use std::sync::atomic::{AtomicU32, Ordering};
3
use std::sync::Arc;
4
use std::time::Duration;
5

            
6
use bonsaidb_core::api::ApiName;
7
use bonsaidb_core::networking::Payload;
8
use bonsaidb_utils::fast_async_lock;
9
use fabruic::{self, Certificate, Endpoint};
10
use flume::Receiver;
11
use futures::StreamExt;
12
use url::Url;
13

            
14
use super::PendingRequest;
15
use crate::client::{
16
    disconnect_pending_requests, AnyApiCallback, ConnectionInfo, OutstandingRequestMapHandle,
17
};
18
use crate::Error;
19

            
20
/// This function will establish a connection and try to keep it active. If an
21
/// error occurs, any queries that come in while reconnecting will have the
22
/// error replayed to them.
23
2340
pub(super) async fn reconnecting_client_loop(
24
2340
    mut server: ConnectionInfo,
25
2340
    protocol_version: &'static str,
26
2340
    certificate: Option<Certificate>,
27
2340
    request_receiver: Receiver<PendingRequest>,
28
2340
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
29
2340
    connection_counter: Arc<AtomicU32>,
30
2340
) -> Result<(), Error> {
31
2340
    if server.url.port().is_none() && server.url.scheme() == "bonsaidb" {
32
120
        let _: Result<_, _> = server.url.set_port(Some(5645));
33
2220
    }
34

            
35
2340
    server.subscribers.clear();
36
2340
    let mut pending_error = None;
37
3300
    while let Ok(request) = request_receiver.recv_async().await {
38
2400
        if let Some(pending_error) = pending_error.take() {
39
30
            drop(request.responder.send(Err(pending_error)));
40
30
            continue;
41
2370
        }
42
2370
        connection_counter.fetch_add(1, Ordering::SeqCst);
43
2370
        if let Err((failed_request, Some(err))) = connect_and_process(
44
2370
            &server.url,
45
2370
            protocol_version,
46
2370
            certificate.as_ref(),
47
2370
            request,
48
2370
            &request_receiver,
49
2370
            custom_apis.clone(),
50
2370
            server.connect_timeout,
51
2370
        )
52
793800
        .await
53
        {
54
810
            if let Some(failed_request) = failed_request {
55
90
                drop(failed_request.responder.send(Err(err)));
56
720
            } else {
57
720
                pending_error = Some(err);
58
720
            }
59
120
        }
60
    }
61

            
62
840
    Ok(())
63
840
}
64

            
65
2370
async fn connect_and_process(
66
2370
    url: &Url,
67
2370
    protocol_version: &str,
68
2370
    certificate: Option<&Certificate>,
69
2370
    initial_request: PendingRequest,
70
2370
    request_receiver: &Receiver<PendingRequest>,
71
2370
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
72
2370
    connect_timeout: Duration,
73
2370
) -> Result<(), (Option<PendingRequest>, Option<Error>)> {
74
2280
    let (_connection, payload_sender, payload_receiver) =
75
2370
        match tokio::time::timeout(connect_timeout, connect(url, certificate, protocol_version))
76
2370
            .await
77
        {
78
2280
            Ok(Ok(result)) => result,
79
30
            Ok(Err(err)) => return Err((Some(initial_request), Some(err))),
80
60
            Err(_) => return Err((Some(initial_request), Some(Error::connect_timeout()))),
81
        };
82

            
83
2280
    let outstanding_requests = OutstandingRequestMapHandle::default();
84
2280
    let request_processor = tokio::spawn(process(
85
2280
        outstanding_requests.clone(),
86
2280
        payload_receiver,
87
2280
        custom_apis,
88
2280
    ));
89

            
90
2280
    if let Err(err) = payload_sender.send(&initial_request.request) {
91
        return Err((Some(initial_request), Some(Error::from(err))));
92
2280
    }
93

            
94
2280
    {
95
2280
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
96
2280
        outstanding_requests.insert(
97
2280
            initial_request
98
2280
                .request
99
2280
                .id
100
2280
                .expect("all requests require ids"),
101
2280
            initial_request,
102
2280
        );
103
    }
104

            
105
793710
    if let Err(err) = futures::try_join!(
106
793710
        process_requests(
107
793710
            outstanding_requests.clone(),
108
793710
            request_receiver,
109
793710
            payload_sender
110
793710
        ),
111
793710
        async { request_processor.await.map_err(|_| Error::disconnected())? }
112
793710
    ) {
113
840
        let mut pending_error = Some(err);
114
840
        // Our socket was disconnected, clear the outstanding requests before returning.
115
840
        disconnect_pending_requests(&outstanding_requests, &mut pending_error).await;
116
840
        return Err((None, pending_error));
117
    }
118

            
119
    Ok(())
120
930
}
121

            
122
2280
async fn process_requests(
123
2280
    outstanding_requests: OutstandingRequestMapHandle,
124
2280
    request_receiver: &Receiver<PendingRequest>,
125
2280
    payload_sender: fabruic::Sender<Payload>,
126
2280
) -> Result<(), Error> {
127
1038090
    while let Ok(client_request) = request_receiver.recv_async().await {
128
1035810
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
129
1035810
        payload_sender.send(&client_request.request)?;
130
1035810
        outstanding_requests.insert(
131
1035810
            client_request.request.id.expect("all requests require ids"),
132
1035810
            client_request,
133
1035810
        );
134
    }
135

            
136
810
    drop(payload_sender.finish());
137
810

            
138
810
    // Return an error to make sure try_join returns.
139
810
    Err(Error::disconnected())
140
810
}
141

            
142
2280
pub async fn process(
143
2280
    outstanding_requests: OutstandingRequestMapHandle,
144
2280
    mut payload_receiver: fabruic::Receiver<Payload>,
145
2280
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
146
2280
) -> Result<(), Error> {
147
1040610
    while let Some(payload) = payload_receiver.next().await {
148
1038810
        let payload = payload?;
149
1038330
        super::process_response_payload(payload, &outstanding_requests, &custom_apis).await;
150
    }
151

            
152
150
    Err(Error::disconnected())
153
630
}
154

            
155
2370
async fn connect(
156
2370
    url: &Url,
157
2370
    certificate: Option<&Certificate>,
158
2370
    protocol_version: &str,
159
2370
) -> Result<
160
2370
    (
161
2370
        fabruic::Connection<()>,
162
2370
        fabruic::Sender<Payload>,
163
2370
        fabruic::Receiver<Payload>,
164
2370
    ),
165
2370
    Error,
166
2370
> {
167
2370
    let mut endpoint = Endpoint::builder();
168
2370
    endpoint
169
2370
        .set_max_idle_timeout(None)
170
2370
        .map_err(|err| Error::Core(bonsaidb_core::Error::other("quic", err)))?;
171
2370
    endpoint.set_protocols([protocol_version.as_bytes().to_vec()]);
172
2370
    let endpoint = endpoint
173
2370
        .build()
174
2370
        .map_err(|err| Error::Core(bonsaidb_core::Error::other("quic", err)))?;
175
2370
    let connecting = if let Some(certificate) = certificate {
176
2310
        endpoint.connect_pinned(url, certificate, None).await?
177
    } else {
178
60
        endpoint.connect(url).await?
179
    };
180

            
181
2370
    let connection = connecting.accept::<()>().await.map_err(|err| {
182
30
        if matches!(err, fabruic::error::Connecting::ProtocolMismatch) {
183
30
            Error::ProtocolVersionMismatch
184
        } else {
185
            Error::from(err)
186
        }
187
2310
    })?;
188
2280
    let (sender, receiver) = connection.open_stream(&()).await?;
189

            
190
2280
    Ok((connection, sender, receiver))
191
2310
}