1
use std::{collections::HashMap, sync::Arc};
2

            
3
use bonsaidb_core::{networking::Payload, schema::ApiName};
4
use bonsaidb_utils::fast_async_lock;
5
use fabruic::{self, Certificate, Endpoint};
6
use flume::Receiver;
7
use futures::StreamExt;
8
use url::Url;
9

            
10
use super::PendingRequest;
11
use crate::{
12
    client::{AnyApiCallback, OutstandingRequestMapHandle, SubscriberMap},
13
    Error,
14
};
15

            
16
/// This function will establish a connection and try to keep it active. If an
17
/// error occurs, any queries that come in while reconnecting will have the
18
/// error replayed to them.
19
1178
pub async fn reconnecting_client_loop(
20
1178
    mut url: Url,
21
1178
    protocol_version: &'static str,
22
1178
    certificate: Option<Certificate>,
23
1178
    request_receiver: Receiver<PendingRequest>,
24
1178
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
25
1178
    subscribers: SubscriberMap,
26
1178
) -> Result<(), Error> {
27
1178
    if url.port().is_none() && url.scheme() == "bonsaidb" {
28
76
        let _ = url.set_port(Some(5645));
29
1102
    }
30

            
31
1178
    subscribers.clear();
32
1235
    while let Ok(request) = request_receiver.recv_async().await {
33
1178
        if let Err((failed_request, err)) = connect_and_process(
34
1178
            &url,
35
1178
            protocol_version,
36
1178
            certificate.as_ref(),
37
1178
            request,
38
1178
            &request_receiver,
39
1178
            custom_apis.clone(),
40
375117
        )
41
375117
        .await
42
        {
43
57
            if let Some(failed_request) = failed_request {
44
19
                drop(failed_request.responder.send(Err(err)));
45
38
            }
46
57
            continue;
47
        }
48
    }
49

            
50
38
    Ok(())
51
38
}
52

            
53
1178
async fn connect_and_process(
54
1178
    url: &Url,
55
1178
    protocol_version: &str,
56
1178
    certificate: Option<&Certificate>,
57
1178
    initial_request: PendingRequest,
58
1178
    request_receiver: &Receiver<PendingRequest>,
59
1178
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
60
1178
) -> Result<(), (Option<PendingRequest>, Error)> {
61
1159
    let (_connection, payload_sender, payload_receiver) =
62
1178
        match connect(url, certificate, protocol_version).await {
63
1159
            Ok(result) => result,
64
19
            Err(err) => return Err((Some(initial_request), err)),
65
        };
66

            
67
1159
    let outstanding_requests = OutstandingRequestMapHandle::default();
68
1159
    let request_processor = tokio::spawn(process(
69
1159
        outstanding_requests.clone(),
70
1159
        payload_receiver,
71
1159
        custom_apis,
72
1159
    ));
73

            
74
1159
    if let Err(err) = payload_sender.send(&initial_request.request) {
75
        return Err((Some(initial_request), Error::from(err)));
76
1159
    }
77

            
78
1159
    {
79
1159
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
80
1159
        outstanding_requests.insert(
81
1159
            initial_request
82
1159
                .request
83
1159
                .id
84
1159
                .expect("all requests require ids"),
85
1159
            initial_request,
86
1159
        );
87
    }
88

            
89
1159
    if let Err(err) = futures::try_join!(
90
375098
        process_requests(
91
375098
            outstanding_requests.clone(),
92
375098
            request_receiver,
93
375098
            payload_sender
94
375098
        ),
95
375098
        async { request_processor.await.map_err(|_| Error::Disconnected)? }
96
375098
    ) {
97
        // Our socket was disconnected, clear the outstanding requests before returning.
98
38
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
99
38
        for (_, pending) in outstanding_requests.drain() {
100
            drop(pending.responder.send(Err(Error::Disconnected)));
101
        }
102
38
        return Err((None, err));
103
    }
104

            
105
    Ok(())
106
57
}
107

            
108
1159
async fn process_requests(
109
1159
    outstanding_requests: OutstandingRequestMapHandle,
110
1159
    request_receiver: &Receiver<PendingRequest>,
111
1159
    payload_sender: fabruic::Sender<Payload>,
112
1159
) -> Result<(), Error> {
113
649116
    while let Ok(client_request) = request_receiver.recv_async().await {
114
647957
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
115
647957
        payload_sender.send(&client_request.request)?;
116
647957
        outstanding_requests.insert(
117
647957
            client_request.request.id.expect("all requests require ids"),
118
647957
            client_request,
119
647957
        );
120
    }
121

            
122
    // Return an error to make sure try_join returns.
123
38
    Err(Error::Disconnected)
124
38
}
125

            
126
1159
pub async fn process(
127
1159
    outstanding_requests: OutstandingRequestMapHandle,
128
1159
    mut payload_receiver: fabruic::Receiver<Payload>,
129
1159
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
130
1159
) -> Result<(), Error> {
131
650560
    while let Some(payload) = payload_receiver.next().await {
132
649667
        let payload = payload?;
133
649401
        super::process_response_payload(payload, &outstanding_requests, &custom_apis).await;
134
    }
135

            
136
    Err(Error::Disconnected)
137
266
}
138

            
139
1178
async fn connect(
140
1178
    url: &Url,
141
1178
    certificate: Option<&Certificate>,
142
1178
    protocol_version: &str,
143
1178
) -> Result<
144
1178
    (
145
1178
        fabruic::Connection<()>,
146
1178
        fabruic::Sender<Payload>,
147
1178
        fabruic::Receiver<Payload>,
148
1178
    ),
149
1178
    Error,
150
1178
> {
151
1178
    let mut endpoint = Endpoint::builder();
152
1178
    endpoint
153
1178
        .set_max_idle_timeout(None)
154
1178
        .map_err(|err| Error::Core(bonsaidb_core::Error::Transport(err.to_string())))?;
155
1178
    endpoint.set_protocols([protocol_version.as_bytes().to_vec()]);
156
1178
    let endpoint = endpoint
157
1178
        .build()
158
1178
        .map_err(|err| Error::Core(bonsaidb_core::Error::Transport(err.to_string())))?;
159
1178
    let connecting = if let Some(certificate) = certificate {
160
1178
        endpoint.connect_pinned(url, certificate, None).await?
161
    } else {
162
        endpoint.connect(url).await?
163
    };
164

            
165
1178
    let connection = connecting.accept::<()>().await.map_err(|err| {
166
19
        if matches!(err, fabruic::error::Connecting::ProtocolMismatch) {
167
19
            Error::ProtocolVersionMismatch
168
        } else {
169
            Error::from(err)
170
        }
171
1178
    })?;
172
1159
    let (sender, receiver) = connection.open_stream(&()).await?;
173

            
174
1159
    Ok((connection, sender, receiver))
175
1178
}