1
use std::sync::Arc;
2

            
3
use async_trait::async_trait;
4
use bonsaidb_core::{
5
    arc_bytes::serde::Bytes,
6
    networking::{CreateSubscriber, Publish, PublishToAll, SubscribeTo, UnsubscribeFrom},
7
    pubsub::{AsyncPubSub, AsyncSubscriber, Receiver},
8
};
9

            
10
use crate::Client;
11

            
12
#[async_trait]
13
impl AsyncPubSub for super::RemoteDatabase {
14
    type Subscriber = RemoteSubscriber;
15

            
16
320
    async fn create_subscriber(&self) -> Result<Self::Subscriber, bonsaidb_core::Error> {
17
320
        let subscriber_id = self
18
320
            .client
19
320
            .send_api_request_async(&CreateSubscriber {
20
320
                database: self.name.to_string(),
21
320
            })
22
320
            .await?;
23

            
24
320
        let (sender, receiver) = flume::unbounded();
25
320
        self.client.register_subscriber(subscriber_id, sender);
26
320
        Ok(RemoteSubscriber {
27
320
            client: self.client.clone(),
28
320
            database: self.name.clone(),
29
320
            id: subscriber_id,
30
320
            receiver: Receiver::new(receiver),
31
320
            #[cfg(not(target_arch = "wasm32"))]
32
320
            tokio: tokio::runtime::Handle::try_current().ok().map(Arc::new),
33
320
        })
34
640
    }
35

            
36
400
    async fn publish_bytes(
37
400
        &self,
38
400
        topic: Vec<u8>,
39
400
        payload: Vec<u8>,
40
400
    ) -> Result<(), bonsaidb_core::Error> {
41
400
        self.client
42
400
            .send_api_request_async(&Publish {
43
400
                database: self.name.to_string(),
44
400
                topic: Bytes::from(topic),
45
400
                payload: Bytes::from(payload),
46
400
            })
47
400
            .await?;
48
400
        Ok(())
49
800
    }
50

            
51
2
    async fn publish_bytes_to_all(
52
2
        &self,
53
2
        topics: impl IntoIterator<Item = Vec<u8>> + Send + 'async_trait,
54
2
        payload: Vec<u8>,
55
2
    ) -> Result<(), bonsaidb_core::Error> {
56
2
        let topics = topics.into_iter().map(Bytes::from).collect();
57
2
        self.client
58
2
            .send_api_request_async(&PublishToAll {
59
2
                database: self.name.to_string(),
60
2
                topics,
61
2
                payload: Bytes::from(payload),
62
2
            })
63
2
            .await?;
64
2
        Ok(())
65
4
    }
66
}
67

            
68
/// A `PubSub` subscriber from a remote server.
69
#[derive(Debug)]
70
pub struct RemoteSubscriber {
71
    pub(crate) client: Client,
72
    pub(crate) database: Arc<String>,
73
    pub(crate) id: u64,
74
    pub(crate) receiver: Receiver,
75
    #[cfg(not(target_arch = "wasm32"))]
76
    pub(crate) tokio: Option<Arc<tokio::runtime::Handle>>,
77
}
78

            
79
#[async_trait]
80
impl AsyncSubscriber for RemoteSubscriber {
81
520
    async fn subscribe_to_bytes(&self, topic: Vec<u8>) -> Result<(), bonsaidb_core::Error> {
82
520
        self.client
83
520
            .send_api_request_async(&SubscribeTo {
84
520
                database: self.database.to_string(),
85
520
                subscriber_id: self.id,
86
520
                topic: Bytes::from(topic),
87
520
            })
88
520
            .await?;
89
520
        Ok(())
90
1040
    }
91

            
92
40
    async fn unsubscribe_from_bytes(&self, topic: &[u8]) -> Result<(), bonsaidb_core::Error> {
93
40
        self.client
94
40
            .send_api_request_async(&UnsubscribeFrom {
95
40
                database: self.database.to_string(),
96
40
                subscriber_id: self.id,
97
40
                topic: Bytes::from(topic),
98
40
            })
99
40
            .await?;
100
40
        Ok(())
101
80
    }
102

            
103
1080
    fn receiver(&self) -> &Receiver {
104
1080
        &self.receiver
105
1080
    }
106
}
107

            
108
#[cfg(target_arch = "wasm32")]
impl Drop for RemoteSubscriber {
    /// Unregisters this subscriber from the server.
    ///
    /// `drop` cannot await, so the async unregistration is scheduled with
    /// `wasm_bindgen_futures::spawn_local` instead of being awaited inline.
    fn drop(&mut self) {
        let (client, database, subscriber_id) =
            (self.client.clone(), self.database.to_string(), self.id);
        wasm_bindgen_futures::spawn_local(async move {
            client
                .unregister_subscriber_async(database, subscriber_id)
                .await;
        });
    }
}
122

            
123
#[cfg(not(target_arch = "wasm32"))]
124
impl Drop for RemoteSubscriber {
125
    fn drop(&mut self) {
126
480
        if let Some(tokio) = &self.tokio {
127
320
            let client = self.client.clone();
128
320
            let database = self.database.to_string();
129
320
            let subscriber_id = self.id;
130
320
            tokio.spawn(async move {
131
40
                client
132
40
                    .unregister_subscriber_async(database, subscriber_id)
133
40
                    .await;
134
320
            });
135
320
        } else {
136
160
            self.client
137
160
                .unregister_subscriber(self.database.to_string(), self.id);
138
160
        }
139
480
    }
140
}