1
use std::{
2
    sync::Arc,
3
    thread::JoinHandle,
4
    time::{Duration, Instant},
5
};
6

            
7
use argon2::{
8
    password_hash::{ParamsString, SaltString},
9
    Algorithm, Argon2, Block, ParamsBuilder, PasswordHash, Version,
10
};
11
use bonsaidb_core::connection::SensitiveString;
12
use once_cell::sync::OnceCell;
13
use rand::{thread_rng, CryptoRng, Rng};
14
use tokio::sync::oneshot;
15

            
16
use crate::{
17
    config::{ArgonConfiguration, ArgonParams},
18
    Error,
19
};
20

            
21
#[derive(Debug)]
22
#[cfg_attr(not(test), allow(dead_code))]
23
pub struct Hasher {
24
    sender: flume::Sender<HashRequest>,
25
    threads: Vec<JoinHandle<()>>,
26
}
27

            
28
impl Hasher {
29
2595
    pub fn new(config: ArgonConfiguration) -> Self {
30
2595
        let (sender, receiver) = flume::unbounded();
31
2595
        let thread = HashingThread {
32
2595
            receiver,
33
2595
            algorithm: config.algorithm,
34
2595
            params: config.params,
35
2595
            blocks: Vec::default(),
36
2595
            builder_template: Arc::default(),
37
2595
        };
38
2595
        let mut threads = Vec::with_capacity(config.hashers as usize);
39
2595
        for _ in 0..config.hashers.max(1) {
40
2595
            let thread = thread.clone();
41
2595
            threads.push(std::thread::spawn(move || thread.process_requests()));
42
2595
        }
43
2595
        Hasher { sender, threads }
44
2595
    }
45

            
46
76
    pub async fn hash(&self, id: u64, password: SensitiveString) -> Result<SensitiveString, Error> {
47
4
        let (result_sender, result_receiver) = oneshot::channel();
48
4
        if self
49
4
            .sender
50
4
            .send(HashRequest {
51
4
                id,
52
4
                password,
53
4
                verify_against: None,
54
4
                result_sender,
55
4
            })
56
4
            .is_ok()
57
        {
58
4
            match result_receiver.await?.map_err(Error::from) {
59
4
                Ok(HashResponse::Hash(hash)) => Ok(hash),
60
                Ok(HashResponse::Verified) => unreachable!(),
61
                Err(err) => Err(err),
62
            }
63
        } else {
64
            Err(Error::InternalCommunication)
65
        }
66
4
    }
67

            
68
126
    pub async fn verify(
69
126
        &self,
70
126
        id: u64,
71
126
        password: SensitiveString,
72
126
        saved_hash: SensitiveString,
73
126
    ) -> Result<(), Error> {
74
6
        let (result_sender, result_receiver) = oneshot::channel();
75
6
        if self
76
6
            .sender
77
6
            .send(HashRequest {
78
6
                id,
79
6
                password,
80
6
                verify_against: Some(saved_hash),
81
6
                result_sender,
82
6
            })
83
6
            .is_ok()
84
        {
85
6
            match result_receiver.await?.map_err(Error::from) {
86
6
                Ok(_) => Ok(()),
87
                Err(err) => {
88
                    eprintln!("Error validating password for user {}: {:?}", id, err);
89
                    Err(Error::Core(bonsaidb_core::Error::InvalidCredentials))
90
                }
91
            }
92
        } else {
93
            Err(Error::InternalCommunication)
94
        }
95
6
    }
96
}
97

            
98
2595
#[derive(Clone, Debug)]
99
struct HashingThread {
100
    receiver: flume::Receiver<HashRequest>,
101
    algorithm: Algorithm,
102
    params: ArgonParams,
103
    blocks: Vec<Block>,
104
    builder_template: Arc<OnceCell<Result<ParamsBuilder, ArgonError>>>,
105
}
106

            
107
impl HashingThread {
108
76
    fn initialize_builder_template<R: Rng + CryptoRng>(
109
76
        &mut self,
110
76
        rng: &mut R,
111
76
    ) -> Result<ParamsBuilder, ArgonError> {
112
76
        match &self.params {
113
            ArgonParams::Params(builder) => Ok(builder.clone()),
114
76
            ArgonParams::Timed(config) => {
115
76
                let mut params_builder = ParamsBuilder::new();
116
76
                let params = params_builder
117
76
                    .m_cost(config.ram_per_hasher)?
118
76
                    .p_cost(config.lanes)?
119
76
                    .data(&0_u64.to_be_bytes())?;
120
76
                let salt = SaltString::generate(rng);
121
76
                let mut salt_arr = [0u8; 64];
122
76
                let salt_bytes = salt.b64_decode(&mut salt_arr)?;
123
76
                let mut output = Vec::default();
124
76

            
125
76
                let minimum_duration = config.minimum_duration;
126
76
                let mut min_cost = 1;
127
76
                let mut total_spent_t = 0;
128
76
                let mut total_duration = Duration::ZERO;
129

            
130
                loop {
131
76
                    let t_cost = if total_spent_t > 0 {
132
                        let average_duration_per_t = total_duration / total_spent_t;
133
                        u32::try_from(ceil_divide(
134
                            minimum_duration.as_nanos(),
135
                            average_duration_per_t.as_nanos(),
136
                        ))
137
                        .unwrap()
138
                        .max(min_cost)
139
                    } else {
140
76
                        min_cost
141
                    };
142
76
                    params.t_cost(t_cost)?;
143

            
144
76
                    let params = params.clone().params()?;
145
76
                    self.allocate_blocks(&params);
146
76
                    let output_len = params
147
76
                        .output_len()
148
76
                        .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
149
76
                    output.resize(output_len, 0);
150
76

            
151
76
                    let start = Instant::now();
152
76
                    let argon = Argon2::new(self.algorithm, Version::V0x13, params);
153
76
                    argon.hash_password_into_with_memory(
154
76
                        b"hunter2",
155
76
                        salt_bytes,
156
76
                        &mut output[..],
157
76
                        &mut self.blocks,
158
76
                    )?;
159
76
                    let elapsed = match Instant::now().checked_duration_since(start) {
160
76
                        Some(elapsed) => elapsed,
161
                        None => continue,
162
                    };
163
76
                    if elapsed < minimum_duration {
164
                        total_spent_t += t_cost;
165
                        total_duration += elapsed;
166
                        min_cost = t_cost + 1;
167
                    } else {
168
                        // TODO if it's too far past the minimum duration, maybe we should try again at a smaller cost?
169
76
                        break;
170
76
                    }
171
76
                }
172
76
                Ok(params_builder)
173
            }
174
        }
175
76
    }
176

            
177
2595
    fn process_requests(mut self) {
178
2595
        let mut rng = thread_rng();
179
2797
        while let Ok(request) = self.receiver.recv() {
180
414
            let result = if let Some(verify_against) = &request.verify_against {
181
126
                Self::verify(&request, verify_against)
182
            } else {
183
76
                self.hash(&request, &mut rng)
184
            };
185
202
            drop(request.result_sender.send(result));
186
        }
187
2383
    }
188

            
189
126
    fn verify(request: &HashRequest, hash: &str) -> Result<HashResponse, Error> {
190
126
        let hash = PasswordHash::new(hash)?;
191

            
192
126
        let algorithm = Algorithm::try_from(hash.algorithm)?;
193
126
        let version = hash
194
126
            .version
195
126
            .map(Version::try_from)
196
126
            .transpose()?
197
126
            .unwrap_or(Version::V0x13);
198

            
199
126
        let argon = Argon2::new(algorithm, version, argon2::Params::try_from(&hash)?);
200

            
201
126
        hash.verify_password(&[&argon], request.password.0.as_bytes())
202
126
            .map(|_| HashResponse::Verified)
203
126
            .map_err(Error::from)
204
126
    }
205

            
206
76
    fn hash<R: Rng + CryptoRng>(
207
76
        &mut self,
208
76
        request: &HashRequest,
209
76
        rng: &mut R,
210
76
    ) -> Result<HashResponse, Error> {
211
76
        let builder_template = self.builder_template.clone();
212
76
        let mut params =
213
76
            match builder_template.get_or_init(|| self.initialize_builder_template(rng)) {
214
76
                Ok(template) => template.clone(),
215
                Err(error) => return Err(Error::PasswordHash(error.to_string())),
216
            };
217

            
218
76
        params.data(&request.id.to_be_bytes())?;
219

            
220
76
        let params = params.params()?;
221
76
        self.allocate_blocks(&params);
222
76

            
223
76
        let salt = SaltString::generate(rng);
224
76
        let mut salt_arr = [0u8; 64];
225
76
        let salt_bytes = salt.b64_decode(&mut salt_arr)?;
226

            
227
76
        let argon = Argon2::new(self.algorithm, Version::V0x13, params);
228
76

            
229
76
        let output_len = argon
230
76
            .params()
231
76
            .output_len()
232
76
            .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
233
76
        let output = argon2::password_hash::Output::init_with(output_len, |out| {
234
76
            Ok(argon.hash_password_into_with_memory(
235
76
                request.password.as_bytes(),
236
76
                salt_bytes,
237
76
                out,
238
76
                &mut self.blocks,
239
76
            )?)
240
76
        })?;
241

            
242
        Ok(HashResponse::Hash(SensitiveString(
243
            PasswordHash {
244
76
                algorithm: self.algorithm.ident(),
245
76
                version: Some(Version::V0x13.into()),
246
76
                params: ParamsString::try_from(argon.params())?,
247
76
                salt: Some(salt.as_salt()),
248
76
                hash: Some(output),
249
76
            }
250
76
            .to_string(),
251
        )))
252
76
    }
253

            
254
152
    fn allocate_blocks(&mut self, params: &argon2::Params) {
255
2490368
        for _ in self.blocks.len()..params.block_count() {
256
2490368
            self.blocks.push(Block::default());
257
2490368
        }
258
152
    }
259
}
260

            
261
#[derive(Debug)]
262
pub struct HashRequest {
263
    id: u64,
264
    password: SensitiveString,
265
    verify_against: Option<SensitiveString>,
266
    result_sender: oneshot::Sender<Result<HashResponse, Error>>,
267
}
268

            
269
#[derive(Debug)]
270
pub enum HashResponse {
271
    Hash(SensitiveString),
272
    Verified,
273
}
274

            
275
#[derive(thiserror::Error, Debug)]
276
enum ArgonError {
277
    #[error("{0}")]
278
    Argon(#[from] argon2::Error),
279
    #[error("{0}")]
280
    Hash(#[from] argon2::password_hash::Error),
281
}
282

            
283
/// Divides `dividend` by `divisor`, rounding any remainder up.
///
/// Uses quotient/remainder rather than `(dividend + divisor - 1) / divisor`,
/// so it cannot overflow even for dividends near `u128::MAX`.
///
/// # Panics
///
/// Panics with `"divide by 0"` if `divisor` is zero.
fn ceil_divide(dividend: u128, divisor: u128) -> u128 {
    if divisor == 0 {
        panic!("divide by 0");
    }
    let quotient = dividend / divisor;
    // A non-zero remainder implies divisor >= 2, so quotient + 1 cannot overflow.
    if dividend % divisor == 0 {
        quotient
    } else {
        quotient + 1
    }
}
293

            
294
1
#[tokio::test]
295
1
async fn basic_test() {
296
1
    use crate::config::SystemDefault;
297
1
    let hasher = Hasher::new(ArgonConfiguration::default());
298
1

            
299
1
    let password = SensitiveString(String::from("hunter2"));
300
1
    let hash = hasher.hash(1, password.clone()).await.unwrap();
301
1
    hasher
302
1
        .verify(1, password.clone(), hash.clone())
303
1
        .await
304
1
        .unwrap();
305
1

            
306
1
    let Hasher { sender, threads } = hasher;
307
1
    drop(sender);
308
2
    for thread in threads {
309
1
        thread.join().unwrap();
310
1
    }
311
1
}