1
use std::{
2
    sync::Arc,
3
    thread::JoinHandle,
4
    time::{Duration, Instant},
5
};
6

            
7
use argon2::{
8
    password_hash::{ParamsString, SaltString},
9
    Algorithm, Argon2, Block, ParamsBuilder, PasswordHash, Version,
10
};
11
use bonsaidb_core::connection::SensitiveString;
12
use once_cell::sync::OnceCell;
13
use rand::{thread_rng, CryptoRng, Rng};
14
use tokio::sync::oneshot;
15

            
16
use crate::{
17
    config::{ArgonConfiguration, ArgonParams},
18
    Error,
19
};
20

            
21
#[derive(Debug)]
22
#[cfg_attr(not(test), allow(dead_code))]
23
pub struct Hasher {
24
    sender: flume::Sender<HashRequest>,
25
    threads: Vec<JoinHandle<()>>,
26
}
27

            
28
impl Hasher {
29
2700
    pub fn new(config: ArgonConfiguration) -> Self {
30
2700
        let (sender, receiver) = flume::unbounded();
31
2700
        let thread = HashingThread {
32
2700
            receiver,
33
2700
            algorithm: config.algorithm,
34
2700
            params: config.params,
35
2700
            blocks: Vec::default(),
36
2700
            builder_template: Arc::default(),
37
2700
        };
38
2700
        let mut threads = Vec::with_capacity(config.hashers as usize);
39
2700
        for _ in 0..config.hashers.max(1) {
40
2700
            let thread = thread.clone();
41
2700
            threads.push(std::thread::spawn(move || thread.process_requests()));
42
2700
        }
43
2700
        Hasher { sender, threads }
44
2700
    }
45

            
46
79
    pub async fn hash(&self, id: u64, password: SensitiveString) -> Result<SensitiveString, Error> {
47
4
        let (result_sender, result_receiver) = oneshot::channel();
48
4
        if self
49
4
            .sender
50
4
            .send(HashRequest {
51
4
                id,
52
4
                password,
53
4
                verify_against: None,
54
4
                result_sender,
55
4
            })
56
4
            .is_ok()
57
        {
58
4
            match result_receiver.await?.map_err(Error::from) {
59
4
                Ok(HashResponse::Hash(hash)) => Ok(hash),
60
                Ok(HashResponse::Verified) => unreachable!(),
61
                Err(err) => Err(err),
62
            }
63
        } else {
64
            Err(Error::InternalCommunication)
65
        }
66
4
    }
67

            
68
131
    pub async fn verify(
69
131
        &self,
70
131
        id: u64,
71
131
        password: SensitiveString,
72
131
        saved_hash: SensitiveString,
73
131
    ) -> Result<(), Error> {
74
6
        let (result_sender, result_receiver) = oneshot::channel();
75
6
        if self
76
6
            .sender
77
6
            .send(HashRequest {
78
6
                id,
79
6
                password,
80
6
                verify_against: Some(saved_hash),
81
6
                result_sender,
82
6
            })
83
6
            .is_ok()
84
        {
85
6
            match result_receiver.await?.map_err(Error::from) {
86
6
                Ok(_) => Ok(()),
87
                Err(err) => {
88
                    eprintln!("Error validating password for user {}: {:?}", id, err);
89
                    Err(Error::Core(bonsaidb_core::Error::InvalidCredentials))
90
                }
91
            }
92
        } else {
93
            Err(Error::InternalCommunication)
94
        }
95
6
    }
96
}
97

            
98
2700
#[derive(Clone, Debug)]
99
struct HashingThread {
100
    receiver: flume::Receiver<HashRequest>,
101
    algorithm: Algorithm,
102
    params: ArgonParams,
103
    blocks: Vec<Block>,
104
    builder_template: Arc<OnceCell<Result<ParamsBuilder, ArgonError>>>,
105
}
106

            
107
impl HashingThread {
108
79
    fn initialize_builder_template<R: Rng + CryptoRng>(
109
79
        &mut self,
110
79
        rng: &mut R,
111
79
    ) -> Result<ParamsBuilder, ArgonError> {
112
79
        match &self.params {
113
            ArgonParams::Params(builder) => Ok(builder.clone()),
114
79
            ArgonParams::Timed(config) => {
115
79
                let mut params_builder = ParamsBuilder::new();
116
79
                let params = params_builder
117
79
                    .m_cost(config.ram_per_hasher)?
118
79
                    .p_cost(config.lanes)?
119
79
                    .data(&0_u64.to_be_bytes())?;
120
79
                let salt = SaltString::generate(rng);
121
79
                let mut salt_arr = [0u8; 64];
122
79
                let salt_bytes = salt.b64_decode(&mut salt_arr)?;
123
79
                let mut output = Vec::default();
124
79

            
125
79
                let minimum_duration = config.minimum_duration;
126
79
                let mut min_cost = 1;
127
79
                let mut total_spent_t = 0;
128
79
                let mut total_duration = Duration::ZERO;
129

            
130
                loop {
131
79
                    let t_cost = if total_spent_t > 0 {
132
                        let average_duration_per_t = total_duration / total_spent_t;
133
                        u32::try_from(ceil_divide(
134
                            minimum_duration.as_nanos(),
135
                            average_duration_per_t.as_nanos(),
136
                        ))
137
                        .unwrap()
138
                        .max(min_cost)
139
                    } else {
140
79
                        min_cost
141
                    };
142
79
                    params.t_cost(t_cost)?;
143

            
144
79
                    let params = params.clone().params()?;
145
79
                    self.allocate_blocks(&params);
146
79
                    let output_len = params
147
79
                        .output_len()
148
79
                        .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
149
79
                    output.resize(output_len, 0);
150
79

            
151
79
                    let start = Instant::now();
152
79
                    let argon = Argon2::new(self.algorithm, Version::V0x13, params);
153
79
                    argon.hash_password_into_with_memory(
154
79
                        b"hunter2",
155
79
                        salt_bytes,
156
79
                        &mut output[..],
157
79
                        &mut self.blocks,
158
79
                    )?;
159
79
                    let elapsed = match Instant::now().checked_duration_since(start) {
160
79
                        Some(elapsed) => elapsed,
161
                        None => continue,
162
                    };
163
79
                    if elapsed < minimum_duration {
164
                        total_spent_t += t_cost;
165
                        total_duration += elapsed;
166
                        min_cost = t_cost + 1;
167
                    } else {
168
                        // TODO if it's too far past the minimum duration, maybe we should try again at a smaller cost?
169
79
                        break;
170
79
                    }
171
79
                }
172
79
                Ok(params_builder)
173
            }
174
        }
175
79
    }
176

            
177
2700
    fn process_requests(mut self) {
178
2700
        let mut rng = thread_rng();
179
2910
        while let Ok(request) = self.receiver.recv() {
180
404
            let result = if let Some(verify_against) = &request.verify_against {
181
131
                Self::verify(&request, verify_against)
182
            } else {
183
79
                self.hash(&request, &mut rng)
184
            };
185
210
            drop(request.result_sender.send(result));
186
        }
187
2506
    }
188

            
189
131
    fn verify(request: &HashRequest, hash: &str) -> Result<HashResponse, Error> {
190
131
        let hash = PasswordHash::new(hash)?;
191

            
192
131
        let algorithm = Algorithm::try_from(hash.algorithm)?;
193
131
        let version = hash
194
131
            .version
195
131
            .map(Version::try_from)
196
131
            .transpose()?
197
131
            .unwrap_or(Version::V0x13);
198

            
199
131
        let argon = Argon2::new(algorithm, version, argon2::Params::try_from(&hash)?);
200

            
201
131
        hash.verify_password(&[&argon], request.password.0.as_bytes())
202
131
            .map(|_| HashResponse::Verified)
203
131
            .map_err(Error::from)
204
131
    }
205

            
206
79
    fn hash<R: Rng + CryptoRng>(
207
79
        &mut self,
208
79
        request: &HashRequest,
209
79
        rng: &mut R,
210
79
    ) -> Result<HashResponse, Error> {
211
79
        let builder_template = self.builder_template.clone();
212
79
        let mut params =
213
79
            match builder_template.get_or_init(|| self.initialize_builder_template(rng)) {
214
79
                Ok(template) => template.clone(),
215
                Err(error) => return Err(Error::PasswordHash(error.to_string())),
216
            };
217

            
218
79
        params.data(&request.id.to_be_bytes())?;
219

            
220
79
        let params = params.params()?;
221
79
        self.allocate_blocks(&params);
222
79

            
223
79
        let salt = SaltString::generate(rng);
224
79
        let mut salt_arr = [0u8; 64];
225
79
        let salt_bytes = salt.b64_decode(&mut salt_arr)?;
226

            
227
79
        let argon = Argon2::new(self.algorithm, Version::V0x13, params);
228
79

            
229
79
        let output_len = argon
230
79
            .params()
231
79
            .output_len()
232
79
            .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
233
79
        let output = argon2::password_hash::Output::init_with(output_len, |out| {
234
79
            Ok(argon.hash_password_into_with_memory(
235
79
                request.password.as_bytes(),
236
79
                salt_bytes,
237
79
                out,
238
79
                &mut self.blocks,
239
79
            )?)
240
79
        })?;
241

            
242
        Ok(HashResponse::Hash(SensitiveString(
243
            PasswordHash {
244
79
                algorithm: self.algorithm.ident(),
245
79
                version: Some(Version::V0x13.into()),
246
79
                params: ParamsString::try_from(argon.params())?,
247
79
                salt: Some(salt.as_salt()),
248
79
                hash: Some(output),
249
79
            }
250
79
            .to_string(),
251
        )))
252
79
    }
253

            
254
158
    fn allocate_blocks(&mut self, params: &argon2::Params) {
255
2588672
        for _ in self.blocks.len()..params.block_count() {
256
2588672
            self.blocks.push(Block::default());
257
2588672
        }
258
158
    }
259
}
260

            
261
#[derive(Debug)]
262
pub struct HashRequest {
263
    id: u64,
264
    password: SensitiveString,
265
    verify_against: Option<SensitiveString>,
266
    result_sender: oneshot::Sender<Result<HashResponse, Error>>,
267
}
268

            
269
#[derive(Debug)]
270
pub enum HashResponse {
271
    Hash(SensitiveString),
272
    Verified,
273
}
274

            
275
#[derive(thiserror::Error, Debug)]
276
enum ArgonError {
277
    #[error("{0}")]
278
    Argon(#[from] argon2::Error),
279
    #[error("{0}")]
280
    Hash(#[from] argon2::password_hash::Error),
281
}
282

            
283
/// Divides `dividend` by `divisor`, rounding the quotient up.
///
/// # Panics
///
/// Panics with `"divide by 0"` when `divisor` is zero.
fn ceil_divide(dividend: u128, divisor: u128) -> u128 {
    if divisor == 0 {
        panic!("divide by 0");
    }
    // Checking the remainder avoids the overflow that
    // `(dividend + divisor - 1) / divisor` hits for dividends near u128::MAX.
    // `quotient + 1` cannot overflow: a nonzero remainder implies divisor >= 2.
    let quotient = dividend / divisor;
    if dividend % divisor == 0 {
        quotient
    } else {
        quotient + 1
    }
}
293

            
294
1
#[tokio::test]
295
1
async fn basic_test() {
296
1
    use crate::config::SystemDefault;
297
1
    let hasher = Hasher::new(ArgonConfiguration::default());
298
1

            
299
1
    let password = SensitiveString(String::from("hunter2"));
300
1
    let hash = hasher.hash(1, password.clone()).await.unwrap();
301
1
    hasher
302
1
        .verify(1, password.clone(), hash.clone())
303
1
        .await
304
1
        .unwrap();
305
1

            
306
1
    let Hasher { sender, threads } = hasher;
307
1
    drop(sender);
308
2
    for thread in threads {
309
1
        thread.join().unwrap();
310
1
    }
311
1
}