1
use std::{
2
    borrow::Cow,
3
    cmp::Ordering,
4
    fmt::{Display, Write},
5
    hash::Hash,
6
    ops::Deref,
7
    str::FromStr,
8
};
9

            
10
use actionable::Identifier;
11
use serde::{de::Visitor, Deserialize, Serialize};
12

            
13
use crate::key::{Key, KeyEncoding};
14

            
15
/// The serialized representation of a document's unique ID.
16
345087
#[derive(Clone, Copy)]
17
pub struct DocumentId {
18
    length: u8,
19
    bytes: [u8; Self::MAX_LENGTH],
20
}
21

            
22
impl Deref for DocumentId {
23
    type Target = [u8];
24
11676232
    fn deref(&self) -> &[u8] {
25
11676232
        &self.bytes[..usize::from(self.length)]
26
11676232
    }
27
}
28

            
29
impl Ord for DocumentId {
30
3503
    fn cmp(&self, other: &Self) -> Ordering {
31
3503
        (**self).cmp(&**other)
32
3503
    }
33
}
34

            
35
impl PartialOrd for DocumentId {
36
1209
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
37
1209
        Some(self.cmp(other))
38
1209
    }
39
}
40

            
41
impl Eq for DocumentId {}
42

            
43
impl PartialEq for DocumentId {
44
95833
    fn eq(&self, other: &Self) -> bool {
45
95833
        **self == **other
46
95833
    }
47
}
48

            
49
impl std::fmt::Debug for DocumentId {
50
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51
217
        f.write_str("DocumentId(")?;
52
217
        arc_bytes::print_bytes(self, f)?;
53
217
        f.write_char(')')
54
217
    }
55
}
56

            
57
impl Display for DocumentId {
58
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59
348
        if let Ok(string) = std::str::from_utf8(self.as_ref()) {
60
351
            if string.bytes().all(|b| (32..=127).contains(&b)) {
61
1
                return f.write_str(string);
62
346
            }
63
1
        }
64

            
65
347
        if let Some((first_nonzero_byte, _)) = self
66
347
            .as_ref()
67
347
            .iter()
68
347
            .copied()
69
347
            .enumerate()
70
2881
            .find(|(_index, b)| *b != 0)
71
        {
72
345
            if first_nonzero_byte > 0 {
73
344
                write!(f, "{:x}$", first_nonzero_byte)?;
74
            } else {
75
1
                f.write_char('$')?;
76
            }
77

            
78
347
            for (index, byte) in self[first_nonzero_byte..].iter().enumerate() {
79
347
                if index > 0 {
80
2
                    write!(f, "{:02x}", byte)?;
81
                } else {
82
345
                    write!(f, "{:x}", byte)?;
83
                }
84
            }
85
345
            Ok(())
86
        } else {
87
            // All zeroes
88
2
            write!(f, "{:x}$", self.len())
89
        }
90
348
    }
91
}
92

            
93
impl Hash for DocumentId {
94
96179
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
95
96179
        (**self).hash(state);
96
96179
    }
97
}
98

            
99
impl<'a> From<DocumentId> for Identifier<'a> {
100
2
    fn from(id: DocumentId) -> Self {
101
2
        Identifier::from(id.to_vec())
102
2
    }
103
}
104

            
105
impl<'a> From<&'a DocumentId> for Identifier<'a> {
106
1013855
    fn from(id: &'a DocumentId) -> Self {
107
1013855
        Identifier::from(&**id)
108
1013855
    }
