1
use std::{
2
    sync::Arc,
3
    thread::JoinHandle,
4
    time::{Duration, Instant},
5
};
6

            
7
use argon2::{
8
    password_hash::{ParamsString, SaltString},
9
    Algorithm, Argon2, Block, ParamsBuilder, PasswordHash, Version,
10
};
11
use bonsaidb_core::connection::SensitiveString;
12
use once_cell::sync::OnceCell;
13
use rand::{thread_rng, CryptoRng, Rng};
14
use tokio::sync::oneshot;
15

            
16
use crate::{
17
    config::{ArgonConfiguration, ArgonParams},
18
    Error,
19
};
20

            
21
#[derive(Debug)]
22
#[cfg_attr(not(test), allow(dead_code))]
23
pub struct Hasher {
24
    sender: flume::Sender<HashRequest>,
25
    threads: Vec<JoinHandle<()>>,
26
}
27

            
28
impl Hasher {
29
2254
    pub fn new(config: ArgonConfiguration) -> Self {
30
2254
        let (sender, receiver) = flume::unbounded();
31
2254
        let thread = HashingThread {
32
2254
            receiver,
33
2254
            algorithm: config.algorithm,
34
2254
            params: config.params,
35
2254
            blocks: Vec::default(),
36
2254
            builder_template: Arc::default(),
37
2254
        };
38
2254
        let mut threads = Vec::with_capacity(config.hashers as usize);
39
2254
        for _ in 0..config.hashers.max(1) {
40
2254
            let thread = thread.clone();
41
2254
            threads.push(std::thread::spawn(move || thread.process_requests()));
42
2254
        }
43
2254
        Hasher { sender, threads }
44
2254
    }
45

            
46
73
    pub async fn hash(&self, id: u64, password: SensitiveString) -> Result<SensitiveString, Error> {
47
4
        let (result_sender, result_receiver) = oneshot::channel();
48
4
        if self
49
4
            .sender
50
4
            .send(HashRequest {
51
4
                id,
52
4
                password,
53
4
                verify_against: None,
54
4
                result_sender,
55
4
            })
56
4
            .is_ok()
57
        {
58
4
            match result_receiver.await?.map_err(Error::from) {
59
4
                Ok(HashResponse::Hash(hash)) => Ok(hash),
60
                Ok(HashResponse::Verified) => unreachable!(),
61
                Err(err) => Err(err),
62
            }
63
        } else {
64
            Err(Error::InternalCommunication)
65
        }
66
4
    }
67

            
68
121
    pub async fn verify(
69
121
        &self,
70
121
        id: u64,
71
121
        password: SensitiveString,
72
121
        saved_hash: SensitiveString,
73
121
    ) -> Result<(), Error> {
74
6
        let (result_sender, result_receiver) = oneshot::channel();
75
6
        if self
76
6
            .sender
77
6
            .send(HashRequest {
78
6
                id,
79
6
                password,
80
6
                verify_against: Some(saved_hash),
81
6
                result_sender,
82
6
            })
83
6
            .is_ok()
84
        {
85
6
            match result_receiver.await?.map_err(Error::from) {
86
6
                Ok(_) => Ok(()),
87
                Err(err) => {
88
                    eprintln!("Error validating password for user {}: {:?}", id, err);
89
                    Err(Error::Core(bonsaidb_core::Error::InvalidCredentials))
90
                }
91
            }
92
        } else {
93
            Err(Error::InternalCommunication)
94
        }
95
6
    }
96
}
97

            
98
2254
#[derive(Clone, Debug)]
99
struct HashingThread {
100
    receiver: flume::Receiver<HashRequest>,
101
    algorithm: Algorithm,
102
    params: ArgonParams,
103
    blocks: Vec<Block>,
104
    builder_template: Arc<OnceCell<Result<ParamsBuilder, ArgonError>>>,
105
}
106

            
107
impl HashingThread {
108
73
    fn initialize_builder_template<R: Rng + CryptoRng>(
109
73
        &mut self,
110
73
        rng: &mut R,
111
73
    ) -> Result<ParamsBuilder, ArgonError> {
112
73
        match &self.params {
113
            ArgonParams::Params(builder) => Ok(builder.clone()),
114
73
            ArgonParams::Timed(config) => {
115
73
                let mut params_builder = ParamsBuilder::new();
116
73
                let params = params_builder
117
73
                    .m_cost(config.ram_per_hasher)?
118
73
                    .p_cost(config.lanes)?
119
73
                    .data(&0_u64.to_be_bytes())?;
120
73
                let salt = SaltString::generate(rng);
121
73
                let mut salt_arr = [0u8; 64];
122
73
                let salt_bytes = salt.b64_decode(&mut salt_arr)?;
123
73
                let mut output = Vec::default();
124
73

            
125
73
                let minimum_duration = config.minimum_duration;
126
73
                let mut min_cost = 1;
127
73
                let mut total_spent_t = 0;
128
73
                let mut total_duration = Duration::ZERO;
129

            
130
                loop {
131
73
                    let t_cost = if total_spent_t > 0 {
132
                        let average_duration_per_t = total_duration / total_spent_t;
133
                        u32::try_from(ceil_divide(
134
                            minimum_duration.as_nanos(),
135
                            average_duration_per_t.as_nanos(),
136
                        ))
137
                        .unwrap()
138
                        .max(min_cost)
139
                    } else {
140
73
                        min_cost
141
                    };
142
73
                    params.t_cost(t_cost)?;
143

            
144
73
                    let params = params.clone().params()?;
145
73
                    self.allocate_blocks(&params);
146
73
                    let output_len = params
147
73
                        .output_len()
148
73
                        .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
149
73
                    output.resize(output_len, 0);
150
73

            
151
73
                    let start = Instant::now();
152
73
                    let argon = Argon2::new(self.algorithm, Version::V0x13, params);
153
73
                    argon.hash_password_into_with_memory(
154
73
                        b"hunter2",
155
73
                        salt_bytes,
156
73
                        &mut output[..],
157
73
                        &mut self.blocks,
158
73
                    )?;
159
73
                    let elapsed = match Instant::now().checked_duration_since(start) {
160
73
                        Some(elapsed) => elapsed,
161
                        None => continue,
162
                    };
163
73
                    if elapsed < minimum_duration {
164
                        total_spent_t += t_cost;
165
                        total_duration += elapsed;
166
                        min_cost = t_cost + 1;
167
                    } else {
168
                        // TODO if it's too far past the minimum duration, maybe we should try again at a smaller cost?
169
73
                        break;
170
73
                    }
171
73
                }
172
73
                Ok(params_builder)
173
            }
174
        }
175
73
    }
176

            
177
2254
    fn process_requests(mut self) {
178
2254
        let mut rng = thread_rng();
179
2448
        while let Ok(request) = self.receiver.recv() {
180
422
            let result = if let Some(verify_against) = &request.verify_against {
181
121
                Self::verify(&request, verify_against)
182
            } else {
183
73
                self.hash(&request, &mut rng)
184
            };
185
194
            drop(request.result_sender.send(result));
186
        }
187
2026
    }
188

            
189
121
    fn verify(request: &HashRequest, hash: &str) -> Result<HashResponse, Error> {
190
121
        let hash = PasswordHash::new(hash)?;
191

            
192
121
        let algorithm = Algorithm::try_from(hash.algorithm)?;
193
121
        let version = hash
194
121
            .version
195
121
            .map(Version::try_from)
196
121
            .transpose()?
197
121
            .unwrap_or(Version::V0x13);
198

            
199
121
        let argon = Argon2::new(algorithm, version, argon2::Params::try_from(&hash)?);
200

            
201
121
        hash.verify_password(&[&argon], request.password.0.as_bytes())
202
121
            .map(|_| HashResponse::Verified)
203
121
            .map_err(Error::from)
204
121
    }
205

            
206
73
    fn hash<R: Rng + CryptoRng>(
207
73
        &mut self,
208
73
        request: &HashRequest,
209
73
        rng: &mut R,
210
73
    ) -> Result<HashResponse, Error> {
211
73
        let builder_template = self.builder_template.clone();
212
73
        let mut params =
213
73
            match builder_template.get_or_init(|| self.initialize_builder_template(rng)) {
214
73
                Ok(template) => template.clone(),
215
                Err(error) => return Err(Error::PasswordHash(error.to_string())),
216
            };
217

            
218
73
        let id = request.id.to_be_bytes();
219
73
        params.data(&id)?;
220

            
221
73
        let params = params.params()?;
222
73
        self.allocate_blocks(&params);
223
73

            
224
73
        let salt = SaltString::generate(rng);
225
73
        let mut salt_arr = [0u8; 64];
226
73
        let salt_bytes = salt.b64_decode(&mut salt_arr)?;
227

            
228
73
        let argon = Argon2::new(self.algorithm, Version::V0x13, params);
229
73

            
230
73
        let output_len = argon
231
73
            .params()
232
73
            .output_len()
233
73
            .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
234
73
        let output = argon2::password_hash::Output::init_with(output_len, |out| {
235
73
            Ok(argon.hash_password_into_with_memory(
236
73
                request.password.as_bytes(),
237
73
                salt_bytes,
238
73
                out,
239
73
                &mut self.blocks,
240
73
            )?)
241
73
        })?;
242

            
243
        Ok(HashResponse::Hash(SensitiveString(
244
            PasswordHash {
245
73
                algorithm: self.algorithm.ident(),
246
73
                version: Some(Version::V0x13.into()),
247
73
                params: ParamsString::try_from(argon.params())?,
248
73
                salt: Some(salt.as_salt()),
249
73
                hash: Some(output),
250
73
            }
251
73
            .to_string(),
252
        )))
253
73
    }
254

            
255
146
    fn allocate_blocks(&mut self, params: &argon2::Params) {
256
2392064
        for _ in self.blocks.len()..params.block_count() {
257
2392064
            self.blocks.push(Block::default());
258
2392064
        }
259
146
    }
260
}
261

            
262
#[derive(Debug)]
263
pub struct HashRequest {
264
    id: u64,
265
    password: SensitiveString,
266
    verify_against: Option<SensitiveString>,
267
    result_sender: oneshot::Sender<Result<HashResponse, Error>>,
268
}
269

            
270
#[derive(Debug)]
271
pub enum HashResponse {
272
    Hash(SensitiveString),
273
    Verified,
274
}
275

            
276
#[derive(thiserror::Error, Debug)]
277
enum ArgonError {
278
    #[error("{0}")]
279
    Argon(#[from] argon2::Error),
280
    #[error("{0}")]
281
    Hash(#[from] argon2::password_hash::Error),
282
}
283

            
284
/// Divides `dividend` by `divisor`, rounding the result up.
///
/// The result is derived from the quotient and remainder rather than by
/// adding `divisor - 1` to the dividend, so it cannot overflow even for
/// dividends close to `u128::MAX`.
///
/// # Panics
///
/// Panics with "divide by 0" when `divisor` is zero.
fn ceil_divide(dividend: u128, divisor: u128) -> u128 {
    assert!(divisor != 0, "divide by 0");
    let quotient = dividend / divisor;
    if dividend % divisor == 0 {
        quotient
    } else {
        // A non-zero remainder implies `divisor >= 2`, so `quotient` is at
        // most `u128::MAX / 2` and the increment cannot overflow.
        quotient + 1
    }
}
294

            
295
1
#[tokio::test]
296
1
async fn basic_test() {
297
1
    use crate::config::SystemDefault;
298
1
    let hasher = Hasher::new(ArgonConfiguration::default());
299
1

            
300
1
    let password = SensitiveString(String::from("hunter2"));
301
1
    let hash = hasher.hash(1, password.clone()).await.unwrap();
302
1
    hasher
303
1
        .verify(1, password.clone(), hash.clone())
304
1
        .await
305
1
        .unwrap();
306
1

            
307
1
    let Hasher { sender, threads } = hasher;
308
1
    drop(sender);
309
2
    for thread in threads {
310
1
        thread.join().unwrap();
311
1
    }
312
1
}