1
use std::{
2
    borrow::Cow, convert::Infallible, io::ErrorKind, num::TryFromIntError, string::FromUtf8Error,
3
};
4

            
5
use arc_bytes::{
6
    serde::{Bytes, CowBytes},
7
    ArcBytes,
8
};
9
use num_traits::{FromPrimitive, ToPrimitive};
10
use ordered_varint::{Signed, Unsigned, Variable};
11

            
12
use crate::AnyError;
13

            
14
/// A trait that enables a type to convert itself to a big-endian/network byte order.
15
pub trait Key<'k>: Clone + Send + Sync {
16
    /// The error type that can be produced by either serialization or
17
    /// deserialization.
18
    type Error: AnyError;
19

            
20
    /// The size of the key, if constant.
21
    const LENGTH: Option<usize>;
22

            
23
    /// Convert `self` into a `Cow<[u8]>` containing bytes ordered in big-endian/network byte order.
24
    fn as_big_endian_bytes(&'k self) -> Result<Cow<'k, [u8]>, Self::Error>;
25

            
26
    /// Convert a slice of bytes into `Self` by interpretting `bytes` in big-endian/network byte order.
27
    fn from_big_endian_bytes(bytes: &'k [u8]) -> Result<Self, Self::Error>;
28
}
29

            
30
impl<'k> Key<'k> for Cow<'k, [u8]> {
31
    type Error = Infallible;
32

            
33
    const LENGTH: Option<usize> = None;
34

            
35
1
    fn as_big_endian_bytes(&'k self) -> Result<Cow<'k, [u8]>, Self::Error> {
36
1
        Ok(self.clone())
37
1
    }
38

            
39
1
    fn from_big_endian_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
40
1
        Ok(Cow::Owned(bytes.to_vec()))
41
1
    }
42
}
43

            
44
impl<'a> Key<'a> for Vec<u8> {
45
    type Error = Infallible;
46

            
47
    const LENGTH: Option<usize> = None;
48

            
49
    fn as_big_endian_bytes(&'a self) -> Result<Cow<'a, [u8]>, Self::Error> {
50
        Ok(Cow::Borrowed(self))
51
    }
52

            
53
    fn from_big_endian_bytes(bytes: &'a [u8]) -> Result<Self, Self::Error> {
54
        Ok(bytes.to_vec())
55
    }
56
}
57

            
58
impl<'a> Key<'a> for ArcBytes<'a> {
59
    type Error = Infallible;
60

            
61
    const LENGTH: Option<usize> = None;
62

            
63
    fn as_big_endian_bytes(&'a self) -> Result<Cow<'a, [u8]>, Self::Error> {
64
        Ok(Cow::Borrowed(self))
65
    }
66

            
67
    fn from_big_endian_bytes(bytes: &'a [u8]) -> Result<Self, Self::Error> {
68
        Ok(Self::from(bytes))
69
    }
70
}
71

            
72
impl<'a> Key<'a> for CowBytes<'a> {
73
    type Error = Infallible;
74

            
75
    const LENGTH: Option<usize> = None;
76

            
77
    fn as_big_endian_bytes(&'a self) -> Result<Cow<'a, [u8]>, Self::Error> {
78
        Ok(self.0.clone())
79
    }
80

            
81
    fn from_big_endian_bytes(bytes: &'a [u8]) -> Result<Self, Self::Error> {
82
        Ok(Self::from(bytes))
83
    }
84
}
85

            
86
impl<'a> Key<'a> for Bytes {
87
    type Error = Infallible;
88

            
89
    const LENGTH: Option<usize> = None;
90

            
91
444061
    fn as_big_endian_bytes(&'a self) -> Result<Cow<'a, [u8]>, Self::Error> {
92
444061
        Ok(Cow::Borrowed(self))
93
444061
    }
94

            
95
    fn from_big_endian_bytes(bytes: &'a [u8]) -> Result<Self, Self::Error> {
96
        Ok(Self::from(bytes))
97
    }
98
}
99

            
100
impl<'a> Key<'a> for String {
101
    type Error = FromUtf8Error;
102

            
103
    const LENGTH: Option<usize> = None;
104

            
105
234255
    fn as_big_endian_bytes(&'a self) -> Result<Cow<'a, [u8]>, Self::Error> {
106
234255
        Ok(Cow::Borrowed(self.as_bytes()))
107
234255
    }
108

            
109
212336
    fn from_big_endian_bytes(bytes: &'a [u8]) -> Result<Self, Self::Error> {
110
212336
        Self::from_utf8(bytes.to_vec())
111
212336
    }
112
}
113

            
114
impl<'a> Key<'a> for () {
115
    type Error = Infallible;
116

            
117
    const LENGTH: Option<usize> = Some(0);
118

            
119
47
    fn as_big_endian_bytes(&'a self) -> Result<Cow<'a, [u8]>, Self::Error> {
120
47
        Ok(Cow::default())
121
47
    }
122

            
123
70
    fn from_big_endian_bytes(_: &'a [u8]) -> Result<Self, Self::Error> {
124
70
        Ok(())
125
70
    }
126
}
127

            
128
/// Implements [`Key`] for a tuple. Each argument is
/// `(tuple index, binding name, type parameter)` for one field.
macro_rules! impl_key_for_tuple {
    ($(($index:tt, $varname:ident, $generic:ident)),+) => {
        impl<'a, $($generic),+> Key<'a> for ($($generic),+)
        where
            $($generic: Key<'a>),+
        {
            type Error = CompositeKeyError;

            // The tuple has a fixed length only when every field does. In that
            // case no field gets a length prefix, so the total is the sum of
            // the field lengths.
            const LENGTH: Option<usize> = match ($($generic::LENGTH),+) {
                ($(Some($varname)),+) => Some(0 $(+ $varname)+),
                _ => None,
            };

            /// Encodes every field in order, appending each to one buffer.
            fn as_big_endian_bytes(&'a self) -> Result<Cow<'a, [u8]>, Self::Error> {
                let mut encoded = Vec::new();
                $(
                    encode_composite_key_field(&self.$index, &mut encoded)?;
                )+
                Ok(Cow::Owned(encoded))
            }

            /// Decodes every field in order and rejects trailing bytes.
            fn from_big_endian_bytes(bytes: &'a [u8]) -> Result<Self, Self::Error> {
                let remaining = bytes;
                $(
                    let ($varname, remaining) = decode_composite_key_field::<$generic>(remaining)?;
                )+
                if !remaining.is_empty() {
                    return Err(CompositeKeyError::new(std::io::Error::from(
                        ErrorKind::InvalidData,
                    )));
                }
                Ok(($($varname),+))
            }
        }
    };
}
163

            
164
impl_key_for_tuple!((0, t1, T1), (1, t2, T2));
165
impl_key_for_tuple!((0, t1, T1), (1, t2, T2), (2, t3, T3));
166
impl_key_for_tuple!((0, t1, T1), (1, t2, T2), (2, t3, T3), (3, t4, T4));
167
impl_key_for_tuple!(
168
    (0, t1, T1),
169
    (1, t2, T2),
170
    (2, t3, T3),
171
    (3, t4, T4),
172
    (4, t5, T5)
173
);
174
impl_key_for_tuple!(
175
    (0, t1, T1),
176
    (1, t2, T2),
177
    (2, t3, T3),
178
    (3, t4, T4),
179
    (4, t5, T5),
180
    (5, t6, T6)
181
);
182
impl_key_for_tuple!(
183
    (0, t1, T1),
184
    (1, t2, T2),
185
    (2, t3, T3),
186
    (3, t4, T4),
187
    (4, t5, T5),
188
    (5, t6, T6),
189
    (6, t7, T7)
190
);
191
impl_key_for_tuple!(
192
    (0, t1, T1),
193
    (1, t2, T2),
194
    (2, t3, T3),
195
    (3, t4, T4),
196
    (4, t5, T5),
197
    (5, t6, T6),
198
    (6, t7, T7),
199
    (7, t8, T8)
200
);
201

            
202
3584
/// Appends the encoding of `value` to `bytes` as one field of a composite key.
///
/// When `T` has no fixed [`Key::LENGTH`], the field is prefixed with its byte
/// length (a variable-length `u64`) so a decoder can tell where it ends.
///
/// # Errors
///
/// Returns a [`CompositeKeyError`] wrapping either the field's own encoding
/// error or the error from writing the length prefix.
fn encode_composite_key_field<'a, T: Key<'a>>(
    value: &'a T,
    bytes: &mut Vec<u8>,
) -> Result<(), CompositeKeyError> {
    let field = T::as_big_endian_bytes(value).map_err(CompositeKeyError::new)?;
    if T::LENGTH.is_none() {
        let field_length = field.len() as u64;
        field_length
            .encode_variable(bytes)
            .map_err(CompositeKeyError::new)?;
    }
    bytes.extend_from_slice(&field);
    Ok(())
}
215

            
216
3584
fn decode_composite_key_field<'a, T: Key<'a>>(
217
3584
    mut bytes: &'a [u8],
218
3584
) -> Result<(T, &[u8]), CompositeKeyError> {
219
3584
    let length = if let Some(length) = T::LENGTH {
220
        length
221
    } else {
222
3584
        usize::try_from(u64::decode_variable(&mut bytes)?)?
223
    };
224
3584
    let (t2, remaining) = bytes.split_at(length);
225
3584
    Ok((
226
3584
        T::from_big_endian_bytes(t2).map_err(CompositeKeyError::new)?,
227
3584
        remaining,
228
    ))
229
3584
}
230

            
231
1
#[test]
fn composite_key_tests() {
    /// Encodes every case, sorts both the cases and their encodings, then
    /// decodes the sorted encodings. The decoded list must equal the sorted
    /// cases, which checks round-tripping and that byte order matches `Ord`.
    fn roundtrip<T: for<'a> Key<'a> + Ord + Eq + std::fmt::Debug>(mut cases: Vec<T>) {
        let mut encoded: Vec<Vec<u8>> = cases
            .iter()
            .map(|case| case.as_big_endian_bytes().unwrap().to_vec())
            .collect();
        cases.sort();
        encoded.sort();
        let decoded: Vec<T> = encoded
            .iter()
            .map(|bytes| T::from_big_endian_bytes(bytes).unwrap())
            .collect();
        assert_eq!(cases, decoded);
    }

    let values = [Unsigned::from(0_u8), Unsigned::from(16_u8)];
    // Bit `index` of `combination` selects which of the two values is used
    // for field `index`.
    let field = |combination: usize, index: usize| values[(combination >> index) & 1];

    // Builds every tuple (2^arity of them) whose fields are drawn from `values`.
    macro_rules! every_combination {
        ($pick:ident; $($index:literal),+) => {
            (0..(1_usize << [$($index),+].len()))
                .map(|combination| ($($pick(combination, $index),)+))
                .collect::<Vec<_>>()
        };
    }

    roundtrip(every_combination!(field; 0, 1));
    roundtrip(every_combination!(field; 0, 1, 2));
    roundtrip(every_combination!(field; 0, 1, 2, 3));
    roundtrip(every_combination!(field; 0, 1, 2, 3, 4));
    roundtrip(every_combination!(field; 0, 1, 2, 3, 4, 5));
    roundtrip(every_combination!(field; 0, 1, 2, 3, 4, 5, 6));
    roundtrip(every_combination!(field; 0, 1, 2, 3, 4, 5, 6, 7));
}
349

            
350
/// An error occurred inside of one of the composite key fields.
351
#[derive(thiserror::Error, Debug)]
352
#[error("key error: {0}")]
353
pub struct CompositeKeyError(Box<dyn AnyError>);
354

            
355
impl CompositeKeyError {
356
    pub(crate) fn new<E: AnyError>(error: E) -> Self {
357
        Self(Box::new(error))
358
    }
359
}
360

            
361
impl From<TryFromIntError> for CompositeKeyError {
362
    fn from(err: TryFromIntError) -> Self {
363
        Self::new(err)
364
    }
365
}
366

            
367
impl From<std::io::Error> for CompositeKeyError {
368
    fn from(err: std::io::Error) -> Self {
369
        Self::new(err)
370
    }
371
}
372

            
373
impl<'a> Key<'a> for Signed {
374
    type Error = std::io::Error;
375

            
376
    const LENGTH: Option<usize> = None;
377

            
378
    fn as_big_endian_bytes(&self) -> Result<Cow<'a, [u8]>, Self::Error> {
379
        self.to_variable_vec().map(Cow::Owned)
380
    }
381

            
382
    fn from_big_endian_bytes(bytes: &'a [u8]) -> Result<Self, Self::Error> {
383
        Self::decode_variable(bytes)
384
    }
385
}
386

            
387
impl<'a> Key<'a> for Unsigned {
388
    type Error = std::io::Error;
389

            
390
    const LENGTH: Option<usize> = None;
391

            
392
3584
    fn as_big_endian_bytes(&'a self) -> Result<Cow<'a, [u8]>, Self::Error> {
393
3584
        self.to_variable_vec().map(Cow::Owned)
394
3584
    }
395

            
396
3584
    fn from_big_endian_bytes(bytes: &'a [u8]) -> Result<Self, Self::Error> {
397
3584
        Self::decode_variable(bytes)
398
3584
    }
399
}
400

            
401
#[cfg(feature = "uuid")]
impl<'k> Key<'k> for uuid::Uuid {
    type Error = std::array::TryFromSliceError;

    const LENGTH: Option<usize> = Some(16);

    /// Borrows the UUID's 16 raw bytes.
    fn as_big_endian_bytes(&'k self) -> Result<Cow<'k, [u8]>, Self::Error> {
        let raw: &'k [u8] = self.as_bytes();
        Ok(Cow::Borrowed(raw))
    }

    /// Rebuilds a UUID from its raw bytes.
    ///
    /// # Errors
    ///
    /// Returns [`std::array::TryFromSliceError`] unless `bytes` is exactly 16
    /// bytes long.
    fn from_big_endian_bytes(bytes: &'k [u8]) -> Result<Self, Self::Error> {
        let raw: [u8; 16] = bytes.try_into()?;
        Ok(Self::from_bytes(raw))
    }
}
415

            
416
impl<'a, T> Key<'a> for Option<T>
417
where
418
    T: Key<'a>,
419
{
420
    type Error = T::Error;
421

            
422
    const LENGTH: Option<usize> = T::LENGTH;
423

            
424
    /// # Panics
425
    ///
426
    /// Panics if `T::into_big_endian_bytes` returns an empty `IVec`.
427
    // TODO consider removing this panic limitation by adding a single byte to
428
    // each key (at the end preferrably) so that we can distinguish between None
429
    // and a 0-byte type
430
3000
    fn as_big_endian_bytes(&'a self) -> Result<Cow<'a, [u8]>, Self::Error> {
431
3000
        if let Some(contents) = self {
432
2539
            let contents = contents.as_big_endian_bytes()?;
433
2539
            assert!(!contents.is_empty());
434
2539
            Ok(contents)
435
        } else {
436
461
            Ok(Cow::default())
437
        }
438
3000
    }
439

            
440
3137
    fn from_big_endian_bytes(bytes: &'a [u8]) -> Result<Self, Self::Error> {
441
3137
        if bytes.is_empty() {
442
1012
            Ok(None)
443
        } else {
444
2125
            Ok(Some(T::from_big_endian_bytes(bytes)?))
445
        }
446
3137
    }
447
}
448

            
449
/// Adds `Key` support to an enum. Requires implementing
450
/// [`ToPrimitive`](num_traits::ToPrimitive) and
451
/// [`FromPrimitive`](num_traits::FromPrimitive), or using a crate like
452
/// [num-derive](https://crates.io/crates/num-derive) to do it automatically.
453
/// Take care when using enums as keys: if the order changes or if the meaning
454
/// of existing numerical values changes, make sure to update any related views'
455
/// version number to ensure the values are re-evaluated.
456
pub trait EnumKey: ToPrimitive + FromPrimitive + Clone + Send + Sync {}
457

            
458
/// An error that indicates an unexpected number of bytes were present.
459
#[derive(thiserror::Error, Debug)]
460
#[error("incorrect byte length")]
461
pub struct IncorrectByteLength;
462

            
463
/// An error that indicates an unexpected enum variant value was found.
464
#[derive(thiserror::Error, Debug)]
465
#[error("unknown enum variant")]
466
pub struct UnknownEnumVariant;
467

            
468
impl From<std::array::TryFromSliceError> for IncorrectByteLength {
469
    fn from(_: std::array::TryFromSliceError) -> Self {
470
        Self
471
    }
472
}
473

            
474
// ANCHOR: impl_key_for_enumkey
475
impl<'a, T> Key<'a> for T
476
where
477
    T: EnumKey,
478
{
479
    type Error = std::io::Error;
480
    const LENGTH: Option<usize> = None;
481

            
482
4
    fn as_big_endian_bytes(&'a self) -> Result<Cow<'a, [u8]>, Self::Error> {
483
4
        let integer = self
484
4
            .to_u64()
485
4
            .map(Unsigned::from)
486
4
            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidData, IncorrectByteLength))?;
487
4
        Ok(Cow::Owned(integer.to_variable_vec()?))
488
4
    }
489

            
490
2
    fn from_big_endian_bytes(bytes: &'a [u8]) -> Result<Self, Self::Error> {
491
2
        let primitive = u64::decode_variable(bytes)?;
492
2
        Self::from_u64(primitive)
493
2
            .ok_or_else(|| std::io::Error::new(ErrorKind::InvalidData, UnknownEnumVariant))
494
2
    }
495
}
496
// ANCHOR_END: impl_key_for_enumkey
497

            
498
macro_rules! impl_key_for_primitive {
499
    ($type:ident) => {
500
        impl<'a> Key<'a> for $type {
501
            type Error = IncorrectByteLength;
502
            const LENGTH: Option<usize> = Some(std::mem::size_of::<$type>());
503

            
504
1992751
            fn as_big_endian_bytes(&'a self) -> Result<Cow<'a, [u8]>, Self::Error> {
505
1992751
                Ok(Cow::from(self.to_be_bytes().to_vec()))
506
1992751
            }
507

            
508
163392
            fn from_big_endian_bytes(bytes: &'a [u8]) -> Result<Self, Self::Error> {
509
163392
                Ok($type::from_be_bytes(bytes.try_into()?))
510
163392
            }
511
        }
512
    };
513
}
514

            
515
impl_key_for_primitive!(i8);
516
impl_key_for_primitive!(u8);
517
impl_key_for_primitive!(i16);
518
impl_key_for_primitive!(u16);
519
impl_key_for_primitive!(i32);
520
impl_key_for_primitive!(u32);
521
impl_key_for_primitive!(i64);
522
impl_key_for_primitive!(u64);
523
impl_key_for_primitive!(i128);
524
impl_key_for_primitive!(u128);
525

            
526
1
#[test]
#[allow(clippy::cognitive_complexity)] // I disagree - @ecton
fn primitive_key_encoding_tests() -> anyhow::Result<()> {
    // For each type: MAX must encode to its native big-endian bytes, and both
    // MAX and MIN must survive an encode/decode round trip.
    macro_rules! assert_primitive_extremes {
        ($($type:ident),+) => {
            $(
                let max: $type = $type::MAX;
                let min: $type = $type::MIN;

                let max_bytes = max.as_big_endian_bytes()?;
                assert_eq!(&max.to_be_bytes(), max_bytes.as_ref());
                assert_eq!(max, $type::from_big_endian_bytes(&max_bytes)?);

                let min_bytes = min.as_big_endian_bytes()?;
                assert_eq!(min, $type::from_big_endian_bytes(&min_bytes)?);
            )+
        };
    }

    assert_primitive_extremes!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);

    Ok(())
}
559

            
560
1
#[test]
fn optional_key_encoding_tests() -> anyhow::Result<()> {
    // `None` encodes to nothing at all.
    let none: Option<i8> = None;
    assert!(none.as_big_endian_bytes()?.is_empty());

    // `Some` round-trips through its encoding.
    let some = Some(1_i8);
    let encoded = some.as_big_endian_bytes()?;
    assert_eq!(some, Option::from_big_endian_bytes(&encoded)?);
    Ok(())
}
569

            
570
1
#[test]
#[allow(clippy::unit_cmp)] // this is more of a compilation test
fn unit_key_encoding_tests() -> anyhow::Result<()> {
    // `()` encodes to zero bytes and decodes from an empty slice.
    let encodes_to_nothing = ().as_big_endian_bytes()?.is_empty();
    assert!(encodes_to_nothing);
    assert_eq!((), <() as Key>::from_big_endian_bytes(&[])?);
    Ok(())
}
577

            
578
1
#[test]
fn vec_key_encoding_tests() -> anyhow::Result<()> {
    const ORIGINAL_VALUE: &[u8] = b"bonsaidb";
    let original: Cow<'_, [u8]> = Cow::Borrowed(ORIGINAL_VALUE);
    let encoded = original.as_big_endian_bytes()?;
    let decoded: Cow<'_, [u8]> = Cow::from_big_endian_bytes(&encoded)?;
    assert_eq!(original, decoded);
    Ok(())
}
588

            
589
1
#[test]
#[allow(clippy::use_self)] // Weird interaction with num_derive
fn enum_derive_tests() -> anyhow::Result<()> {
    #[derive(Clone, num_derive::ToPrimitive, num_derive::FromPrimitive)]
    enum SomeEnum {
        One = 1,
        NineNineNine = 999,
    }

    impl EnumKey for SomeEnum {}

    // Both a small and a multi-byte discriminant must round-trip.
    let one = SomeEnum::One;
    let decoded = SomeEnum::from_big_endian_bytes(&one.as_big_endian_bytes()?)?;
    assert!(matches!(decoded, SomeEnum::One));

    let nine_nine_nine = SomeEnum::NineNineNine;
    let decoded = SomeEnum::from_big_endian_bytes(&nine_nine_nine.as_big_endian_bytes()?)?;
    assert!(matches!(decoded, SomeEnum::NineNineNine));

    Ok(())
}