1
//! Macros for BonsaiDb.
2

            
3
#![forbid(unsafe_code)]
4
#![warn(
5
    clippy::cargo,
6
    missing_docs,
7
    // clippy::missing_docs_in_private_items,
8
    clippy::pedantic,
9
    future_incompatible,
10
    rust_2018_idioms,
11
)]
12
#![cfg_attr(doc, deny(rustdoc::all))]
13

            
14
use attribute_derive::{Attribute, ConvertParsed};
15
use manyhow::{bail, error_message, manyhow, JoinToTokensError, Result};
16
use proc_macro2::{Span, TokenStream};
17
use proc_macro_crate::{crate_name, FoundCrate};
18
use quote::{quote_spanned, ToTokens};
19
use quote_use::{
20
    format_ident_namespaced as format_ident, parse_quote_use as parse_quote, quote_use as quote,
21
};
22
use syn::punctuated::Punctuated;
23
use syn::spanned::Spanned;
24
use syn::{
25
    parse, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed,
26
    FieldsUnnamed, Ident, Index, Path, Token, Type, TypePath, Variant,
27
};
28

            
29
mod view;
30

            
31
// -----------------------------------------------------------------------------
32
//     - Core Macros -
33
// -----------------------------------------------------------------------------
34

            
35
fn core_path() -> Path {
36
96
    match crate_name("bonsaidb")
37
96
        .or_else(|_| crate_name("bonsaidb_server"))
38
96
        .or_else(|_| crate_name("bonsaidb_local"))
39
96
        .or_else(|_| crate_name("bonsaidb_client"))
40
    {
41
96
        Ok(FoundCrate::Name(name)) => {
42
96
            let ident = Ident::new(&name, Span::call_site());
43
96
            parse_quote!(::#ident::core)
44
        }
45
        Ok(FoundCrate::Itself) => parse_quote!(crate::core),
46
        Err(_) => match crate_name("bonsaidb_core") {
47
            Ok(FoundCrate::Name(name)) => {
48
                let ident = Ident::new(&name, Span::call_site());
49
                parse_quote!(::#ident)
50
            }
51
            Ok(FoundCrate::Itself) => parse_quote!(crate),
52
            Err(_) => match () {
53
                () if cfg!(feature = "omnibus-path") => parse_quote!(::bonsaidb::core),
54
                () if cfg!(feature = "server-path") => parse_quote!(::bonsaidb_server::core),
55
                () if cfg!(feature = "local-path") => parse_quote!(::bonsaidb_local::core),
56
                () if cfg!(feature = "client-path") => parse_quote!(::bonsaidb_client::core),
57
                _ => parse_quote!(::bonsaidb_core),
58
            },
59
        },
60
    }
61
96
}
62

            
63
641
#[derive(Attribute)]
64
#[attribute(ident = collection)]
65
struct CollectionAttribute {
66
    authority: Option<Expr>,
67
    #[attribute(example = "\"name\"")]
68
    name: String,
69
    #[attribute(optional, example = "[SomeView, AnotherView]")]
70
    views: Vec<Type>,
71
    #[attribute(example = "Format or None")]
72
    serialization: Option<Path>,
73
    #[attribute(example = "Some(KeyId::Master)")]
74
    encryption_key: Option<Expr>,
75
    encryption_required: bool,
76
    encryption_optional: bool,
77
    #[attribute(example = "u64")]
78
    primary_key: Option<Type>,
79
    #[attribute(example = "self.0 or something(self)")]
80
    natural_id: Option<Expr>,
81
    #[attribute(example = "bosaidb::core")]
82
    core: Option<Path>,
83
}
84

            
85
/// Derives the `bonsaidb::core::schema::Collection` trait.
86
/// `#[collection(authority = "Authority", name = "Name", views = [a, b, c])]`
87
70
#[manyhow]
88
#[proc_macro_derive(Collection, attributes(collection, natural_id))]
89
70
pub fn collection_derive(input: proc_macro::TokenStream) -> Result {
90
    let DeriveInput {
91
70
        attrs,
92
70
        ident,
93
70
        generics,
94
70
        data,
95
        ..
96
70
    } = parse(input)?;
97

            
98
    let CollectionAttribute {
99
70
        authority,
100
70
        name,
101
70
        views,
102
70
        serialization,
103
70
        mut primary_key,
104
70
        mut natural_id,
105
70
        core,
106
70
        encryption_key,
107
70
        encryption_required,
108
70
        encryption_optional,
109
70
    } = CollectionAttribute::from_attributes(&attrs)?;
110

            
111
70
    if let Data::Struct(DataStruct { fields, .. }) = data {
112
70
        let mut previous: Option<syn::Attribute> = None;
113
        for (
114
183
            idx,
115
183
            Field {
116
183
                attrs, ident, ty, ..
117
            },
118
70
        ) in fields.into_iter().enumerate()
119
        {
120
183
            if let Some(attr) = attrs
121
183
                .into_iter()
122
183
                .find(|attr| attr.path().is_ident("natural_id"))
123
            {
124
6
                if let Some(previous) = &previous {
125
                    bail!(error_message!(attr,
126
                            "marked multiple fields as `natural_id`";
127
                            note="currently only one field can be marked as `natural_id`";
128
                            help="use `#[collection(natural_id=...)]` on the struct instead")
129
                    .join(error_message!(previous, "previous `natural_id`")));
130
6
                }
131
6
                if let Some(natural_id) = &natural_id {
132
                    bail!(error_message!(attr, "field marked as `natural_id` while `natural_id` expression is specified as well";
133
                            help = "remove `#[natural_id]` attribute on field")
134
                        .join(error_message!(natural_id, "`natural_id` expression is specified here")));
135
6
                }
136
6
                previous = Some(attr);
137
6
                let ident = if let Some(ident) = ident {
138
5
                    quote!(#ident)
139
                } else {
140
1
                    let idx = Index::from(idx);
141
1
                    quote_spanned!(ty.span()=> #idx)
142
                };
143
6
                natural_id = Some(parse_quote!(Some(Clone::clone(&self.#ident))));
144
6
                if primary_key.is_none() {
145
2
                    primary_key = Some(ty);
146
4
                }
147
177
            }
148
        }
149
    };
150

            
151
70
    if encryption_required && encryption_key.is_none() {
152
        bail!("If `collection(encryption_required)` is set you need to provide an encryption key via `collection(encryption_key = EncryptionKey)`")
153
70
    }
154
70

            
155
70
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
156
70

            
157
70
    let core = core.unwrap_or_else(core_path);
158
70

            
159
70
    let primary_key = primary_key.unwrap_or_else(|| parse_quote!(u64));
160

            
161
70
    let serialization = if matches!(&serialization, Some(serialization) if serialization.is_ident("None"))
162
    {
163
1
        if let Some(natural_id) = natural_id {
164
            bail!(
165
                natural_id,
166
                "`natural_id` must be manually implemented when using `serialization = None`"
167
            );
168
1
        }
169
1

            
170
1
        TokenStream::new()
171
    } else {
172
69
        let natural_id = natural_id.map(|natural_id| {
173
8
            quote!(
174
8
                fn natural_id(&self) -> Option<Self::PrimaryKey> {
175
8
                    #[allow(clippy::clone_on_copy)]
176
8
                    #natural_id
177
8
                }
178
8
            )
179
69
        });
180

            
181
69
        if let Some(serialization) = serialization {
182
2
            let serialization = if serialization.is_ident("Key") {
183
1
                quote!(#core::key::KeyFormat)
184
            } else {
185
1
                quote!(#serialization)
186
            };
187
2
            quote! {
188
2
                impl #impl_generics #core::schema::SerializedCollection for #ident #ty_generics #where_clause {
189
2
                    type Contents = #ident #ty_generics;
190
2
                    type Format = #serialization;
191
2

            
192
2
                    fn format() -> Self::Format {
193
2
                        #serialization::default()
194
2
                    }
195
2

            
196
2
                    #natural_id
197
2
                }
198
2
            }
199
        } else {
200
67
            quote! {
201
67
                impl #impl_generics #core::schema::DefaultSerialization for #ident #ty_generics #where_clause {
202
67
                    #natural_id
203
67
                }
204
67
            }
205
        }
206
    };
207

            
208
70
    let name = authority.map_or_else(
209
70
        || quote!(#core::schema::Qualified::private(#name)),
210
70
        |authority| quote!(#core::schema::Qualified::new(#authority, #name)),
211
70
    );
212
70

            
213
70
    let encryption = encryption_key.map(|encryption_key| {
214
13
        let encryption = if encryption_required || !encryption_optional {
215
2
            encryption_key.into_token_stream()
216
        } else {
217
11
            quote! {
218
11
                if #core::ENCRYPTION_ENABLED {
219
11
                    #encryption_key
220
11
                } else {
221
11
                    None
222
11
                }
223
11
            }
224
        };
225
13
        quote! {
226
13
            fn encryption_key() -> Option<#core::document::KeyId> {
227
13
                #encryption
228
13
            }
229
13
        }
230
70
    });
231

            
232
70
    Ok(quote! {
233
70
        impl #impl_generics #core::schema::Collection for #ident #ty_generics #where_clause {
234
70
            type PrimaryKey = #primary_key;
235
70

            
236
70
            fn collection_name() -> #core::schema::CollectionName {
237
70
                #name
238
70
            }
239
70
            fn define_views(schema: &mut #core::schema::Schematic) -> Result<(), #core::Error> {
240
70
                #( schema.define_view(#views)?; )*
241
70
                Ok(())
242
70
            }
243
70
            #encryption
244
70
        }
245
70
        #serialization
246
70
    })
247
}
248
/// Derives the `bonsaidb::core::schema::View` trait.
249
///
250
/// `#[view(collection=CollectionType, key=KeyType, value=ValueType, name = "by-name")]`
251
/// `name` and `value` are optional
252
57
#[manyhow]
253
#[proc_macro_derive(View, attributes(view))]
254
57
pub fn view_derive(input: proc_macro::TokenStream) -> Result {
255
57
    view::derive(parse(input)?)
256
}
257
/// Derives the `bonsaidb::core::schema::ViewSchema` trait.
258
51
#[manyhow]
259
/// `#[view_schema(version = 1, policy = Unique, view=ViewType, mapped_key=KeyType<'doc>)]`
260
///
261
/// All attributes are optional.
262
#[proc_macro_derive(ViewSchema, attributes(view_schema))]
263
51
pub fn view_schema_derive(input: proc_macro::TokenStream) -> Result {
264
51
    view::derive_schema(parse(input)?)
265
}
266

            
267
137
#[derive(Attribute)]
268
#[attribute(ident = schema)]
269
struct SchemaAttribute {
270
    #[attribute(example = "\"name\"")]
271
    name: String,
272
    #[attribute(example = "\"authority\"")]
273
    authority: Option<Expr>,
274
    #[attribute(optional, example = "[SomeCollection, AnotherCollection]")]
275
    collections: Vec<Type>,
276
    #[attribute(optional, example = "[SomeSchema, AnotherSchema]")]
277
    include: Vec<Type>,
278
    #[attribute(example = "bosaidb::core")]
279
    core: Option<Path>,
280
}
281

            
282
/// Derives the `bonsaidb::core::schema::Schema` trait.
283
///
284
/// `#[schema(name = "Name", authority = "Authority", collections = [A, B, C]), core = bonsaidb::core]`
285
/// `authority`, `collections` and `core` are optional
286
19
#[manyhow]
287
#[proc_macro_derive(Schema, attributes(schema))]
288
19
pub fn schema_derive(input: proc_macro::TokenStream) -> Result {
289
    let DeriveInput {
290
19
        attrs,
291
19
        ident,
292
19
        generics,
293
        ..
294
19
    } = parse(input)?;
295

            
296
    let SchemaAttribute {
297
19
        name,
298
19
        authority,
299
19
        collections,
300
19
        include,
301
19
        core,
302
19
    } = SchemaAttribute::from_attributes(&attrs)?;
303

            
304
19
    let core = core.unwrap_or_else(core_path);
305
19
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
306
19

            
307
19
    let name = authority.map_or_else(
308
19
        || quote!(#core::schema::Qualified::private(#name)),
309
19
        |authority| quote!(#core::schema::Qualified::new(#authority, #name)),
310
19
    );
311

            
312
19
    Ok(quote! {
313
19
        impl #impl_generics #core::schema::Schema for #ident #ty_generics #where_clause {
314
19
            fn schema_name() -> #core::schema::SchemaName {
315
19
                #name
316
19
            }
317
19

            
318
19
            fn define_collections(
319
19
                schema: &mut #core::schema::Schematic
320
19
            ) -> Result<(), #core::Error> {
321
19
                #( schema.define_collection::<#collections>()?; )*
322
19

            
323
19
                #( <#include as #core::schema::Schema>::define_collections(schema)?; )*
324
19

            
325
19
                Ok(())
326
19
            }
327
19
        }
328
19
    })
329
}
330

            
331
18
#[derive(Attribute)]
332
#[attribute(ident = key)]
333
struct KeyAttribute {
334
    #[attribute(example = "bosaidb::core")]
335
    core: Option<Path>,
336
    #[attribute(default = NullHandling::Escape, example = "escape")]
337
    null_handling: NullHandling,
338
    can_own_bytes: bool,
339
    #[attribute(example = "u8")]
340
    enum_repr: Option<Type>,
341
    #[attribute(example = "\"name\"")]
342
    name: Option<String>,
343
}
344

            
345
enum NullHandling {
346
    Escape,
347
    Allow,
348
    Deny,
349
}
350

            
351
impl ConvertParsed for NullHandling {
352
    type Type = Ident;
353

            
354
    fn convert(value: Self::Type) -> syn::Result<Self> {
355
        if value == "escape" {
356
            Ok(NullHandling::Escape)
357
        } else if value == "allow" {
358
            Ok(NullHandling::Allow)
359
        } else if value == "deny" {
360
            Ok(NullHandling::Deny)
361
        } else {
362
            Err(syn::Error::new(
363
                Span::call_site(),
364
                "only `escape`, `allow`, and `deny` are allowed for `null_handling`",
365
            ))
366
        }
367
    }
368
}
369

            
370
/// Derives the `bonsaidb::core::key::Key` trait.
371
///
372
/// `#[key(null_handling = escape, enum_repr = u8, core = bonsaidb::core)]`, all parameters are optional
373
13
#[manyhow]
374
#[proc_macro_derive(Key, attributes(key))]
375
13
pub fn key_derive(input: proc_macro::TokenStream) -> Result {
376
    let DeriveInput {
377
13
        attrs,
378
13
        ident,
379
13
        generics,
380
13
        data,
381
        ..
382
13
    } = parse(input)?;
383

            
384
    // Only relevant if it is an enum, gets the representation to use for the variant key
385
13
    let repr = attrs.iter().find_map(|attr| {
386
5
        attr.path()
387
5
            .is_ident("repr")
388
5
            .then(|| attr.parse_args::<Ident>().ok())
389
5
            .flatten()
390
5
            .and_then(|ident| {
391
2
                matches!(
392
2
                    ident.to_string().as_ref(),
393
2
                    "u8" | "u16"
394
1
                        | "u32"
395
1
                        | "u64"
396
                        | "u128"
397
                        | "usize"
398
                        | "i8"
399
                        | "i16"
400
                        | "i32"
401
                        | "i64"
402
                        | "i128"
403
                        | "isize"
404
                )
405
2
                .then(|| ident)
406
5
            })
407
13
    });
408

            
409
    let KeyAttribute {
410
13
        core,
411
13
        null_handling,
412
13
        enum_repr,
413
13
        can_own_bytes,
414
13
        name,
415
13
    } = KeyAttribute::from_attributes(&attrs)?;
416

            
417
13
    let name = name.map_or_else(
418
13
        || quote!(std::any::type_name::<Self>()),
419
13
        |name| quote!(#name),
420
13
    );
421

            
422
13
    if matches!(data, Data::Struct(_)) && enum_repr.is_some() {
423
        // TODO better span when attribute-derive supports that
424
        bail!(enum_repr, "`enum_repr` is only usable with enums")
425
13
    }
426
13

            
427
13
    let repr: Type = enum_repr.unwrap_or_else(|| {
428
12
        Type::Path(TypePath {
429
12
            qself: None,
430
12
            path: repr.unwrap_or_else(|| format_ident!("isize")).into(),
431
12
        })
432
13
    });
433

            
434
13
    let (encoder_constructor, decoder_constructor) = match null_handling {
435
13
        NullHandling::Escape => (quote!(default), quote!(default_for)),
436
        NullHandling::Allow => (quote!(allowing_null_bytes), quote!(allowing_null_bytes)),
437
        NullHandling::Deny => (quote!(denying_null_bytes), quote!(denying_null_bytes)),
438
    };
439

            
440
13
    let core = core.unwrap_or_else(core_path);
441
13
    let (_, ty_generics, _) = generics.split_for_impl();
442
13
    let mut generics = generics.clone();
443
13
    let lifetimes: Vec<_> = generics.lifetimes().cloned().collect();
444
13
    let where_clause = generics.make_where_clause();
445
15
    for lifetime in lifetimes {
446
2
        where_clause.predicates.push(parse_quote!($'key: #lifetime));
447
2
    }
448
13
    generics
449
13
        .params
450
13
        .push(syn::GenericParam::Lifetime(parse_quote!($'key)));
451
13
    let (impl_generics, _, where_clause) = generics.split_for_impl();
452

            
453
    // Special case the implementation for 1
454
    // field -- just pass through to the
455
    // inner type so that this encoding is
456
    // completely transparent.
457
3
    if let Some((name, ty, map)) = match &data {
458
4
        Data::Struct(DataStruct {
459
4
            fields: Fields::Named(FieldsNamed { named, .. }),
460
4
            ..
461
4
        }) if named.len() == 1 => {
462
2
            let name = &named[0].ident;
463
2
            Some((
464
2
                quote!(#name),
465
2
                named[0].ty.clone(),
466
2
                quote!(|value| Self { #name: value }),
467
2
            ))
468
        }
469
        Data::Struct(DataStruct {
470
3
            fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
471
3
            ..
472
3
        }) if unnamed.len() == 1 => Some((quote!(0), unnamed[0].ty.clone(), quote!(Self))),
473
10
        _ => None,
474
    } {
475
3
        return Ok(quote! {
476
3
            # use std::{borrow::Cow, io::{self, ErrorKind}};
477
3
            # use #core::key::{ByteSource, KeyVisitor, IncorrectByteLength, Key, KeyEncoding};
478
3

            
479
3
            impl #impl_generics Key<$'key> for #ident #ty_generics #where_clause {
480
3
                const CAN_OWN_BYTES: bool = <#ty>::CAN_OWN_BYTES;
481
3

            
482
3
                fn from_ord_bytes<$'b>(bytes: ByteSource<$'key, $'b>) -> Result<Self, Self::Error> {
483
3
                    <#ty>::from_ord_bytes(bytes).map(#map)
484
3
                }
485
3
            }
486
3

            
487
3
            impl #impl_generics KeyEncoding<Self> for #ident #ty_generics #where_clause {
488
3
                type Error = <#ty as KeyEncoding>::Error;
489
3

            
490
3
                const LENGTH: Option<usize> = <#ty>::LENGTH;
491
3

            
492
3
                fn describe<Visitor>(visitor: &mut Visitor)
493
3
                where
494
3
                    Visitor: KeyVisitor,
495
3
                {
496
3
                    <#ty>::describe(visitor)
497
3
                }
498
3

            
499
3
                fn as_ord_bytes(&self) -> Result<Cow<'_, [u8]>, Self::Error> {
500
3
                    self.#name.as_ord_bytes()
501
3
                }
502
3
            }
503
3
        });
504
10
    }
505

            
506
5
    let (encode_fields, decode_fields, describe, composite_kind, field_count): (
507
5
        TokenStream,
508
5
        TokenStream,
509
5
        TokenStream,
510
5
        TokenStream,
511
5
        usize,
512
10
    ) = match data {
513
5
        Data::Struct(DataStruct { fields, .. }) => {
514
5
            let (encode_fields, decode_fields, describe, field_count) = match fields {
515
2
                Fields::Named(FieldsNamed { named, .. }) => {
516
2
                    let field_count = named.len();
517
2
                    let (encode_fields, (decode_fields, describe)): (
518
2
                        TokenStream,
519
2
                        (TokenStream, TokenStream),
520
2
                    ) = named
521
2
                        .into_iter()
522
4
                        .map(|Field { ident, ty, .. }| {
523
4
                            let ident = ident.expect("named fields have idents");
524
4
                            (
525
4
                                quote!($encoder.encode(&self.#ident)?;),
526
4
                                (
527
4
                                    quote!(#ident: $decoder.decode()?,),
528
4
                                    quote!(<#ty>::describe(visitor);),
529
4
                                ),
530
4
                            )
531
4
                        })
532
2
                        .unzip();
533
2
                    (
534
2
                        encode_fields,
535
2
                        quote!( Self { #decode_fields }),
536
2
                        describe,
537
2
                        field_count,
538
2
                    )
539
                }
540
2
                Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
541
2
                    let field_count = unnamed.len();
542
2
                    let (encode_fields, (decode_fields, describe)): (
543
2
                        TokenStream,
544
2
                        (TokenStream, TokenStream),
545
2
                    ) = unnamed
546
2
                        .into_iter()
547
2
                        .enumerate()
548
5
                        .map(|(idx, field)| {
549
5
                            let ty = field.ty;
550
5
                            let idx = Index::from(idx);
551
5
                            (
552
5
                                quote!($encoder.encode(&self.#idx)?;),
553
5
                                (
554
5
                                    quote!($decoder.decode()?,),
555
5
                                    quote!(<#ty>::describe(visitor);),
556
5
                                ),
557
5
                            )
558
5
                        })
559
2
                        .unzip();
560
2
                    (
561
2
                        encode_fields,
562
2
                        quote!(Self(#decode_fields)),
563
2
                        describe,
564
2
                        field_count,
565
2
                    )
566
                }
567
                Fields::Unit => {
568
1
                    return Ok(quote! {
569
1
                        # use std::{borrow::Cow, io::{self, ErrorKind}};
570
1
                        # use #core::key::{ByteSource, KeyVisitor, IncorrectByteLength, Key, KeyKind, KeyEncoding};
571
1

            
572
1
                        impl #impl_generics Key<$'key> for #ident #ty_generics #where_clause {
573
1
                            const CAN_OWN_BYTES: bool = false;
574
1

            
575
1
                            fn from_ord_bytes<$'b>(bytes: ByteSource<$'key, $'b>) -> Result<Self, Self::Error> {
576
1
                                Ok(Self)
577
1
                            }
578
1
                        }
579
1

            
580
1
                        impl #impl_generics KeyEncoding<Self> for #ident #ty_generics #where_clause {
581
1
                            type Error = std::convert::Infallible;
582
1

            
583
1
                            const LENGTH: Option<usize> = Some(0);
584
1

            
585
1
                            fn describe<Visitor>(visitor: &mut Visitor)
586
1
                            where
587
1
                                Visitor: KeyVisitor,
588
1
                            {
589
1
                                visitor.visit_type(KeyKind::Unit);
590
1
                            }
591
1

            
592
1
                            fn as_ord_bytes(&self) -> Result<Cow<'_, [u8]>, Self::Error> {
593
1
                                Ok(Cow::Borrowed(&[]))
594
1
                            }
595
1
                        }
596
1
                    })
597
                }
598
            };
599
4
            (
600
4
                encode_fields,
601
4
                quote!(let $self_ = #decode_fields;),
602
4
                describe,
603
4
                quote!(#core::key::CompositeKind::Struct(std::borrow::Cow::Borrowed(#name))),
604
4
                field_count,
605
4
            )
606
        }
607
5
        Data::Enum(DataEnum { variants, .. }) => {
608
5
            let mut prev_ident = None;
609
5
            let field_count = variants.len();
610
11
            let all_variants_are_empty = variants.iter().all(|variant| variant.fields.is_empty());
611
5

            
612
5
            let (consts, (encode_variants, (decode_variants, describe))): (
613
5
                TokenStream,
614
5
                (TokenStream, (TokenStream, TokenStream)),
615
5
            ) = variants
616
5
                .into_iter()
617
5
                .enumerate()
618
5
                .map(
619
5
                    |(
620
                        idx,
621
                        Variant {
622
                            fields,
623
                            ident,
624
                            discriminant,
625
                            ..
626
                        },
627
12
                    )| {
628
12
                        let discriminant = discriminant.map_or_else(
629
12
                            || {
630
6
                                prev_ident
631
6
                                    .as_ref()
632
6
                                    .map_or_else(|| quote!(0), |ident| quote!(#ident + 1))
633
12
                            },
634
12
                            |(_, expr)| expr.to_token_stream(),
635
12
                        );
636
12

            
637
12
                        let const_ident = format_ident!("$discriminant{idx}");
638
12
                        let const_ = quote!(const #const_ident: #repr = #discriminant;);
639

            
640
12
                        let ret = (
641
12
                            const_,
642
12
                            match fields {
643
1
                                Fields::Named(FieldsNamed { named, .. }) => {
644
1
                                    let (idents, (encode_fields, (decode_fields, describe))): (
645
1
                                        Punctuated<_, Token![,]>,
646
1
                                        (TokenStream, (TokenStream, TokenStream)),
647
1
                                    ) = named
648
1
                                        .into_iter()
649
2
                                        .map(|Field { ident, ty, .. }| {
650
2
                                            let ident = ident.expect("named fields have idents");
651
2
                                            (
652
2
                                                ident.clone(),
653
2
                                                (
654
2
                                                    quote!($encoder.encode(#ident)?;),
655
2
                                                    (
656
2
                                                        quote!(#ident: $decoder.decode()?,),
657
2
                                                        quote!(<#ty>::describe(visitor);),
658
2
                                                    ),
659
2
                                                ),
660
2
                                            )
661
2
                                        })
662
1
                                        .unzip();
663
1
                                    (
664
1
                                        quote! {
665
1
                                            Self::#ident{#idents} => {
666
1
                                                $encoder.encode(&#const_ident)?;
667
1
                                                #encode_fields
668
1
                                            },
669
1
                                        },
670
1
                                        (
671
1
                                            quote! {
672
1
                                                #const_ident => Self::#ident{#decode_fields},
673
1
                                            },
674
1
                                            describe,
675
1
                                        ),
676
1
                                    )
677
                                }
678
1
                                Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
679
1
                                    let (idents, (encode_fields, (decode_fields, describe))): (
680
1
                                        Punctuated<_, Token![,]>,
681
1
                                        (TokenStream, (TokenStream, TokenStream)),
682
1
                                    ) = unnamed
683
1
                                        .into_iter()
684
1
                                        .enumerate()
685
2
                                        .map(|(idx, field)| {
686
2
                                            let ident = format_ident!("$field_{idx}");
687
2
                                            let ty = field.ty;
688
2
                                            (
689
2
                                                ident.clone(),
690
2
                                                (
691
2
                                                    quote!($encoder.encode(#ident)?;),
692
2
                                                    (
693
2
                                                        quote!($decoder.decode()?,),
694
2
                                                        quote!(<#ty>::describe(visitor);),
695
2
                                                    ),
696
2
                                                ),
697
2
                                            )
698
2
                                        })
699
1
                                        .unzip();
700
1
                                    (
701
1
                                        quote! {
702
1
                                            Self::#ident(#idents) => {
703
1
                                                $encoder.encode(&#const_ident)?;
704
1
                                                #encode_fields
705
1
                                            },
706
1
                                        },
707
1
                                        (
708
1
                                            quote! {
709
1
                                                #const_ident => Self::#ident(#decode_fields),
710
1
                                            },
711
1
                                            describe,
712
1
                                        ),
713
1
                                    )
714
                                }
715
                                Fields::Unit => {
716
10
                                    let encode = if all_variants_are_empty {
717
9
                                        quote!(Self::#ident => #const_ident.as_ord_bytes(),)
718
                                    } else {
719
1
                                        quote!(Self::#ident => $encoder.encode(&#const_ident)?,)
720
                                    };
721
10
                                    (
722
10
                                        encode,
723
10
                                        (
724
10
                                            quote!(#const_ident => Self::#ident,),
725
10
                                            quote!(visitor.visit_type(#core::key::KeyKind::Unit);),
726
10
                                        ),
727
10
                                    )
728
                                }
729
                            },
730
                        );
731
12
                        prev_ident = Some(const_ident);
732
12
                        ret
733
12
                    },
734
5
                )
735
5
                .unzip();
736
5

            
737
5
            if all_variants_are_empty {
738
                // Special case: if no enum variants have embedded values,
739
                // implement Key as a plain value, avoiding the composite key
740
                // overhead.
741
4
                return Ok(quote! {
742
4
                    # use std::{borrow::Cow, io::{self, ErrorKind}};
743
4
                    # use #core::key::{ByteSource, CompositeKeyDecoder, KeyVisitor, CompositeKeyEncoder, CompositeKeyError, Key, KeyEncoding};
744
4

            
745
4
                    impl #impl_generics Key<$'key> for #ident #ty_generics #where_clause {
746
4
                        const CAN_OWN_BYTES: bool = false;
747
4

            
748
4
                        fn from_ord_bytes<$'b>(mut $bytes: ByteSource<$'key, $'b>) -> Result<Self, Self::Error> {
749
4
                            #consts
750
4
                            Ok(match <#repr>::from_ord_bytes($bytes).map_err(#core::key::CompositeKeyError::new)? {
751
4
                                #decode_variants
752
4
                                _ => return Err(#core::key::CompositeKeyError::from(io::Error::from(
753
4
                                        ErrorKind::InvalidData,
754
4
                                )))
755
4
                            })
756
4
                        }
757
4
                    }
758
4

            
759
4
                    impl #impl_generics KeyEncoding<Self> for #ident #ty_generics #where_clause {
760
4
                        type Error = CompositeKeyError;
761
4

            
762
4
                        const LENGTH: Option<usize> = <#repr as KeyEncoding>::LENGTH;
763
4

            
764
4
                        fn describe<Visitor>(visitor: &mut Visitor)
765
4
                        where
766
4
                            Visitor: KeyVisitor,
767
4
                        {
768
4
                            <#repr>::describe(visitor);
769
4
                        }
770
4

            
771
4
                        fn as_ord_bytes(&self) -> Result<Cow<'_, [u8]>, Self::Error> {
772
4
                            #consts
773
4
                            match self {
774
4
                                #encode_variants
775
4
                            }.map_err(#core::key::CompositeKeyError::new)
776
4
                        }
777
4
                    }
778
4
                });
779
1
            }
780
1

            
781
1
            // At least one variant has a value, which means we need to encode a composite field.
782
1
            (
783
1
                quote! {
784
1
                    #consts
785
1
                    match self{
786
1
                        #encode_variants
787
1
                    }
788
1
                },
789
1
                quote! {
790
1
                    # use std::io::{self, ErrorKind};
791
1
                    #consts
792
1
                    let $self_ = match $decoder.decode::<#repr>()? {
793
1
                        #decode_variants
794
1
                        _ => return Err(#core::key::CompositeKeyError::from(io::Error::from(
795
1
                                ErrorKind::InvalidData,
796
1
                        )))
797
1
                    };
798
1
                },
799
1
                describe,
800
1
                quote!(#core::key::CompositeKind::Tuple),
801
1
                field_count,
802
1
            )
803
        }
804
        Data::Union(_) => bail!("unions are not supported"),
805
    };
806

            
807
5
    Ok(quote! {
808
5
        # use std::{borrow::Cow, io::{self, ErrorKind}};
809
5
        # use #core::key::{ByteSource, CompositeKeyDecoder, KeyVisitor, CompositeKeyEncoder, CompositeKeyError, Key, KeyEncoding};
810
5

            
811
5
        impl #impl_generics Key<$'key> for #ident #ty_generics #where_clause {
812
5
            const CAN_OWN_BYTES: bool = #can_own_bytes;
813
5

            
814
5
            fn from_ord_bytes<$'b>(mut $bytes: ByteSource<$'key, $'b>) -> Result<Self, Self::Error> {
815
5

            
816
5
                let mut $decoder = CompositeKeyDecoder::#decoder_constructor($bytes);
817
5

            
818
5
                #decode_fields
819
5

            
820
5
                $decoder.finish()?;
821
5

            
822
5
                Ok($self_)
823
5
            }
824
5
        }
825
5

            
826
5
        impl #impl_generics KeyEncoding<Self> for #ident #ty_generics #where_clause {
827
5
            type Error = CompositeKeyError;
828
5

            
829
5
            // TODO fixed width if possible
830
5
            const LENGTH: Option<usize> = None;
831
5

            
832
5
            fn describe<Visitor>(visitor: &mut Visitor)
833
5
            where
834
5
                Visitor: KeyVisitor,
835
5
            {
836
5
                visitor.visit_composite(#composite_kind, #field_count);
837
5
                #describe
838
5
            }
839
5

            
840
5
            fn as_ord_bytes(&self) -> Result<Cow<'_, [u8]>, Self::Error> {
841
5
                let mut $encoder = CompositeKeyEncoder::#encoder_constructor();
842
5

            
843
5
                #encode_fields
844
5

            
845
5
                Ok(Cow::Owned($encoder.finish()))
846
5
            }
847
5
        }
848
5
    })
849
}
850

            
851
40
#[derive(Attribute)]
852
#[attribute(ident = api)]
853
struct ApiAttribute {
854
    #[attribute(example = "\"name\"")]
855
    name: String,
856
    #[attribute(example = "\"authority\"")]
857
    authority: Option<Expr>,
858
    #[attribute(example = "ResponseType")]
859
    response: Option<Type>,
860
    #[attribute(example = "ErrorType")]
861
    error: Option<Type>,
862
    #[attribute(example = "bosaidb::core")]
863
    core: Option<Path>,
864
}
865

            
866
/// Derives the `bonsaidb::core::api::Api` trait.
867
///
868
/// `#[api(name = "Name", authority = "Authority", response = ResponseType, error = ErrorType, core = bonsaidb::core)]`
869
/// `authority`, `response`, `error` and `core` are optional
870
8
#[manyhow]
871
#[proc_macro_derive(Api, attributes(api))]
872
8
pub fn api_derive(input: proc_macro::TokenStream) -> Result {
873
    let DeriveInput {
874
8
        attrs,
875
8
        ident,
876
8
        generics,
877
        ..
878
8
    } = parse(input)?;
879

            
880
    let ApiAttribute {
881
8
        name,
882
8
        authority,
883
8
        response,
884
8
        error,
885
8
        core,
886
8
    } = ApiAttribute::from_attributes(&attrs)?;
887

            
888
8
    let core = core.unwrap_or_else(core_path);
889
8
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
890
8

            
891
8
    let name = authority.map_or_else(
892
8
        || quote!(#core::schema::Qualified::private(#name)),
893
8
        |authority| quote!(#core::schema::Qualified::new(#authority, #name)),
894
8
    );
895
8

            
896
8
    let response = response.unwrap_or_else(|| parse_quote!(()));
897
8
    let error = error.unwrap_or_else(|| parse_quote!(#core::api::Infallible));
898
8

            
899
8
    Ok(quote! {
900
8
        # use #core::api::{Api, ApiName};
901
8

            
902
8
        impl #impl_generics Api for #ident #ty_generics #where_clause {
903
8
            type Response = #response;
904
8
            type Error = #error;
905
8

            
906
8
            fn name() -> ApiName {
907
8
                #name
908
8
            }
909
8
        }
910
8
    })
911
}
912

            
913
// -----------------------------------------------------------------------------
914
//     - File Macros -
915
// -----------------------------------------------------------------------------
916

            
917
fn files_path() -> Path {
918
5
    match crate_name("bonsaidb") {
919
5
        Ok(FoundCrate::Name(name)) => {
920
5
            let ident = Ident::new(&name, Span::call_site());
921
5
            parse_quote!(::#ident::files)
922
        }
923
        Ok(FoundCrate::Itself) => parse_quote!(crate::files),
924
        Err(_) => match crate_name("bonsaidb_files") {
925
            Ok(FoundCrate::Name(name)) => {
926
                let ident = Ident::new(&name, Span::call_site());
927
                parse_quote!(::#ident)
928
            }
929
            Ok(FoundCrate::Itself) => parse_quote!(crate),
930
            Err(_) if cfg!(feature = "omnibus-path") => parse_quote!(::bonsaidb::files),
931
            Err(_) => parse_quote!(::bonsaidb_core),
932
        },
933
    }
934
5
}
935

            
936
18
#[derive(Attribute)]
937
#[attribute(ident = file_config)]
938
struct FileConfigAttribute {
939
    #[attribute(example = "MetadataType")]
940
    metadata: Option<Type>,
941
    #[attribute(example = "65_536")]
942
    block_size: Option<usize>,
943
    #[attribute(example = "\"authority\"")]
944
    authority: Option<Expr>,
945
    #[attribute(example = "\"files\"")]
946
    files_name: Option<String>,
947
    #[attribute(example = "\"blocks\"")]
948
    blocks_name: Option<String>,
949
    #[attribute(example = "bosaidb::core")]
950
    core: Option<Path>,
951
    #[attribute(example = "bosaidb::files")]
952
    files: Option<Path>,
953
}
954

            
955
/// Derives the `bonsaidb::files::FileConfig` trait.
956
///
957
/// `#[api(metadata = MetadataType, block_size = 65_536, authority = "authority", files_name = "files", blocks_name = "blocks", core = bonsaidb::core, files = bosaidb::files)]`
958
/// all arguments are optional
959
5
#[manyhow]
960
#[proc_macro_derive(FileConfig, attributes(file_config))]
961
5
pub fn file_config_derive(input: proc_macro::TokenStream) -> Result {
962
    let DeriveInput {
963
5
        attrs,
964
5
        ident,
965
5
        generics,
966
        ..
967
5
    } = parse(input)?;
968

            
969
    let FileConfigAttribute {
970
5
        metadata,
971
5
        block_size,
972
5
        authority,
973
5
        files_name,
974
5
        blocks_name,
975
5
        core,
976
5
        files,
977
5
    } = FileConfigAttribute::from_attributes(&attrs)?;
978

            
979
5
    let core = core.unwrap_or_else(core_path);
980
5
    let files = files.unwrap_or_else(files_path);
981
5
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
982

            
983
5
    let (files_name, blocks_name) = match (authority, files_name, blocks_name) {
984
3
        (None, None, None) => (
985
3
            quote!(#files::BonsaiFiles::files_name()),
986
3
            quote!(#files::BonsaiFiles::blocks_name()),
987
3
        ),
988
1
        (Some(authority), Some(files_name), Some(blocks_name)) => (
989
1
            quote!(#core::schema::Qualified::new(#authority, #files_name)),
990
1
            quote!(#core::schema::Qualified::new(#authority, #blocks_name)),
991
1
        ),
992
1
        (None, Some(files_name), Some(blocks_name)) => (
993
1
            quote!(#core::schema::Qualified::private(#files_name)),
994
1
            quote!(#core::schema::Qualified::private(#blocks_name)),
995
1
        ),
996
        (Some(_), ..) => bail!(
997
            "if `authority` is specified, `files_name` and `blocks_name need to be provided as well"
998
        ),
999
        (_, Some(_), _) => {
            bail!("if `files_name` is specified, `blocks_name` needs to be provided as well")
        }
        (_, _, Some(_)) => {
            bail!("if `blocks_name` is specified, `files_name` needs to be provided as well")
        }
    };

            
5
    let metadata = metadata
5
        .unwrap_or_else(|| parse_quote!(<#files::BonsaiFiles as #files::FileConfig>::Metadata));
5
    let block_size = block_size.map_or_else(
5
        || quote!(<#files::BonsaiFiles as #files::FileConfig>::BLOCK_SIZE),
5
        |block_size| quote!(#block_size),
5
    );
5

            
5
    Ok(quote! {
5
        # use #files::FileConfig;
5
        # use #core::schema::CollectionName;
5

            
5
        impl #impl_generics FileConfig for #ident #ty_generics #where_clause {
5
            type Metadata = #metadata;
5
            const BLOCK_SIZE: usize = #block_size;
5

            
5
            fn files_name() -> CollectionName {
5
                #files_name
5
            }
5

            
5
            fn blocks_name() -> CollectionName {
5
                #blocks_name
5
            }
5
        }
5
    })
}

            
1
#[test]
fn ui() {
    // Every file matched by the glob below is registered as a trybuild
    // compile-fail case; the cases are checked when `cases` is dropped.
    let cases = trybuild::TestCases::new();
    cases.compile_fail("tests/ui/*/*.rs");
}