1
//! Macros for BonsaiDb.
2

            
3
#![forbid(unsafe_code)]
4
#![warn(
5
    clippy::cargo,
6
    missing_docs,
7
    // clippy::missing_docs_in_private_items,
8
    clippy::pedantic,
9
    future_incompatible,
10
    rust_2018_idioms,
11
)]
12
#![cfg_attr(doc, deny(rustdoc::all))]
13

            
14
use attribute_derive::parsing::{AttributeBase, AttributeValue, SpannedValue};
15
use attribute_derive::{Attribute, FromAttr};
16
use manyhow::{bail, error_message, manyhow, JoinToTokensError, Result};
17
use proc_macro2::{Span, TokenStream};
18
use proc_macro_crate::{crate_name, FoundCrate};
19
use quote::{quote_spanned, ToTokens};
20
use quote_use::{
21
    format_ident_namespaced as format_ident, parse_quote_use as parse_quote, quote_use as quote,
22
};
23
use syn::parse::ParseStream;
24
use syn::punctuated::Punctuated;
25
use syn::spanned::Spanned;
26
use syn::{
27
    parse, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed,
28
    FieldsUnnamed, Ident, Index, Path, Token, Type, TypePath, Variant,
29
};
30

            
31
mod view;
32

            
33
// -----------------------------------------------------------------------------
34
//     - Core Macros -
35
// -----------------------------------------------------------------------------
36

            
37
96
fn core_path() -> Path {
38
96
    match crate_name("bonsaidb")
39
96
        .or_else(|_| crate_name("bonsaidb_server"))
40
96
        .or_else(|_| crate_name("bonsaidb_local"))
41
96
        .or_else(|_| crate_name("bonsaidb_client"))
42
    {
43
96
        Ok(FoundCrate::Name(name)) => {
44
96
            let ident = Ident::new(&name, Span::call_site());
45
96
            parse_quote!(::#ident::core)
46
        }
47
        Ok(FoundCrate::Itself) => parse_quote!(crate::core),
48
        Err(_) => match crate_name("bonsaidb_core") {
49
            Ok(FoundCrate::Name(name)) => {
50
                let ident = Ident::new(&name, Span::call_site());
51
                parse_quote!(::#ident)
52
            }
53
            Ok(FoundCrate::Itself) => parse_quote!(crate),
54
            Err(_) => match () {
55
                () if cfg!(feature = "omnibus-path") => parse_quote!(::bonsaidb::core),
56
                () if cfg!(feature = "server-path") => parse_quote!(::bonsaidb_server::core),
57
                () if cfg!(feature = "local-path") => parse_quote!(::bonsaidb_local::core),
58
                () if cfg!(feature = "client-path") => parse_quote!(::bonsaidb_client::core),
59
                () => parse_quote!(::bonsaidb_core),
60
            },
61
        },
62
    }
63
96
}
64

            
65
1574
#[derive(FromAttr)]
66
#[attribute(ident = collection)]
67
struct CollectionAttribute {
68
    authority: Option<Expr>,
69
    #[attribute(example = "\"name\"")]
70
    name: String,
71
    #[attribute(optional, example = "[SomeView, AnotherView]")]
72
    views: Vec<Type>,
73
    #[attribute(example = "Format or None")]
74
    serialization: Option<Path>,
75
    #[attribute(example = "Some(KeyId::Master)")]
76
    encryption_key: Option<Expr>,
77
    encryption_required: bool,
78
    encryption_optional: bool,
79
    #[attribute(example = "u64")]
80
    primary_key: Option<Type>,
81
    #[attribute(example = "self.0 or something(self)")]
82
    natural_id: Option<Expr>,
83
    #[attribute(example = "bosaidb::core")]
84
    core: Option<Path>,
85
}
86

            
87
/// Derives the `bonsaidb::core::schema::Collection` trait.
88
/// `#[collection(authority = "Authority", name = "Name", views = [a, b, c])]`
89
70
#[manyhow]
90
#[proc_macro_derive(Collection, attributes(collection, natural_id))]
91
70
pub fn collection_derive(input: proc_macro::TokenStream) -> Result {
92
    let DeriveInput {
93
70
        attrs,
94
70
        ident,
95
70
        generics,
96
70
        data,
97
        ..
98
70
    } = parse(input)?;
99

            
100
    let CollectionAttribute {
101
70
        authority,
102
70
        name,
103
70
        views,
104
70
        serialization,
105
70
        mut primary_key,
106
70
        mut natural_id,
107
70
        core,
108
70
        encryption_key,
109
70
        encryption_required,
110
70
        encryption_optional,
111
70
    } = CollectionAttribute::from_attributes(&attrs)?;
112

            
113
70
    if let Data::Struct(DataStruct { fields, .. }) = data {
114
70
        let mut previous: Option<syn::Attribute> = None;
115
        for (
116
183
            idx,
117
183
            Field {
118
183
                attrs, ident, ty, ..
119
            },
120
70
        ) in fields.into_iter().enumerate()
121
        {
122
183
            if let Some(attr) = attrs
123
183
                .into_iter()
124
183
                .find(|attr| attr.path().is_ident("natural_id"))
125
            {
126
6
                if let Some(previous) = &previous {
127
                    bail!(error_message!(attr,
128
                            "marked multiple fields as `natural_id`";
129
                            note="currently only one field can be marked as `natural_id`";
130
                            help="use `#[collection(natural_id=...)]` on the struct instead")
131
                    .join(error_message!(previous, "previous `natural_id`")));
132
6
                }
133
6
                if let Some(natural_id) = &natural_id {
134
                    bail!(error_message!(attr, "field marked as `natural_id` while `natural_id` expression is specified as well";
135
                            help = "remove `#[natural_id]` attribute on field")
136
                        .join(error_message!(natural_id, "`natural_id` expression is specified here")));
137
6
                }
138
6
                previous = Some(attr);
139
6
                let ident = if let Some(ident) = ident {
140
5
                    quote!(#ident)
141
                } else {
142
1
                    let idx = Index::from(idx);
143
1
                    quote_spanned!(ty.span()=> #idx)
144
                };
145
6
                natural_id = Some(parse_quote!(Some(Clone::clone(&self.#ident))));
146
6
                if primary_key.is_none() {
147
2
                    primary_key = Some(ty);
148
4
                }
149
177
            }
150
        }
151
    };
152

            
153
70
    if encryption_required && encryption_key.is_none() {
154
        bail!("If `collection(encryption_required)` is set you need to provide an encryption key via `collection(encryption_key = EncryptionKey)`")
155
70
    }
156
70

            
157
70
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
158
70

            
159
70
    let core = core.unwrap_or_else(core_path);
160
70

            
161
70
    let primary_key = primary_key.unwrap_or_else(|| parse_quote!(u64));
162

            
163
70
    let serialization = if matches!(&serialization, Some(serialization) if serialization.is_ident("None"))
164
    {
165
1
        if let Some(natural_id) = natural_id {
166
            bail!(
167
                natural_id,
168
                "`natural_id` must be manually implemented when using `serialization = None`"
169
            );
170
1
        }
171
1

            
172
1
        TokenStream::new()
173
    } else {
174
69
        let natural_id = natural_id.map(|natural_id| {
175
8
            quote!(
176
8
                fn natural_id(&self) -> Option<Self::PrimaryKey> {
177
8
                    #[allow(clippy::clone_on_copy)]
178
8
                    #natural_id
179
8
                }
180
8
            )
181
69
        });
182

            
183
69
        if let Some(serialization) = serialization {
184
2
            let serialization = if serialization.is_ident("Key") {
185
1
                quote!(#core::key::KeyFormat)
186
            } else {
187
1
                quote!(#serialization)
188
            };
189
2
            quote! {
190
2
                impl #impl_generics #core::schema::SerializedCollection for #ident #ty_generics #where_clause {
191
2
                    type Contents = #ident #ty_generics;
192
2
                    type Format = #serialization;
193
2

            
194
2
                    fn format() -> Self::Format {
195
2
                        #serialization::default()
196
2
                    }
197
2

            
198
2
                    #natural_id
199
2
                }
200
2
            }
201
        } else {
202
67
            quote! {
203
67
                impl #impl_generics #core::schema::DefaultSerialization for #ident #ty_generics #where_clause {
204
67
                    #natural_id
205
67
                }
206
67
            }
207
        }
208
    };
209

            
210
70
    let name = authority.map_or_else(
211
70
        || quote!(#core::schema::Qualified::private(#name)),
212
70
        |authority| quote!(#core::schema::Qualified::new(#authority, #name)),
213
70
    );
214
70

            
215
70
    let encryption = encryption_key.map(|encryption_key| {
216
13
        let encryption = if encryption_required || !encryption_optional {
217
2
            encryption_key.into_token_stream()
218
        } else {
219
11
            quote! {
220
11
                if #core::ENCRYPTION_ENABLED {
221
11
                    #encryption_key
222
11
                } else {
223
11
                    None
224
11
                }
225
11
            }
226
        };
227
13
        quote! {
228
13
            fn encryption_key() -> Option<#core::document::KeyId> {
229
13
                #encryption
230
13
            }
231
13
        }
232
70
    });
233

            
234
70
    Ok(quote! {
235
70
        impl #impl_generics #core::schema::Collection for #ident #ty_generics #where_clause {
236
70
            type PrimaryKey = #primary_key;
237
70

            
238
70
            fn collection_name() -> #core::schema::CollectionName {
239
70
                #name
240
70
            }
241
70
            fn define_views(schema: &mut #core::schema::Schematic) -> Result<(), #core::Error> {
242
70
                #( schema.define_view(#views)?; )*
243
70
                Ok(())
244
70
            }
245
70
            #encryption
246
70
        }
247
70
        #serialization
248
70
    })
249
}
250
/// Derives the `bonsaidb::core::schema::View` trait.
251
///
252
/// `#[view(collection=CollectionType, key=KeyType, value=ValueType, name = "by-name")]`
253
/// `name` and `value` are optional
254
57
#[manyhow]
255
#[proc_macro_derive(View, attributes(view))]
256
57
pub fn view_derive(input: proc_macro::TokenStream) -> Result {
257
57
    view::derive(parse(input)?)
258
}
259
/// Derives the `bonsaidb::core::schema::ViewSchema` trait.
260
51
#[manyhow]
261
/// `#[view_schema(version = 1, policy = Unique, view=ViewType, mapped_key=KeyType<'doc>)]`
262
///
263
/// All attributes are optional.
264
#[proc_macro_derive(ViewSchema, attributes(view_schema))]
265
51
pub fn view_schema_derive(input: proc_macro::TokenStream) -> Result {
266
51
    view::derive_schema(parse(input)?)
267
}
268

            
269
199
#[derive(FromAttr)]
270
#[attribute(ident = schema)]
271
struct SchemaAttribute {
272
    #[attribute(example = "\"name\"")]
273
    name: String,
274
    #[attribute(example = "\"authority\"")]
275
    authority: Option<Expr>,
276
    #[attribute(optional, example = "[SomeCollection, AnotherCollection]")]
277
    collections: Vec<Type>,
278
    #[attribute(optional, example = "[SomeSchema, AnotherSchema]")]
279
    include: Vec<Type>,
280
    #[attribute(example = "bosaidb::core")]
281
    core: Option<Path>,
282
}
283

            
284
/// Derives the `bonsaidb::core::schema::Schema` trait.
285
///
286
/// `#[schema(name = "Name", authority = "Authority", collections = [A, B, C]), core = bonsaidb::core]`
287
/// `authority`, `collections` and `core` are optional
288
19
#[manyhow]
289
#[proc_macro_derive(Schema, attributes(schema))]
290
19
pub fn schema_derive(input: proc_macro::TokenStream) -> Result {
291
    let DeriveInput {
292
19
        attrs,
293
19
        ident,
294
19
        generics,
295
        ..
296
19
    } = parse(input)?;
297

            
298
    let SchemaAttribute {
299
19
        name,
300
19
        authority,
301
19
        collections,
302
19
        include,
303
19
        core,
304
19
    } = SchemaAttribute::from_attributes(&attrs)?;
305

            
306
19
    let core = core.unwrap_or_else(core_path);
307
19
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
308
19

            
309
19
    let name = authority.map_or_else(
310
19
        || quote!(#core::schema::Qualified::private(#name)),
311
19
        |authority| quote!(#core::schema::Qualified::new(#authority, #name)),
312
19
    );
313

            
314
19
    Ok(quote! {
315
19
        impl #impl_generics #core::schema::Schema for #ident #ty_generics #where_clause {
316
19
            fn schema_name() -> #core::schema::SchemaName {
317
19
                #name
318
19
            }
319
19

            
320
19
            fn define_collections(
321
19
                schema: &mut #core::schema::Schematic
322
19
            ) -> Result<(), #core::Error> {
323
19
                #( schema.define_collection::<#collections>()?; )*
324
19

            
325
19
                #( <#include as #core::schema::Schema>::define_collections(schema)?; )*
326
19

            
327
19
                Ok(())
328
19
            }
329
19
        }
330
19
    })
331
}
332

            
333
28
#[derive(FromAttr)]
334
#[attribute(ident = key)]
335
struct KeyAttribute {
336
    #[attribute(example = "bosaidb::core")]
337
    core: Option<Path>,
338
13
    #[attribute(default = NullHandling::Escape, example = "escape")]
339
    null_handling: NullHandling,
340
    can_own_bytes: bool,
341
    #[attribute(example = "u8")]
342
    enum_repr: Option<Type>,
343
    #[attribute(example = "\"name\"")]
344
    name: Option<String>,
345
}
346

            
347
/// Value of `#[key(null_handling = ...)]`.
///
/// In the code generated by the `Key` derive, `Escape` selects the
/// `default`/`default_for` encoder/decoder constructors, `Allow` selects
/// `allowing_null_bytes`, and `Deny` selects `denying_null_bytes`.
enum NullHandling {
    Escape,
    Allow,
    Deny,
}
352

            
353
// The partial (pre-merge) representation is the value itself: `parse_value`
// produces a complete `NullHandling` from a single identifier.
impl AttributeBase for NullHandling {
    type Partial = NullHandling;
}
356

            
357
impl AttributeValue for NullHandling {
358
    fn parse_value(input: ParseStream<'_>) -> syn::Result<SpannedValue<Self>> {
359
        let ident: Ident = input.parse()?;
360

            
361
        Ok(SpannedValue::new(
362
            match ident.to_string().as_str() {
363
                "escape" => NullHandling::Escape,
364
                "allow" => NullHandling::Allow,
365
                "deny" => NullHandling::Deny,
366
                _ => {
367
                    return Err(syn::Error::new(
368
                        Span::call_site(),
369
                        "only `escape`, `allow`, and `deny` are allowed for `null_handling`",
370
                    ))
371
                }
372
            },
373
            ident.span(),
374
        ))
375
    }
376
}
377

            
378
/// Derives the `bonsaidb::core::key::Key` trait.
379
///
380
/// `#[key(null_handling = escape, enum_repr = u8, core = bonsaidb::core)]`, all parameters are optional
381
13
#[manyhow]
382
#[proc_macro_derive(Key, attributes(key))]
383
13
pub fn key_derive(input: proc_macro::TokenStream) -> Result {
384
    let DeriveInput {
385
13
        attrs,
386
13
        ident,
387
13
        generics,
388
13
        data,
389
        ..
390
13
    } = parse(input)?;
391

            
392
    // Only relevant if it is an enum, gets the representation to use for the variant key
393
13
    let repr = attrs.iter().find_map(|attr| {
394
5
        attr.path()
395
5
            .is_ident("repr")
396
5
            .then(|| attr.parse_args::<Ident>().ok())
397
5
            .flatten()
398
5
            .and_then(|ident| {
399
2
                matches!(
400
2
                    ident.to_string().as_ref(),
401
2
                    "u8" | "u16"
402
1
                        | "u32"
403
1
                        | "u64"
404
                        | "u128"
405
                        | "usize"
406
                        | "i8"
407
                        | "i16"
408
                        | "i32"
409
                        | "i64"
410
                        | "i128"
411
                        | "isize"
412
                )
413
2
                .then(|| ident)
414
5
            })
415
13
    });
416

            
417
    let KeyAttribute {
418
13
        core,
419
13
        null_handling,
420
13
        enum_repr,
421
13
        can_own_bytes,
422
13
        name,
423
13
    } = KeyAttribute::from_attributes(&attrs)?;
424

            
425
13
    let name = name.map_or_else(
426
13
        || quote!(std::any::type_name::<Self>()),
427
13
        |name| quote!(#name),
428
13
    );
429

            
430
13
    if matches!(data, Data::Struct(_)) && enum_repr.is_some() {
431
        // TODO better span when attribute-derive supports that
432
        bail!(enum_repr, "`enum_repr` is only usable with enums")
433
13
    }
434
13

            
435
13
    let repr: Type = enum_repr.unwrap_or_else(|| {
436
12
        Type::Path(TypePath {
437
12
            qself: None,
438
12
            path: repr.unwrap_or_else(|| format_ident!("isize")).into(),
439
12
        })
440
13
    });
441

            
442
13
    let (encoder_constructor, decoder_constructor) = match null_handling {
443
13
        NullHandling::Escape => (quote!(default), quote!(default_for)),
444
        NullHandling::Allow => (quote!(allowing_null_bytes), quote!(allowing_null_bytes)),
445
        NullHandling::Deny => (quote!(denying_null_bytes), quote!(denying_null_bytes)),
446
    };
447

            
448
13
    let core = core.unwrap_or_else(core_path);
449
13
    let (_, ty_generics, _) = generics.split_for_impl();
450
13
    let mut generics = generics.clone();
451
13
    let lifetimes: Vec<_> = generics.lifetimes().cloned().collect();
452
13
    let where_clause = generics.make_where_clause();
453
15
    for lifetime in lifetimes {
454
2
        where_clause.predicates.push(parse_quote!($'key: #lifetime));
455
2
    }
456
13
    generics
457
13
        .params
458
13
        .push(syn::GenericParam::Lifetime(parse_quote!($'key)));
459
13
    let (impl_generics, _, where_clause) = generics.split_for_impl();
460

            
461
    // Special case the implementation for 1
462
    // field -- just pass through to the
463
    // inner type so that this encoding is
464
    // completely transparent.
465
3
    if let Some((name, ty, map)) = match &data {
466
4
        Data::Struct(DataStruct {
467
4
            fields: Fields::Named(FieldsNamed { named, .. }),
468
4
            ..
469
4
        }) if named.len() == 1 => {
470
2
            let name = &named[0].ident;
471
2
            Some((
472
2
                quote!(#name),
473
2
                named[0].ty.clone(),
474
2
                quote!(|value| Self { #name: value }),
475
2
            ))
476
        }
477
        Data::Struct(DataStruct {
478
3
            fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
479
3
            ..
480
3
        }) if unnamed.len() == 1 => Some((quote!(0), unnamed[0].ty.clone(), quote!(Self))),
481
10
        _ => None,
482
    } {
483
3
        return Ok(quote! {
484
3
            # use std::{borrow::Cow, io::{self, ErrorKind}};
485
3
            # use #core::key::{ByteSource, KeyVisitor, IncorrectByteLength, Key, KeyEncoding};
486
3

            
487
3
            impl #impl_generics Key<$'key> for #ident #ty_generics #where_clause {
488
3
                const CAN_OWN_BYTES: bool = <#ty>::CAN_OWN_BYTES;
489
3

            
490
3
                fn from_ord_bytes<$'b>(bytes: ByteSource<$'key, $'b>) -> Result<Self, Self::Error> {
491
3
                    <#ty>::from_ord_bytes(bytes).map(#map)
492
3
                }
493
3
            }
494
3

            
495
3
            impl #impl_generics KeyEncoding<Self> for #ident #ty_generics #where_clause {
496
3
                type Error = <#ty as KeyEncoding>::Error;
497
3

            
498
3
                const LENGTH: Option<usize> = <#ty>::LENGTH;
499
3

            
500
3
                fn describe<Visitor>(visitor: &mut Visitor)
501
3
                where
502
3
                    Visitor: KeyVisitor,
503
3
                {
504
3
                    <#ty>::describe(visitor)
505
3
                }
506
3

            
507
3
                fn as_ord_bytes(&self) -> Result<Cow<'_, [u8]>, Self::Error> {
508
3
                    self.#name.as_ord_bytes()
509
3
                }
510
3
            }
511
3
        });
512
10
    }
513

            
514
5
    let (encode_fields, decode_fields, describe, composite_kind, field_count): (
515
5
        TokenStream,
516
5
        TokenStream,
517
5
        TokenStream,
518
5
        TokenStream,
519
5
        usize,
520
10
    ) = match data {
521
5
        Data::Struct(DataStruct { fields, .. }) => {
522
5
            let (encode_fields, decode_fields, describe, field_count) = match fields {
523
2
                Fields::Named(FieldsNamed { named, .. }) => {
524
2
                    let field_count = named.len();
525
2
                    let (encode_fields, (decode_fields, describe)): (
526
2
                        TokenStream,
527
2
                        (TokenStream, TokenStream),
528
2
                    ) = named
529
2
                        .into_iter()
530
4
                        .map(|Field { ident, ty, .. }| {
531
4
                            let ident = ident.expect("named fields have idents");
532
4
                            (
533
4
                                quote!($encoder.encode(&self.#ident)?;),
534
4
                                (
535
4
                                    quote!(#ident: $decoder.decode()?,),
536
4
                                    quote!(<#ty>::describe(visitor);),
537
4
                                ),
538
4
                            )
539
4
                        })
540
2
                        .unzip();
541
2
                    (
542
2
                        encode_fields,
543
2
                        quote!( Self { #decode_fields }),
544
2
                        describe,
545
2
                        field_count,
546
2
                    )
547
                }
548
2
                Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
549
2
                    let field_count = unnamed.len();
550
2
                    let (encode_fields, (decode_fields, describe)): (
551
2
                        TokenStream,
552
2
                        (TokenStream, TokenStream),
553
2
                    ) = unnamed
554
2
                        .into_iter()
555
2
                        .enumerate()
556
5
                        .map(|(idx, field)| {
557
5
                            let ty = field.ty;
558
5
                            let idx = Index::from(idx);
559
5
                            (
560
5
                                quote!($encoder.encode(&self.#idx)?;),
561
5
                                (
562
5
                                    quote!($decoder.decode()?,),
563
5
                                    quote!(<#ty>::describe(visitor);),
564
5
                                ),
565
5
                            )
566
5
                        })
567
2
                        .unzip();
568
2
                    (
569
2
                        encode_fields,
570
2
                        quote!(Self(#decode_fields)),
571
2
                        describe,
572
2
                        field_count,
573
2
                    )
574
                }
575
                Fields::Unit => {
576
1
                    return Ok(quote! {
577
1
                        # use std::{borrow::Cow, io::{self, ErrorKind}};
578
1
                        # use #core::key::{ByteSource, KeyVisitor, IncorrectByteLength, Key, KeyKind, KeyEncoding};
579
1

            
580
1
                        impl #impl_generics Key<$'key> for #ident #ty_generics #where_clause {
581
1
                            const CAN_OWN_BYTES: bool = false;
582
1

            
583
1
                            fn from_ord_bytes<$'b>(bytes: ByteSource<$'key, $'b>) -> Result<Self, Self::Error> {
584
1
                                Ok(Self)
585
1
                            }
586
1
                        }
587
1

            
588
1
                        impl #impl_generics KeyEncoding<Self> for #ident #ty_generics #where_clause {
589
1
                            type Error = std::convert::Infallible;
590
1

            
591
1
                            const LENGTH: Option<usize> = Some(0);
592
1

            
593
1
                            fn describe<Visitor>(visitor: &mut Visitor)
594
1
                            where
595
1
                                Visitor: KeyVisitor,
596
1
                            {
597
1
                                visitor.visit_type(KeyKind::Unit);
598
1
                            }
599
1

            
600
1
                            fn as_ord_bytes(&self) -> Result<Cow<'_, [u8]>, Self::Error> {
601
1
                                Ok(Cow::Borrowed(&[]))
602
1
                            }
603
1
                        }
604
1
                    })
605
                }
606
            };
607
4
            (
608
4
                encode_fields,
609
4
                quote!(let $self_ = #decode_fields;),
610
4
                describe,
611
4
                quote!(#core::key::CompositeKind::Struct(std::borrow::Cow::Borrowed(#name))),
612
4
                field_count,
613
4
            )
614
        }
615
5
        Data::Enum(DataEnum { variants, .. }) => {
616
5
            let mut prev_ident = None;
617
5
            let field_count = variants.len();
618
11
            let all_variants_are_empty = variants.iter().all(|variant| variant.fields.is_empty());
619
5

            
620
5
            let (consts, (encode_variants, (decode_variants, describe))): (
621
5
                TokenStream,
622
5
                (TokenStream, (TokenStream, TokenStream)),
623
5
            ) = variants
624
5
                .into_iter()
625
5
                .enumerate()
626
5
                .map(
627
5
                    |(
628
                        idx,
629
                        Variant {
630
                            fields,
631
                            ident,
632
                            discriminant,
633
                            ..
634
                        },
635
12
                    )| {
636
12
                        let discriminant = discriminant.map_or_else(
637
12
                            || {
638
6
                                prev_ident
639
6
                                    .as_ref()
640
6
                                    .map_or_else(|| quote!(0), |ident| quote!(#ident + 1))
641
12
                            },
642
12
                            |(_, expr)| expr.to_token_stream(),
643
12
                        );
644
12

            
645
12
                        let const_ident = format_ident!("$discriminant{idx}");
646
12
                        let const_ = quote!(const #const_ident: #repr = #discriminant;);
647

            
648
12
                        let ret = (
649
12
                            const_,
650
12
                            match fields {
651
1
                                Fields::Named(FieldsNamed { named, .. }) => {
652
1
                                    let (idents, (encode_fields, (decode_fields, describe))): (
653
1
                                        Punctuated<_, Token![,]>,
654
1
                                        (TokenStream, (TokenStream, TokenStream)),
655
1
                                    ) = named
656
1
                                        .into_iter()
657
2
                                        .map(|Field { ident, ty, .. }| {
658
2
                                            let ident = ident.expect("named fields have idents");
659
2
                                            (
660
2
                                                ident.clone(),
661
2
                                                (
662
2
                                                    quote!($encoder.encode(#ident)?;),
663
2
                                                    (
664
2
                                                        quote!(#ident: $decoder.decode()?,),
665
2
                                                        quote!(<#ty>::describe(visitor);),
666
2
                                                    ),
667
2
                                                ),
668
2
                                            )
669
2
                                        })
670
1
                                        .unzip();
671
1
                                    (
672
1
                                        quote! {
673
1
                                            Self::#ident{#idents} => {
674
1
                                                $encoder.encode(&#const_ident)?;
675
1
                                                #encode_fields
676
1
                                            },
677
1
                                        },
678
1
                                        (
679
1
                                            quote! {
680
1
                                                #const_ident => Self::#ident{#decode_fields},
681
1
                                            },
682
1
                                            describe,
683
1
                                        ),
684
1
                                    )
685
                                }
686
1
                                Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
687
1
                                    let (idents, (encode_fields, (decode_fields, describe))): (
688
1
                                        Punctuated<_, Token![,]>,
689
1
                                        (TokenStream, (TokenStream, TokenStream)),
690
1
                                    ) = unnamed
691
1
                                        .into_iter()
692
1
                                        .enumerate()
693
2
                                        .map(|(idx, field)| {
694
2
                                            let ident = format_ident!("$field_{idx}");
695
2
                                            let ty = field.ty;
696
2
                                            (
697
2
                                                ident.clone(),
698
2
                                                (
699
2
                                                    quote!($encoder.encode(#ident)?;),
700
2
                                                    (
701
2
                                                        quote!($decoder.decode()?,),
702
2
                                                        quote!(<#ty>::describe(visitor);),
703
2
                                                    ),
704
2
                                                ),
705
2
                                            )
706
2
                                        })
707
1
                                        .unzip();
708
1
                                    (
709
1
                                        quote! {
710
1
                                            Self::#ident(#idents) => {
711
1
                                                $encoder.encode(&#const_ident)?;
712
1
                                                #encode_fields
713
1
                                            },
714
1
                                        },
715
1
                                        (
716
1
                                            quote! {
717
1
                                                #const_ident => Self::#ident(#decode_fields),
718
1
                                            },
719
1
                                            describe,
720
1
                                        ),
721
1
                                    )
722
                                }
723
                                Fields::Unit => {
724
10
                                    let encode = if all_variants_are_empty {
725
9
                                        quote!(Self::#ident => #const_ident.as_ord_bytes(),)
726
                                    } else {
727
1
                                        quote!(Self::#ident => $encoder.encode(&#const_ident)?,)
728
                                    };
729
10
                                    (
730
10
                                        encode,
731
10
                                        (
732
10
                                            quote!(#const_ident => Self::#ident,),
733
10
                                            quote!(visitor.visit_type(#core::key::KeyKind::Unit);),
734
10
                                        ),
735
10
                                    )
736
                                }
737
                            },
738
                        );
739
12
                        prev_ident = Some(const_ident);
740
12
                        ret
741
12
                    },
742
5
                )
743
5
                .unzip();
744
5

            
745
5
            if all_variants_are_empty {
746
                // Special case: if no enum variants have embedded values,
747
                // implement Key as a plain value, avoiding the composite key
748
                // overhead.
749
4
                return Ok(quote! {
750
4
                    # use std::{borrow::Cow, io::{self, ErrorKind}};
751
4
                    # use #core::key::{ByteSource, CompositeKeyDecoder, KeyVisitor, CompositeKeyEncoder, CompositeKeyError, Key, KeyEncoding};
752
4

            
753
4
                    impl #impl_generics Key<$'key> for #ident #ty_generics #where_clause {
754
4
                        const CAN_OWN_BYTES: bool = false;
755
4

            
756
4
                        fn from_ord_bytes<$'b>(mut $bytes: ByteSource<$'key, $'b>) -> Result<Self, Self::Error> {
757
4
                            #consts
758
4
                            Ok(match <#repr>::from_ord_bytes($bytes).map_err(#core::key::CompositeKeyError::new)? {
759
4
                                #decode_variants
760
4
                                _ => return Err(#core::key::CompositeKeyError::from(io::Error::from(
761
4
                                        ErrorKind::InvalidData,
762
4
                                )))
763
4
                            })
764
4
                        }
765
4
                    }
766
4

            
767
4
                    impl #impl_generics KeyEncoding<Self> for #ident #ty_generics #where_clause {
768
4
                        type Error = CompositeKeyError;
769
4

            
770
4
                        const LENGTH: Option<usize> = <#repr as KeyEncoding>::LENGTH;
771
4

            
772
4
                        fn describe<Visitor>(visitor: &mut Visitor)
773
4
                        where
774
4
                            Visitor: KeyVisitor,
775
4
                        {
776
4
                            <#repr>::describe(visitor);
777
4
                        }
778
4

            
779
4
                        fn as_ord_bytes(&self) -> Result<Cow<'_, [u8]>, Self::Error> {
780
4
                            #consts
781
4
                            match self {
782
4
                                #encode_variants
783
4
                            }.map_err(#core::key::CompositeKeyError::new)
784
4
                        }
785
4
                    }
786
4
                });
787
1
            }
788
1

            
789
1
            // At least one variant has a value, which means we need to encode a composite field.
790
1
            (
791
1
                quote! {
792
1
                    #consts
793
1
                    match self{
794
1
                        #encode_variants
795
1
                    }
796
1
                },
797
1
                quote! {
798
1
                    # use std::io::{self, ErrorKind};
799
1
                    #consts
800
1
                    let $self_ = match $decoder.decode::<#repr>()? {
801
1
                        #decode_variants
802
1
                        _ => return Err(#core::key::CompositeKeyError::from(io::Error::from(
803
1
                                ErrorKind::InvalidData,
804
1
                        )))
805
1
                    };
806
1
                },
807
1
                describe,
808
1
                quote!(#core::key::CompositeKind::Tuple),
809
1
                field_count,
810
1
            )
811
        }
812
        Data::Union(_) => bail!("unions are not supported"),
813
    };
814

            
815
5
    Ok(quote! {
816
5
        # use std::{borrow::Cow, io::{self, ErrorKind}};
817
5
        # use #core::key::{ByteSource, CompositeKeyDecoder, KeyVisitor, CompositeKeyEncoder, CompositeKeyError, Key, KeyEncoding};
818
5

            
819
5
        impl #impl_generics Key<$'key> for #ident #ty_generics #where_clause {
820
5
            const CAN_OWN_BYTES: bool = #can_own_bytes;
821
5

            
822
5
            fn from_ord_bytes<$'b>(mut $bytes: ByteSource<$'key, $'b>) -> Result<Self, Self::Error> {
823
5

            
824
5
                let mut $decoder = CompositeKeyDecoder::#decoder_constructor($bytes);
825
5

            
826
5
                #decode_fields
827
5

            
828
5
                $decoder.finish()?;
829
5

            
830
5
                Ok($self_)
831
5
            }
832
5
        }
833
5

            
834
5
        impl #impl_generics KeyEncoding<Self> for #ident #ty_generics #where_clause {
835
5
            type Error = CompositeKeyError;
836
5

            
837
5
            // TODO fixed width if possible
838
5
            const LENGTH: Option<usize> = None;
839
5

            
840
5
            fn describe<Visitor>(visitor: &mut Visitor)
841
5
            where
842
5
                Visitor: KeyVisitor,
843
5
            {
844
5
                visitor.visit_composite(#composite_kind, #field_count);
845
5
                #describe
846
5
            }
847
5

            
848
5
            fn as_ord_bytes(&self) -> Result<Cow<'_, [u8]>, Self::Error> {
849
5
                let mut $encoder = CompositeKeyEncoder::#encoder_constructor();
850
5

            
851
5
                #encode_fields
852
5

            
853
5
                Ok(Cow::Owned($encoder.finish()))
854
5
            }
855
5
        }
856
5
    })
857
}
858

            
859
64
#[derive(FromAttr)]
860
#[attribute(ident = api)]
861
struct ApiAttribute {
862
    #[attribute(example = "\"name\"")]
863
    name: String,
864
    #[attribute(example = "\"authority\"")]
865
    authority: Option<Expr>,
866
    #[attribute(example = "ResponseType")]
867
    response: Option<Type>,
868
    #[attribute(example = "ErrorType")]
869
    error: Option<Type>,
870
    #[attribute(example = "bosaidb::core")]
871
    core: Option<Path>,
872
}
873

            
874
/// Derives the `bonsaidb::core::api::Api` trait.
875
///
876
/// `#[api(name = "Name", authority = "Authority", response = ResponseType, error = ErrorType, core = bonsaidb::core)]`
877
/// `authority`, `response`, `error` and `core` are optional
878
8
#[manyhow]
879
#[proc_macro_derive(Api, attributes(api))]
880
8
pub fn api_derive(input: proc_macro::TokenStream) -> Result {
881
    let DeriveInput {
882
8
        attrs,
883
8
        ident,
884
8
        generics,
885
        ..
886
8
    } = parse(input)?;
887

            
888
    let ApiAttribute {
889
8
        name,
890
8
        authority,
891
8
        response,
892
8
        error,
893
8
        core,
894
8
    } = ApiAttribute::from_attributes(&attrs)?;
895

            
896
8
    let core = core.unwrap_or_else(core_path);
897
8
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
898
8

            
899
8
    let name = authority.map_or_else(
900
8
        || quote!(#core::schema::Qualified::private(#name)),
901
8
        |authority| quote!(#core::schema::Qualified::new(#authority, #name)),
902
8
    );
903
8

            
904
8
    let response = response.unwrap_or_else(|| parse_quote!(()));
905
8
    let error = error.unwrap_or_else(|| parse_quote!(#core::api::Infallible));
906
8

            
907
8
    Ok(quote! {
908
8
        # use #core::api::{Api, ApiName};
909
8

            
910
8
        impl #impl_generics Api for #ident #ty_generics #where_clause {
911
8
            type Response = #response;
912
8
            type Error = #error;
913
8

            
914
8
            fn name() -> ApiName {
915
8
                #name
916
8
            }
917
8
        }
918
8
    })
919
}
920

            
921
// -----------------------------------------------------------------------------
922
//     - File Macros -
923
// -----------------------------------------------------------------------------
924

            
925
5
fn files_path() -> Path {
926
5
    match crate_name("bonsaidb") {
927
5
        Ok(FoundCrate::Name(name)) => {
928
5
            let ident = Ident::new(&name, Span::call_site());
929
5
            parse_quote!(::#ident::files)
930
        }
931
        Ok(FoundCrate::Itself) => parse_quote!(crate::files),
932
        Err(_) => match crate_name("bonsaidb_files") {
933
            Ok(FoundCrate::Name(name)) => {
934
                let ident = Ident::new(&name, Span::call_site());
935
                parse_quote!(::#ident)
936
            }
937
            Ok(FoundCrate::Itself) => parse_quote!(crate),
938
            Err(_) if cfg!(feature = "omnibus-path") => parse_quote!(::bonsaidb::files),
939
            Err(_) => parse_quote!(::bonsaidb_core),
940
        },
941
    }
942
5
}
943

            
944
51
#[derive(FromAttr)]
945
#[attribute(ident = file_config)]
946
struct FileConfigAttribute {
947
    #[attribute(example = "MetadataType")]
948
    metadata: Option<Type>,
949
    #[attribute(example = "65_536")]
950
    block_size: Option<usize>,
951
    #[attribute(example = "\"authority\"")]
952
    authority: Option<Expr>,
953
    #[attribute(example = "\"files\"")]
954
    files_name: Option<String>,
955
    #[attribute(example = "\"blocks\"")]
956
    blocks_name: Option<String>,
957
    #[attribute(example = "bosaidb::core")]
958
    core: Option<Path>,
959
    #[attribute(example = "bosaidb::files")]
960
    files: Option<Path>,
961
}
962

            
963
/// Derives the `bonsaidb::files::FileConfig` trait.
964
///
965
/// `#[api(metadata = MetadataType, block_size = 65_536, authority = "authority", files_name = "files", blocks_name = "blocks", core = bonsaidb::core, files = bosaidb::files)]`
966
/// all arguments are optional
967
5
#[manyhow]
968
#[proc_macro_derive(FileConfig, attributes(file_config))]
969
5
pub fn file_config_derive(input: proc_macro::TokenStream) -> Result {
970
    let DeriveInput {
971
5
        attrs,
972
5
        ident,
973
5
        generics,
974
        ..
975
5
    } = parse(input)?;
976

            
977
    let FileConfigAttribute {
978
5
        metadata,
979
5
        block_size,
980
5
        authority,
981
5
        files_name,
982
5
        blocks_name,
983
5
        core,
984
5
        files,
985
5
    } = FileConfigAttribute::from_attributes(&attrs)?;
986

            
987
5
    let core = core.unwrap_or_else(core_path);
988
5
    let files = files.unwrap_or_else(files_path);
989
5
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
990

            
991
5
    let (files_name, blocks_name) = match (authority, files_name, blocks_name) {
992
3
        (None, None, None) => (
993
3
            quote!(#files::BonsaiFiles::files_name()),
994
3
            quote!(#files::BonsaiFiles::blocks_name()),
995
3
        ),
996
1
        (Some(authority), Some(files_name), Some(blocks_name)) => (
997
1
            quote!(#core::schema::Qualified::new(#authority, #files_name)),
998
1
            quote!(#core::schema::Qualified::new(#authority, #blocks_name)),
999
1
        ),
1
        (None, Some(files_name), Some(blocks_name)) => (
1
            quote!(#core::schema::Qualified::private(#files_name)),
1
            quote!(#core::schema::Qualified::private(#blocks_name)),
1
        ),
        (Some(_), ..) => bail!(
            "if `authority` is specified, `files_name` and `blocks_name need to be provided as well"
        ),
        (_, Some(_), _) => {
            bail!("if `files_name` is specified, `blocks_name` needs to be provided as well")
        }
        (_, _, Some(_)) => {
            bail!("if `blocks_name` is specified, `files_name` needs to be provided as well")
        }
    };

            
5
    let metadata = metadata
5
        .unwrap_or_else(|| parse_quote!(<#files::BonsaiFiles as #files::FileConfig>::Metadata));
5
    let block_size = block_size.map_or_else(
5
        || quote!(<#files::BonsaiFiles as #files::FileConfig>::BLOCK_SIZE),
5
        |block_size| quote!(#block_size),
5
    );
5

            
5
    Ok(quote! {
5
        # use #files::FileConfig;
5
        # use #core::schema::CollectionName;
5

            
5
        impl #impl_generics FileConfig for #ident #ty_generics #where_clause {
5
            type Metadata = #metadata;
5
            const BLOCK_SIZE: usize = #block_size;
5

            
5
            fn files_name() -> CollectionName {
5
                #files_name
5
            }
5

            
5
            fn blocks_name() -> CollectionName {
5
                #blocks_name
5
            }
5
        }
5
    })
}

            
1
#[test]
1
fn ui() {
1
    use trybuild::TestCases;
1

            
1
    TestCases::new().compile_fail("tests/ui/*/*.rs");
1
}