1
use std::{
2
    cmp::Ordering,
3
    fmt::{Display, Write},
4
    hash::Hash,
5
    ops::Deref,
6
    str::FromStr,
7
};
8

            
9
use actionable::Identifier;
10
use serde::{de::Visitor, Deserialize, Serialize};
11

            
12
use crate::key::Key;
13

            
14
/// The serialized representation of a document's unique ID.
15
298117
#[derive(Clone, Copy)]
16
pub struct DocumentId {
17
    length: u8,
18
    bytes: [u8; Self::MAX_LENGTH],
19
}
20

            
21
impl Deref for DocumentId {
22
    type Target = [u8];
23
8845759
    fn deref(&self) -> &[u8] {
24
8845759
        &self.bytes[..usize::from(self.length)]
25
8845759
    }
26
}
27

            
28
impl Ord for DocumentId {
29
650
    fn cmp(&self, other: &Self) -> Ordering {
30
650
        (&**self).cmp(&**other)
31
650
    }
32
}
33

            
34
impl PartialOrd for DocumentId {
35
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
36
        Some(self.cmp(other))
37
    }
38
}
39

            
40
impl Eq for DocumentId {}
41

            
42
impl PartialEq for DocumentId {
43
106294
    fn eq(&self, other: &Self) -> bool {
44
106294
        **self == **other
45
106294
    }
46
}
47

            
48
impl std::fmt::Debug for DocumentId {
49
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50
130
        f.write_str("DocumentId(")?;
51
130
        arc_bytes::print_bytes(self, f)?;
52
130
        f.write_char(')')
53
130
    }
54
}
55

            
56
impl Display for DocumentId {
57
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58
293
        if let Ok(string) = std::str::from_utf8(self.as_ref()) {
59
296
            if string.bytes().all(|b| (32..=127).contains(&b)) {
60
1
                return f.write_str(string);
61
291
            }
62
1
        }
63

            
64
292
        if let Some((first_nonzero_byte, _)) = self
65
292
            .as_ref()
66
292
            .iter()
67
292
            .copied()
68
292
            .enumerate()
69
2441
            .find(|(_index, b)| *b != 0)
70
        {
71
290
            if first_nonzero_byte > 0 {
72
289
                write!(f, "{:x}$", first_nonzero_byte)?;
73
            } else {
74
1
                f.write_char('$')?;
75
            }
76

            
77
292
            for (index, byte) in self[first_nonzero_byte..].iter().enumerate() {
78
292
                if index > 0 {
79
2
                    write!(f, "{:02x}", byte)?;
80
                } else {
81
290
                    write!(f, "{:x}", byte)?;
82
                }
83
            }
84
290
            Ok(())
85
        } else {
86
            // All zeroes
87
2
            write!(f, "{:x}$", self.len())
88
        }
89
293
    }
90
}
91

            
92
impl Hash for DocumentId {
93
82934
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
94
82934
        (&**self).hash(state);
95
82934
    }
96
}
97

            
98
impl<'a> From<DocumentId> for Identifier<'a> {
99
2
    fn from(id: DocumentId) -> Self {
100
2
        Identifier::from(id.to_vec())
101
2
    }
102
}
103

            
104
impl<'a> From<&'a DocumentId> for Identifier<'a> {
105
313664
    fn from(id: &'a DocumentId) -> Self {
106
313664
        Identifier::from(&**id)
107
313664
    }
