1
use std::collections::HashMap;
2
use std::sync::atomic::{AtomicU32, Ordering};
3
use std::sync::Arc;
4

            
5
use bonsaidb_core::api::ApiName;
6
use bonsaidb_core::networking::Payload;
7
use bonsaidb_utils::fast_async_lock;
8
use flume::Receiver;
9
use futures::stream::{SplitSink, SplitStream};
10
use futures::{SinkExt, StreamExt};
11
use tokio::net::TcpStream;
12
use tokio_tungstenite::tungstenite::handshake::client::generate_key;
13
use tokio_tungstenite::tungstenite::Message;
14
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
15

            
16
use super::PendingRequest;
17
use crate::client::{
18
    disconnect_pending_requests, AnyApiCallback, ConnectionInfo, OutstandingRequestMapHandle,
19
};
20
use crate::Error;
21

            
22
3210
pub(super) async fn reconnecting_client_loop(
23
3210
    server: ConnectionInfo,
24
3210
    protocol_version: &str,
25
3210
    request_receiver: Receiver<PendingRequest>,
26
3210
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
27
3210
    connection_counter: Arc<AtomicU32>,
28
3210
) -> Result<(), Error> {
29
3210
    let mut pending_error = None;
30
3270
    while let Ok(request) = {
31
5190
        server.subscribers.clear();
32
5190
        request_receiver.recv_async().await
33
    } {
34
3270
        if let Some(pending_error) = pending_error.take() {
35
30
            drop(request.responder.send(Err(pending_error)));
36
30
            continue;
37
3240
        }
38
3240

            
39
3240
        connection_counter.fetch_add(1, Ordering::SeqCst);
40
3240
        let (stream, _) = match tokio::time::timeout(
41
3240
            server.connect_timeout,
42
3240
            tokio_tungstenite::connect_async(
43
3240
                tokio_tungstenite::tungstenite::handshake::client::Request::get(
44
3240
                    server.url.as_str(),
45
3240
                )
46
3240
                .header("Sec-WebSocket-Protocol", protocol_version)
47
3240
                .header("Sec-WebSocket-Version", "13")
48
3240
                .header("Sec-WebSocket-Key", generate_key())
49
3240
                .header("Host", server.url.host_str().expect("no host"))
50
3240
                .header("Connection", "Upgrade")
51
3240
                .header("Upgrade", "websocket")
52
3240
                .body(())
53
3240
                .unwrap(),
54
3240
            ),
55
3240
        )
56
10020
        .await
57
        {
58
3150
            Ok(Ok(result)) => result,
59
30
            Ok(Err(err)) => {
60
30
                drop(request.responder.send(Err(Error::from(err))));
61
30
                continue;
62
            }
63
            Err(_) => {
64
60
                drop(request.responder.send(Err(Error::connect_timeout())));
65
60
                continue;
66
            }
67
        };
68

            
69
3150
        let (mut sender, receiver) = stream.split();
70
3150

            
71
3150
        let outstanding_requests = OutstandingRequestMapHandle::default();
72
        {
73
3150
            let mut outstanding_requests = fast_async_lock!(outstanding_requests);
74
3150
            if let Err(err) = sender
75
3150
                .send(Message::Binary(bincode::serialize(&request.request)?))
76
                .await
77
            {
78
                drop(request.responder.send(Err(Error::from(err))));
79
                continue;
80
3150
            }
81
3150
            outstanding_requests.insert(
82
3150
                request.request.id.expect("all requests must have ids"),
83
3150
                request,
84
3150
            );
85
        }
86

            
87
1616610
        if let Err(err) = tokio::try_join!(
88
1616610
            request_sender(&request_receiver, sender, outstanding_requests.clone()),
89
1616610
            response_processor(receiver, outstanding_requests.clone(), &custom_apis,)
90
1616610
        ) {
91
            // Our socket was disconnected, clear the outstanding requests before returning.
92
1320
            log::error!("Error on socket {:?}", err);
93
1860
            pending_error = Some(err);
94
1860
            disconnect_pending_requests(&outstanding_requests, &mut pending_error).await;
95
        }
96
    }
97

            
98
1860
    Ok(())
99
1860
}
100

            
101
3150
async fn request_sender(
102
3150
    request_receiver: &Receiver<PendingRequest>,
103
3150
    mut sender: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
104
3150
    outstanding_requests: OutstandingRequestMapHandle,
105
3150
) -> Result<(), Error> {
106
1613250
    while let Ok(pending) = request_receiver.recv_async().await {
107
1120500
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
108
1120500
        sender
109
1120500
            .send(Message::Binary(bincode::serialize(&pending.request)?))
110
210
            .await?;
111

            
112
1120500
        outstanding_requests.insert(
113
1120500
            pending.request.id.expect("all requests must have ids"),
114
1120500
            pending,
115
1120500
        );
116
    }
117

            
118
1830
    Err(Error::disconnected())
119
1830
}
120

            
121
#[allow(clippy::collapsible_else_if)] // not possible due to cfg statement
122
3150
async fn response_processor(
123
3150
    mut receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
124
3150
    outstanding_requests: OutstandingRequestMapHandle,
125
3150
    custom_apis: &HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>,
126
3150
) -> Result<(), Error> {
127
1611840
    while let Some(message) = receiver.next().await {
128
1124400
        let message = message?;
129
1124370
        match message {
130
1124370
            Message::Binary(response) => {
131
1124370
                let payload = bincode::deserialize::<Payload>(&response)?;
132

            
133
1124370
                super::process_response_payload(payload, &outstanding_requests, custom_apis).await;
134
            }
135
            other => {
136
                log::error!("Unexpected websocket message: {:?}", other);
137
            }
138
        }
139
    }
140

            
141
    Ok(())
142
30
}