1
use std::{collections::HashMap, sync::Arc};
2

            
3
use bonsaidb_core::{networking::Payload, schema::ApiName};
4
use bonsaidb_utils::fast_async_lock;
5
use flume::Receiver;
6
use futures::{
7
    stream::{SplitSink, SplitStream},
8
    SinkExt, StreamExt,
9
};
10
use tokio::net::TcpStream;
11
use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
12
use url::Url;
13

            
14
use super::PendingRequest;
15
use crate::{
16
    client::{AnyApiCallback, OutstandingRequestMapHandle, SubscriberMap},
17
    Error,
18
};
19

            
20
1653
pub async fn reconnecting_client_loop(
21
1653
    url: Url,
22
1653
    protocol_version: &str,
23
1653
    request_receiver: Receiver<PendingRequest>,
24
1653
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
25
1653
    subscribers: SubscriberMap,
26
1653
) -> Result<(), Error> {
27
1653
    while let Ok(request) = {
28
1805
        subscribers.clear();
29
1805
        request_receiver.recv_async().await
30
    } {
31
1653
        let (stream, _) = match tokio_tungstenite::connect_async(
32
1653
            tokio_tungstenite::tungstenite::handshake::client::Request::get(url.as_str())
33
1653
                .header("Sec-WebSocket-Protocol", protocol_version)
34
1653
                .body(())
35
1653
                .unwrap(),
36
5149
        )
37
5149
        .await
38
        {
39
1634
            Ok(result) => result,
40
19
            Err(err) => {
41
19
                drop(request.responder.send(Err(Error::from(err))));
42
19
                continue;
43
            }
44
        };
45

            
46
1634
        let (mut sender, receiver) = stream.split();
47
1634

            
48
1634
        let outstanding_requests = OutstandingRequestMapHandle::default();
49
        {
50
1634
            let mut outstanding_requests = fast_async_lock!(outstanding_requests);
51
            if let Err(err) = sender
52
1634
                .send(Message::Binary(bincode::serialize(&request.request)?))
53
                .await
54
            {
55
                drop(request.responder.send(Err(Error::from(err))));
56
                continue;
57
1634
            }
58
1634
            outstanding_requests.insert(
59
1634
                request.request.id.expect("all requests must have ids"),
60
1634
                request,
61
1634
            );
62
        }
63

            
64
1634
        if let Err(err) = tokio::try_join!(
65
1006449
            request_sender(&request_receiver, sender, outstanding_requests.clone()),
66
1006449
            response_processor(receiver, outstanding_requests.clone(), &custom_apis,)
67
1006449
        ) {
68
            // Our socket was disconnected, clear the outstanding requests before returning.
69
133
            let mut outstanding_requests = fast_async_lock!(outstanding_requests);
70
133
            for (_, pending) in outstanding_requests.drain() {
71
19
                drop(pending.responder.send(Err(Error::Disconnected)));
72
19
            }
73
76
            log::error!("Error on socket {:?}", err);
74
        }
75
    }
76

            
77
133
    Ok(())
78
133
}
79

            
80
1634
async fn request_sender(
81
1634
    request_receiver: &Receiver<PendingRequest>,
82
1634
    mut sender: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
83
1634
    outstanding_requests: OutstandingRequestMapHandle,
84
1634
) -> Result<(), Error> {
85
1004701
    while let Ok(pending) = request_receiver.recv_async().await {
86
700891
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
87
        sender
88
700891
            .send(Message::Binary(bincode::serialize(&pending.request)?))
89
133
            .await?;
90

            
91
700891
        outstanding_requests.insert(
92
700891
            pending.request.id.expect("all requests must have ids"),
93
700891
            pending,
94
700891
        );
95
    }
96

            
97
95
    Err(Error::Disconnected)
98
95
}
99

            
100
#[allow(clippy::collapsible_else_if)] // not possible due to cfg statement
101
1634
async fn response_processor(
102
1634
    mut receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
103
1634
    outstanding_requests: OutstandingRequestMapHandle,
104
1634
    custom_apis: &HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>,
105
1634
) -> Result<(), Error> {
106
1004739
    while let Some(message) = receiver.next().await {
107
703114
        let message = message?;
108
703076
        match message {
109
703076
            Message::Binary(response) => {
110
703076
                let payload = bincode::deserialize::<Payload>(&response)?;
111

            
112
703076
                super::process_response_payload(payload, &outstanding_requests, custom_apis).await;
113
            }
114
            other => {
115
                log::error!("Unexpected websocket message: {:?}", other);
116
            }
117
        }
118
    }
119

            
120
    Ok(())
121
38
}