1
use std::sync::Arc;
2

            
3
use async_trait::async_trait;
4
use circulate::{flume, Message, Relay};
5
use serde::Serialize;
6

            
7
use crate::Error;
8

            
9
/// Publishes and Subscribes to messages on topics.
10
#[async_trait]
11
pub trait PubSub {
12
    /// The Subscriber type for this `PubSub` connection.
13
    type Subscriber: Subscriber;
14

            
15
    /// Create a new [`Subscriber`] for this relay.
16
    async fn create_subscriber(&self) -> Result<Self::Subscriber, Error>;
17

            
18
    /// Publishes a `payload` to all subscribers of `topic`.
19
    async fn publish<S: Into<String> + Send, P: Serialize + Sync>(
20
        &self,
21
        topic: S,
22
        payload: &P,
23
    ) -> Result<(), Error>;
24

            
25
    /// Publishes a `payload` to all subscribers of all `topics`.
26
    async fn publish_to_all<P: Serialize + Sync>(
27
        &self,
28
        topics: Vec<String>,
29
        payload: &P,
30
    ) -> Result<(), Error>;
31
}
32

            
33
/// A subscriber to one or more topics.
34
#[async_trait]
35
pub trait Subscriber {
36
    /// Subscribe to [`Message`]s published to `topic`.
37
    async fn subscribe_to<S: Into<String> + Send>(&self, topic: S) -> Result<(), Error>;
38

            
39
    /// Unsubscribe from [`Message`]s published to `topic`.
40
    async fn unsubscribe_from(&self, topic: &str) -> Result<(), Error>;
41

            
42
    /// Returns the receiver to receive [`Message`]s.
43
    #[must_use]
44
    fn receiver(&self) -> &'_ flume::Receiver<Arc<Message>>;
45
}
46

            
47
#[async_trait]
48
impl PubSub for Relay {
49
    type Subscriber = circulate::Subscriber;
50

            
51
7
    async fn create_subscriber(&self) -> Result<Self::Subscriber, Error> {
52
7
        Ok(self.create_subscriber().await)
53
14
    }
54

            
55
    async fn publish<S: Into<String> + Send, P: Serialize + Sync>(
56
        &self,
57
        topic: S,
58
        payload: &P,
59
    ) -> Result<(), Error> {
60
        self.publish(topic, payload).await?;
61
        Ok(())
62
    }
63

            
64
    async fn publish_to_all<P: Serialize + Sync>(
65
        &self,
66
        topics: Vec<String>,
67
        payload: &P,
68
    ) -> Result<(), Error> {
69
        self.publish_to_all(topics, payload).await?;
70
        Ok(())
71
    }
72
}
73

            
74
#[async_trait]
75
impl Subscriber for circulate::Subscriber {
76
12
    async fn subscribe_to<S: Into<String> + Send>(&self, topic: S) -> Result<(), Error> {
77
12
        self.subscribe_to(topic).await;
78
12
        Ok(())
79
24
    }
80

            
81
1
    async fn unsubscribe_from(&self, topic: &str) -> Result<(), Error> {
82
1
        self.unsubscribe_from(topic).await;
83
1
        Ok(())
84
2
    }
85

            
86
    fn receiver(&self) -> &'_ flume::Receiver<Arc<Message>> {
87
        self.receiver()
88
    }
