1
use std::sync::Arc;
2
use std::thread::JoinHandle;
3
use std::time::{Duration, Instant};
4

            
5
use argon2::password_hash::{ParamsString, SaltString};
6
use argon2::{Algorithm, Argon2, AssociatedData, Block, ParamsBuilder, PasswordHash, Version};
7
use bonsaidb_core::connection::SensitiveString;
8
use once_cell::sync::OnceCell;
9
use rand::{thread_rng, CryptoRng, Rng};
10

            
11
use crate::config::{ArgonConfiguration, ArgonParams};
12
use crate::Error;
13

            
14
#[derive(Debug)]
15
#[cfg_attr(not(test), allow(dead_code))]
16
pub struct Hasher {
17
    sender: flume::Sender<HashRequest>,
18
    threads: Vec<JoinHandle<()>>,
19
}
20

            
21
impl Hasher {
22
5126
    pub fn new(config: ArgonConfiguration) -> Self {
23
5126
        let (sender, receiver) = flume::unbounded();
24
5126
        let thread = HashingThread {
25
5126
            receiver,
26
5126
            algorithm: config.algorithm,
27
5126
            params: config.params,
28
5126
            blocks: Vec::default(),
29
5126
            builder_template: Arc::default(),
30
5126
        };
31
5126
        let mut threads = Vec::with_capacity(config.hashers as usize);
32
5126
        for _ in 0..config.hashers.max(1) {
33
5126
            let thread = thread.clone();
34
5126
            threads.push(
35
5126
                std::thread::Builder::new()
36
5126
                    .name(String::from("argon2"))
37
5126
                    .spawn(move || thread.process_requests())
38
5126
                    .unwrap(),
39
5126
            );
40
5126
        }
41
5126
        Hasher { sender, threads }
42
5126
    }
43

            
44
181
    pub fn hash(&self, id: u64, password: SensitiveString) -> Result<SensitiveString, Error> {
45
181
        let (result_sender, result_receiver) = flume::bounded(1);
46
181
        if self
47
181
            .sender
48
181
            .send(HashRequest {
49
181
                id,
50
181
                password,
51
181
                verify_against: None,
52
181
                result_sender,
53
181
            })
54
181
            .is_ok()
55
        {
56
181
            match result_receiver.recv()?.map_err(Error::from) {
57
181
                Ok(HashResponse::Hash(hash)) => Ok(hash),
58
                Ok(HashResponse::Verified) => unreachable!(),
59
                Err(err) => Err(err),
60
            }
61
        } else {
62
            Err(Error::InternalCommunication)
63
        }
64
181
    }
65

            
66
253
    pub fn verify(
67
253
        &self,
68
253
        id: u64,
69
253
        password: SensitiveString,
70
253
        saved_hash: SensitiveString,
71
253
    ) -> Result<(), Error> {
72
253
        let (result_sender, result_receiver) = flume::bounded(1);
73
253
        if self
74
253
            .sender
75
253
            .send(HashRequest {
76
253
                id,
77
253
                password,
78
253
                verify_against: Some(saved_hash),
79
253
                result_sender,
80
253
            })
81
253
            .is_ok()
82
        {
83
253
            match result_receiver.recv()?.map_err(Error::from) {
84
253
                Ok(_) => Ok(()),
85
                Err(err) => {
86
                    eprintln!("Error validating password for user {id}: {err:?}");
87
                    Err(Error::Core(bonsaidb_core::Error::InvalidCredentials))
88
                }
89
            }
90
        } else {
91
            Err(Error::InternalCommunication)
92
        }
93
253
    }
94
}
95

            
96
5126
#[derive(Clone, Debug)]
97
struct HashingThread {
98
    receiver: flume::Receiver<HashRequest>,
99
    algorithm: Algorithm,
100
    params: ArgonParams,
101
    blocks: Vec<Block>,
102
    builder_template: Arc<OnceCell<Result<ParamsBuilder, ArgonError>>>,
103
}
104

            
105
impl HashingThread {
106
181
    fn initialize_builder_template<R: Rng + CryptoRng>(
107
181
        &mut self,
108
181
        rng: &mut R,
109
181
    ) -> Result<ParamsBuilder, ArgonError> {
110
181
        match &self.params {
111
            ArgonParams::Params(builder) => Ok(builder.clone()),
112
181
            ArgonParams::Timed(config) => {
113
181
                let mut params_builder = ParamsBuilder::new();
114
181
                let params = params_builder
115
181
                    .m_cost(config.ram_per_hasher / 1_024)
116
181
                    .p_cost(config.lanes)
117
181
                    .data(AssociatedData::new(&0_u64.to_be_bytes())?);
118
181
                let salt = SaltString::generate(rng);
119
181
                let mut salt_arr = [0u8; 64];
120
181
                let salt_bytes = salt.decode_b64(&mut salt_arr)?;
121
181
                let mut output = Vec::default();
122
181

            
123
181
                let minimum_duration = config.minimum_duration;
124
181
                let mut min_cost = 2; // OWASP sets the minimum iteration count at 2
125
181
                let mut total_spent_t = 0;
126
181
                let mut total_duration = Duration::ZERO;
127

            
128
                loop {
129
362
                    let t_cost = if total_spent_t > 0 {
130
181
                        let average_duration_per_t = total_duration / total_spent_t;
131
181
                        u32::try_from(ceil_divide(
132
181
                            minimum_duration.as_nanos(),
133
181
                            average_duration_per_t.as_nanos(),
134
181
                        ))
135
181
                        .unwrap()
136
181
                        .max(min_cost)
137
                    } else {
138
181
                        min_cost
139
                    };
140
362
                    params.t_cost(t_cost);
141

            
142
362
                    let params = params.clone().build()?;
143
362
                    self.allocate_blocks(&params);
144
362
                    let output_len = params
145
362
                        .output_len()
146
362
                        .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
147
362
                    output.resize(output_len, 0);
148
362

            
149
362
                    let start = Instant::now();
150
362
                    let argon = Argon2::new(self.algorithm, Version::V0x13, params);
151
362
                    argon.hash_password_into_with_memory(
152
362
                        b"hunter2",
153
362
                        salt_bytes,
154
362
                        &mut output[..],
155
362
                        &mut self.blocks,
156
362
                    )?;
157

            
158
362
                    let Some(elapsed) = Instant::now().checked_duration_since(start) else {
159
                        continue;
160
                    };
161
362
                    if elapsed < minimum_duration {
162
181
                        total_spent_t += t_cost;
163
181
                        total_duration += elapsed;
164
181
                        min_cost = t_cost + 1;
165
181
                    } else {
166
                        // TODO if it's too far past the minimum duration, maybe we should try again at a smaller cost?
167
181
                        break;
168
181
                    }
169
181
                }
170
181
                Ok(params_builder)
171
            }
172
        }
173
181
    }
174

            
175
5126
    fn process_requests(mut self) {
176
5126
        let mut rng = thread_rng();
177
5560
        while let Ok(request) = self.receiver.recv() {
178
578
            let result = if let Some(verify_against) = &request.verify_against {
179
253
                Self::verify(&request, verify_against)
180
            } else {
181
181
                self.hash(&request, &mut rng)
182
            };
183
434
            drop(request.result_sender.send(result));
184
        }
185
4982
    }
186

            
187
253
    fn verify(request: &HashRequest, hash: &str) -> Result<HashResponse, Error> {
188
253
        let hash = PasswordHash::new(hash)?;
189

            
190
253
        let algorithm = Algorithm::try_from(hash.algorithm)?;
191
253
        let version = hash
192
253
            .version
193
253
            .map(Version::try_from)
194
253
            .transpose()?
195
253
            .unwrap_or(Version::V0x13);
196

            
197
253
        let argon = Argon2::new(algorithm, version, argon2::Params::try_from(&hash)?);
198

            
199
253
        hash.verify_password(&[&argon], request.password.0.as_bytes())
200
253
            .map(|()| HashResponse::Verified)
201
253
            .map_err(Error::from)
202
253
    }
203

            
204
181
    fn hash<R: Rng + CryptoRng>(
205
181
        &mut self,
206
181
        request: &HashRequest,
207
181
        rng: &mut R,
208
181
    ) -> Result<HashResponse, Error> {
209
181
        let builder_template = self.builder_template.clone();
210
181
        let mut params =
211
181
            match builder_template.get_or_init(|| self.initialize_builder_template(rng)) {
212
181
                Ok(template) => template.clone(),
213
                Err(error) => return Err(Error::other("argon2", error)),
214
            };
215

            
216
181
        params.data(AssociatedData::new(&request.id.to_be_bytes())?);
217

            
218
181
        let params = params.build()?;
219
181
        self.allocate_blocks(&params);
220
181

            
221
181
        let salt = SaltString::generate(rng);
222
181
        let mut salt_arr = [0u8; 64];
223
181
        let salt_bytes = salt.decode_b64(&mut salt_arr)?;
224

            
225
181
        let argon = Argon2::new(self.algorithm, Version::V0x13, params);
226
181

            
227
181
        let output_len = argon
228
181
            .params()
229
181
            .output_len()
230
181
            .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
231
181
        let output = argon2::password_hash::Output::init_with(output_len, |out| {
232
181
            Ok(argon.hash_password_into_with_memory(
233
181
                request.password.as_bytes(),
234
181
                salt_bytes,
235
181
                out,
236
181
                &mut self.blocks,
237
181
            )?)
238
181
        })?;
239

            
240
        Ok(HashResponse::Hash(SensitiveString(
241
            PasswordHash {
242
181
                algorithm: self.algorithm.ident(),
243
181
                version: Some(Version::V0x13.into()),
244
181
                params: ParamsString::try_from(argon.params())?,
245
181
                salt: Some(salt.as_salt()),
246
181
                hash: Some(output),
247
181
            }
248
181
            .to_string(),
249
        )))
250
181
    }
251

            
252
543
    fn allocate_blocks(&mut self, params: &argon2::Params) {
253
3521536
        for _ in self.blocks.len()..params.block_count() {
254
3521536
            self.blocks.push(Block::default());
255
3521536
        }
256
543
    }
257
}
258

            
259
#[derive(Debug)]
260
pub struct HashRequest {
261
    id: u64,
262
    password: SensitiveString,
263
    verify_against: Option<SensitiveString>,
264
    result_sender: flume::Sender<Result<HashResponse, Error>>,
265
}
266

            
267
#[derive(Debug)]
268
pub enum HashResponse {
269
    Hash(SensitiveString),
270
    Verified,
271
}
272

            
273
#[derive(thiserror::Error, Debug)]
274
enum ArgonError {
275
    #[error("{0}")]
276
    Argon(#[from] argon2::Error),
277
    #[error("{0}")]
278
    Hash(#[from] argon2::password_hash::Error),
279
}
280

            
281
181
/// Divides `dividend` by `divisor`, rounding any remainder up.
///
/// Computed as quotient-plus-remainder-check rather than
/// `(dividend + divisor - 1) / divisor`, so dividends close to `u128::MAX`
/// cannot overflow.
///
/// # Panics
///
/// Panics with `"divide by 0"` if `divisor` is zero.
fn ceil_divide(dividend: u128, divisor: u128) -> u128 {
    if divisor == 0 {
        panic!("divide by 0");
    }
    let quotient = dividend / divisor;
    // A nonzero remainder implies divisor >= 2, so `quotient + 1` cannot overflow.
    if dividend % divisor == 0 {
        quotient
    } else {
        quotient + 1
    }
}
291

            
292
1
#[test]
fn basic_test() {
    use crate::config::SystemDefault;
    let hasher = Hasher::new(ArgonConfiguration::default());

    // Round trip: a freshly produced hash verifies for the same password.
    let password = SensitiveString(String::from("hunter2"));
    let hash = hasher.hash(1, password.clone()).unwrap();
    hasher.verify(1, password, hash.clone()).unwrap();

    // A different password must be rejected as invalid credentials.
    let wrong_password = SensitiveString(String::from("hunter3"));
    assert!(matches!(
        hasher.verify(1, wrong_password, hash),
        Err(Error::Core(bonsaidb_core::Error::InvalidCredentials))
    ));

    // Dropping the only sender closes the queue, letting every worker exit.
    let Hasher { sender, threads } = hasher;
    drop(sender);
    for thread in threads {
        thread.join().unwrap();
    }
}