1
use std::{
2
    sync::Arc,
3
    thread::JoinHandle,
4
    time::{Duration, Instant},
5
};
6

            
7
use argon2::{
8
    password_hash::{ParamsString, SaltString},
9
    Algorithm, Argon2, Block, ParamsBuilder, PasswordHash, Version,
10
};
11
use bonsaidb_core::connection::SensitiveString;
12
use once_cell::sync::OnceCell;
13
use rand::{thread_rng, CryptoRng, Rng};
14

            
15
use crate::{
16
    config::{ArgonConfiguration, ArgonParams},
17
    Error,
18
};
19

            
20
#[derive(Debug)]
21
#[cfg_attr(not(test), allow(dead_code))]
22
pub struct Hasher {
23
    sender: flume::Sender<HashRequest>,
24
    threads: Vec<JoinHandle<()>>,
25
}
26

            
27
impl Hasher {
28
3408
    pub fn new(config: ArgonConfiguration) -> Self {
29
3408
        let (sender, receiver) = flume::unbounded();
30
3408
        let thread = HashingThread {
31
3408
            receiver,
32
3408
            algorithm: config.algorithm,
33
3408
            params: config.params,
34
3408
            blocks: Vec::default(),
35
3408
            builder_template: Arc::default(),
36
3408
        };
37
3408
        let mut threads = Vec::with_capacity(config.hashers as usize);
38
3408
        for _ in 0..config.hashers.max(1) {
39
3408
            let thread = thread.clone();
40
3408
            threads.push(
41
3408
                std::thread::Builder::new()
42
3408
                    .name(String::from("argon2"))
43
3408
                    .spawn(move || thread.process_requests())
44
3408
                    .unwrap(),
45
3408
            );
46
3408
        }
47
3408
        Hasher { sender, threads }
48
3408
    }
49

            
50
91
    pub fn hash(&self, id: u64, password: SensitiveString) -> Result<SensitiveString, Error> {
51
91
        let (result_sender, result_receiver) = flume::bounded(1);
52
91
        if self
53
91
            .sender
54
91
            .send(HashRequest {
55
91
                id,
56
91
                password,
57
91
                verify_against: None,
58
91
                result_sender,
59
91
            })
60
91
            .is_ok()
61
        {
62
91
            match result_receiver.recv()?.map_err(Error::from) {
63
91
                Ok(HashResponse::Hash(hash)) => Ok(hash),
64
                Ok(HashResponse::Verified) => unreachable!(),
65
                Err(err) => Err(err),
66
            }
67
        } else {
68
            Err(Error::InternalCommunication)
69
        }
70
91
    }
71

            
72
151
    pub fn verify(
73
151
        &self,
74
151
        id: u64,
75
151
        password: SensitiveString,
76
151
        saved_hash: SensitiveString,
77
151
    ) -> Result<(), Error> {
78
151
        let (result_sender, result_receiver) = flume::bounded(1);
79
151
        if self
80
151
            .sender
81
151
            .send(HashRequest {
82
151
                id,
83
151
                password,
84
151
                verify_against: Some(saved_hash),
85
151
                result_sender,
86
151
            })
87
151
            .is_ok()
88
        {
89
151
            match result_receiver.recv()?.map_err(Error::from) {
90
151
                Ok(_) => Ok(()),
91
                Err(err) => {
92
                    eprintln!("Error validating password for user {}: {:?}", id, err);
93
                    Err(Error::Core(bonsaidb_core::Error::InvalidCredentials))
94
                }
95
            }
96
        } else {
97
            Err(Error::InternalCommunication)
98
        }
99
151
    }
100
}
101

            
102
3408
#[derive(Clone, Debug)]
103
struct HashingThread {
104
    receiver: flume::Receiver<HashRequest>,
105
    algorithm: Algorithm,
106
    params: ArgonParams,
107
    blocks: Vec<Block>,
108
    builder_template: Arc<OnceCell<Result<ParamsBuilder, ArgonError>>>,
109
}
110

            
111
impl HashingThread {
112
91
    fn initialize_builder_template<R: Rng + CryptoRng>(
113
91
        &mut self,
114
91
        rng: &mut R,
115
91
    ) -> Result<ParamsBuilder, ArgonError> {
116
91
        match &self.params {
117
            ArgonParams::Params(builder) => Ok(builder.clone()),
118
91
            ArgonParams::Timed(config) => {
119
91
                let mut params_builder = ParamsBuilder::new();
120
91
                let params = params_builder
121
91
                    .m_cost(config.ram_per_hasher)?
122
91
                    .p_cost(config.lanes)?
123
91
                    .data(&0_u64.to_be_bytes())?;
124
91
                let salt = SaltString::generate(rng);
125
91
                let mut salt_arr = [0u8; 64];
126
91
                let salt_bytes = salt.b64_decode(&mut salt_arr)?;
127
91
                let mut output = Vec::default();
128
91

            
129
91
                let minimum_duration = config.minimum_duration;
130
91
                let mut min_cost = 1;
131
91
                let mut total_spent_t = 0;
132
91
                let mut total_duration = Duration::ZERO;
133

            
134
                loop {
135
91
                    let t_cost = if total_spent_t > 0 {
136
                        let average_duration_per_t = total_duration / total_spent_t;
137
                        u32::try_from(ceil_divide(
138
                            minimum_duration.as_nanos(),
139
                            average_duration_per_t.as_nanos(),
140
                        ))
141
                        .unwrap()
142
                        .max(min_cost)
143
                    } else {
144
91
                        min_cost
145
                    };
146
91
                    params.t_cost(t_cost)?;
147

            
148
91
                    let params = params.clone().params()?;
149
91
                    self.allocate_blocks(&params);
150
91
                    let output_len = params
151
91
                        .output_len()
152
91
                        .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
153
91
                    output.resize(output_len, 0);
154
91

            
155
91
                    let start = Instant::now();
156
91
                    let argon = Argon2::new(self.algorithm, Version::V0x13, params);
157
91
                    argon.hash_password_into_with_memory(
158
91
                        b"hunter2",
159
91
                        salt_bytes,
160
91
                        &mut output[..],
161
91
                        &mut self.blocks,
162
91
                    )?;
163
91
                    let elapsed = match Instant::now().checked_duration_since(start) {
164
91
                        Some(elapsed) => elapsed,
165
                        None => continue,
166
                    };
167
91
                    if elapsed < minimum_duration {
168
                        total_spent_t += t_cost;
169
                        total_duration += elapsed;
170
                        min_cost = t_cost + 1;
171
                    } else {
172
                        // TODO if it's too far past the minimum duration, maybe we should try again at a smaller cost?
173
91
                        break;
174
91
                    }
175
91
                }
176
91
                Ok(params_builder)
177
            }
178
        }
179
91
    }
180

            
181
3408
    fn process_requests(mut self) {
182
3408
        let mut rng = thread_rng();
183
3650
        while let Ok(request) = self.receiver.recv() {
184
332
            let result = if let Some(verify_against) = &request.verify_against {
185
151
                Self::verify(&request, verify_against)
186
            } else {
187
91
                self.hash(&request, &mut rng)
188
            };
189
242
            drop(request.result_sender.send(result));
190
        }
191
3318
    }
192

            
193
151
    fn verify(request: &HashRequest, hash: &str) -> Result<HashResponse, Error> {
194
151
        let hash = PasswordHash::new(hash)?;
195

            
196
151
        let algorithm = Algorithm::try_from(hash.algorithm)?;
197
151
        let version = hash
198
151
            .version
199
151
            .map(Version::try_from)
200
151
            .transpose()?
201
151
            .unwrap_or(Version::V0x13);
202

            
203
151
        let argon = Argon2::new(algorithm, version, argon2::Params::try_from(&hash)?);
204

            
205
151
        hash.verify_password(&[&argon], request.password.0.as_bytes())
206
151
            .map(|_| HashResponse::Verified)
207
151
            .map_err(Error::from)
208
151
    }
209

            
210
91
    fn hash<R: Rng + CryptoRng>(
211
91
        &mut self,
212
91
        request: &HashRequest,
213
91
        rng: &mut R,
214
91
    ) -> Result<HashResponse, Error> {
215
91
        let builder_template = self.builder_template.clone();
216
91
        let mut params =
217
91
            match builder_template.get_or_init(|| self.initialize_builder_template(rng)) {
218
91
                Ok(template) => template.clone(),
219
                Err(error) => return Err(Error::PasswordHash(error.to_string())),
220
            };
221

            
222
91
        params.data(&request.id.to_be_bytes())?;
223

            
224
91
        let params = params.params()?;
225
91
        self.allocate_blocks(&params);
226
91

            
227
91
        let salt = SaltString::generate(rng);
228
91
        let mut salt_arr = [0u8; 64];
229
91
        let salt_bytes = salt.b64_decode(&mut salt_arr)?;
230

            
231
91
        let argon = Argon2::new(self.algorithm, Version::V0x13, params);
232
91

            
233
91
        let output_len = argon
234
91
            .params()
235
91
            .output_len()
236
91
            .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
237
91
        let output = argon2::password_hash::Output::init_with(output_len, |out| {
238
91
            Ok(argon.hash_password_into_with_memory(
239
91
                request.password.as_bytes(),
240
91
                salt_bytes,
241
91
                out,
242
91
                &mut self.blocks,
243
91
            )?)
244
91
        })?;
245

            
246
        Ok(HashResponse::Hash(SensitiveString(
247
            PasswordHash {
248
91
                algorithm: self.algorithm.ident(),
249
91
                version: Some(Version::V0x13.into()),
250
91
                params: ParamsString::try_from(argon.params())?,
251
91
                salt: Some(salt.as_salt()),
252
91
                hash: Some(output),
253
91
            }
254
91
            .to_string(),
255
        )))
256
91
    }
257

            
258
182
    fn allocate_blocks(&mut self, params: &argon2::Params) {
259
2981888
        for _ in self.blocks.len()..params.block_count() {
260
2981888
            self.blocks.push(Block::default());
261
2981888
        }
262
182
    }
263
}
264

            
265
#[derive(Debug)]
266
pub struct HashRequest {
267
    id: u64,
268
    password: SensitiveString,
269
    verify_against: Option<SensitiveString>,
270
    result_sender: flume::Sender<Result<HashResponse, Error>>,
271
}
272

            
273
#[derive(Debug)]
274
pub enum HashResponse {
275
    Hash(SensitiveString),
276
    Verified,
277
}
278

            
279
#[derive(thiserror::Error, Debug)]
280
enum ArgonError {
281
    #[error("{0}")]
282
    Argon(#[from] argon2::Error),
283
    #[error("{0}")]
284
    Hash(#[from] argon2::password_hash::Error),
285
}
286

            
287
/// Divides `dividend` by `divisor`, rounding the quotient up.
///
/// # Panics
///
/// Panics with `"divide by 0"` if `divisor` is zero.
fn ceil_divide(dividend: u128, divisor: u128) -> u128 {
    if divisor == 0 {
        panic!("divide by 0");
    }
    // Derived from quotient and remainder so that, unlike
    // `(dividend + divisor - 1) / divisor`, it cannot overflow near
    // `u128::MAX`. When the remainder is non-zero, `divisor >= 2`, so
    // `quotient + 1` cannot overflow either.
    let quotient = dividend / divisor;
    if dividend % divisor == 0 {
        quotient
    } else {
        quotient + 1
    }
}
297

            
298
1
/// Round trip: hash a password, verify it against the produced hash, then
/// shut the worker pool down and make sure every worker exits cleanly.
#[test]
fn basic_test() {
    use crate::config::SystemDefault;

    let hasher = Hasher::new(ArgonConfiguration::default());
    let password = SensitiveString(String::from("hunter2"));

    let stored = hasher.hash(1, password.clone()).unwrap();
    hasher.verify(1, password, stored).unwrap();

    // Dropping the pool's only request sender disconnects the queue, which
    // ends each worker's receive loop; joining then surfaces any worker panic.
    let Hasher { sender, threads } = hasher;
    drop(sender);
    threads
        .into_iter()
        .for_each(|handle| handle.join().unwrap());
}