1
use std::sync::Arc;
2

            
3
use bonsaidb_core::{
4
    custom_api::{CustomApi, CustomApiResult},
5
    networking::{Payload, Response},
6
};
7
use bonsaidb_utils::fast_async_lock;
8
use flume::Receiver;
9
use futures::{
10
    stream::{SplitSink, SplitStream},
11
    SinkExt, StreamExt,
12
};
13
use tokio::net::TcpStream;
14
use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
15
use url::Url;
16

            
17
use super::{CustomApiCallback, PendingRequest};
18
use crate::{
19
    client::{OutstandingRequestMapHandle, SubscriberMap},
20
    Error,
21
};
22

            
23
56
pub async fn reconnecting_client_loop<A: CustomApi>(
24
56
    url: Url,
25
56
    protocol_version: &str,
26
56
    request_receiver: Receiver<PendingRequest<A>>,
27
56
    custom_api_callback: Option<Arc<dyn CustomApiCallback<A>>>,
28
56
    subscribers: SubscriberMap,
29
56
) -> Result<(), Error<A::Error>> {
30
56
    while let Ok(request) = {
31
59
        subscribers.clear().await;
32
59
        request_receiver.recv_async().await
33
    } {
34
56
        let (stream, _) = match tokio_tungstenite::connect_async(
35
56
            tokio_tungstenite::tungstenite::handshake::client::Request::get(url.as_str())
36
56
                .header("Sec-WebSocket-Protocol", protocol_version)
37
56
                .body(())
38
56
                .unwrap(),
39
185
        )
40
185
        .await
41
        {
42
55
            Ok(result) => result,
43
1
            Err(err) => {
44
1
                drop(request.responder.send(Err(Error::from(err))));
45
1
                continue;
46
            }
47
        };
48

            
49
55
        let (mut sender, receiver) = stream.split();
50
55

            
51
55
        let outstanding_requests = OutstandingRequestMapHandle::default();
52
        {
53
55
            let mut outstanding_requests = fast_async_lock!(outstanding_requests);
54
            if let Err(err) = sender
55
55
                .send(Message::Binary(bincode::serialize(&request.request)?))
56
                .await
57
            {
58
                drop(request.responder.send(Err(Error::from(err))));
59
                continue;
60
55
            }
61
55
            outstanding_requests.insert(
62
55
                request.request.id.expect("all requests must have ids"),
63
55
                request,
64
55
            );
65
        }
66

            
67
55
        if let Err(err) = tokio::try_join!(
68
31703
            request_sender(&request_receiver, sender, outstanding_requests.clone()),
69
31703
            response_processor(
70
31703
                receiver,
71
31703
                outstanding_requests.clone(),
72
31703
                custom_api_callback.as_deref(),
73
31703
                subscribers.clone()
74
31703
            )
75
31703
        ) {
76
            // Our socket was disconnected, clear the outstanding requests before returning.
77
2
            let mut outstanding_requests = fast_async_lock!(outstanding_requests);
78
2
            for (_, pending) in outstanding_requests.drain() {
79
                drop(pending.responder.send(Err(Error::Disconnected)));
80
            }
81
2
            log::error!("Error on socket {:?}", err);
82
        }
83
    }
84

            
85
2
    Ok(())
86
2
}
87

            
88
55
async fn request_sender<Api: CustomApi>(
89
55
    request_receiver: &Receiver<PendingRequest<Api>>,
90
55
    mut sender: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
91
55
    outstanding_requests: OutstandingRequestMapHandle<Api>,
92
55
) -> Result<(), Error<Api::Error>> {
93
32690
    while let Ok(pending) = request_receiver.recv_async().await {
94
21952
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
95
        sender
96
21952
            .send(Message::Binary(bincode::serialize(&pending.request)?))
97
7
            .await?;
98

            
99
21952
        outstanding_requests.insert(
100
21952
            pending.request.id.expect("all requests must have ids"),
101
21952
            pending,
102
21952
        );
103
    }
104

            
105
    Err(Error::Disconnected)
106
}
107

            
108
#[allow(clippy::collapsible_else_if)] // not possible due to cfg statement
109
55
async fn response_processor<A: CustomApi>(
110
55
    mut receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
111
55
    outstanding_requests: OutstandingRequestMapHandle<A>,
112
55
    custom_api_callback: Option<&dyn CustomApiCallback<A>>,
113
55
    subscribers: SubscriberMap,
114
55
) -> Result<(), Error<A::Error>> {
115
32733
    while let Some(message) = receiver.next().await {
116
22023
        let message = message?;
117
22021
        match message {
118
22021
            Message::Binary(response) => {
119
22021
                let payload =
120
22021
                    bincode::deserialize::<Payload<Response<CustomApiResult<A>>>>(&response)?;
121

            
122
22021
                super::process_response_payload(
123
22021
                    payload,
124
22021
                    &outstanding_requests,
125
22021
                    custom_api_callback,
126
22021
                    &subscribers,
127
22021
                )
128
                .await;
129
            }
130
            other => {
131
                log::error!("Unexpected websocket message: {:?}", other);
132
            }
133
        }
134
    }
135

            
136
    Ok(())
137
2
}