1
use std::{convert::Infallible, fmt::Debug, hash::Hash};
2

            
3
use async_trait::async_trait;
4

            
5
use super::Manager;
6
use crate::tasks::{Job, Keyed};
7

            
8
/// Minimal test job: wraps a single value and, when run, hands that value
/// straight back. The same value also serves as the job's key.
#[derive(Debug)]
struct Echo<T>(
    /// The value returned by the job and used as its key.
    T,
);
10

            
11
#[async_trait]
12
impl<T> Job for Echo<T>
13
where
14
    T: Clone + Eq + Hash + Debug + Send + Sync + 'static,
15
{
16
    type Output = T;
17
    type Error = Infallible;
18

            
19
2
    async fn execute(&mut self) -> Result<Self::Output, Self::Error> {
20
2
        Ok(self.0.clone())
21
2
    }
22
}
23

            
24
impl<T> Keyed<T> for Echo<T>
25
where
26
    T: Clone + Eq + Hash + Debug + Send + Sync + 'static,
27
{
28
3
    fn key(&self) -> T {
29
3
        self.0.clone()
30
3
    }
31
}
32

            
33
1
#[tokio::test]
34
1
async fn simple() -> Result<(), tokio::sync::oneshot::error::RecvError> {
35
1
    let manager = Manager::<usize>::default();
36
1
    manager.spawn_worker();
37
1
    let handle = manager.enqueue(Echo(1)).await;
38
1
    if let Ok(value) = handle.receive().await? {
39
1
        assert_eq!(value, 1);
40

            
41
1
        Ok(())
42
    } else {
43
        unreachable!()
44
    }
45
1
}
46

            
47
1
#[tokio::test]
48
1
async fn keyed_simple() -> Result<(), tokio::sync::oneshot::error::RecvError> {
49
1
    let manager = Manager::<usize>::default();
50
1
    let handle = manager.lookup_or_enqueue(Echo(1)).await;
51
1
    let handle2 = manager.lookup_or_enqueue(Echo(1)).await;
52
    // Tests that they received the same job id
53
1
    assert_eq!(handle.id, handle2.id);
54
1
    let mut handle3 = manager.lookup_or_enqueue(Echo(1)).await;
55
1
    assert_eq!(handle3.id, handle.id);
56

            
57
1
    manager.spawn_worker();
58

            
59
1
    let (result1, result2) = tokio::try_join!(handle.receive(), handle2.receive())?;
60
    // Because they're all the same handle, if those have returned, this one
61
    // should be available without blocking.
62
1
    let result3 = handle3
63
1
        .try_receive()
64
1
        .expect("try_receive failed even though other channels were available");
65

            
66
3
    for result in [result1, result2, result3] {
67
3
        assert_eq!(result.unwrap(), 1);
68
    }
69

            
70
1
    Ok(())
71
1
}