1
use std::{
2
    cmp::Ordering,
3
    fmt::{Display, Write},
4
    hash::Hash,
5
    ops::Deref,
6
    str::FromStr,
7
};
8

            
9
use actionable::Identifier;
10
use serde::{de::Visitor, Deserialize, Serialize};
11

            
12
use crate::key::Key;
13

            
14
/// The serialized representation of a document's unique ID.
15
272094
#[derive(Clone, Copy)]
16
pub struct DocumentId {
17
    length: u8,
18
    bytes: [u8; Self::MAX_LENGTH],
19
}
20

            
21
impl Deref for DocumentId {
22
    type Target = [u8];
23
9161610
    fn deref(&self) -> &[u8] {
24
9161610
        &self.bytes[..usize::from(self.length)]
25
9161610
    }
26
}
27

            
28
impl Ord for DocumentId {
29
2214
    fn cmp(&self, other: &Self) -> Ordering {
30
2214
        (&**self).cmp(&**other)
31
2214
    }
32
}
33

            
34
impl PartialOrd for DocumentId {
35
729
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
36
729
        Some(self.cmp(other))
37
729
    }
38
}
39

            
40
impl Eq for DocumentId {}
41

            
42
impl PartialEq for DocumentId {
43
69834
    fn eq(&self, other: &Self) -> bool {
44
69834
        **self == **other
45
69834
    }
46
}
47

            
48
impl std::fmt::Debug for DocumentId {
49
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50
135
        f.write_str("DocumentId(")?;
51
135
        arc_bytes::print_bytes(self, f)?;
52
135
        f.write_char(')')
53
135
    }
54
}
55

            
56
impl Display for DocumentId {
57
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58
304
        if let Ok(string) = std::str::from_utf8(self.as_ref()) {
59
307
            if string.bytes().all(|b| (32..=127).contains(&b)) {
60
1
                return f.write_str(string);
61
302
            }
62
1
        }
63

            
64
303
        if let Some((first_nonzero_byte, _)) = self
65
303
            .as_ref()
66
303
            .iter()
67
303
            .copied()
68
303
            .enumerate()
69
2529
            .find(|(_index, b)| *b != 0)
70
        {
71
301
            if first_nonzero_byte > 0 {
72
300
                write!(f, "{:x}$", first_nonzero_byte)?;
73
            } else {
74
1
                f.write_char('$')?;
75
            }
76

            
77
303
            for (index, byte) in self[first_nonzero_byte..].iter().enumerate() {
78
303
                if index > 0 {
79
2
                    write!(f, "{:02x}", byte)?;
80
                } else {
81
301
                    write!(f, "{:x}", byte)?;
82
                }
83
            }
84
301
            Ok(())
85
        } else {
86
            // All zeroes
87
2
            write!(f, "{:x}$", self.len())
88
        }
89
304
    }
90
}
91

            
92
impl Hash for DocumentId {
93
56639
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
94
56639
        (&**self).hash(state);
95
56639
    }
96
}
97

            
98
impl<'a> From<DocumentId> for Identifier<'a> {
99
2
    fn from(id: DocumentId) -> Self {
100
2
        Identifier::from(id.to_vec())
101
2
    }
102
}
103

            
104
impl<'a> From<&'a DocumentId> for Identifier<'a> {
105
333882
    fn from(id: &'a DocumentId) -> Self {
106
333882
        Identifier::from(&**id)
107
333882
    }
