1
use std::sync::Arc;
2

            
3
use bonsaidb_core::{
4
    custom_api::{CustomApi, CustomApiResult},
5
    networking::{Payload, Request, Response},
6
};
7
use bonsaidb_utils::fast_async_lock;
8
use fabruic::{self, Certificate, Endpoint};
9
use flume::Receiver;
10
use futures::StreamExt;
11
use url::Url;
12

            
13
use super::{CustomApiCallback, PendingRequest};
14
use crate::{
15
    client::{OutstandingRequestMapHandle, SubscriberMap},
16
    Error,
17
};
18

            
19
/// This function will establish a connection and try to keep it active. If an
20
/// error occurs, any queries that come in while reconnecting will have the
21
/// error replayed to them.
22
62
pub async fn reconnecting_client_loop<A: CustomApi>(
23
62
    mut url: Url,
24
62
    protocol_version: &'static str,
25
62
    certificate: Option<Certificate>,
26
62
    request_receiver: Receiver<PendingRequest<A>>,
27
62
    custom_api_callback: Option<Arc<dyn CustomApiCallback<A>>>,
28
62
    subscribers: SubscriberMap,
29
62
) -> Result<(), Error<A::Error>> {
30
62
    if url.port().is_none() && url.scheme() == "bonsaidb" {
31
4
        let _ = url.set_port(Some(5645));
32
62
    }
33

            
34
62
    subscribers.clear().await;
35
63
    while let Ok(request) = request_receiver.recv_async().await {
36
62
        if let Err((failed_request, err)) = connect_and_process(
37
62
            &url,
38
62
            protocol_version,
39
62
            certificate.as_ref(),
40
62
            request,
41
62
            &request_receiver,
42
62
            custom_api_callback.clone(),
43
62
            &subscribers,
44
16958
        )
45
16958
        .await
46
        {
47
1
            if let Some(failed_request) = failed_request {
48
1
                drop(failed_request.responder.send(Err(err)));
49
1
            }
50
1
            continue;
51
        }
52
    }
53

            
54
    Ok(())
55
}
56

            
57
62
async fn connect_and_process<A: CustomApi>(
58
62
    url: &Url,
59
62
    protocol_version: &str,
60
62
    certificate: Option<&Certificate>,
61
62
    initial_request: PendingRequest<A>,
62
62
    request_receiver: &Receiver<PendingRequest<A>>,
63
62
    custom_api_callback: Option<Arc<dyn CustomApiCallback<A>>>,
64
62
    subscribers: &SubscriberMap,
65
62
) -> Result<(), (Option<PendingRequest<A>>, Error<A::Error>)> {
66
61
    let (_connection, payload_sender, payload_receiver) =
67
62
        match connect::<A>(url, certificate, protocol_version).await {
68
61
            Ok(result) => result,
69
1
            Err(err) => return Err((Some(initial_request), err)),
70
        };
71

            
72
61
    let outstanding_requests = OutstandingRequestMapHandle::default();
73
61
    let request_processor = tokio::spawn(process(
74
61
        outstanding_requests.clone(),
75
61
        payload_receiver,
76
61
        custom_api_callback,
77
61
        subscribers.clone(),
78
61
    ));
79

            
80
61
    if let Err(err) = payload_sender.send(&initial_request.request) {
81
        return Err((Some(initial_request), Error::from(err)));
82
61
    }
83

            
84
61
    {
85
61
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
86
61
        outstanding_requests.insert(
87
61
            initial_request
88
61
                .request
89
61
                .id
90
61
                .expect("all requests require ids"),
91
61
            initial_request,
92
61
        );
93
    }
94

            
95
61
    if let Err(err) = futures::try_join!(
96
16957
        process_requests::<A>(
97
16957
            outstanding_requests.clone(),
98
16957
            request_receiver,
99
16957
            payload_sender
100
16957
        ),
101
16957
        async { request_processor.await.map_err(|_| Error::Disconnected)? }
102
16957
    ) {
103
        // Our socket was disconnected, clear the outstanding requests before returning.
104
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
105
        for (_, pending) in outstanding_requests.drain() {
106
            drop(pending.responder.send(Err(Error::Disconnected)));
107
        }
108
        return Err((None, err));
109
    }
110

            
111
    Ok(())
112
1
}
113

            
114
61
async fn process_requests<A: CustomApi>(
115
61
    outstanding_requests: OutstandingRequestMapHandle<A>,
116
61
    request_receiver: &Receiver<PendingRequest<A>>,
117
61
    payload_sender: fabruic::Sender<Payload<Request<A::Request>>>,
118
61
) -> Result<(), Error<A::Error>> {
119
30223
    while let Ok(client_request) = request_receiver.recv_async().await {
120
30162
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
121
30162
        payload_sender.send(&client_request.request)?;
122
30162
        outstanding_requests.insert(
123
30162
            client_request.request.id.expect("all requests require ids"),
124
30162
            client_request,
125
30162
        );
126
    }
127

            
128
    // Return an error to make sure try_join returns.
129
    Err(Error::Disconnected)
130
}
131

            
132
61
pub async fn process<A: CustomApi>(
133
61
    outstanding_requests: OutstandingRequestMapHandle<A>,
134
61
    mut payload_receiver: fabruic::Receiver<Payload<Response<CustomApiResult<A>>>>,
135
61
    custom_api_callback: Option<Arc<dyn CustomApiCallback<A>>>,
136
61
    subscribers: SubscriberMap,
137
61
) -> Result<(), Error<A::Error>> {
138
30296
    while let Some(payload) = payload_receiver.next().await {
139
30253
        let payload = payload?;
140
30237
        super::process_response_payload(
141
30237
            payload,
142
30237
            &outstanding_requests,
143
30237
            custom_api_callback.as_deref(),
144
30237
            &subscribers,
145
30237
        )
146
        .await;
147
    }
148

            
149
    Err(Error::Disconnected)
150
16
}
151

            
152
62
async fn connect<A: CustomApi>(
153
62
    url: &Url,
154
62
    certificate: Option<&Certificate>,
155
62
    protocol_version: &str,
156
62
) -> Result<
157
62
    (
158
62
        fabruic::Connection<()>,
159
62
        fabruic::Sender<Payload<Request<A::Request>>>,
160
62
        fabruic::Receiver<Payload<Response<CustomApiResult<A>>>>,
161
62
    ),
162
62
    Error<A::Error>,
163
62
> {
164
62
    let mut endpoint = Endpoint::builder();
165
62
    endpoint
166
62
        .set_max_idle_timeout(None)
167
62
        .map_err(|err| Error::Core(bonsaidb_core::Error::Transport(err.to_string())))?;
168
62
    endpoint.set_protocols([protocol_version.as_bytes().to_vec()]);
169
62
    let endpoint = endpoint
170
62
        .build()
171
62
        .map_err(|err| Error::Core(bonsaidb_core::Error::Transport(err.to_string())))?;
172
62
    let connecting = if let Some(certificate) = certificate {
173
62
        endpoint.connect_pinned(url, certificate, None).await?
174
    } else {
175
        endpoint.connect(url).await?
176
    };
177

            
178
62
    let connection = connecting.accept::<()>().await.map_err(|err| {
179
1
        if matches!(err, fabruic::error::Connecting::ProtocolMismatch) {
180
1
            Error::ProtocolVersionMismatch
181
        } else {
182
            Error::from(err)
183
        }
184
62
    })?;
185
61
    let (sender, receiver) = connection.open_stream(&()).await?;
186

            
187
61
    Ok((connection, sender, receiver))
188
62
}