1
use std::time::Duration;
2

            
3
use async_lock::Mutex;
4
use bonsaidb_utils::fast_async_lock;
5
use tokio::{sync::watch, time::Instant};
6

            
7
#[derive(Debug)]
8
pub struct Shutdown {
9
    sender: watch::Sender<ShutdownState>,
10
    receiver: Mutex<Option<watch::Receiver<ShutdownState>>>,
11
}
12

            
13
6
/// The states published to shutdown watchers.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShutdownState {
    /// The initial state: no shutdown has been requested.
    Running,
    /// A graceful shutdown has been requested.
    GracefulShutdown,
    /// An immediate shutdown has been requested.
    Shutdown,
}
19

            
20
impl Shutdown {
21
913
    pub fn new() -> Self {
22
913
        let (sender, receiver) = watch::channel(ShutdownState::Running);
23
913
        Self {
24
913
            sender,
25
913
            receiver: Mutex::new(Some(receiver)),
26
913
        }
27
913
    }
28

            
29
702
    pub async fn watcher(&self) -> Option<ShutdownStateWatcher> {
30
39
        let receiver = fast_async_lock!(self.receiver);
31
39
        receiver
32
39
            .clone()
33
39
            .map(|receiver| ShutdownStateWatcher { receiver })
34
39
    }
35

            
36
97
    async fn stop_watching(&self) {
37
29
        let mut receiver = fast_async_lock!(self.receiver);
38
29
        *receiver = None;
39
29
    }
40

            
41
72
    pub async fn graceful_shutdown(&self, timeout: Duration) {
42
4
        self.stop_watching().await;
43
4
        if self.sender.send(ShutdownState::GracefulShutdown).is_ok() {
44
4
            let timeout = tokio::time::sleep_until(Instant::now() + timeout);
45
4
            if !tokio::select! {
46
8
                _ = self.sender.closed() => true,
47
8
                _ = timeout => false,
48
8
            } {
49
                // Failed to gracefully shut down
50
                self.shutdown().await;
51
4
            }
52
        }
53
4
    }
54

            
55
25
    pub async fn shutdown(&self) {
56
25
        self.stop_watching().await;
57
25
        drop(self.sender.send(ShutdownState::Shutdown));
58
25
    }
59
}
60

            
61
pub struct ShutdownStateWatcher {
62
    receiver: watch::Receiver<ShutdownState>,
63
}
64

            
65
impl ShutdownStateWatcher {
66
2844
    pub async fn wait_for_shutdown(&mut self) -> ShutdownState {
67
157
        if self.receiver.changed().await.is_ok() {
68
6
            self.receiver.borrow().clone()
69
        } else {
70
            ShutdownState::Shutdown
71
        }
72
6
    }
73
}