108
}
109

            
// Checks that converting a `DocumentId` into an `Identifier` gives the same
// result as converting the original value directly.
#[test]
fn document_id_identifier_tests() {
    assert_eq!(
        Identifier::from(DocumentId::new(String::from("hello")).unwrap()),
        Identifier::from("hello")
    );
    assert_eq!(
        Identifier::from(DocumentId::from_u64(1)),
        Identifier::from(1)
    );
}

            
122
/// An invalid hexadecimal character was encountered.
123
#[derive(thiserror::Error, Debug)]
124
#[error("invalid hexadecimal bytes")]
125
pub struct InvalidHexadecimal;
126

            
127
284
const fn decode_hex_nibble(byte: u8) -> Result<u8, InvalidHexadecimal> {
128
284
    match byte {
129
284
        b'0'..=b'9' => Ok(byte - b'0'),
130
8
        b'A'..=b'F' => Ok(byte - b'A' + 10),
131
8
        b'a'..=b'f' => Ok(byte - b'a' + 10),
132
        _ => Err(InvalidHexadecimal),
133
    }
134
284
}
135

            
136
impl FromStr for DocumentId {
137
    type Err = crate::Error;
138

            
139
141
    fn from_str(s: &str) -> Result<Self, Self::Err> {
140
141
        if s.is_empty() {
141
            return Ok(Self::default());
142
141
        }
143
141
        let mut id = Self::default();
144
141
        let bytes = s.as_bytes();
145

            
146
286
        if let Some((pound_offset, _)) = s.bytes().enumerate().find(|(_index, b)| *b == b'$') {
147
140
            if pound_offset > 2 {
148
                return Err(crate::Error::DocumentIdTooLong);
149
140
            }
150

            
151
140
            let preceding_zeroes = if pound_offset > 0 {
152
139
                let mut length = [0_u8];
153
139
                decode_big_endian_hex(&bytes[0..pound_offset], &mut length)?;
154
139
                usize::from(length[0])
155
            } else {
156
1
                0
157
            };
158

            
159
140
            let decoded_length = decode_big_endian_hex(&bytes[pound_offset + 1..], &mut id.bytes)?;
160
140
            if preceding_zeroes > 0 {
161
139
                let total_length = preceding_zeroes + usize::from(decoded_length);
162
139
                if total_length > Self::MAX_LENGTH {
163
                    return Err(crate::Error::DocumentIdTooLong);
164
139
                }
165
139
                // The full length indicated a longer ID, so we need to prefix some null bytes.
166
139
                id.bytes
167
139
                    .copy_within(0..usize::from(decoded_length), preceding_zeroes);
168
139
                id.bytes[0..preceding_zeroes].fill(0);
169
139
                id.length = u8::try_from(total_length).unwrap();
170
1
            } else {
171
1
                id.length = decoded_length;
172
1
            }
173
1
        } else if bytes.len() > Self::MAX_LENGTH {
174
            return Err(crate::Error::DocumentIdTooLong);
175
1
        } else {
176
1
            // UTF-8 representable
177
1
            id.length = u8::try_from(bytes.len()).unwrap();
178
1
            id.bytes[0..bytes.len()].copy_from_slice(bytes);
179
1
        }
180
141
        Ok(id)
181
141
    }
182
}
183

            
184
279
fn decode_big_endian_hex(bytes: &[u8], output: &mut [u8]) -> Result<u8, crate::Error> {
185
279
    let mut length = 0;
186
279
    let mut chunks = if bytes.len() & 1 == 0 {
187
5
        bytes.chunks_exact(2)
188
    } else {
189
        // Odd amount of bytes, special case the first char
190
274
        output[0] = decode_hex_nibble(bytes[0])?;
191
274
        length = 1;
192
274
        bytes[1..].chunks_exact(2)
193
    };
194
284
    for chunk in &mut chunks {
195
5
        let write_at = length;
196
5
        length += 1;
197
5
        if length > output.len() {
198
            return Err(crate::Error::DocumentIdTooLong);
199
5
        }
200
5
        let upper = decode_hex_nibble(chunk[0])?;
201
5
        let lower = decode_hex_nibble(chunk[1])?;
202
5
        output[write_at] = upper << 4 | lower;
203
    }
204
279
    if !chunks.remainder().is_empty() {
205
        return Err(crate::Error::from(InvalidHexadecimal));
206
279
    }
207
279
    Ok(u8::try_from(length).unwrap())
208
279
}
209

            
// Checks that IDs format to the expected text and that parsing that text
// yields the original bytes again.
#[test]
fn document_id_parsing() {
    fn test_id(bytes: &[u8], display: &str) {
        let id = DocumentId::try_from(bytes).unwrap();
        let as_string = id.to_string();
        assert_eq!(as_string, display);
        let parsed = DocumentId::from_str(&as_string).unwrap();
        assert_eq!(&*parsed, bytes);
    }

    test_id(b"hello", "hello");
    test_id(b"\x00\x0a\xaf\xfa", "1$aaffa");
    test_id(&1_u128.to_be_bytes(), "f$1");
    test_id(&17_u8.to_be_bytes(), "$11");
    test_id(&[0_u8; 63], "3f$");
    // The above test is the same as this one, at the time of writing, but in
    // case we update MAX_LENGTH in the future, this extra test will ensure the
    // max-length formatting is always tested.
    test_id(
        &[0_u8; DocumentId::MAX_LENGTH],
        &format!("{:x}$", DocumentId::MAX_LENGTH),
    );
}

            
234
impl Default for DocumentId {
235
4676124
    fn default() -> Self {
236
4676124
        Self {
237
4676124
            length: 0,
238
4676124
            bytes: [0; Self::MAX_LENGTH],
239
4676124
        }
240
4676124
    }
241
}
242

            
243
impl<'a> TryFrom<&'a [u8]> for DocumentId {
244
    type Error = crate::Error;
245

            
246
1449186
    fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
247
1449186
        if bytes.len() <= Self::MAX_LENGTH {
248
1449186
            let mut new_id = Self {
249
1449186
                length: u8::try_from(bytes.len()).unwrap(),
250
1449186
                ..Self::default()
251
1449186
            };
252
1449186
            new_id.bytes[..bytes.len()].copy_from_slice(bytes);
253
1449186
            Ok(new_id)
254
        } else {
255
            Err(crate::Error::DocumentIdTooLong)
256
        }
257
1449186
    }
