1
use std::{collections::HashMap, sync::Arc};
2

            
3
use bonsaidb_core::{networking::Payload, schema::ApiName};
4
use bonsaidb_utils::fast_async_lock;
5
use fabruic::{self, Certificate, Endpoint};
6
use flume::Receiver;
7
use futures::StreamExt;
8
use url::Url;
9

            
10
use super::PendingRequest;
11
use crate::{
12
    client::{AnyApiCallback, OutstandingRequestMapHandle, SubscriberMap},
13
    Error,
14
};
15

            
16
/// This function will establish a connection and try to keep it active. If an
17
/// error occurs, any queries that come in while reconnecting will have the
18
/// error replayed to them.
19
1240
pub async fn reconnecting_client_loop(
20
1240
    mut url: Url,
21
1240
    protocol_version: &'static str,
22
1240
    certificate: Option<Certificate>,
23
1240
    request_receiver: Receiver<PendingRequest>,
24
1240
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
25
1240
    subscribers: SubscriberMap,
26
1240
) -> Result<(), Error> {
27
1240
    if url.port().is_none() && url.scheme() == "bonsaidb" {
28
80
        let _ = url.set_port(Some(5645));
29
1160
    }
30

            
31
1240
    subscribers.clear();
32
1280
    while let Ok(request) = request_receiver.recv_async().await {
33
1240
        if let Err((failed_request, err)) = connect_and_process(
34
1240
            &url,
35
1240
            protocol_version,
36
1240
            certificate.as_ref(),
37
1240
            request,
38
1240
            &request_receiver,
39
1240
            custom_apis.clone(),
40
401860
        )
41
401860
        .await
42
        {
43
40
            if let Some(failed_request) = failed_request {
44
20
                drop(failed_request.responder.send(Err(err)));
45
20
            }
46
40
            continue;
47
        }
48
    }
49

            
50
20
    Ok(())
51
20
}
52

            
53
1240
async fn connect_and_process(
54
1240
    url: &Url,
55
1240
    protocol_version: &str,
56
1240
    certificate: Option<&Certificate>,
57
1240
    initial_request: PendingRequest,
58
1240
    request_receiver: &Receiver<PendingRequest>,
59
1240
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
60
1240
) -> Result<(), (Option<PendingRequest>, Error)> {
61
1220
    let (_connection, payload_sender, payload_receiver) =
62
1240
        match connect(url, certificate, protocol_version).await {
63
1220
            Ok(result) => result,
64
20
            Err(err) => return Err((Some(initial_request), err)),
65
        };
66

            
67
1220
    let outstanding_requests = OutstandingRequestMapHandle::default();
68
1220
    let request_processor = tokio::spawn(process(
69
1220
        outstanding_requests.clone(),
70
1220
        payload_receiver,
71
1220
        custom_apis,
72
1220
    ));
73

            
74
1220
    if let Err(err) = payload_sender.send(&initial_request.request) {
75
        return Err((Some(initial_request), Error::from(err)));
76
1220
    }
77

            
78
1220
    {
79
1220
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
80
1220
        outstanding_requests.insert(
81
1220
            initial_request
82
1220
                .request
83
1220
                .id
84
1220
                .expect("all requests require ids"),
85
1220
            initial_request,
86
1220
        );
87
    }
88

            
89
1220
    if let Err(err) = futures::try_join!(
90
401840
        process_requests(
91
401840
            outstanding_requests.clone(),
92
401840
            request_receiver,
93
401840
            payload_sender
94
401840
        ),
95
401840
        async { request_processor.await.map_err(|_| Error::Disconnected)? }
96
401840
    ) {
97
        // Our socket was disconnected, clear the outstanding requests before returning.
98
20
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
99
20
        for (_, pending) in outstanding_requests.drain() {
100
            drop(pending.responder.send(Err(Error::Disconnected)));
101
        }
102
20
        return Err((None, err));
103
    }
104

            
105
    Ok(())
106
40
}
107

            
108
1220
async fn process_requests(
109
1220
    outstanding_requests: OutstandingRequestMapHandle,
110
1220
    request_receiver: &Receiver<PendingRequest>,
111
1220
    payload_sender: fabruic::Sender<Payload>,
112
1220
) -> Result<(), Error> {
113
695960
    while let Ok(client_request) = request_receiver.recv_async().await {
114
694760
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
115
694760
        payload_sender.send(&client_request.request)?;
116
694760
        outstanding_requests.insert(
117
694760
            client_request.request.id.expect("all requests require ids"),
118
694760
            client_request,
119
694760
        );
120
    }
121

            
122
    // Return an error to make sure try_join returns.
123
20
    Err(Error::Disconnected)
124
20
}
125

            
126
1220
pub async fn process(
127
1220
    outstanding_requests: OutstandingRequestMapHandle,
128
1220
    mut payload_receiver: fabruic::Receiver<Payload>,
129
1220
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
130
1220
) -> Result<(), Error> {
131
697500
    while let Some(payload) = payload_receiver.next().await {
132
696540
        let payload = payload?;
133
696280
        super::process_response_payload(payload, &outstanding_requests, &custom_apis).await;
134
    }
135

            
136
    Err(Error::Disconnected)
137
260
}
138

            
139
1240
async fn connect(
140
1240
    url: &Url,
141
1240
    certificate: Option<&Certificate>,
142
1240
    protocol_version: &str,
143
1240
) -> Result<
144
1240
    (
145
1240
        fabruic::Connection<()>,
146
1240
        fabruic::Sender<Payload>,
147
1240
        fabruic::Receiver<Payload>,
148
1240
    ),
149
1240
    Error,
150
1240
> {
151
1240
    let mut endpoint = Endpoint::builder();
152
1240
    endpoint
153
1240
        .set_max_idle_timeout(None)
154
1240
        .map_err(|err| Error::Core(bonsaidb_core::Error::Transport(err.to_string())))?;
155
1240
    endpoint.set_protocols([protocol_version.as_bytes().to_vec()]);
156
1240
    let endpoint = endpoint
157
1240
        .build()
158
1240
        .map_err(|err| Error::Core(bonsaidb_core::Error::Transport(err.to_string())))?;
159
1240
    let connecting = if let Some(certificate) = certificate {
160
1240
        endpoint.connect_pinned(url, certificate, None).await?
161
    } else {
162
        endpoint.connect(url).await?
163
    };
164

            
165
1240
    let connection = connecting.accept::<()>().await.map_err(|err| {
166
20
        if matches!(err, fabruic::error::Connecting::ProtocolMismatch) {
167
20
            Error::ProtocolVersionMismatch
168
        } else {
169
            Error::from(err)
170
        }
171
1240
    })?;
172
1220
    let (sender, receiver) = connection.open_stream(&()).await?;
173

            
174
1220
    Ok((connection, sender, receiver))
175
1240
}