89
}
90

            
91
/// Creates a topic for use in a server. This is an internal API, which is why
/// the documentation is hidden. This is an implementation detail, but both
/// Client and Server must agree on this format, which is why it lives in core.
///
/// The result is `database`, then a single NUL character (`'\u{0}'`), then
/// `topic`. Neither part is escaped, so a `database` that itself contains NUL
/// would make the boundary between the two parts ambiguous.
#[doc(hidden)]
#[must_use]
pub fn database_topic(database: &str, topic: &str) -> String {
    [database, topic].join("\u{0}")
}
99

            
100
/// Expands into a suite of pubsub unit tests using the passed type as the test harness.
///
/// `$harness` must name a type with:
/// - an async `new($crate::test_util::HarnessTest)` constructor returning a
///   `Result` whose error converts into `anyhow::Error`, and
/// - an async `connect(&self)` method returning a `Result` whose success value
///   implements `PubSub`.
///
/// The expansion site must also have `tokio` (for `#[tokio::test]` and
/// `spawn_blocking`), `anyhow`, and a `flume` path in scope: the generated
/// tests refer to `flume::TryRecvError` unqualified.
#[cfg(any(test, feature = "test-util"))]
#[cfg_attr(feature = "test-util", macro_export)]
macro_rules! define_pubsub_test_suite {
    ($harness:ident) => {
        #[cfg(test)]
        use $crate::pubsub::{PubSub, Subscriber};

        #[tokio::test]
        async fn simple_pubsub_test() -> anyhow::Result<()> {
            let harness = $harness::new($crate::test_util::HarnessTest::PubSubSimple).await?;
            let pubsub = harness.connect().await?;
            let subscriber = PubSub::create_subscriber(&pubsub).await?;
            Subscriber::subscribe_to(&subscriber, "mytopic").await?;
            pubsub.publish("mytopic", &String::from("test")).await?;
            // `subscriber` never subscribed to this topic, so it must not see it.
            pubsub.publish("othertopic", &String::from("test")).await?;
            // Cloned so it can be moved into the `spawn_blocking` closure below.
            let receiver = subscriber.receiver().clone();
            let message = receiver.recv_async().await.expect("No message received");
            assert_eq!(message.payload::<String>()?, "test");
            // The message should only be received once. `recv_timeout` blocks its
            // thread, so it runs on tokio's blocking pool instead of the runtime.
            assert!(matches!(
                tokio::task::spawn_blocking(
                    move || receiver.recv_timeout(std::time::Duration::from_millis(100))
                )
                .await,
                Ok(Err(_))
            ));
            Ok(())
        }

        #[tokio::test]
        async fn multiple_subscribers_test() -> anyhow::Result<()> {
            let harness =
                $harness::new($crate::test_util::HarnessTest::PubSubMultipleSubscribers).await?;
            let pubsub = harness.connect().await?;
            // subscriber_a listens on "a"; subscriber_ab listens on "a" and "b".
            let subscriber_a = PubSub::create_subscriber(&pubsub).await?;
            let subscriber_ab = PubSub::create_subscriber(&pubsub).await?;
            Subscriber::subscribe_to(&subscriber_a, "a").await?;
            Subscriber::subscribe_to(&subscriber_ab, "a").await?;
            Subscriber::subscribe_to(&subscriber_ab, "b").await?;

            let mut messages_a = Vec::new();
            let mut messages_ab = Vec::new();
            pubsub.publish("a", &String::from("a1")).await?;
            messages_a.push(
                subscriber_a
                    .receiver()
                    .recv_async()
                    .await?
                    .payload::<String>()?,
            );
            messages_ab.push(
                subscriber_ab
                    .receiver()
                    .recv_async()
                    .await?
                    .payload::<String>()?,
            );

            pubsub.publish("b", &String::from("b1")).await?;
            messages_ab.push(
                subscriber_ab
                    .receiver()
                    .recv_async()
                    .await?
                    .payload::<String>()?,
            );

            pubsub.publish("a", &String::from("a2")).await?;
            messages_a.push(
                subscriber_a
                    .receiver()
                    .recv_async()
                    .await?
                    .payload::<String>()?,
            );
            messages_ab.push(
                subscriber_ab
                    .receiver()
                    .recv_async()
                    .await?
                    .payload::<String>()?,
            );

            // subscriber_a's second message being "a2" also shows "b1" skipped it.
            assert_eq!(&messages_a[0], "a1");
            assert_eq!(&messages_a[1], "a2");

            assert_eq!(&messages_ab[0], "a1");
            assert_eq!(&messages_ab[1], "b1");
            assert_eq!(&messages_ab[2], "a2");

            Ok(())
        }

        #[tokio::test]
        async fn unsubscribe_test() -> anyhow::Result<()> {
            let harness = $harness::new($crate::test_util::HarnessTest::PubSubUnsubscribe).await?;
            let pubsub = harness.connect().await?;
            let subscriber = PubSub::create_subscriber(&pubsub).await?;
            Subscriber::subscribe_to(&subscriber, "a").await?;

            pubsub.publish("a", &String::from("a1")).await?;
            Subscriber::unsubscribe_from(&subscriber, "a").await?;
            pubsub.publish("a", &String::from("a2")).await?;
            Subscriber::subscribe_to(&subscriber, "a").await?;
            pubsub.publish("a", &String::from("a3")).await?;

            // Check subscriber for a1 and a3; a2 was published while unsubscribed
            // and must have been skipped.
            let message = subscriber.receiver().recv_async().await?;
            assert_eq!(message.payload::<String>()?, "a1");
            let message = subscriber.receiver().recv_async().await?;
            assert_eq!(message.payload::<String>()?, "a3");
            // Re-subscribing must not leave any extra (e.g. duplicated) messages.
            assert!(matches!(
                subscriber.receiver().try_recv(),
                Err(flume::TryRecvError::Empty)
            ));

            Ok(())
        }

        #[tokio::test]
        async fn publish_to_all_test() -> anyhow::Result<()> {
            let harness = $harness::new($crate::test_util::HarnessTest::PubSubPublishAll).await?;
            let pubsub = harness.connect().await?;
            let subscriber_a = PubSub::create_subscriber(&pubsub).await?;
            let subscriber_b = PubSub::create_subscriber(&pubsub).await?;
            let subscriber_c = PubSub::create_subscriber(&pubsub).await?;
            Subscriber::subscribe_to(&subscriber_a, "1").await?;
            Subscriber::subscribe_to(&subscriber_b, "1").await?;
            Subscriber::subscribe_to(&subscriber_b, "2").await?;
            Subscriber::subscribe_to(&subscriber_c, "2").await?;
            Subscriber::subscribe_to(&subscriber_a, "3").await?;
            Subscriber::subscribe_to(&subscriber_c, "3").await?;

            pubsub
                .publish_to_all(
                    vec![String::from("1"), String::from("2"), String::from("3")],
                    &String::from("1"),
                )
                .await?;

            // Each subscriber listens on exactly two of the three topics, so it
            // should get "1" twice, once per topic, and nothing more.
            for subscriber in &[subscriber_a, subscriber_b, subscriber_c] {
                let mut message_topics = Vec::new();
                for _ in 0..2_u8 {
                    let message = subscriber.receiver().recv_async().await?;
                    assert_eq!(message.payload::<String>()?, "1");
                    message_topics.push(message.topic.clone());
                }
                assert!(matches!(
                    subscriber.receiver().try_recv(),
                    Err(flume::TryRecvError::Empty)
                ));
                assert!(message_topics[0] != message_topics[1]);
            }

            Ok(())
        }
    };
}
256

            
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::HarnessTest;

    /// In-process harness that runs the shared pubsub suite directly against a
    /// [`Relay`]. The `HarnessTest` selector is ignored: every test gets its own
    /// `Relay::default()`, and `connect` hands out a clone of it.
    struct Harness {
        relay: Relay,
    }

    impl Harness {
        // `async` with no awaits only because the suite macro `.await`s both
        // `new` and `connect`; neither can actually fail here.
        async fn new(_test: HarnessTest) -> Result<Self, Error> {
            let relay = Relay::default();
            Ok(Self { relay })
        }

        async fn connect(&self) -> Result<Relay, Error> {
            let connection = self.relay.clone();
            Ok(connection)
        }
    }

    define_pubsub_test_suite!(Harness);
}