109
}
110

            
111
1
#[test]
fn document_id_identifier_tests() {
    // An ID created from a string key converts to the same identifier as the
    // string itself.
    let from_string = DocumentId::new(String::from("hello")).unwrap();
    assert_eq!(Identifier::from(from_string), Identifier::from("hello"));

    // An ID created from an integer converts to the same identifier as the
    // integer itself.
    let from_integer = DocumentId::from_u64(1);
    assert_eq!(Identifier::from(from_integer), Identifier::from(1));
}
122

            
123
/// An invalid hexadecimal character was encountered.
124
#[derive(thiserror::Error, Debug)]
125
#[error("invalid hexadecimal bytes")]
126
pub struct InvalidHexadecimal;
127

            
128
324
const fn decode_hex_nibble(byte: u8) -> Result<u8, InvalidHexadecimal> {
129
324
    match byte {
130
324
        b'0'..=b'9' => Ok(byte - b'0'),
131
8
        b'A'..=b'F' => Ok(byte - b'A' + 10),
132
8
        b'a'..=b'f' => Ok(byte - b'a' + 10),
133
        _ => Err(InvalidHexadecimal),
134
    }
135
324
}
136

            
137
impl FromStr for DocumentId {
138
    type Err = crate::Error;
139

            
140
161
    fn from_str(s: &str) -> Result<Self, Self::Err> {
141
161
        if s.is_empty() {
142
            return Ok(Self::default());
143
161
        }
144
161
        let mut id = Self::default();
145
161
        let bytes = s.as_bytes();
146

            
147
326
        if let Some((pound_offset, _)) = s.bytes().enumerate().find(|(_index, b)| *b == b'$') {
148
160
            if pound_offset > 2 {
149
                return Err(crate::Error::DocumentIdTooLong);
150
160
            }
151

            
152
160
            let preceding_zeroes = if pound_offset > 0 {
153
159
                let mut length = [0_u8];
154
159
                decode_big_endian_hex(&bytes[0..pound_offset], &mut length)?;
155
159
                usize::from(length[0])
156
            } else {
157
1
                0
158
            };
159

            
160
160
            let decoded_length = decode_big_endian_hex(&bytes[pound_offset + 1..], &mut id.bytes)?;
161
160
            if preceding_zeroes > 0 {
162
159
                let total_length = preceding_zeroes + usize::from(decoded_length);
163
159
                if total_length > Self::MAX_LENGTH {
164
                    return Err(crate::Error::DocumentIdTooLong);
165
159
                }
166
159
                // The full length indicated a longer ID, so we need to prefix some null bytes.
167
159
                id.bytes
168
159
                    .copy_within(0..usize::from(decoded_length), preceding_zeroes);
169
159
                id.bytes[0..preceding_zeroes].fill(0);
170
159
                id.length = u8::try_from(total_length).unwrap();
171
1
            } else {
172
1
                id.length = decoded_length;
173
1
            }
174
1
        } else if bytes.len() > Self::MAX_LENGTH {
175
            return Err(crate::Error::DocumentIdTooLong);
176
1
        } else {
177
1
            // UTF-8 representable
178
1
            id.length = u8::try_from(bytes.len()).unwrap();
179
1
            id.bytes[0..bytes.len()].copy_from_slice(bytes);
180
1
        }
181
161
        Ok(id)
182
161
    }
183
}
184

            
185
319
fn decode_big_endian_hex(bytes: &[u8], output: &mut [u8]) -> Result<u8, crate::Error> {
186
319
    let mut length = 0;
187
319
    let mut chunks = if bytes.len() & 1 == 0 {
188
5
        bytes.chunks_exact(2)
189
    } else {
190
        // Odd amount of bytes, special case the first char
191
314
        output[0] = decode_hex_nibble(bytes[0])?;
192
314
        length = 1;
193
314
        bytes[1..].chunks_exact(2)
194
    };
195
324
    for chunk in &mut chunks {
196
5
        let write_at = length;
197
5
        length += 1;
198
5
        if length > output.len() {
199
            return Err(crate::Error::DocumentIdTooLong);
200
5
        }
201
5
        let upper = decode_hex_nibble(chunk[0])?;
202
5
        let lower = decode_hex_nibble(chunk[1])?;
203
5
        output[write_at] = upper << 4 | lower;
204
    }
205
319
    if !chunks.remainder().is_empty() {
206
        return Err(crate::Error::from(InvalidHexadecimal));
207
319
    }
208
319
    Ok(u8::try_from(length).unwrap())
209
319
}
210

            
211
1
#[test]
fn document_id_parsing() {
    /// Checks that an ID built from `bytes` is displayed as `display`, and
    /// that parsing the displayed text yields the original bytes.
    fn assert_round_trip(bytes: &[u8], display: &str) {
        let id = DocumentId::try_from(bytes).unwrap();
        let rendered = id.to_string();
        assert_eq!(rendered, display);

        let reparsed: DocumentId = rendered.parse().unwrap();
        assert_eq!(&*reparsed, bytes);
    }

    assert_round_trip(b"hello", "hello");
    assert_round_trip(b"\x00\x0a\xaf\xfa", "1$aaffa");
    assert_round_trip(&1_u128.to_be_bytes(), "f$1");
    assert_round_trip(&17_u8.to_be_bytes(), "$11");
    assert_round_trip(&[0_u8; 63], "3f$");

    // At the time of writing this duplicates the previous case, but if
    // MAX_LENGTH is ever changed this case keeps the max-length formatting
    // covered.
    let max_length_zeroes = [0_u8; DocumentId::MAX_LENGTH];
    let max_length_display = format!("{:x}$", DocumentId::MAX_LENGTH);
    assert_round_trip(&max_length_zeroes, &max_length_display);
}
234

            
235
impl Default for DocumentId {
236
6980477
    fn default() -> Self {
237
6980477
        Self {
238
6980477
            length: 0,
239
6980477
            bytes: [0; Self::MAX_LENGTH],
240
6980477
        }
241
6980477
    }
242
}
243

            
244
impl<'a> TryFrom<&'a [u8]> for DocumentId {
245
    type Error = crate::Error;
246

            
247
2420650
    fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
248
2420650
        if bytes.len() <= Self::MAX_LENGTH {
249
2420650
            let mut new_id = Self {
250
2420650
                length: u8::try_from(bytes.len()).unwrap(),
251
2420650
                ..Self::default()
252
2420650
            };
253
2420650
            new_id.bytes[..bytes.len()].copy_from_slice(bytes);
254
2420650
            Ok(new_id)
255
        } else {
256
            Err(crate::Error::DocumentIdTooLong)
257
        }
258
2420650
    }
