1
use std::collections::HashMap;
2
use std::sync::atomic::{AtomicU32, Ordering};
3
use std::sync::Arc;
4
use std::time::Duration;
5

            
6
use bonsaidb_core::api::ApiName;
7
use bonsaidb_core::networking::Payload;
8
use bonsaidb_utils::fast_async_lock;
9
use fabruic::{self, Certificate, Endpoint};
10
use flume::Receiver;
11
use futures::StreamExt;
12
use url::Url;
13

            
14
use super::PendingRequest;
15
use crate::client::{
16
    disconnect_pending_requests, AnyApiCallback, ConnectionInfo, OutstandingRequestMapHandle,
17
};
18
use crate::Error;
19

            
20
/// This function will establish a connection and try to keep it active. If an
21
/// error occurs, any queries that come in while reconnecting will have the
22
/// error replayed to them.
23
1224
pub(super) async fn reconnecting_client_loop(
24
1224
    mut server: ConnectionInfo,
25
1224
    protocol_version: &'static str,
26
1224
    certificate: Option<Certificate>,
27
1224
    request_receiver: Receiver<PendingRequest>,
28
1224
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
29
1224
    connection_counter: Arc<AtomicU32>,
30
1224
) -> Result<(), Error> {
31
1224
    if server.url.port().is_none() && server.url.scheme() == "bonsaidb" {
32
68
        let _: Result<_, _> = server.url.set_port(Some(5645));
33
1156
    }
34

            
35
1224
    server.subscribers.clear();
36
1224
    let mut pending_error = None;
37
1700
    while let Ok(request) = request_receiver.recv_async().await {
38
1258
        if let Some(pending_error) = pending_error.take() {
39
17
            drop(request.responder.send(Err(pending_error)));
40
17
            continue;
41
1241
        }
42
1241
        connection_counter.fetch_add(1, Ordering::SeqCst);
43
1241
        if let Err((failed_request, Some(err))) = connect_and_process(
44
1241
            &server.url,
45
1241
            protocol_version,
46
1241
            certificate.as_ref(),
47
1241
            request,
48
1241
            &request_receiver,
49
1241
            custom_apis.clone(),
50
1241
            server.connect_timeout,
51
1241
        )
52
410958
        .await
53
        {
54
391
            if let Some(failed_request) = failed_request {
55
51
                drop(failed_request.responder.send(Err(err)));
56
340
            } else {
57
340
                pending_error = Some(err);
58
340
            }
59
68
        }
60
    }
61

            
62
408
    Ok(())
63
408
}
64

            
65
1241
async fn connect_and_process(
66
1241
    url: &Url,
67
1241
    protocol_version: &str,
68
1241
    certificate: Option<&Certificate>,
69
1241
    initial_request: PendingRequest,
70
1241
    request_receiver: &Receiver<PendingRequest>,
71
1241
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
72
1241
    connect_timeout: Duration,
73
1241
) -> Result<(), (Option<PendingRequest>, Option<Error>)> {
74
1190
    let (_connection, payload_sender, payload_receiver) =
75
1241
        match tokio::time::timeout(connect_timeout, connect(url, certificate, protocol_version))
76
1241
            .await
77
        {
78
1190
            Ok(Ok(result)) => result,
79
17
            Ok(Err(err)) => return Err((Some(initial_request), Some(err))),
80
34
            Err(_) => return Err((Some(initial_request), Some(Error::connect_timeout()))),
81
        };
82

            
83
1190
    let outstanding_requests = OutstandingRequestMapHandle::default();
84
1190
    let request_processor = tokio::spawn(process(
85
1190
        outstanding_requests.clone(),
86
1190
        payload_receiver,
87
1190
        custom_apis,
88
1190
    ));
89

            
90
1190
    if let Err(err) = payload_sender.send(&initial_request.request) {
91
        return Err((Some(initial_request), Some(Error::from(err))));
92
1190
    }
93

            
94
1190
    {
95
1190
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
96
1190
        outstanding_requests.insert(
97
1190
            initial_request
98
1190
                .request
99
1190
                .id
100
1190
                .expect("all requests require ids"),
101
1190
            initial_request,
102
1190
        );
103
    }
104

            
105
1190
    if let Err(err) = futures::try_join!(
106
410907
        process_requests(
107
410907
            outstanding_requests.clone(),
108
410907
            request_receiver,
109
410907
            payload_sender
110
410907
        ),
111
410907
        async { request_processor.await.map_err(|_| Error::disconnected())? }
112
410907
    ) {
113
408
        let mut pending_error = Some(err);
114
408
        // Our socket was disconnected, clear the outstanding requests before returning.
115
408
        disconnect_pending_requests(&outstanding_requests, &mut pending_error).await;
116
408
        return Err((None, pending_error));
117
    }
118

            
119
    Ok(())
120
459
}
121

            
122
1190
async fn process_requests(
123
1190
    outstanding_requests: OutstandingRequestMapHandle,
124
1190
    request_receiver: &Receiver<PendingRequest>,
125
1190
    payload_sender: fabruic::Sender<Payload>,
126
1190
) -> Result<(), Error> {
127
595935
    while let Ok(client_request) = request_receiver.recv_async().await {
128
594745
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
129
594745
        payload_sender.send(&client_request.request)?;
130
594745
        outstanding_requests.insert(
131
594745
            client_request.request.id.expect("all requests require ids"),
132
594745
            client_request,
133
594745
        );
134
    }
135

            
136
391
    drop(payload_sender.finish());
137
391

            
138
391
    // Return an error to make sure try_join returns.
139
391
    Err(Error::disconnected())
140
391
}
141

            
142
1190
pub async fn process(
143
1190
    outstanding_requests: OutstandingRequestMapHandle,
144
1190
    mut payload_receiver: fabruic::Receiver<Payload>,
145
1190
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
146
1190
) -> Result<(), Error> {
147
597261
    while let Some(payload) = payload_receiver.next().await {
148
596241
        let payload = payload?;
149
596071
        super::process_response_payload(payload, &outstanding_requests, &custom_apis).await;
150
    }
151

            
152
68
    Err(Error::disconnected())
153
238
}
154

            
155
1241
async fn connect(
156
1241
    url: &Url,
157
1241
    certificate: Option<&Certificate>,
158
1241
    protocol_version: &str,
159
1241
) -> Result<
160
1241
    (
161
1241
        fabruic::Connection<()>,
162
1241
        fabruic::Sender<Payload>,
163
1241
        fabruic::Receiver<Payload>,
164
1241
    ),
165
1241
    Error,
166
1241
> {
167
1241
    let mut endpoint = Endpoint::builder();
168
1241
    endpoint
169
1241
        .set_max_idle_timeout(None)
170
1241
        .map_err(|err| Error::Core(bonsaidb_core::Error::other("quic", err)))?;
171
1241
    endpoint.set_protocols([protocol_version.as_bytes().to_vec()]);
172
1241
    let endpoint = endpoint
173
1241
        .build()
174
1241
        .map_err(|err| Error::Core(bonsaidb_core::Error::other("quic", err)))?;
175
1241
    let connecting = if let Some(certificate) = certificate {
176
1207
        endpoint.connect_pinned(url, certificate, None).await?
177
    } else {
178
34
        endpoint.connect(url).await?
179
    };
180

            
181
1241
    let connection = connecting.accept::<()>().await.map_err(|err| {
182
17
        if matches!(err, fabruic::error::Connecting::ProtocolMismatch) {
183
17
            Error::ProtocolVersionMismatch
184
        } else {
185
            Error::from(err)
186
        }
187
1207
    })?;
188
1190
    let (sender, receiver) = connection.open_stream(&()).await?;
189

            
190
1190
    Ok((connection, sender, receiver))
191
1207
}