bones_schema_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Punct, Spacing, TokenStream as TokenStream2, TokenTree as TokenTree2};
3use quote::{format_ident, quote, quote_spanned, spanned::Spanned};
4use venial::{GenericBound, StructFields};
5
/// Helper macro that aborts the enclosing function with a compile error.
///
/// It evaluates the span of `$spanned` (via `quote::spanned::Spanned::__span`) and returns a
/// `compile_error!($message);` invocation carrying that span. The return value is converted
/// with `.into()` into the enclosing function's return type. The expansion is a sequence of
/// statements ending in `return`, so it must be used in statement position.
macro_rules! throw {
    ($spanned:expr, $message:literal) => {
        let error_span = $spanned.__span();
        return quote_spanned! {error_span =>
            compile_error!($message);
        }
        .into();
    };
}
15
16/// Derive macro for the HasSchema trait.
17///
18/// ## Usage with #[repr(C)]
19/// HasSchema works with the #[repr(C)] annotation to fully implement its features.
20///
21/// If there is no #[repr(C)] annotation, the SchemaKind of your type's schema will be Opaque.
22///
23/// This means if you don't know the kind of type, like in the case of a SchemaBox, you'll be unable
24/// to read the fields.
25///
26/// This applies to bones' lua scripting since SchemaBox is effectively the "lua type".
27/// See SchemaBox.
28///
29/// If you intend a type to be opaque even though it has #[repr(C)] you can use #[schema(opaque)]
30/// to force an opaque schema representation.
31///
32/// Keep in mind, enums deriving HasSchema with a #[repr(C)] annotation must also specify an
33/// enum tag type like #[repr(C, u8)] where u8 could be either u16 or u32 if you need
34/// more than 256 enum variants.
35///
36/// ## no_default & no_clone attributes
37/// HasSchema derive requires the type to implement Default & Clone, if either of these cannot be
38/// implemented you can use the no_default & no_clone schema attributes respectively to ignore
39/// these requirements.
40/// ```ignore
41/// #[derive(HasSchema, Default)]
42/// #[schema(no_clone)]
43/// struct DoesntImplClone;
44///
45/// #[derive(HasSchema, Clone)]
46/// #[schema(no_default)]
47/// struct DoesntImplDefault;
48/// ```
49/// The caveat for putting no_default on your type is that it cannot be created from a Schema.
50/// This is necessary if you want to create your type within a bones lua script.
51///
52/// Since the fields that need to be initialized to create a complete version of your type cannot be
53/// determined, Schema needs a default function to initialize the data properly.
54///
55/// The caveat for putting no_clone on your type is that it cannot be cloned in the form of a
56/// SchemaBox.
57///
58/// This is critical in the case of bones' networking which will panic if your type is in the world
59/// and does not implement clone during the network rollback.
60///
61/// ## type_data attribute
62/// This attribute takes an expression and stores that value in what is basically
63/// a type keyed map accessible from your type's Schema.
64///
65/// ## derive_type_data attribute
66/// This attribute is simply a shortcut equivalent to using the type_data attribute
67/// with any type's `FromType<YourHasSchemaType>` implementation like so:
68/// ```ignore
69/// #[derive(HasSchema, Clone, Default)]
70/// #[type_data(<OtherType as FromType<Data>>::from_type())]
71/// struct Data;
72/// ```
73/// Simply specify a type instead of an expression:
74/// ```ignore
75/// #[derive(HasSchema, Clone, Default)]
76/// #[derive_type_data(OtherType)] // OtherType implements FromType<Data>
77/// struct Data;
78/// ```
79/// ## Known Limitations
80///
81/// Currently it isn't possible to construct a struct that contains itself. For example, this will
82/// not work:
83///
84/// ```ignore
85/// #[derive(HasSchema)]
86/// struct Data {
87///     others: Vec<Data>,
88/// }
89/// ```
90///
91/// If this is a problem for your use-case, please open an issue.
92#[proc_macro_derive(
93    HasSchema,
94    attributes(schema, derive_type_data, type_data, schema_module)
95)]
96pub fn derive_has_schema(input: TokenStream) -> TokenStream {
97    let input = venial::parse_declaration(input.into()).unwrap();
98    let name = input.name().expect("Type must have a name");
99
100    // Get the schema module, reading optionally from the `schema_module` attribute, so that we can
101    // set the module to `crate` when we want to use it within the `bones_schema` crate itself.
102    let schema_mod = input
103        .attributes()
104        .iter()
105        .find_map(|attr| {
106            (attr.path.len() == 1 && attr.path[0].to_string() == "schema_module").then(|| {
107                attr.value
108                    .get_value_tokens()
109                    .iter()
110                    .cloned()
111                    .collect::<TokenStream2>()
112            })
113        })
114        .unwrap_or_else(|| quote!(bones_schema));
115
116    // Get the type datas that have been added and derived
117    let derive_type_data_flags = get_flags_for_attr(&input, "derive_type_data");
118    let type_datas = {
119        let add_derive_type_datas = derive_type_data_flags.into_iter().map(|ty| {
120            let ty = format_ident!("{ty}");
121            quote! {
122                tds.insert(<#ty as #schema_mod::FromType<#name>>::from_type()).unwrap();
123            }
124        });
125        let add_type_datas = input
126            .attributes()
127            .iter()
128            .filter(|x| x.path.len() == 1 && x.path[0].to_string() == "type_data")
129            .map(|x| x.get_value_tokens())
130            .map(|x| x.iter().cloned().collect::<TokenStream2>());
131
132        quote! {
133            {
134                let tds = #schema_mod::alloc::TypeDatas::default();
135                #(#add_derive_type_datas),*
136                #(
137                    tds.insert(#add_type_datas).unwrap();
138                ),*
139                tds
140            }
141        }
142    };
143
144    // Collect repr tags
145    let mut repr_flags = get_flags_for_attr(&input, "repr");
146    repr_flags.iter_mut().for_each(|x| *x = x.to_lowercase());
147    let repr_c = repr_flags.iter().any(|x| x == "c");
148    let primitive_repr = repr_flags.iter().find_map(|x| match x.as_ref() {
149        "u8" => Some(quote!(U8)),
150        "u16" => Some(quote!(U16)),
151        "u32" => Some(quote!(U32)),
152        _ => None,
153    });
154
155    // Collect schema flags
156    let schema_flags = get_flags_for_attr(&input, "schema");
157    let no_clone = schema_flags.iter().any(|x| x.as_str() == "no_clone");
158    let no_default = schema_flags.iter().any(|x| x.as_str() == "no_default");
159    let is_opaque = schema_flags.iter().any(|x| x.as_str() == "opaque")
160        || !(repr_c || primitive_repr.is_some());
161
162    // Get the clone and default functions based on the flags
163    let clone_fn = if no_clone {
164        quote!(None)
165    } else {
166        quote!(Some(<Self as #schema_mod::raw_fns::RawClone>::raw_clone_cb()))
167    };
168    let default_fn = if no_default {
169        quote!(None)
170    } else {
171        quote!(Some(<Self as #schema_mod::raw_fns::RawDefault>::raw_default_cb()))
172    };
173
174    // Get the schema kind
175    let schema_kind = (|| {
176        if is_opaque {
177            return quote! {
178                {
179                    let layout = ::std::alloc::Layout::new::<Self>();
180                    #schema_mod::SchemaKind::Primitive(#schema_mod::Primitive::Opaque {
181                        size: layout.size(),
182                        align: layout.align(),
183                    })
184                }
185            };
186        }
187
188        // Helper to parse struct fields from structs or enum variants
189        let parse_struct_fields = |fields: &StructFields| {
190            match fields {
191                venial::StructFields::Tuple(tuple) => tuple
192                    .fields
193                    .iter()
194                    .map(|(field, _)| {
195                        let ty = &field.ty;
196                        quote_spanned! {field.ty.__span() =>
197                            #schema_mod::StructFieldInfo {
198                                name: None,
199                                schema: <#ty as #schema_mod::HasSchema>::schema(),
200                            }
201                        }
202                    })
203                    .collect::<Vec<_>>(),
204                venial::StructFields::Named(named) => named
205                    .fields
206                    .iter()
207                    .map(|(field, _)| {
208                        let name = &field.name;
209                        let ty = &field.ty;
210                        let opaque = field.attributes.iter().any(|attr| {
211                            &attr.path[0].to_string() == "schema"
212                                && &attr.value.get_value_tokens()[0].to_string() == "opaque"
213                        });
214
215                        if opaque {
216                            quote_spanned! {field.ty.__span() =>
217                                #schema_mod::StructFieldInfo {
218                                    name: Some(stringify!(#name).into()),
219                                    schema: {
220                                        let layout = ::std::alloc::Layout::new::<#ty>();
221                                        #schema_mod::registry::SCHEMA_REGISTRY.register(#schema_mod::SchemaData {
222                                            name: stringify!(#ty).into(),
223                                            full_name: concat!(module_path!(), "::", stringify!(#ty)).into(),
224                                            kind: #schema_mod::SchemaKind::Primitive(#schema_mod::Primitive::Opaque {
225                                                size: layout.size(),
226                                                align: layout.align(),
227                                            }),
228                                            type_id: Some(std::any::TypeId::of::<#ty>()),
229                                            type_data: #type_datas,
230                                            clone_fn: #clone_fn,
231                                            default_fn: #default_fn,
232                                            eq_fn: None,
233                                            hash_fn: None,
234                                            drop_fn: Some(<Self as #schema_mod::raw_fns::RawDrop>::raw_drop_cb()),
235                                        })
236                                    },
237                                }
238                            }
239                        } else {
240                            quote_spanned! {field.ty.__span() =>
241                                #schema_mod::StructFieldInfo {
242                                    name: Some(stringify!(#name).into()),
243                                    schema: <#ty as #schema_mod::HasSchema>::schema(),
244                                }
245                            }
246                        }
247                    })
248                    .collect::<Vec<_>>(),
249                venial::StructFields::Unit => Vec::new(),
250            }
251        };
252
253        // Match on the the type we are deriving on and return its SchemaData
254        match &input {
255            venial::Declaration::Struct(s) => {
256                let fields = parse_struct_fields(&s.fields);
257
258                quote! {
259                    #schema_mod::SchemaKind::Struct(#schema_mod::StructSchemaInfo {
260                        fields: vec![
261                            #(#fields),*
262                        ]
263                    })
264                }
265            }
266            venial::Declaration::Enum(e) => {
267                let Some(tag_type) = primitive_repr else {
268                    throw!(
269                        e,
270                        "Enums deriving HasSchema with a `#[repr(C)]` annotation \
271                        must also specify an enum tag type like `#[repr(C, u8)]` where \
272                        `u8` could be either `u16` or `u32` if you need more than 256 enum \
273                        variants."
274                    );
275                };
276                let mut variants = Vec::new();
277
278                for v in e.variants.items() {
279                    let name = v.name.to_string();
280                    let variant_schema_name = format!("{}::{}", e.name, name);
281                    let fields = parse_struct_fields(&v.contents);
282
283                    let register_schema = if input.generic_params().is_some() {
284                        quote! {
285                            static S: OnceLock<RwLock<HashMap<TypeId, &'static Schema>>> = OnceLock::new();
286                            let schema = {
287                                S.get_or_init(Default::default)
288                                    .read()
289                                    .get(&TypeId::of::<Self>())
290                                    .copied()
291                            };
292                            schema.unwrap_or_else(|| {
293                                let schema = compute_schema();
294
295                                S.get_or_init(Default::default)
296                                    .write()
297                                    .insert(TypeId::of::<Self>(), schema);
298
299                                schema
300                            })
301                        }
302                    } else {
303                        quote! {
304                            static S: ::std::sync::OnceLock<&'static #schema_mod::Schema> = ::std::sync::OnceLock::new();
305                            S.get_or_init(compute_schema)
306                        }
307                    };
308
309                    variants.push(quote! {
310                        #schema_mod::VariantInfo {
311                            name: #name.into(),
312                            schema: {
313                                let compute_schema = || {
314                                    #schema_mod::registry::SCHEMA_REGISTRY.register(#schema_mod::SchemaData {
315                                        name: #variant_schema_name.into(),
316                                        full_name: concat!(module_path!(), "::", #variant_schema_name).into(),
317                                        type_id: None,
318                                        kind: #schema_mod::SchemaKind::Struct(#schema_mod::StructSchemaInfo {
319                                            fields: vec![
320                                                #(#fields),*
321                                            ]
322                                        }),
323                                        type_data: Default::default(),
324                                        default_fn: None,
325                                        clone_fn: None,
326                                        eq_fn: None,
327                                        hash_fn: None,
328                                        drop_fn: None,
329                                    })
330                                };
331                                #register_schema
332                            }
333                        }
334                    })
335                }
336
337                quote! {
338                    #schema_mod::SchemaKind::Enum(#schema_mod::EnumSchemaInfo {
339                        tag_type: #schema_mod::EnumTagType::#tag_type,
340                        variants: vec![#(#variants),*],
341                    })
342                }
343            }
344            _ => {
345                throw!(
346                    input,
347                    "You may only derive HasSchema for structs and enums."
348                );
349            }
350        }
351    })();
352
353    let schema_register = quote! {
354        #schema_mod::registry::SCHEMA_REGISTRY.register(#schema_mod::SchemaData {
355            name: stringify!(#name).into(),
356            full_name: concat!(module_path!(), "::", stringify!(#name)).into(),
357            type_id: Some(::std::any::TypeId::of::<Self>()),
358            kind: #schema_kind,
359            type_data: #type_datas,
360            default_fn: #default_fn,
361            clone_fn: #clone_fn,
362            eq_fn: None,
363            hash_fn: None,
364            drop_fn: Some(<Self as #schema_mod::raw_fns::RawDrop>::raw_drop_cb()),
365        })
366    };
367
368    if let Some(generic_params) = input.generic_params() {
369        let mut sync_send_generic_params = generic_params.clone();
370        for (param, _) in sync_send_generic_params.params.iter_mut() {
371            let clone_bound = if !no_clone { quote!(+ Clone) } else { quote!() };
372            param.bound = Some(GenericBound {
373                tk_colon: Punct::new(':', Spacing::Joint),
374                tokens: quote!(HasSchema #clone_bound ).into_iter().collect(),
375            });
376        }
377        quote! {
378            unsafe impl #sync_send_generic_params #schema_mod::HasSchema for #name #generic_params {
379                fn schema() -> &'static #schema_mod::Schema {
380                    use ::std::sync::{OnceLock};
381                    use ::std::any::TypeId;
382                    use bones_utils::HashMap;
383                    use parking_lot::RwLock;
384                    static S: OnceLock<RwLock<HashMap<TypeId, &'static Schema>>> = OnceLock::new();
385                    let schema = {
386                        S.get_or_init(Default::default)
387                            .read()
388                            .get(&TypeId::of::<Self>())
389                            .copied()
390                    };
391                    schema.unwrap_or_else(|| {
392                        let schema = #schema_register;
393
394                        S.get_or_init(Default::default)
395                            .write()
396                            .insert(TypeId::of::<Self>(), schema);
397
398                        schema
399                    })
400
401                }
402            }
403        }
404    } else {
405        quote! {
406            unsafe impl #schema_mod::HasSchema for #name {
407                fn schema() -> &'static #schema_mod::Schema {
408                    static S: ::std::sync::OnceLock<&'static #schema_mod::Schema> = ::std::sync::OnceLock::new();
409                    S.get_or_init(|| {
410                        #schema_register
411                    })
412                }
413            }
414        }
415    }
416    .into()
417}
418
419//
420// Helpers
421//
422
423/// Look for an attribute with the given name and get all of the comma-separated flags that are
424/// in that attribute.
425///
426/// For example, with the given struct:
427///
428/// ```ignore
429/// #[example(test)]
430/// #[my_attr(hello, world)]
431/// struct Hello;
432/// ```
433///
434/// Calling `get_flags_for_attr("my_attr")` would return `vec!["hello", "world"]`.
435fn get_flags_for_attr(input: &venial::Declaration, attr_name: &str) -> Vec<String> {
436    let attrs = input
437        .attributes()
438        .iter()
439        .filter(|attr| attr.path.len() == 1 && attr.path[0].to_string() == attr_name)
440        .collect::<Vec<_>>();
441    attrs
442        .iter()
443        .map(|attr| match &attr.value {
444            venial::AttributeValue::Group(_, value) => {
445                let mut flags = Vec::new();
446
447                let mut current_flag = proc_macro2::TokenStream::new();
448                for token in value {
449                    match token {
450                        TokenTree2::Punct(x) if x.as_char() == ',' => {
451                            flags.push(current_flag.to_string());
452                            current_flag = Default::default();
453                        }
454                        x => current_flag.extend(std::iter::once(x.clone())),
455                    }
456                }
457                flags.push(current_flag.to_string());
458
459                flags
460            }
461            venial::AttributeValue::Equals(_, _) => {
462                // TODO: Improve macro error message span.
463                panic!("Unsupported attribute format");
464            }
465            venial::AttributeValue::Empty => Vec::new(),
466        })
467        .fold(Vec::new(), |mut acc, item| {
468            acc.extend(item);
469            acc
470        })
471}