1
use std::{
2
    sync::Arc,
3
    thread::JoinHandle,
4
    time::{Duration, Instant},
5
};
6

            
7
use argon2::{
8
    password_hash::{ParamsString, SaltString},
9
    Algorithm, Argon2, Block, ParamsBuilder, PasswordHash, Version,
10
};
11
use bonsaidb_core::connection::SensitiveString;
12
use once_cell::sync::OnceCell;
13
use rand::{thread_rng, CryptoRng, Rng};
14
use tokio::sync::oneshot;
15

            
16
use crate::{
17
    config::{ArgonConfiguration, ArgonParams},
18
    Error,
19
};
20

            
21
#[derive(Debug)]
22
#[cfg_attr(not(test), allow(dead_code))]
23
pub struct Hasher {
24
    sender: flume::Sender<HashRequest>,
25
    threads: Vec<JoinHandle<()>>,
26
}
27

            
28
impl Hasher {
29
2042
    pub fn new(config: ArgonConfiguration) -> Self {
30
2042
        let (sender, receiver) = flume::unbounded();
31
2042
        let thread = HashingThread {
32
2042
            receiver,
33
2042
            algorithm: config.algorithm,
34
2042
            params: config.params,
35
2042
            blocks: Vec::default(),
36
2042
            builder_template: Arc::default(),
37
2042
        };
38
2042
        let mut threads = Vec::with_capacity(config.hashers as usize);
39
2042
        for _ in 0..config.hashers.max(1) {
40
2042
            let thread = thread.clone();
41
2042
            threads.push(std::thread::spawn(move || thread.process_requests()));
42
2042
        }
43
2042
        Hasher { sender, threads }
44
2042
    }
45

            
46
67
    pub async fn hash(&self, id: u64, password: SensitiveString) -> Result<SensitiveString, Error> {
47
4
        let (result_sender, result_receiver) = oneshot::channel();
48
4
        if self
49
4
            .sender
50
4
            .send(HashRequest {
51
4
                id,
52
4
                password,
53
4
                verify_against: None,
54
4
                result_sender,
55
4
            })
56
4
            .is_ok()
57
        {
58
4
            match result_receiver.await?.map_err(Error::from) {
59
4
                Ok(HashResponse::Hash(hash)) => Ok(hash),
60
                Ok(HashResponse::Verified) => unreachable!(),
61
                Err(err) => Err(err),
62
            }
63
        } else {
64
            Err(Error::InternalCommunication)
65
        }
66
4
    }
67

            
68
111
    pub async fn verify(
69
111
        &self,
70
111
        id: u64,
71
111
        password: SensitiveString,
72
111
        saved_hash: SensitiveString,
73
111
    ) -> Result<(), Error> {
74
6
        let (result_sender, result_receiver) = oneshot::channel();
75
6
        if self
76
6
            .sender
77
6
            .send(HashRequest {
78
6
                id,
79
6
                password,
80
6
                verify_against: Some(saved_hash),
81
6
                result_sender,
82
6
            })
83
6
            .is_ok()
84
        {
85
6
            match result_receiver.await?.map_err(Error::from) {
86
6
                Ok(_) => Ok(()),
87
                Err(err) => {
88
                    eprintln!("Error validating password for user {}: {:?}", id, err);
89
                    Err(Error::Core(bonsaidb_core::Error::InvalidCredentials))
90
                }
91
            }
92
        } else {
93
            Err(Error::InternalCommunication)
94
        }
95
6
    }
96
}
97

            
98
2042
#[derive(Clone, Debug)]
99
struct HashingThread {
100
    receiver: flume::Receiver<HashRequest>,
101
    algorithm: Algorithm,
102
    params: ArgonParams,
103
    blocks: Vec<Block>,
104
    builder_template: Arc<OnceCell<Result<ParamsBuilder, ArgonError>>>,
105
}
106

            
107
impl HashingThread {
108
67
    fn initialize_builder_template<R: Rng + CryptoRng>(
109
67
        &mut self,
110
67
        rng: &mut R,
111
67
    ) -> Result<ParamsBuilder, ArgonError> {
112
67
        match &self.params {
113
            ArgonParams::Params(builder) => Ok(builder.clone()),
114
67
            ArgonParams::Timed(config) => {
115
67
                let mut params_builder = ParamsBuilder::new();
116
67
                let params = params_builder
117
67
                    .m_cost(config.ram_per_hasher)?
118
67
                    .p_cost(config.lanes)?
119
67
                    .data(&0_u64.to_be_bytes())?;
120
67
                let salt = SaltString::generate(rng);
121
67
                let mut salt_arr = [0u8; 64];
122
67
                let salt_bytes = salt.b64_decode(&mut salt_arr)?;
123
67
                let mut output = Vec::default();
124
67

            
125
67
                let minimum_duration = config.minimum_duration;
126
67
                let mut min_cost = 1;
127
67
                let mut total_spent_t = 0;
128
67
                let mut total_duration = Duration::ZERO;
129

            
130
                loop {
131
67
                    let t_cost = if total_spent_t > 0 {
132
                        let average_duration_per_t = total_duration / total_spent_t;
133
                        u32::try_from(ceil_divide(
134
                            minimum_duration.as_nanos(),
135
                            average_duration_per_t.as_nanos(),
136
                        ))
137
                        .unwrap()
138
                        .max(min_cost)
139
                    } else {
140
67
                        min_cost
141
                    };
142
67
                    params.t_cost(t_cost)?;
143

            
144
67
                    let params = params.clone().params()?;
145
67
                    self.allocate_blocks(&params);
146
67
                    let output_len = params
147
67
                        .output_len()
148
67
                        .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
149
67
                    output.resize(output_len, 0);
150
67

            
151
67
                    let start = Instant::now();
152
67
                    let argon = Argon2::new(self.algorithm, Version::V0x13, params);
153
67
                    argon.hash_password_into_with_memory(
154
67
                        b"hunter2",
155
67
                        salt_bytes,
156
67
                        &mut output[..],
157
67
                        &mut self.blocks,
158
67
                    )?;
159
67
                    let elapsed = match Instant::now().checked_duration_since(start) {
160
67
                        Some(elapsed) => elapsed,
161
                        None => continue,
162
                    };
163
67
                    if elapsed < minimum_duration {
164
                        total_spent_t += t_cost;
165
                        total_duration += elapsed;
166
                        min_cost = t_cost + 1;
167
                    } else {
168
                        // TODO if it's too far past the minimum duration, maybe we should try again at a smaller cost?
169
67
                        break;
170
67
                    }
171
67
                }
172
67
                Ok(params_builder)
173
            }
174
        }
175
67
    }
176

            
177
2042
    fn process_requests(mut self) {
178
2042
        let mut rng = thread_rng();
179
2220
        while let Ok(request) = self.receiver.recv() {
180
382
            let result = if let Some(verify_against) = &request.verify_against {
181
111
                Self::verify(&request, verify_against)
182
            } else {
183
67
                self.hash(&request, &mut rng)
184
            };
185
178
            drop(request.result_sender.send(result));
186
        }
187
1838
    }
188

            
189
111
    fn verify(request: &HashRequest, hash: &str) -> Result<HashResponse, Error> {
190
111
        let hash = PasswordHash::new(hash)?;
191

            
192
111
        let algorithm = Algorithm::try_from(hash.algorithm)?;
193
111
        let version = hash
194
111
            .version
195
111
            .map(Version::try_from)
196
111
            .transpose()?
197
111
            .unwrap_or(Version::V0x13);
198

            
199
111
        let argon = Argon2::new(algorithm, version, argon2::Params::try_from(&hash)?);
200

            
201
111
        hash.verify_password(&[&argon], request.password.0.as_bytes())
202
111
            .map(|_| HashResponse::Verified)
203
111
            .map_err(Error::from)
204
111
    }
205

            
206
67
    fn hash<R: Rng + CryptoRng>(
207
67
        &mut self,
208
67
        request: &HashRequest,
209
67
        rng: &mut R,
210
67
    ) -> Result<HashResponse, Error> {
211
67
        let builder_template = self.builder_template.clone();
212
67
        let mut params =
213
67
            match builder_template.get_or_init(|| self.initialize_builder_template(rng)) {
214
67
                Ok(template) => template.clone(),
215
                Err(error) => return Err(Error::PasswordHash(error.to_string())),
216
            };
217

            
218
67
        let id = request.id.to_be_bytes();
219
67
        params.data(&id)?;
220

            
221
67
        let params = params.params()?;
222
67
        self.allocate_blocks(&params);
223
67

            
224
67
        let salt = SaltString::generate(rng);
225
67
        let mut salt_arr = [0u8; 64];
226
67
        let salt_bytes = salt.b64_decode(&mut salt_arr)?;
227

            
228
67
        let argon = Argon2::new(self.algorithm, Version::V0x13, params);
229
67

            
230
67
        let output_len = argon
231
67
            .params()
232
67
            .output_len()
233
67
            .unwrap_or(argon2::Params::DEFAULT_OUTPUT_LEN);
234
67
        let output = argon2::password_hash::Output::init_with(output_len, |out| {
235
67
            Ok(argon.hash_password_into_with_memory(
236
67
                request.password.as_bytes(),
237
67
                salt_bytes,
238
67
                out,
239
67
                &mut self.blocks,
240
67
            )?)
241
67
        })?;
242

            
243
        Ok(HashResponse::Hash(SensitiveString(
244
            PasswordHash {
245
67
                algorithm: self.algorithm.ident(),
246
67
                version: Some(Version::V0x13.into()),
247
67
                params: ParamsString::try_from(argon.params())?,
248
67
                salt: Some(salt.as_salt()),
249
67
                hash: Some(output),
250
67
            }
251
67
            .to_string(),
252
        )))
253
67
    }
254

            
255
134
    fn allocate_blocks(&mut self, params: &argon2::Params) {
256
2195456
        for _ in self.blocks.len()..params.block_count() {
257
2195456
            self.blocks.push(Block::default());
258
2195456
        }
259
134
    }
260
}
261

            
262
#[derive(Debug)]
263
pub struct HashRequest {
264
    id: u64,
265
    password: SensitiveString,
266
    verify_against: Option<SensitiveString>,
267
    result_sender: oneshot::Sender<Result<HashResponse, Error>>,
268
}
269

            
270
#[derive(Debug)]
271
pub enum HashResponse {
272
    Hash(SensitiveString),
273
    Verified,
274
}
275

            
276
#[derive(thiserror::Error, Debug)]
277
enum ArgonError {
278
    #[error("{0}")]
279
    Argon(#[from] argon2::Error),
280
    #[error("{0}")]
281
    Hash(#[from] argon2::password_hash::Error),
282
}
283

            
284
/// Divides `dividend` by `divisor`, rounding the quotient up to the next whole number.
///
/// # Panics
///
/// Panics with `"divide by 0"` when `divisor` is zero.
fn ceil_divide(dividend: u128, divisor: u128) -> u128 {
    if divisor == 0 {
        panic!("divide by 0");
    }
    // Computed from the quotient and remainder rather than `(dividend + divisor - 1) / divisor`,
    // which overflows when `dividend` is close to `u128::MAX`. `quotient + 1` cannot overflow:
    // a non-zero remainder implies `divisor >= 2`, so `quotient <= u128::MAX / 2`.
    let quotient = dividend / divisor;
    if dividend % divisor == 0 {
        quotient
    } else {
        quotient + 1
    }
}
294

            
295
1
#[tokio::test]
296
1
async fn basic_test() {
297
1
    use crate::config::SystemDefault;
298
1
    let hasher = Hasher::new(ArgonConfiguration::default());
299
1

            
300
1
    let password = SensitiveString(String::from("hunter2"));
301
1
    let hash = hasher.hash(1, password.clone()).await.unwrap();
302
1
    hasher
303
1
        .verify(1, password.clone(), hash.clone())
304
1
        .await
305
1
        .unwrap();
306
1

            
307
1
    let Hasher { sender, threads } = hasher;
308
1
    drop(sender);
309
2
    for thread in threads {
310
1
        thread.join().unwrap();
311
1
    }
312
1
}