108
}
109

            
110
1
#[test]
fn document_id_identifier_tests() {
    // A string key converts to the same identifier as the string itself.
    let from_string = DocumentId::new(String::from("hello")).unwrap();
    assert_eq!(Identifier::from(from_string), Identifier::from("hello"));

    // An integer key converts to the same identifier as the integer itself.
    let from_integer = DocumentId::from_u64(1);
    assert_eq!(Identifier::from(from_integer), Identifier::from(1));
}
121

            
122
/// An invalid hexadecimal character was encountered.
123
#[derive(thiserror::Error, Debug)]
124
#[error("invalid hexadecimal bytes")]
125
pub struct InvalidHexadecimal;
126

            
127
274
const fn decode_hex_nibble(byte: u8) -> Result<u8, InvalidHexadecimal> {
128
274
    match byte {
129
274
        b'0'..=b'9' => Ok(byte - b'0'),
130
8
        b'A'..=b'F' => Ok(byte - b'A' + 10),
131
8
        b'a'..=b'f' => Ok(byte - b'a' + 10),
132
        _ => Err(InvalidHexadecimal),
133
    }
134
274
}
135

            
136
impl FromStr for DocumentId {
137
    type Err = crate::Error;
138

            
139
136
    fn from_str(s: &str) -> Result<Self, Self::Err> {
140
136
        if s.is_empty() {
141
            return Ok(Self::default());
142
136
        }
143
136
        let mut id = Self::default();
144
136
        let bytes = s.as_bytes();
145

            
146
276
        if let Some((pound_offset, _)) = s.bytes().enumerate().find(|(_index, b)| *b == b'$') {
147
135
            if pound_offset > 2 {
148
                return Err(crate::Error::DocumentIdTooLong);
149
135
            }
150

            
151
135
            let preceding_zeroes = if pound_offset > 0 {
152
134
                let mut length = [0_u8];
153
134
                decode_big_endian_hex(&bytes[0..pound_offset], &mut length)?;
154
134
                usize::from(length[0])
155
            } else {
156
1
                0
157
            };
158

            
159
135
            let decoded_length = decode_big_endian_hex(&bytes[pound_offset + 1..], &mut id.bytes)?;
160
135
            if preceding_zeroes > 0 {
161
134
                let total_length = preceding_zeroes + usize::from(decoded_length);
162
134
                if total_length > Self::MAX_LENGTH {
163
                    return Err(crate::Error::DocumentIdTooLong);
164
134
                }
165
134
                // The full length indicated a longer ID, so we need to prefix some null bytes.
166
134
                id.bytes
167
134
                    .copy_within(0..usize::from(decoded_length), preceding_zeroes);
168
134
                id.bytes[0..preceding_zeroes].fill(0);
169
134
                id.length = u8::try_from(total_length).unwrap();
170
1
            } else {
171
1
                id.length = decoded_length;
172
1
            }
173
1
        } else if bytes.len() > Self::MAX_LENGTH {
174
            return Err(crate::Error::DocumentIdTooLong);
175
1
        } else {
176
1
            // UTF-8 representable
177
1
            id.length = u8::try_from(bytes.len()).unwrap();
178
1
            id.bytes[0..bytes.len()].copy_from_slice(bytes);
179
1
        }
180
136
        Ok(id)
181
136
    }
182
}
183

            
184
269
fn decode_big_endian_hex(bytes: &[u8], output: &mut [u8]) -> Result<u8, crate::Error> {
185
269
    let mut length = 0;
186
269
    let mut chunks = if bytes.len() & 1 == 0 {
187
5
        bytes.chunks_exact(2)
188
    } else {
189
        // Odd amount of bytes, special case the first char
190
264
        output[0] = decode_hex_nibble(bytes[0])?;
191
264
        length = 1;
192
264
        bytes[1..].chunks_exact(2)
193
    };
194
274
    for chunk in &mut chunks {
195
5
        let write_at = length;
196
5
        length += 1;
197
5
        if length > output.len() {
198
            return Err(crate::Error::DocumentIdTooLong);
199
5
        }
200
5
        let upper = decode_hex_nibble(chunk[0])?;
201
5
        let lower = decode_hex_nibble(chunk[1])?;
202
5
        output[write_at] = upper << 4 | lower;
203
    }
204
269
    if !chunks.remainder().is_empty() {
205
        return Err(crate::Error::from(InvalidHexadecimal));
206
269
    }
207
269
    Ok(u8::try_from(length).unwrap())
208
269
}
209

            
210
1
#[test]
fn document_id_parsing() {
    /// Asserts that `bytes` is displayed as `expected` and that parsing the
    /// displayed text yields the original bytes again.
    fn assert_round_trip(bytes: &[u8], expected: &str) {
        let original = DocumentId::try_from(bytes).unwrap();
        let rendered = original.to_string();
        assert_eq!(rendered, expected);
        let reparsed = DocumentId::from_str(&rendered).unwrap();
        assert_eq!(&*reparsed, bytes);
    }

    assert_round_trip(b"hello", "hello");
    assert_round_trip(b"\x00\x0a\xaf\xfa", "1$aaffa");
    assert_round_trip(&1_u128.to_be_bytes(), "f$1");
    assert_round_trip(&17_u8.to_be_bytes(), "$11");
    assert_round_trip(&[0_u8; 63], "3f$");
    // At the time of writing this repeats the previous case. It keeps the
    // maximum-length formatting covered if `MAX_LENGTH` ever changes.
    let max_length_text = format!("{:x}$", DocumentId::MAX_LENGTH);
    assert_round_trip(&[0_u8; DocumentId::MAX_LENGTH], &max_length_text);
}
233

            
234
impl Default for DocumentId {
235
4922519
    fn default() -> Self {
236
4922519
        Self {
237
4922519
            length: 0,
238
4922519
            bytes: [0; Self::MAX_LENGTH],
239
4922519
        }
240
4922519
    }
241
}
242

            
243
impl<'a> TryFrom<&'a [u8]> for DocumentId {
244
    type Error = crate::Error;
245

            
246
1398655
    fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
247
1398655
        if bytes.len() <= Self::MAX_LENGTH {
248
1398655
            let mut new_id = Self {
249
1398655
                length: u8::try_from(bytes.len()).unwrap(),
250
1398655
                ..Self::default()
251
1398655
            };
252
1398655
            new_id.bytes[..bytes.len()].copy_from_slice(bytes);
253
1398655
            Ok(new_id)
254
        } else {
255
            Err(crate::Error::DocumentIdTooLong)
256
        }
257
1398655
    }
