1
use std::collections::HashMap;
2
use std::sync::atomic::{AtomicU32, Ordering};
3
use std::sync::Arc;
4

            
5
use bonsaidb_core::api::ApiName;
6
use bonsaidb_core::networking::Payload;
7
use bonsaidb_utils::fast_async_lock;
8
use flume::Receiver;
9
use futures::stream::{SplitSink, SplitStream};
10
use futures::{SinkExt, StreamExt};
11
use tokio::net::TcpStream;
12
use tokio_tungstenite::tungstenite::handshake::client::generate_key;
13
use tokio_tungstenite::tungstenite::Message;
14
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
15

            
16
use super::PendingRequest;
17
use crate::client::{
18
    disconnect_pending_requests, AnyApiCallback, ConnectionInfo, OutstandingRequestMapHandle,
19
};
20
use crate::Error;
21

            
22
1717
pub(super) async fn reconnecting_client_loop(
23
1717
    server: ConnectionInfo,
24
1717
    protocol_version: &str,
25
1717
    request_receiver: Receiver<PendingRequest>,
26
1717
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
27
1717
    connection_counter: Arc<AtomicU32>,
28
1717
) -> Result<(), Error> {
29
1717
    let mut pending_error = None;
30
1751
    while let Ok(request) = {
31
2737
        server.subscribers.clear();
32
2737
        request_receiver.recv_async().await
33
    } {
34
1751
        if let Some(pending_error) = pending_error.take() {
35
17
            drop(request.responder.send(Err(pending_error)));
36
17
            continue;
37
1734
        }
38
1734

            
39
1734
        connection_counter.fetch_add(1, Ordering::SeqCst);
40
1734
        let (stream, _) = match tokio::time::timeout(
41
1734
            server.connect_timeout,
42
1734
            tokio_tungstenite::connect_async(
43
1734
                tokio_tungstenite::tungstenite::handshake::client::Request::get(
44
1734
                    server.url.as_str(),
45
1734
                )
46
1734
                .header("Sec-WebSocket-Protocol", protocol_version)
47
1734
                .header("Sec-WebSocket-Version", "13")
48
1734
                .header("Sec-WebSocket-Key", generate_key())
49
1734
                .header("Host", server.url.host_str().expect("no host"))
50
1734
                .header("Connection", "Upgrade")
51
1734
                .header("Upgrade", "websocket")
52
1734
                .body(())
53
1734
                .unwrap(),
54
1734
            ),
55
1734
        )
56
5253
        .await
57
        {
58
1683
            Ok(Ok(result)) => result,
59
17
            Ok(Err(err)) => {
60
17
                drop(request.responder.send(Err(Error::from(err))));
61
17
                continue;
62
            }
63
            Err(_) => {
64
34
                drop(request.responder.send(Err(Error::connect_timeout())));
65
34
                continue;
66
            }
67
        };
68

            
69
1683
        let (mut sender, receiver) = stream.split();
70
1683

            
71
1683
        let outstanding_requests = OutstandingRequestMapHandle::default();
72
        {
73
1683
            let mut outstanding_requests = fast_async_lock!(outstanding_requests);
74
            if let Err(err) = sender
75
1683
                .send(Message::Binary(bincode::serialize(&request.request)?))
76
34
                .await
77
            {
78
                drop(request.responder.send(Err(Error::from(err))));
79
                continue;
80
1683
            }
81
1683
            outstanding_requests.insert(
82
1683
                request.request.id.expect("all requests must have ids"),
83
1683
                request,
84
1683
            );
85
        }
86

            
87
1683
        if let Err(err) = tokio::try_join!(
88
930240
            request_sender(&request_receiver, sender, outstanding_requests.clone()),
89
930240
            response_processor(receiver, outstanding_requests.clone(), &custom_apis,)
90
930240
        ) {
91
            // Our socket was disconnected, clear the outstanding requests before returning.
92
748
            log::error!("Error on socket {:?}", err);
93
952
            pending_error = Some(err);
94
952
            disconnect_pending_requests(&outstanding_requests, &mut pending_error).await;
95
        }
96
    }
97

            
98
952
    Ok(())
99
952
}
100

            
101
1683
async fn request_sender(
102
1683
    request_receiver: &Receiver<PendingRequest>,
103
1683
    mut sender: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
104
1683
    outstanding_requests: OutstandingRequestMapHandle,
105
1683
) -> Result<(), Error> {
106
928438
    while let Ok(pending) = request_receiver.recv_async().await {
107
642736
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
108
        sender
109
642736
            .send(Message::Binary(bincode::serialize(&pending.request)?))
110
119
            .await?;
111

            
112
642736
        outstanding_requests.insert(
113
642736
            pending.request.id.expect("all requests must have ids"),
114
642736
            pending,
115
642736
        );
116
    }
117

            
118
935
    Err(Error::disconnected())
119
935
}
120

            
121
#[allow(clippy::collapsible_else_if)] // not possible due to cfg statement
122
1683
async fn response_processor(
123
1683
    mut receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
124
1683
    outstanding_requests: OutstandingRequestMapHandle,
125
1683
    custom_apis: &HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>,
126
1683
) -> Result<(), Error> {
127
927945
    while let Some(message) = receiver.next().await {
128
644844
        let message = message?;
129
644827
        match message {
130
644827
            Message::Binary(response) => {
131
644827
                let payload = bincode::deserialize::<Payload>(&response)?;
132

            
133
644827
                super::process_response_payload(payload, &outstanding_requests, custom_apis).await;
134
            }
135
            other => {
136
                log::error!("Unexpected websocket message: {:?}", other);
137
            }
138
        }
139
    }
140

            
141
    Ok(())
142
17
}