1
use std::collections::HashMap;
2
use std::sync::atomic::{AtomicU32, Ordering};
3
use std::sync::Arc;
4

            
5
use bonsaidb_core::api::ApiName;
6
use bonsaidb_core::networking::Payload;
7
use bonsaidb_utils::fast_async_lock;
8
use flume::Receiver;
9
use futures::stream::{SplitSink, SplitStream};
10
use futures::{SinkExt, StreamExt};
11
use tokio::net::TcpStream;
12
use tokio_tungstenite::tungstenite::handshake::client::generate_key;
13
use tokio_tungstenite::tungstenite::Message;
14
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
15

            
16
use super::PendingRequest;
17
use crate::client::{
18
    disconnect_pending_requests, AnyApiCallback, ConnectionInfo, OutstandingRequestMapHandle,
19
};
20
use crate::Error;
21

            
22
1818
pub(super) async fn reconnecting_client_loop(
23
1818
    server: ConnectionInfo,
24
1818
    protocol_version: &str,
25
1818
    request_receiver: Receiver<PendingRequest>,
26
1818
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
27
1818
    connection_counter: Arc<AtomicU32>,
28
1818
) -> Result<(), Error> {
29
1818
    let mut pending_error = None;
30
1854
    while let Ok(request) = {
31
2898
        server.subscribers.clear();
32
2898
        request_receiver.recv_async().await
33
    } {
34
1854
        if let Some(pending_error) = pending_error.take() {
35
18
            drop(request.responder.send(Err(pending_error)));
36
18
            continue;
37
1836
        }
38
1836

            
39
1836
        connection_counter.fetch_add(1, Ordering::SeqCst);
40
1836
        let (stream, _) = match tokio::time::timeout(
41
1836
            server.connect_timeout,
42
1836
            tokio_tungstenite::connect_async(
43
1836
                tokio_tungstenite::tungstenite::handshake::client::Request::get(
44
1836
                    server.url.as_str(),
45
1836
                )
46
1836
                .header("Sec-WebSocket-Protocol", protocol_version)
47
1836
                .header("Sec-WebSocket-Version", "13")
48
1836
                .header("Sec-WebSocket-Key", generate_key())
49
1836
                .header("Host", server.url.host_str().expect("no host"))
50
1836
                .header("Connection", "Upgrade")
51
1836
                .header("Upgrade", "websocket")
52
1836
                .body(())
53
1836
                .unwrap(),
54
1836
            ),
55
1836
        )
56
5544
        .await
57
        {
58
1782
            Ok(Ok(result)) => result,
59
18
            Ok(Err(err)) => {
60
18
                drop(request.responder.send(Err(Error::from(err))));
61
18
                continue;
62
            }
63
            Err(_) => {
64
36
                drop(request.responder.send(Err(Error::connect_timeout())));
65
36
                continue;
66
            }
67
        };
68

            
69
1782
        let (mut sender, receiver) = stream.split();
70
1782

            
71
1782
        let outstanding_requests = OutstandingRequestMapHandle::default();
72
        {
73
1782
            let mut outstanding_requests = fast_async_lock!(outstanding_requests);
74
1782
            if let Err(err) = sender
75
1782
                .send(Message::Binary(bincode::serialize(&request.request)?))
76
                .await
77
            {
78
                drop(request.responder.send(Err(Error::from(err))));
79
                continue;
80
1782
            }
81
1782
            outstanding_requests.insert(
82
1782
                request.request.id.expect("all requests must have ids"),
83
1782
                request,
84
1782
            );
85
        }
86

            
87
1782
        if let Err(err) = tokio::try_join!(
88
979884
            request_sender(&request_receiver, sender, outstanding_requests.clone()),
89
979884
            response_processor(receiver, outstanding_requests.clone(), &custom_apis,)
90
979884
        ) {
91
            // Our socket was disconnected, clear the outstanding requests before returning.
92
792
            log::error!("Error on socket {:?}", err);
93
1008
            pending_error = Some(err);
94
1008
            disconnect_pending_requests(&outstanding_requests, &mut pending_error).await;
95
        }
96
    }
97

            
98
1008
    Ok(())
99
1008
}
100

            
101
1782
async fn request_sender(
102
1782
    request_receiver: &Receiver<PendingRequest>,
103
1782
    mut sender: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
104
1782
    outstanding_requests: OutstandingRequestMapHandle,
105
1782
) -> Result<(), Error> {
106
977976
    while let Ok(pending) = request_receiver.recv_async().await {
107
675540
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
108
675540
        sender
109
675540
            .send(Message::Binary(bincode::serialize(&pending.request)?))
110
126
            .await?;
111

            
112
675540
        outstanding_requests.insert(
113
675540
            pending.request.id.expect("all requests must have ids"),
114
675540
            pending,
115
675540
        );
116
    }
117

            
118
990
    Err(Error::disconnected())
119
990
}
120

            
121
#[allow(clippy::collapsible_else_if)] // not possible due to cfg statement
122
1782
async fn response_processor(
123
1782
    mut receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
124
1782
    outstanding_requests: OutstandingRequestMapHandle,
125
1782
    custom_apis: &HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>,
126
1782
) -> Result<(), Error> {
127
977400
    while let Some(message) = receiver.next().await {
128
677772
        let message = message?;
129
677754
        match message {
130
677754
            Message::Binary(response) => {
131
677754
                let payload = bincode::deserialize::<Payload>(&response)?;
132

            
133
677754
                super::process_response_payload(payload, &outstanding_requests, custom_apis).await;
134
            }
135
            other => {
136
                log::error!("Unexpected websocket message: {:?}", other);
137
            }
138
        }
139
    }
140

            
141
    Ok(())
142
18
}