1
use std::sync::Arc;
2

            
3
use async_trait::async_trait;
4
use bonsaidb_core::{
5
    arc_bytes::serde::Bytes,
6
    networking::{CreateSubscriber, Publish, PublishToAll, SubscribeTo, UnsubscribeFrom},
7
    pubsub::{AsyncPubSub, AsyncSubscriber, Receiver},
8
};
9

            
10
use crate::Client;
11

            
12
#[async_trait]
13
impl AsyncPubSub for super::RemoteDatabase {
14
    type Subscriber = RemoteSubscriber;
15

            
16
304
    async fn create_subscriber(&self) -> Result<Self::Subscriber, bonsaidb_core::Error> {
17
304
        let subscriber_id = self
18
304
            .client
19
304
            .send_api_request_async(&CreateSubscriber {
20
304
                database: self.name.to_string(),
21
304
            })
22
304
            .await?;
23

            
24
304
        let (sender, receiver) = flume::unbounded();
25
304
        self.client.register_subscriber(subscriber_id, sender);
26
304
        Ok(RemoteSubscriber {
27
304
            client: self.client.clone(),
28
304
            database: self.name.clone(),
29
304
            id: subscriber_id,
30
304
            receiver: Receiver::new(receiver),
31
304
            #[cfg(not(target_arch = "wasm32"))]
32
304
            tokio: tokio::runtime::Handle::try_current().ok().map(Arc::new),
33
304
        })
34
608
    }
35

            
36
380
    async fn publish_bytes(
37
380
        &self,
38
380
        topic: Vec<u8>,
39
380
        payload: Vec<u8>,
40
380
    ) -> Result<(), bonsaidb_core::Error> {
41
380
        self.client
42
380
            .send_api_request_async(&Publish {
43
380
                database: self.name.to_string(),
44
380
                topic: Bytes::from(topic),
45
380
                payload: Bytes::from(payload),
46
380
            })
47
380
            .await?;
48
380
        Ok(())
49
760
    }
50

            
51
2
    async fn publish_bytes_to_all(
52
2
        &self,
53
2
        topics: impl IntoIterator<Item = Vec<u8>> + Send + 'async_trait,
54
2
        payload: Vec<u8>,
55
2
    ) -> Result<(), bonsaidb_core::Error> {
56
2
        let topics = topics.into_iter().map(Bytes::from).collect();
57
2
        self.client
58
2
            .send_api_request_async(&PublishToAll {
59
2
                database: self.name.to_string(),
60
2
                topics,
61
2
                payload: Bytes::from(payload),
62
2
            })
63
2
            .await?;
64
2
        Ok(())
65
4
    }
66
}
67

            
68
/// A `PubSub` subscriber from a remote server.
69
#[derive(Debug)]
70
pub struct RemoteSubscriber {
71
    pub(crate) client: Client,
72
    pub(crate) database: Arc<String>,
73
    pub(crate) id: u64,
74
    pub(crate) receiver: Receiver,
75
    #[cfg(not(target_arch = "wasm32"))]
76
    pub(crate) tokio: Option<Arc<tokio::runtime::Handle>>,
77
}
78

            
79
#[async_trait]
80
impl AsyncSubscriber for RemoteSubscriber {
81
494
    async fn subscribe_to_bytes(&self, topic: Vec<u8>) -> Result<(), bonsaidb_core::Error> {
82
494
        self.client
83
494
            .send_api_request_async(&SubscribeTo {
84
494
                database: self.database.to_string(),
85
494
                subscriber_id: self.id,
86
494
                topic: Bytes::from(topic),
87
494
            })
88
494
            .await?;
89
494
        Ok(())
90
988
    }
91

            
92
38
    async fn unsubscribe_from_bytes(&self, topic: &[u8]) -> Result<(), bonsaidb_core::Error> {
93
38
        self.client
94
38
            .send_api_request_async(&UnsubscribeFrom {
95
38
                database: self.database.to_string(),
96
38
                subscriber_id: self.id,
97
38
                topic: Bytes::from(topic),
98
38
            })
99
38
            .await?;
100
38
        Ok(())
101
76
    }
102

            
103
1026
    fn receiver(&self) -> &Receiver {
104
1026
        &self.receiver
105
1026
    }
106
}
107

            
108
#[cfg(target_arch = "wasm32")]
impl Drop for RemoteSubscriber {
    /// Spawns `unregister_subscriber_async` for this subscriber as a local
    /// task, because `drop` itself cannot `.await`.
    fn drop(&mut self) {
        let (client, database, subscriber_id) =
            (self.client.clone(), self.database.to_string(), self.id);
        wasm_bindgen_futures::spawn_local(async move {
            client
                .unregister_subscriber_async(database, subscriber_id)
                .await;
        });
    }
}
122

            
123
#[cfg(not(target_arch = "wasm32"))]
124
impl Drop for RemoteSubscriber {
125
    fn drop(&mut self) {
126
456
        if let Some(tokio) = &self.tokio {
127
304
            let client = self.client.clone();
128
304
            let database = self.database.to_string();
129
304
            let subscriber_id = self.id;
130
304
            tokio.spawn(async move {
131
38
                client
132
38
                    .unregister_subscriber_async(database, subscriber_id)
133
38
                    .await;
134
304
            });
135
304
        } else {
136
152
            self.client
137
152
                .unregister_subscriber(self.database.to_string(), self.id);
138
152
        }
139
456
    }
140
}