1
use std::sync::Arc;
2
use std::thread::JoinHandle;
3
use std::time::{Duration, Instant};
4

            
5
use argon2::password_hash::{ParamsString, SaltString};
6
use argon2::{Algorithm, Argon2, AssociatedData, Block, ParamsBuilder, PasswordHash, Version};
7
use bonsaidb_core::connection::SensitiveString;
8
use once_cell::sync::OnceCell;
9
use rand::{thread_rng, CryptoRng, Rng};
10

            
11
use crate::config::{ArgonConfiguration, ArgonParams};
12
use crate::Error;
13

            
14
#[derive(Debug)]
15
#[cfg_attr(not(test), allow(dead_code))]
16
pub struct Hasher {
17
    sender: flume::Sender<HashRequest>,
18
    threads: Vec<JoinHandle<()>>,
19
}
20

            
21
impl Hasher {
22
5262
    pub fn new(config: ArgonConfiguration) -> Self {
23
5262
        let (sender, receiver) = flume::unbounded();
24
5262
        let thread = HashingThread {
25
5262
            receiver,
26
5262
            algorithm: config.algorithm,
27
5262
            params: config.params,
28
5262
            blocks: Vec::default(),
29
5262
            builder_template: Arc::default(),
30
5262
        };
31
5262
        let mut threads = Vec::with_capacity(config.hashers as usize);
32
5262
        for _ in 0..config.hashers.max(1) {
33
5262
            let thread = thread.clone();
34
5262
            threads.push(
35
5262
                std::thread::Builder::new()
36
5262
                    .name(String::from("argon2"))
37
5262
                    .spawn(move || thread.process_requests())
38
5262
                    .unwrap(),
39
5262
            );
40
5262
        }
41
5262
        Hasher { sender, threads }
42
5262
    }
43

            
44
186
    pub fn hash(&self, id: u64, password: SensitiveString) -> Result<SensitiveString, Error> {
45
186
        let (result_sender, result_receiver) = flume::bounded(1);
46
186
        if self
47
186
            .sender
48
186
            .send(HashRequest {
49
186
                id,
50
186
                password,
51
186
                verify_against: None,
52
186
                result_sender,
53
186
            })
54
186
            .is_ok()
55
        {
56
186
            match result_receiver.recv()?.map_err(Error::from) {
57
186
                Ok(HashResponse::Hash(hash)) => Ok(hash),
58
                Ok(HashResponse::Verified) => unreachable!(),
59
                Err(err) => Err(err),
60
            }
61
        } else {
62
            Err(Error::InternalCommunication)
63
        }
64
186
    }
65

            
66
260
    pub fn verify(
67
260
        &self,
68
260
        id: u64,
69
260
        password: SensitiveString,
70
260
        saved_hash: SensitiveString,
71
260
    ) -> Result<(), Error> {
72
260
        let (result_sender, result_receiver) = flume::bounded(1);
73
260
        if self
74
260
            .sender
75
260
            .send(HashRequest {
76
260
                id,
77
260
                password,
78
260
                verify_against: Some(saved_hash),
79
260
                result_sender,
80
260
            })
81
260
            .is_ok()
82
        {
83
260
            match result_receiver.recv()?.map_err(Error::from) {
84
260
                Ok(_) => Ok(()),
85
                Err(err) => {
86
                    eprintln!("Error validating password for user {id}: {err:?}");
87
                    Err(Error::Core(bonsaidb_core::Error::InvalidCredentials))
88
                }
89
            }
90
        } else {
91
            Err(Error::InternalCommunication)
92
        }
93
260
    }
94
}
95

            
96
5262
#[derive(Clone, Debug)]
97
struct HashingThread {
98
    receiver: flume::Receiver<HashRequest>,
99
    algorithm: Algorithm,
100
    params: ArgonParams,
101
    blocks: Vec<Block>,
102
    builder_template: Arc<OnceCell<Result<ParamsBuilder, ArgonError>>>,
103
}
104

            
105
impl HashingThread {
106
186
    fn initialize_builder_template<R: Rng + CryptoRng>(
107
186
        &mut self,
108
186
        rng: &mut R,
109
186
    ) -> Result<ParamsBuilder, ArgonError> {
110
186
        match &self.params {
111
            ArgonParams::Params(builder) => Ok(builder.clone()),
112
186
            ArgonParams::Timed(config) => {
113
186
                let mut params_builder = ParamsBuilder::new();
114
186
                let params = params_builder
115
186
                    .m_cost(config.ram_per_hasher / 1_024)
116
186
                    .p_cost(config.lanes)
117
186
                    .data(AssociatedData::new(&0_u64.to_be_bytes())?);
118
186
                let salt = SaltString::generate(rng);
119
186
                let mut salt_arr = [0u8; 64];
120
186
                let salt_bytes = salt.decode_b64(&mut salt_arr)?;
121
186
                let mut output = Vec::default();
122
186

            
123
186
                let minimum_duration = config.minimum_duration;
124
186
                let mut min_cost = 2; // OWASP sets the minimum iteration count at 2
125
186
                let mut total_spent_t = 0;
126
186
                let mut total_duration = Duration::ZERO;
127

            
128
                loop {
129
372
                    let t_cost = if total_spent_t > 0 {
130
186
                        let average_duration_per_t = total_duration / total_spent_t;
131
186
                        u32::try_from(ceil_divide(
132
186
                            minimum_duration.as_nanos(),
133
186
                            average_duration_per_t.as_nanos(),
134
186
                        ))
135
186
                        .unwrap()
136
186
                        .max(min_cost)
137
                    } else {
138
186
                        min_cost
139
                    };
140
372
                    params.t_cost(t_cost);
141

            
142
372
                    let params = params.clone().build()?;
143
372
                    self.allocate_blocks(&params);
144
372
                    let output_len = params
145
372
                        .output_len()
146
372
                        .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
147
372
                    output.resize(output_len, 0);
148
372

            
149
372
                    let start = Instant::now();
150
372
                    let argon = Argon2::new(self.algorithm, Version::V0x13, params);
151
372
                    argon.hash_password_into_with_memory(
152
372
                        b"hunter2",
153
372
                        salt_bytes,
154
372
                        &mut output[..],
155
372
                        &mut self.blocks,
156
372
                    )?;
157

            
158
372
                    let Some(elapsed) = Instant::now().checked_duration_since(start) else {
159
                        continue;
160
                    };
161
372
                    if elapsed < minimum_duration {
162
186
                        total_spent_t += t_cost;
163
186
                        total_duration += elapsed;
164
186
                        min_cost = t_cost + 1;
165
186
                    } else {
166
                        // TODO if it's too far past the minimum duration, maybe we should try again at a smaller cost?
167
186
                        break;
168
186
                    }
169
186
                }
170
186
                Ok(params_builder)
171
            }
172
        }
173
186
    }
174

            
175
5262
    fn process_requests(mut self) {
176
5262
        let mut rng = thread_rng();
177
5708
        while let Ok(request) = self.receiver.recv() {
178
594
            let result = if let Some(verify_against) = &request.verify_against {
179
260
                Self::verify(&request, verify_against)
180
            } else {
181
186
                self.hash(&request, &mut rng)
182
            };
183
446
            drop(request.result_sender.send(result));
184
        }
185
5114
    }
186

            
187
260
    fn verify(request: &HashRequest, hash: &str) -> Result<HashResponse, Error> {
188
260
        let hash = PasswordHash::new(hash)?;
189

            
190
260
        let algorithm = Algorithm::try_from(hash.algorithm)?;
191
260
        let version = hash
192
260
            .version
193
260
            .map(Version::try_from)
194
260
            .transpose()?
195
260
            .unwrap_or(Version::V0x13);
196

            
197
260
        let argon = Argon2::new(algorithm, version, argon2::Params::try_from(&hash)?);
198

            
199
260
        hash.verify_password(&[&argon], request.password.0.as_bytes())
200
260
            .map(|_| HashResponse::Verified)
201
260
            .map_err(Error::from)
202
260
    }
203

            
204
186
    fn hash<R: Rng + CryptoRng>(
205
186
        &mut self,
206
186
        request: &HashRequest,
207
186
        rng: &mut R,
208
186
    ) -> Result<HashResponse, Error> {
209
186
        let builder_template = self.builder_template.clone();
210
186
        let mut params =
211
186
            match builder_template.get_or_init(|| self.initialize_builder_template(rng)) {
212
186
                Ok(template) => template.clone(),
213
                Err(error) => return Err(Error::other("argon2", error)),
214
            };
215

            
216
186
        params.data(AssociatedData::new(&request.id.to_be_bytes())?);
217

            
218
186
        let params = params.build()?;
219
186
        self.allocate_blocks(&params);
220
186

            
221
186
        let salt = SaltString::generate(rng);
222
186
        let mut salt_arr = [0u8; 64];
223
186
        let salt_bytes = salt.decode_b64(&mut salt_arr)?;
224

            
225
186
        let argon = Argon2::new(self.algorithm, Version::V0x13, params);
226
186

            
227
186
        let output_len = argon
228
186
            .params()
229
186
            .output_len()
230
186
            .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
231
186
        let output = argon2::password_hash::Output::init_with(output_len, |out| {
232
186
            Ok(argon.hash_password_into_with_memory(
233
186
                request.password.as_bytes(),
234
186
                salt_bytes,
235
186
                out,
236
186
                &mut self.blocks,
237
186
            )?)
238
186
        })?;
239

            
240
        Ok(HashResponse::Hash(SensitiveString(
241
            PasswordHash {
242
186
                algorithm: self.algorithm.ident(),
243
186
                version: Some(Version::V0x13.into()),
244
186
                params: ParamsString::try_from(argon.params())?,
245
186
                salt: Some(salt.as_salt()),
246
186
                hash: Some(output),
247
186
            }
248
186
            .to_string(),
249
        )))
250
186
    }
251

            
252
558
    fn allocate_blocks(&mut self, params: &argon2::Params) {
253
3618816
        for _ in self.blocks.len()..params.block_count() {
254
3618816
            self.blocks.push(Block::default());
255
3618816
        }
256
558
    }
257
}
258

            
259
#[derive(Debug)]
260
pub struct HashRequest {
261
    id: u64,
262
    password: SensitiveString,
263
    verify_against: Option<SensitiveString>,
264
    result_sender: flume::Sender<Result<HashResponse, Error>>,
265
}
266

            
267
#[derive(Debug)]
268
pub enum HashResponse {
269
    Hash(SensitiveString),
270
    Verified,
271
}
272

            
273
#[derive(thiserror::Error, Debug)]
274
enum ArgonError {
275
    #[error("{0}")]
276
    Argon(#[from] argon2::Error),
277
    #[error("{0}")]
278
    Hash(#[from] argon2::password_hash::Error),
279
}
280

            
281
186
/// Divides `dividend` by `divisor`, rounding the quotient up.
///
/// # Panics
///
/// Panics with "divide by 0" if `divisor` is 0.
fn ceil_divide(dividend: u128, divisor: u128) -> u128 {
    if divisor == 0 {
        panic!("divide by 0");
    }
    // `(dividend + divisor - 1) / divisor` overflows for dividends close to
    // `u128::MAX`; rounding up based on the remainder cannot overflow.
    let quotient = dividend / divisor;
    if dividend % divisor == 0 {
        quotient
    } else {
        quotient + 1
    }
}
291

            
292
1
/// Round-trips a password through `Hasher::hash` and `Hasher::verify`. It then
/// shuts the pool down by dropping the request sender and joining every
/// worker thread.
#[test]
fn basic_test() {
    use crate::config::SystemDefault;

    let hasher = Hasher::new(ArgonConfiguration::default());
    let secret = SensitiveString(String::from("hunter2"));

    let stored = hasher.hash(1, secret.clone()).unwrap();
    hasher.verify(1, secret, stored).unwrap();

    // `sender` is the only handle on the request channel, so dropping it
    // makes every worker's `recv` fail and its loop exit.
    let Hasher { sender, threads } = hasher;
    drop(sender);
    threads
        .into_iter()
        .for_each(|handle| handle.join().unwrap());
}