258
}
259

            
260
impl<const N: usize> TryFrom<[u8; N]> for DocumentId {
261
    type Error = crate::Error;
262

            
263
    fn try_from(bytes: [u8; N]) -> Result<Self, Self::Error> {
264
        Self::try_from(&bytes[..])
265
    }
266
}
267

            
268
impl DocumentId {
269
    /// The maximum length, in bytes, that an id can contain.
270
    pub const MAX_LENGTH: usize = 63;
271

            
272
    /// Returns a new instance with `value` as the identifier..
273
355698
    pub fn new<PrimaryKey: for<'a> Key<'a>>(value: PrimaryKey) -> Result<Self, crate::Error> {
274
355698
        let bytes = value
275
355698
            .as_ord_bytes()
276
355698
            .map_err(|err| crate::Error::Serialization(err.to_string()))?;
277
355698
        Self::try_from(&bytes[..])
278
355698
    }
279

            
280
    /// Returns a new document ID for a u64. This is equivalent to
281
    /// `DocumentId::new(id)`, but since this function accepts a non-generic
282
    /// type, it can help with type inference in some expressions.
283
    #[must_use]
284
    #[allow(clippy::missing_panics_doc)] // Unwrap is impossible to fail.
285
157
    pub fn from_u64(id: u64) -> Self {
286
157
        Self::try_from(&id.to_be_bytes()[..]).unwrap()
287
157
    }
288

            
289
    /// Returns a new document ID for a u32. This is equivalent to
290
    /// `DocumentId::new(id)`, but since this function accepts a non-generic
291
    /// type, it can help with type inference in some expressions.
292
    #[must_use]
293
    #[allow(clippy::missing_panics_doc)] // Unwrap is impossible to fail.
294
    pub fn from_u32(id: u32) -> Self {
295
        Self::try_from(&id.to_be_bytes()[..]).unwrap()
296
    }
297

            
298
    /// Returns the contained value, deserialized back to its original type.
299
546410
    pub fn deserialize<'a, PrimaryKey: Key<'a>>(&'a self) -> Result<PrimaryKey, crate::Error> {
300
546410
        PrimaryKey::from_ord_bytes(self.as_ref())
301
546410
            .map_err(|err| crate::Error::Serialization(err.to_string()))
302
546410
    }
303
}
304

            
305
impl Serialize for DocumentId {
306
2570272
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
307
2570272
    where
308
2570272
        S: serde::Serializer,
309
2570272
    {
310
2570272
        serializer.serialize_bytes(self.as_ref())
311
2570272
    }
312
}
313

            
314
impl<'de> Deserialize<'de> for DocumentId {
315
2441304
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
316
2441304
    where
317
2441304
        D: serde::Deserializer<'de>,
318
2441304
    {
319
2441304
        deserializer.deserialize_byte_buf(DocumentIdVisitor)
320
2441304
    }
321
}
322

            
323
struct DocumentIdVisitor;
324

            
325
impl<'de> Visitor<'de> for DocumentIdVisitor {
326
    type Value = DocumentId;
327

            
328
    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329
        formatter.write_str("a document id (bytes)")
330
    }
