1
use std::time::Duration;
2

            
3
use async_lock::Mutex;
4
use bonsaidb_utils::fast_async_lock;
5
use tokio::sync::watch;
6

            
7
/// Coordinates shutting down everything that holds a [`ShutdownStateWatcher`].
///
/// State changes are broadcast through a `tokio::sync::watch` channel.
#[derive(Debug)]
pub struct Shutdown {
    /// Publishes [`ShutdownState`] updates to every watcher.
    sender: watch::Sender<ShutdownState>,
    /// Template receiver that `watcher()` clones to create new watchers.
    /// Set to `None` once a shutdown begins, so no new watchers are handed out.
    receiver: Mutex<Option<watch::Receiver<ShutdownState>>>,
}
12

            
13
23
/// The lifecycle state broadcast by a [`Shutdown`] coordinator.
///
/// Derives `PartialEq`/`Eq` so states can be compared directly
/// (e.g. `state == ShutdownState::Shutdown`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShutdownState {
    /// The initial state of a freshly created coordinator.
    Running,
    /// Watchers have been asked to finish their work and drop their handles.
    GracefulShutdown,
    /// Watchers should stop immediately. This state is also reported to a
    /// watcher whose channel has been closed.
    Shutdown,
}
19

            
20
impl Shutdown {
21
807
    pub fn new() -> Self {
22
807
        let (sender, receiver) = watch::channel(ShutdownState::Running);
23
807
        Self {
24
807
            sender,
25
807
            receiver: Mutex::new(Some(receiver)),
26
807
        }
27
807
    }
28

            
29
3010
    pub async fn watcher(&self) -> Option<ShutdownStateWatcher> {
30
215
        let receiver = fast_async_lock!(self.receiver);
31
215
        receiver
32
215
            .clone()
33
215
            .map(|receiver| ShutdownStateWatcher { receiver })
34
215
    }
35

            
36
114
    async fn stop_watching(&self) {
37
36
        let mut receiver = fast_async_lock!(self.receiver);
38
36
        *receiver = None;
39
36
    }
40

            
41
56
    pub async fn graceful_shutdown(&self, timeout: Duration) {
42
4
        self.stop_watching().await;
43
4
        if self.sender.send(ShutdownState::GracefulShutdown).is_ok()
44
4
            && tokio::time::timeout(timeout, self.sender.closed())
45
4
                .await
46
4
                .is_err()
47
        {
48
            // Failed to gracefully shut down. If we gracefully shut down, there
49
            // are no watchers remaining, therefore updating the state doesn't
50
            // matter.
51
            self.shutdown().await;
52
4
        }
53
4
    }
54

            
55
58
    pub async fn shutdown(&self) {
56
32
        self.stop_watching().await;
57
32
        drop(self.sender.send(ShutdownState::Shutdown));
58
32
    }
59

            
60
14
    pub fn should_shutdown(&self) -> bool {
61
14
        matches!(&*self.sender.borrow(), ShutdownState::Shutdown)
62
14
    }
63
}
64

            
65
239
#[derive(Clone)]
66
pub struct ShutdownStateWatcher {
67
    receiver: watch::Receiver<ShutdownState>,
68
}
69

            
70
impl ShutdownStateWatcher {
71
2542232
    pub async fn wait_for_shutdown(&mut self) -> ShutdownState {
72
143527
        if self.receiver.changed().await.is_ok() {
73
23
            self.receiver.borrow().clone()
74
        } else {
75
            ShutdownState::Shutdown
76
        }
77
23
    }
78
}