258
}
259

            
260
impl<const N: usize> TryFrom<[u8; N]> for DocumentId {
261
    type Error = crate::Error;
262

            
263
    fn try_from(bytes: [u8; N]) -> Result<Self, Self::Error> {
264
        Self::try_from(&bytes[..])
265
    }
266
}
267

            
268
impl DocumentId {
269
    /// The maximum length, in bytes, that an id can contain.
270
    pub const MAX_LENGTH: usize = 63;
271

            
272
    /// Returns a new instance with `value` as the identifier..
273
372446
    pub fn new<PrimaryKey: for<'a> Key<'a>>(value: PrimaryKey) -> Result<Self, crate::Error> {
274
372446
        let bytes = value
275
372446
            .as_ord_bytes()
276
372446
            .map_err(|err| crate::Error::Serialization(err.to_string()))?;
277
372446
        Self::try_from(&bytes[..])
278
372446
    }
279

            
280
    /// Returns a new document ID for a u64. This is equivalent to
281
    /// `DocumentId::new(id)`, but since this function accepts a non-generic
282
    /// type, it can help with type inference in some expressions.
283
    #[must_use]
284
    #[allow(clippy::missing_panics_doc)] // Unwrap is impossible to fail.
285
2408
    pub fn from_u64(id: u64) -> Self {
286
2408
        Self::try_from(&id.to_be_bytes()[..]).unwrap()
287
2408
    }
288

            
289
    /// Returns a new document ID for a u32. This is equivalent to
290
    /// `DocumentId::new(id)`, but since this function accepts a non-generic
291
    /// type, it can help with type inference in some expressions.
292
    #[must_use]
293
    #[allow(clippy::missing_panics_doc)] // Unwrap is impossible to fail.
294
    pub fn from_u32(id: u32) -> Self {
295
        Self::try_from(&id.to_be_bytes()[..]).unwrap()
296
    }
297

            
298
    /// Returns the contained value, deserialized back to its original type.
299
572110
    pub fn deserialize<'a, PrimaryKey: Key<'a>>(&'a self) -> Result<PrimaryKey, crate::Error> {
300
572110
        PrimaryKey::from_ord_bytes(self.as_ref())
301
572110
            .map_err(|err| crate::Error::Serialization(err.to_string()))
302
572110
    }
303
}
304

            
305
impl Serialize for DocumentId {
306
2653827
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
307
2653827
    where
308
2653827
        S: serde::Serializer,
309
2653827
    {
310
2653827
        serializer.serialize_bytes(self.as_ref())
311
2653827
    }
312
}
313

            
314
impl<'de> Deserialize<'de> for DocumentId {
315
2097859
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
316
2097859
    where
317
2097859
        D: serde::Deserializer<'de>,
318
2097859
    {
319
2097859
        deserializer.deserialize_byte_buf(DocumentIdVisitor)
320
2097859
    }
321
}
322

            
323
struct DocumentIdVisitor;
324

            
325
impl<'de> Visitor<'de> for DocumentIdVisitor {
326
    type Value = DocumentId;
327

            
328
    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329
        formatter.write_str("a document id (bytes)")
330
    }
331

            
332
2983243
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
333
2983243
    where
334
2983243
        E: serde::de::Error,
335
2983243
    {
336
2983243
        if v.len() <= DocumentId::MAX_LENGTH {
337
2983243
            let mut document_id = DocumentId {
338
2983243
                length: u8::try_from(v.len()).unwrap(),
339
2983243
                ..DocumentId::default()
340
2983243
            };
341
2983243
            document_id.bytes[..v.len()].copy_from_slice(v);
342
2983243
            Ok(document_id)
343
        } else {
344
            Err(E::invalid_length(v.len(), &"< 64 bytes"))
345
        }
346
2983243
    }
347
}
348

            
349
/// A unique id for a document, either serialized or deserialized.
350
pub enum AnyDocumentId<PrimaryKey> {
351
    /// A serialized id.
352
    Serialized(DocumentId),
353
    /// A deserialized id.
354
    Deserialized(PrimaryKey),
355
}
356

            
357
impl<PrimaryKey> AnyDocumentId<PrimaryKey>
358
where
359
    PrimaryKey: for<'k> Key<'k>,
360
{
361
    /// Converts this value to a document id.
362
126432
    pub fn to_document_id(&self) -> Result<DocumentId, crate::Error> {
363
126432
        match self {
364
46453
            Self::Serialized(id) => Ok(*id),
365
79979
            Self::Deserialized(key) => DocumentId::new(key.clone()),
366
        }
367
126432
    }
368

            
369
    /// Converts this value to the primary key type.
370
26
    pub fn to_primary_key(&self) -> Result<PrimaryKey, crate::Error> {
371
26
        match self {
372
1
            Self::Serialized(id) => id.deserialize::<PrimaryKey>(),
373
25
            Self::Deserialized(key) => Ok(key.clone()),
374
        }
375
26
    }
376
}
377

            
378
impl<PrimaryKey> From<PrimaryKey> for AnyDocumentId<PrimaryKey>
379
where
380
    PrimaryKey: for<'k> Key<'k>,
381
{
382
80379
    fn from(key: PrimaryKey) -> Self {
383
80379
        Self::Deserialized(key)
384
80379
    }
385
}
386

            
387
impl<PrimaryKey> From<DocumentId> for AnyDocumentId<PrimaryKey> {
388
5550
    fn from(id: DocumentId) -> Self {
389
5550
        Self::Serialized(id)
390
5550
    }
391
}