1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

use bonsaidb_core::api::ApiName;
use bonsaidb_core::networking::Payload;
use bonsaidb_utils::fast_async_lock;
use flume::Receiver;
use futures::stream::{SplitSink, SplitStream};
use futures::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::handshake::client::generate_key;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};

use super::PendingRequest;
use crate::client::{
    disconnect_pending_requests, AnyApiCallback, ConnectionInfo, OutstandingRequestMapHandle,
};
use crate::Error;

pub(super) async fn reconnecting_client_loop(
    server: ConnectionInfo,
    protocol_version: &str,
    request_receiver: Receiver<PendingRequest>,
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
    connection_counter: Arc<AtomicU32>,
) -> Result<(), Error> {
    let mut pending_error = None;
    while let Ok(request) = {
        server.subscribers.clear();
        request_receiver.recv_async().await
    } {
        if let Some(pending_error) = pending_error.take() {
            drop(request.responder.send(Err(pending_error)));
            continue;
        }

        connection_counter.fetch_add(1, Ordering::SeqCst);
        let (stream, _) = match tokio::time::timeout(
            server.connect_timeout,
            tokio_tungstenite::connect_async(
                tokio_tungstenite::tungstenite::handshake::client::Request::get(
                    server.url.as_str(),
                )
                .header("Sec-WebSocket-Protocol", protocol_version)
                .header("Sec-WebSocket-Version", "13")
                .header("Sec-WebSocket-Key", generate_key())
                .header("Host", server.url.host_str().expect("no host"))
                .header("Connection", "Upgrade")
                .header("Upgrade", "websocket")
                .body(())
                .unwrap(),
            ),
        )
        .await
        {
            Ok(Ok(result)) => result,
            Ok(Err(err)) => {
                drop(request.responder.send(Err(Error::from(err))));
                continue;
            }
            Err(_) => {
                drop(request.responder.send(Err(Error::connect_timeout())));
                continue;
            }
        };

        let (mut sender, receiver) = stream.split();

        let outstanding_requests = OutstandingRequestMapHandle::default();
        {
            let mut outstanding_requests = fast_async_lock!(outstanding_requests);
            if let Err(err) = sender
                .send(Message::Binary(bincode::serialize(&request.request)?))
                .await
            {
                drop(request.responder.send(Err(Error::from(err))));
                continue;
            }
            outstanding_requests.insert(
                request.request.id.expect("all requests must have ids"),
                request,
            );
        }

        if let Err(err) = tokio::try_join!(
            request_sender(&request_receiver, sender, outstanding_requests.clone()),
            response_processor(receiver, outstanding_requests.clone(), &custom_apis,)
        ) {
            // Our socket was disconnected, clear the outstanding requests before returning.
            log::error!("Error on socket {:?}", err);
            pending_error = Some(err);
            disconnect_pending_requests(&outstanding_requests, &mut pending_error).await;
        }
    }

    Ok(())
}

async fn request_sender(
    request_receiver: &Receiver<PendingRequest>,
    mut sender: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
    outstanding_requests: OutstandingRequestMapHandle,
) -> Result<(), Error> {
    while let Ok(pending) = request_receiver.recv_async().await {
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
        sender
            .send(Message::Binary(bincode::serialize(&pending.request)?))
            .await?;

        outstanding_requests.insert(
            pending.request.id.expect("all requests must have ids"),
            pending,
        );
    }

    Err(Error::disconnected())
}

#[allow(clippy::collapsible_else_if)] // not possible due to cfg statement
async fn response_processor(
    mut receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    outstanding_requests: OutstandingRequestMapHandle,
    custom_apis: &HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>,
) -> Result<(), Error> {
    while let Some(message) = receiver.next().await {
        let message = message?;
        match message {
            Message::Binary(response) => {
                let payload = bincode::deserialize::<Payload>(&response)?;

                super::process_response_payload(payload, &outstanding_requests, custom_apis).await;
            }
            other => {
                log::error!("Unexpected websocket message: {:?}", other);
            }
        }
    }

    Ok(())
}