1
use std::{collections::HashMap, sync::Arc};
2

            
3
use bonsaidb_core::{networking::Payload, schema::ApiName};
4
use bonsaidb_utils::fast_async_lock;
5
use flume::Receiver;
6
use futures::{
7
    stream::{SplitSink, SplitStream},
8
    SinkExt, StreamExt,
9
};
10
use tokio::net::TcpStream;
11
use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
12
use url::Url;
13

            
14
use super::PendingRequest;
15
use crate::{
16
    client::{AnyApiCallback, OutstandingRequestMapHandle, SubscriberMap},
17
    Error,
18
};
19

            
20
1740
pub async fn reconnecting_client_loop(
21
1740
    url: Url,
22
1740
    protocol_version: &str,
23
1740
    request_receiver: Receiver<PendingRequest>,
24
1740
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
25
1740
    subscribers: SubscriberMap,
26
1740
) -> Result<(), Error> {
27
1740
    while let Ok(request) = {
28
1860
        subscribers.clear();
29
1860
        request_receiver.recv_async().await
30
    } {
31
1740
        let (stream, _) = match tokio_tungstenite::connect_async(
32
1740
            tokio_tungstenite::tungstenite::handshake::client::Request::get(url.as_str())
33
1740
                .header("Sec-WebSocket-Protocol", protocol_version)
34
1740
                .body(())
35
1740
                .unwrap(),
36
5320
        )
37
5320
        .await
38
        {
39
1720
            Ok(result) => result,
40
20
            Err(err) => {
41
20
                drop(request.responder.send(Err(Error::from(err))));
42
20
                continue;
43
            }
44
        };
45

            
46
1720
        let (mut sender, receiver) = stream.split();
47
1720

            
48
1720
        let outstanding_requests = OutstandingRequestMapHandle::default();
49
        {
50
1720
            let mut outstanding_requests = fast_async_lock!(outstanding_requests);
51
            if let Err(err) = sender
52
1720
                .send(Message::Binary(bincode::serialize(&request.request)?))
53
                .await
54
            {
55
                drop(request.responder.send(Err(Error::from(err))));
56
                continue;
57
1720
            }
58
1720
            outstanding_requests.insert(
59
1720
                request.request.id.expect("all requests must have ids"),
60
1720
                request,
61
1720
            );
62
        }
63

            
64
1720
        if let Err(err) = tokio::try_join!(
65
1079200
            request_sender(&request_receiver, sender, outstanding_requests.clone()),
66
1079200
            response_processor(receiver, outstanding_requests.clone(), &custom_apis,)
67
1079200
        ) {
68
            // Our socket was disconnected, clear the outstanding requests before returning.
69
100
            let mut outstanding_requests = fast_async_lock!(outstanding_requests);
70
100
            for (_, pending) in outstanding_requests.drain() {
71
                drop(pending.responder.send(Err(Error::Disconnected)));
72
            }
73
60
            log::error!("Error on socket {:?}", err);
74
        }
75
    }
76

            
77
100
    Ok(())
78
100
}
79

            
80
1720
async fn request_sender(
81
1720
    request_receiver: &Receiver<PendingRequest>,
82
1720
    mut sender: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
83
1720
    outstanding_requests: OutstandingRequestMapHandle,
84
1720
) -> Result<(), Error> {
85
1077380
    while let Ok(pending) = request_receiver.recv_async().await {
86
750420
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
87
        sender
88
750420
            .send(Message::Binary(bincode::serialize(&pending.request)?))
89
140
            .await?;
90

            
91
750420
        outstanding_requests.insert(
92
750420
            pending.request.id.expect("all requests must have ids"),
93
750420
            pending,
94
750420
        );
95
    }
96

            
97
60
    Err(Error::Disconnected)
98
60
}
99

            
100
#[allow(clippy::collapsible_else_if)] // not possible due to cfg statement
101
1720
async fn response_processor(
102
1720
    mut receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
103
1720
    outstanding_requests: OutstandingRequestMapHandle,
104
1720
    custom_apis: &HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>,
105
1720
) -> Result<(), Error> {
106
1077460
    while let Some(message) = receiver.next().await {
107
752760
        let message = message?;
108
752720
        match message {
109
752720
            Message::Binary(response) => {
110
752720
                let payload = bincode::deserialize::<Payload>(&response)?;
111

            
112
752720
                super::process_response_payload(payload, &outstanding_requests, custom_apis).await;
113
            }
114
            other => {
115
                log::error!("Unexpected websocket message: {:?}", other);
116
            }
117
        }
118
    }
119

            
120
    Ok(())
121
40
}