331

            
332
3274615
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
333
3274615
    where
334
3274615
        E: serde::de::Error,
335
3274615
    {
336
3274615
        if v.len() <= DocumentId::MAX_LENGTH {
337
3274615
            let mut document_id = DocumentId {
338
3274615
                length: u8::try_from(v.len()).unwrap(),
339
3274615
                ..DocumentId::default()
340
3274615
            };
341
3274615
            document_id.bytes[..v.len()].copy_from_slice(v);
342
3274615
            Ok(document_id)
343
        } else {
344
            Err(E::invalid_length(v.len(), &"< 64 bytes"))
345
        }
346
3274615
    }
347

            
348
    // Provided for backwards compatibility. No new data is written with this.
349
    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
350
    where
351
        E: serde::de::Error,
352
    {
353
        Ok(DocumentId::from_u64(v))
354
    }
355
}
356

            
357
/// A unique id for a document, either serialized or deserialized.
358
pub enum AnyDocumentId<PrimaryKey> {
359
    /// A serialized id.
360
    Serialized(DocumentId),
361
    /// A deserialized id.
362
    Deserialized(PrimaryKey),
363
}
364

            
365
impl<PrimaryKey> AnyDocumentId<PrimaryKey>
366
where
367
    PrimaryKey: for<'k> Key<'k>,
368
{
369
    /// Converts this value to a document id.
370
121353
    pub fn to_document_id(&self) -> Result<DocumentId, crate::Error> {
371
121353
        match self {
372
44750
            Self::Serialized(id) => Ok(*id),
373
76603
            Self::Deserialized(key) => DocumentId::new(key.clone()),
374
        }
375
121353
    }
376

            
377
    /// Converts this value to the primary key type.
378
26
    pub fn to_primary_key(&self) -> Result<PrimaryKey, crate::Error> {
379
26
        match self {
380
1
            Self::Serialized(id) => id.deserialize::<PrimaryKey>(),
381
25
            Self::Deserialized(key) => Ok(key.clone()),
382
        }
383
26
    }
384
}
385

            
386
impl<PrimaryKey> From<PrimaryKey> for AnyDocumentId<PrimaryKey>
387
where
388
    PrimaryKey: for<'k> Key<'k>,
389
{
390
76988
    fn from(key: PrimaryKey) -> Self {
391
76988
        Self::Deserialized(key)
392
76988
    }
393
}
394

            
395
impl<PrimaryKey> From<DocumentId> for AnyDocumentId<PrimaryKey> {
396
5482
    fn from(id: DocumentId) -> Self {
397
5482
        Self::Serialized(id)
398
5482
    }
399
}