259
}
260

            
261
impl<const N: usize> TryFrom<[u8; N]> for DocumentId {
262
    type Error = crate::Error;
263

            
264
    fn try_from(bytes: [u8; N]) -> Result<Self, Self::Error> {
265
        Self::try_from(&bytes[..])
266
    }
267
}
268

            
269
impl DocumentId {
270
    /// The maximum length, in bytes, that an id can contain.
271
    pub const MAX_LENGTH: usize = 63;
272

            
273
    /// Returns a new instance with `value` as the identifier..
274
605387
    pub fn new<PrimaryKey: for<'k> Key<'k>, PrimaryKeyRef: for<'k> KeyEncoding<'k, PrimaryKey>>(
275
605387
        value: PrimaryKeyRef,
276
605387
    ) -> Result<Self, crate::Error> {
277
605387
        let bytes = value
278
605387
            .as_ord_bytes()
279
605387
            .map_err(|err| crate::Error::Serialization(err.to_string()))?;
280
605387
        Self::try_from(&bytes[..])
281
605387
    }
282

            
283
    /// Returns a new document ID for a u64. This is equivalent to
284
    /// `DocumentId::new(id)`, but since this function accepts a non-generic
285
    /// type, it can help with type inference in some expressions.
286
    #[must_use]
287
    #[allow(clippy::missing_panics_doc)] // Unwrap is impossible to fail.
288
2826
    pub fn from_u64(id: u64) -> Self {
289
2826
        Self::try_from(&id.to_be_bytes()[..]).unwrap()
290
2826
    }
291

            
292
    /// Returns a new document ID for a u32. This is equivalent to
293
    /// `DocumentId::new(id)`, but since this function accepts a non-generic
294
    /// type, it can help with type inference in some expressions.
295
    #[must_use]
296
    #[allow(clippy::missing_panics_doc)] // Unwrap is impossible to fail.
297
    pub fn from_u32(id: u32) -> Self {
298
        Self::try_from(&id.to_be_bytes()[..]).unwrap()
299
    }
300

            
301
    /// Returns the contained value, deserialized back to its original type.
302
871776
    pub fn deserialize<'a, PrimaryKey: Key<'a>>(&'a self) -> Result<PrimaryKey, crate::Error> {
303
871776
        PrimaryKey::from_ord_bytes(self.as_ref())
304
871776
            .map_err(|err| crate::Error::Serialization(err.to_string()))
305
871776
    }
306
}
307

            
308
impl Serialize for DocumentId {
309
3574120
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
310
3574120
    where
311
3574120
        S: serde::Serializer,
312
3574120
    {
313
3574120
        serializer.serialize_bytes(self.as_ref())
314
3574120
    }
315
}
316

            
317
impl<'de> Deserialize<'de> for DocumentId {
318
3665595
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
319
3665595
    where
320
3665595
        D: serde::Deserializer<'de>,
321
3665595
    {
322
3665595
        deserializer.deserialize_byte_buf(DocumentIdVisitor)
323
3665595
    }
324
}
325

            
326
struct DocumentIdVisitor;
327

            
328
impl<'de> Visitor<'de> for DocumentIdVisitor {
329
    type Value = DocumentId;
330

            
331
    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332
        formatter.write_str("a document id (bytes)")
333
    }
334

            
335
3665512
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
336
3665512
    where
337
3665512
        E: serde::de::Error,
338
3665512
    {
339
3665512
        if v.len() <= DocumentId::MAX_LENGTH {
340
3665512
            let mut document_id = DocumentId {
341
3665512
                length: u8::try_from(v.len()).unwrap(),
342
3665512
                ..DocumentId::default()
343
3665512
            };
344
3665512
            document_id.bytes[..v.len()].copy_from_slice(v);
345
3665512
            Ok(document_id)
346
        } else {
347
            Err(E::invalid_length(v.len(), &"< 64 bytes"))
348
        }
349
3665512
    }
350
}
351

            
352
impl<'k> Key<'k> for DocumentId {
353
    fn from_ord_bytes(bytes: &'k [u8]) -> Result<Self, Self::Error> {
354
        Self::try_from(bytes)
355
    }
356
}
357

            
358
impl<'k, PrimaryKey> KeyEncoding<'k, PrimaryKey> for DocumentId
359
where
360
    PrimaryKey: for<'a> Key<'a>,
361
{
362
    type Error = crate::Error;
363

            
364
    const LENGTH: Option<usize> = None;
365

            
366
58982
    fn as_ord_bytes(&'k self) -> Result<Cow<'k, [u8]>, Self::Error> {
367
58982
        Ok(Cow::Borrowed(self))
368
58982
    }
369
}