1
use std::{sync::Arc, time::Duration};
2

            
3
use async_acme::cache::AcmeCache;
4
use async_trait::async_trait;
5
use bonsaidb_core::{
6
    arc_bytes::serde::Bytes,
7
    connection::Connection,
8
    define_basic_unique_mapped_view,
9
    document::{CollectionDocument, KeyId},
10
    schema::{Collection, CollectionName, DefaultSerialization, Schematic, SerializedCollection},
11
    ENCRYPTION_ENABLED,
12
};
13
use serde::{Deserialize, Serialize};
14

            
15
use crate::{Backend, CustomServer, Error};
16

            
17
#[derive(Debug, Serialize, Deserialize)]
18
pub struct AcmeAccount {
19
    pub contacts: Vec<String>,
20
    pub data: Bytes,
21
}
22

            
23
impl Collection for AcmeAccount {
24
3802
    fn encryption_key() -> Option<KeyId> {
25
3802
        if ENCRYPTION_ENABLED {
26
3802
            Some(KeyId::Master)
27
        } else {
28
            None
29
        }
30
3802
    }
31

            
32
11405
    fn collection_name() -> CollectionName {
33
11405
        CollectionName::new("khonsulabs", "acme-accounts")
34
11405
    }
35

            
36
    fn define_views(schema: &mut Schematic) -> Result<(), bonsaidb_core::Error> {
37
3802
        schema.define_view(AcmeAccountByContacts)?;
38
3802
        Ok(())
39
3802
    }
40
}
41

            
42
// Use the default document serialization for `AcmeAccount` (no custom format override).
impl DefaultSerialization for AcmeAccount {}
43

            
44
// Unique view (version 1, "by-contacts") mapping each `AcmeAccount` to a `String` key
// built from its contact list.
define_basic_unique_mapped_view!(
    AcmeAccountByContacts,
    AcmeAccount,
    1,
    "by-contacts",
    String,
    |document: CollectionDocument<AcmeAccount>| {
        // The key is the contacts joined with ';'. The ACME cache lookups build their query key
        // the same way, so the two formats must stay in sync. Contact lists whose entries
        // themselves contain ';' could therefore map to the same key.
        let contacts_key = document.contents.contacts.join(";");
        document.header.emit_key(contacts_key)
    }
);
56

            
57
#[async_trait]
58
impl<B: Backend> AcmeCache for CustomServer<B> {
59
    type Error = Error;
60

            
61
    async fn read_account(&self, contacts: &[&str]) -> Result<Option<Vec<u8>>, Self::Error> {
62
        let db = self.hosted().await;
63
        let contact = db
64
            .view::<AcmeAccountByContacts>()
65
            .with_key(contacts.join(";"))
66
            .query_with_collection_docs()
67
            .await?
68
            .documents
69
            .into_iter()
70
            .next();
71

            
72
        if let Some((_, contact)) = contact {
73
            Ok(Some(contact.contents.data.into_vec()))
74
        } else {
75
            Ok(None)
76
        }
77
    }
78

            
79
    async fn write_account(&self, contacts: &[&str], contents: &[u8]) -> Result<(), Self::Error> {
80
        let db = self.hosted().await;
81
        let mapped_account = db
82
            .view::<AcmeAccountByContacts>()
83
            .with_key(contacts.join(";"))
84
            .query_with_collection_docs()
85
            .await?
86
            .documents
87
            .into_iter()
88
            .next();
89
        if let Some((_, mut account)) = mapped_account {
90
            account.contents.data = Bytes::from(contents);
91
            account.update(&db).await?;
92
        } else {
93
            AcmeAccount {
94
                contacts: contacts.iter().map(|&c| c.to_string()).collect(),
95
                data: Bytes::from(contents),
96
            }
97
            .push_into(&db)
98
            .await?;
99
        }
100

            
101
        Ok(())
102
    }
103

            
104
    async fn write_certificate(
105
        &self,
106
        _domains: &[String],
107
        _directory_url: &str,
108
        key_pem: &str,
109
        certificate_pem: &str,
110
    ) -> Result<(), Self::Error> {
111
        self.install_pem_certificate(certificate_pem.as_bytes(), key_pem.as_bytes())
112
            .await
113
    }
114
}
115

            
116
impl<B: Backend> CustomServer<B> {
117
1
    pub(crate) async fn update_acme_certificates(&self) -> Result<(), Error> {
118
        loop {
119
            {
120
1
                let key = self.data.primary_tls_key.lock().clone();
121
1
                while async_acme::rustls_helper::duration_until_renewal_attempt(key.as_deref(), 0)
122
1
                    > Duration::from_secs(24 * 60 * 60 * 14)
123
                {
124
1
                    tokio::time::sleep(Duration::from_secs(60 * 60)).await;
125
                }
126
            }
127

            
128
            log::info!(
129
                "requesting new tls certificate for {}",
130
                self.data.primary_domain
131
            );
132
            let domains = vec![self.data.primary_domain.clone()];
133
            async_acme::rustls_helper::order(
134
                |domain, key| {
135
                    let mut auth_keys = self.data.alpn_keys.lock().unwrap();
136
                    auth_keys.insert(domain, Arc::new(key));
137
                    Ok(())
138
                },
139
                &self.data.acme.directory,
140
                &domains,
141
                Some(self),
142
                &self
143
                    .data
144
                    .acme
145
                    .contact_email
146
                    .iter()
147
                    .cloned()
148
                    .collect::<Vec<_>>(),
149
            )
150
            .await?;
151
        }
152
    }
153
}