bones_schema/
ser_de.rs

1use std::any::type_name;
2
3use erased_serde::Deserializer;
4use serde::{
5    de::{DeserializeSeed, Error},
6    Deserialize, Serialize,
7};
8
9use crate::prelude::*;
10
11pub use serializer_deserializer::*;
12mod serializer_deserializer {
13    use serde::{
14        de::{Unexpected, VariantAccess, Visitor},
15        ser::{SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant},
16    };
17    use ustr::{ustr, Ustr};
18
19    use super::*;
20
21    /// A struct that implements [`Serialize`] and wraps around a [`SchemaRef`] to serialize the value
22    /// using it's schema.
23    ///
24    /// This will error if there are opaque types in the schema ref that cannot be serialized.
25    pub struct SchemaSerializer<'a>(pub SchemaRef<'a>);
26
27    impl<'a> Serialize for SchemaSerializer<'a> {
28        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
29        where
30            S: serde::Serializer,
31        {
32            // Specifically handle `Ustr`
33            if let Ok(u) = self.0.try_cast::<Ustr>() {
34                return serializer.serialize_str(u);
35            }
36
37            match self.0.access() {
38                SchemaRefAccess::Struct(s) => {
39                    if s.fields().count() == 1 && s.fields().nth(0).unwrap().name.is_none() {
40                        SchemaSerializer(s.fields().nth(0).unwrap().value).serialize(serializer)
41                    } else {
42                        let named = s.fields().nth(0).map(|x| x.name.is_some()).unwrap_or(false);
43
44                        if named {
45                            let mut ser_struct = serializer
46                                .serialize_struct(&self.0.schema().name, s.fields().count())?;
47                            for field in s.fields() {
48                                ser_struct.serialize_field(
49                                    field.name.as_ref().unwrap(),
50                                    &SchemaSerializer(field.value),
51                                )?;
52                            }
53                            ser_struct.end()
54                        } else {
55                            let mut seq = serializer.serialize_seq(Some(s.fields().count()))?;
56                            for field in s.fields() {
57                                seq.serialize_element(&SchemaSerializer(field.value))?;
58                            }
59                            seq.end()
60                        }
61                    }
62                }
63                SchemaRefAccess::Vec(v) => {
64                    let mut seq = serializer.serialize_seq(Some(v.len()))?;
65                    for item in v.iter() {
66                        seq.serialize_element(&SchemaSerializer(item))?;
67                    }
68                    seq.end()
69                }
70                SchemaRefAccess::Map(m) => {
71                    let mut map = serializer.serialize_map(Some(m.len()))?;
72                    for (key, value) in m.iter() {
73                        map.serialize_entry(&SchemaSerializer(key), &SchemaSerializer(value))?;
74                    }
75                    map.end()
76                }
77                SchemaRefAccess::Enum(e) => {
78                    let variant_idx = e.variant_idx();
79                    let variant_info = e.variant_info();
80                    let access = e.value();
81                    let field_count = access.fields().count();
82
83                    if field_count == 0 {
84                        serializer.serialize_unit_variant(
85                            &self.0.schema().name,
86                            variant_idx,
87                            &variant_info.name,
88                        )
89                    } else if field_count == 1 && access.fields().nth(0).unwrap().name.is_none() {
90                        serializer.serialize_newtype_variant(
91                            &self.0.schema().name,
92                            variant_idx,
93                            &variant_info.name,
94                            &SchemaSerializer(access.as_schema_ref()),
95                        )
96                    } else {
97                        let mut ser_struct = serializer.serialize_struct_variant(
98                            &self.0.schema().name,
99                            variant_idx,
100                            &variant_info.name,
101                            field_count,
102                        )?;
103
104                        for field in access.fields() {
105                            ser_struct.serialize_field(
106                                field.name.as_ref().unwrap(),
107                                &SchemaSerializer(field.value),
108                            )?;
109                        }
110
111                        ser_struct.end()
112                    }
113                }
114                SchemaRefAccess::Primitive(p) => match p {
115                    PrimitiveRef::Bool(b) => serializer.serialize_bool(*b),
116                    PrimitiveRef::U8(n) => serializer.serialize_u8(*n),
117                    PrimitiveRef::U16(n) => serializer.serialize_u16(*n),
118                    PrimitiveRef::U32(n) => serializer.serialize_u32(*n),
119                    PrimitiveRef::U64(n) => serializer.serialize_u64(*n),
120                    PrimitiveRef::U128(n) => serializer.serialize_u128(*n),
121                    PrimitiveRef::I8(n) => serializer.serialize_i8(*n),
122                    PrimitiveRef::I16(n) => serializer.serialize_i16(*n),
123                    PrimitiveRef::I32(n) => serializer.serialize_i32(*n),
124                    PrimitiveRef::I64(n) => serializer.serialize_i64(*n),
125                    PrimitiveRef::I128(n) => serializer.serialize_i128(*n),
126                    PrimitiveRef::F32(n) => serializer.serialize_f32(*n),
127                    PrimitiveRef::F64(n) => serializer.serialize_f64(*n),
128                    PrimitiveRef::String(n) => serializer.serialize_str(n),
129                    PrimitiveRef::Opaque { .. } => {
130                        use serde::ser::Error;
131                        Err(S::Error::custom("Cannot serialize opaque types"))
132                    }
133                },
134            }
135        }
136    }
137
138    /// A struct that implements [`DeserializeSeed`] and can be used to deserialize values matching a
139    /// given [`Schema`].
140    ///
141    /// This will error if there are opaque types in the schema that cannot be deserialized.
142    pub struct SchemaDeserializer(pub &'static Schema);
143
144    impl<'de> DeserializeSeed<'de> for SchemaDeserializer {
145        type Value = SchemaBox;
146
147        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
148        where
149            D: serde::Deserializer<'de>,
150        {
151            // Allocate the object.
152            let mut ptr = SchemaBox::default(self.0);
153
154            // Deserialize into it
155            ptr.as_mut().deserialize(deserializer)?;
156
157            Ok(ptr)
158        }
159    }
160
161    impl<'a, 'de> DeserializeSeed<'de> for SchemaRefMut<'a> {
162        type Value = ();
163
164        fn deserialize<D>(mut self, deserializer: D) -> Result<Self::Value, D::Error>
165        where
166            D: serde::Deserializer<'de>,
167        {
168            // Use custom deserializer if present.
169            if let Some(schema_deserialize) = self.schema().type_data.get::<SchemaDeserialize>() {
170                return schema_deserialize.deserialize(self, deserializer);
171            }
172
173            match &self.schema().kind {
174                SchemaKind::Struct(s) => {
175                    // If this is a newtype struct
176                    if s.fields.len() == 1 && s.fields[0].name.is_none() {
177                        // Deserialize it as the inner type
178                        // SOUND: it is safe to cast a struct with one field to it's field type
179                        unsafe { SchemaRefMut::from_ptr_schema(self.as_ptr(), s.fields[0].schema) }
180                            .deserialize(deserializer)?
181                    } else {
182                        deserializer.deserialize_any(StructVisitor(self))?
183                    }
184                }
185                SchemaKind::Vec(_) => deserializer.deserialize_seq(VecVisitor(self))?,
186                SchemaKind::Map { .. } => deserializer.deserialize_map(MapVisitor(self))?,
187                SchemaKind::Enum(_) => deserializer.deserialize_any(EnumVisitor(self))?,
188                SchemaKind::Box(_) => self.into_box().unwrap().deserialize(deserializer)?,
189                SchemaKind::Primitive(p) => {
190                    match p {
191                        Primitive::Bool => *self.cast_mut() = bool::deserialize(deserializer)?,
192                        Primitive::U8 => *self.cast_mut() = u8::deserialize(deserializer)?,
193                        Primitive::U16 => *self.cast_mut() = u16::deserialize(deserializer)?,
194                        Primitive::U32 => *self.cast_mut() = u32::deserialize(deserializer)?,
195                        Primitive::U64 => *self.cast_mut() = u64::deserialize(deserializer)?,
196                        Primitive::U128 => *self.cast_mut() = u128::deserialize(deserializer)?,
197                        Primitive::I8 => *self.cast_mut() = i8::deserialize(deserializer)?,
198                        Primitive::I16 => *self.cast_mut() = i16::deserialize(deserializer)?,
199                        Primitive::I32 => *self.cast_mut() = i32::deserialize(deserializer)?,
200                        Primitive::I64 => *self.cast_mut() = i64::deserialize(deserializer)?,
201                        Primitive::I128 => *self.cast_mut() = i128::deserialize(deserializer)?,
202                        Primitive::F32 => *self.cast_mut() = f32::deserialize(deserializer)?,
203                        Primitive::F64 => *self.cast_mut() = f64::deserialize(deserializer)?,
204                        Primitive::String => *self.cast_mut() = String::deserialize(deserializer)?,
205                        Primitive::Opaque { .. } => {
206                            return Err(D::Error::custom(
207                                "Opaque types must be #[repr(C)] or have `SchemaDeserialize` type \
208                                data in order to be deserialized.",
209                            ));
210                        }
211                    };
212                }
213            };
214
215            Ok(())
216        }
217    }
218
219    struct StructVisitor<'a>(SchemaRefMut<'a>);
220    impl<'a, 'de> Visitor<'de> for StructVisitor<'a> {
221        type Value = ();
222        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
223            write!(
224                formatter,
225                "asset metadata matching the schema: {:#?}",
226                self.0.schema()
227            )
228        }
229
230        fn visit_seq<A>(mut self, mut seq: A) -> Result<Self::Value, A::Error>
231        where
232            A: serde::de::SeqAccess<'de>,
233        {
234            let field_count = self.0.schema().kind.as_struct().unwrap().fields.len();
235
236            for i in 0..field_count {
237                let field = self.0.access_mut().field(i).unwrap().into_schema_ref_mut();
238                if seq.next_element_seed(field)?.is_none() {
239                    break;
240                }
241            }
242
243            Ok(())
244        }
245
246        fn visit_map<A>(mut self, mut map: A) -> Result<Self::Value, A::Error>
247        where
248            A: serde::de::MapAccess<'de>,
249        {
250            while let Some(key) = map.next_key::<String>()? {
251                match self
252                    .0
253                    .access_mut()
254                    .field(&key)
255                    .map(|x| x.into_schema_ref_mut())
256                {
257                    Ok(field) => {
258                        map.next_value_seed(field)?;
259                    }
260                    Err(_) => {
261                        let fields = &self.0.schema().kind.as_struct().unwrap().fields;
262                        let mut msg = format!("unknown field `{key}`, ");
263                        if !fields.is_empty() {
264                            msg += "expected one of ";
265                            for (i, field) in fields.iter().enumerate() {
266                                msg += &field
267                                    .name
268                                    .as_ref()
269                                    .map(|x| format!("`{x}`"))
270                                    .unwrap_or_else(|| format!("`{i}`"));
271                                if i < fields.len() - 1 {
272                                    msg += ", "
273                                }
274                            }
275                        } else {
276                            msg += "there are no fields"
277                        }
278                        return Err(A::Error::custom(msg));
279                    }
280                }
281            }
282
283            Ok(())
284        }
285    }
286
287    struct VecVisitor<'a>(SchemaRefMut<'a>);
288    impl<'a, 'de> Visitor<'de> for VecVisitor<'a> {
289        type Value = ();
290        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
291            write!(
292                formatter,
293                "asset metadata matching the schema: {:#?}",
294                self.0.schema()
295            )
296        }
297
298        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
299        where
300            A: serde::de::SeqAccess<'de>,
301        {
302            // SOUND: schema asserts this is a SchemaVec.
303            let v = unsafe { &mut *(self.0.as_ptr() as *mut SchemaVec) };
304            loop {
305                let item_schema = v.schema();
306                let mut item = SchemaBox::default(item_schema);
307                let item_ref = item.as_mut();
308                if seq.next_element_seed(item_ref)?.is_none() {
309                    break;
310                }
311                v.push_box(item);
312            }
313
314            Ok(())
315        }
316    }
317    struct MapVisitor<'a>(SchemaRefMut<'a>);
318    impl<'a, 'de> Visitor<'de> for MapVisitor<'a> {
319        type Value = ();
320        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
321            write!(
322                formatter,
323                "asset metadata matching the schema: {:#?}",
324                self.0.schema()
325            )
326        }
327
328        fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
329        where
330            A: serde::de::MapAccess<'de>,
331        {
332            // SOUND: schema asserts this is a SchemaMap.
333            let v = unsafe { &mut *(self.0.as_ptr() as *mut SchemaMap) };
334            let is_ustr = v.key_schema() == Ustr::schema();
335            if v.key_schema() != String::schema() && !is_ustr {
336                return Err(A::Error::custom(
337                    "Can only deserialize maps with `String` or `Ustr` keys.",
338                ));
339            }
340            while let Some(key) = map.next_key::<String>()? {
341                let key = if is_ustr {
342                    SchemaBox::new(ustr(&key))
343                } else {
344                    SchemaBox::new(key)
345                };
346                let mut value = SchemaBox::default(v.value_schema());
347                map.next_value_seed(value.as_mut())?;
348
349                v.insert_box(key, value);
350            }
351            Ok(())
352        }
353    }
354    struct EnumVisitor<'a>(SchemaRefMut<'a>);
355    impl<'a, 'de> Visitor<'de> for EnumVisitor<'a> {
356        type Value = ();
357        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
358            write!(
359                formatter,
360                "asset metadata matching the schema: {:#?}",
361                self.0.schema()
362            )
363        }
364
365        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
366        where
367            E: Error,
368        {
369            let enum_info = self.0.schema().kind.as_enum().unwrap();
370            let var_idx = enum_info
371                .variants
372                .iter()
373                .position(|x| x.name == v)
374                .ok_or_else(|| E::invalid_value(Unexpected::Str(v), &self))?;
375
376            if !enum_info.variants[var_idx]
377                .schema
378                .kind
379                .as_struct()
380                .unwrap()
381                .fields
382                .is_empty()
383            {
384                return Err(E::custom(format!(
385                    "Cannot deserialize enum variant with fields from string: {v}"
386                )));
387            }
388
389            // SOUND: we match the cast with the enum tag type.
390            unsafe {
391                match enum_info.tag_type {
392                    EnumTagType::U8 => self.0.as_ptr().cast::<u8>().write(var_idx as u8),
393                    EnumTagType::U16 => self.0.as_ptr().cast::<u16>().write(var_idx as u16),
394                    EnumTagType::U32 => self.0.as_ptr().cast::<u32>().write(var_idx as u32),
395                }
396            }
397
398            Ok(())
399        }
400
401        fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
402        where
403            A: serde::de::EnumAccess<'de>,
404        {
405            let (value_ptr, var_access) = data.variant_seed(EnumLoad(self.0))?;
406            var_access.newtype_variant_seed(value_ptr)?;
407            Ok(())
408        }
409    }
410
411    struct EnumLoad<'a>(SchemaRefMut<'a>);
412    impl<'a, 'de> DeserializeSeed<'de> for EnumLoad<'a> {
413        type Value = SchemaRefMut<'a>;
414
415        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
416        where
417            D: serde::Deserializer<'de>,
418        {
419            let var_name = String::deserialize(deserializer)?;
420            let enum_info = self.0.schema().kind.as_enum().unwrap();
421            let value_offset = self.0.schema().field_offsets()[0].1;
422            let (var_idx, var_schema) = enum_info
423                .variants
424                .iter()
425                .enumerate()
426                .find_map(|(idx, info)| (info.name == var_name).then_some((idx, info.schema)))
427                .ok_or_else(|| {
428                    D::Error::custom(format!(
429                        "Unknown enum variant `{var_name}`, expected one of: {}",
430                        enum_info
431                            .variants
432                            .iter()
433                            .map(|x| format!("`{}`", x.name))
434                            .collect::<Vec<_>>()
435                            .join(", ")
436                    ))
437                })?;
438
439            // Write the enum variant
440            // SOUND: the schema asserts that the write to the enum discriminant is valid
441            match enum_info.tag_type {
442                EnumTagType::U8 => unsafe { self.0.as_ptr().cast::<u8>().write(var_idx as u8) },
443                EnumTagType::U16 => unsafe { self.0.as_ptr().cast::<u16>().write(var_idx as u16) },
444                EnumTagType::U32 => unsafe { self.0.as_ptr().cast::<u32>().write(var_idx as u32) },
445            }
446
447            if var_schema.kind.as_struct().is_none() {
448                return Err(D::Error::custom(
449                    "All enum variant types must have a struct Schema",
450                ));
451            }
452
453            unsafe {
454                Ok(SchemaRefMut::from_ptr_schema(
455                    self.0.as_ptr().add(value_offset),
456                    var_schema,
457                ))
458            }
459        }
460    }
461}
462
463/// Derivable schema [`type_data`][SchemaData::type_data] for types that implement
464/// [`Deserialize`].
465///
466/// This allows you use serde to implement custom deserialization logic instead of the default one
467/// used for `#[repr(C)]` structs that implement [`HasSchema`].
468pub struct SchemaDeserialize {
469    /// The function that may be used to deserialize the type.
470    pub deserialize_fn: for<'a, 'de> fn(
471        SchemaRefMut<'a>,
472        deserializer: &'a mut dyn Deserializer<'de>,
473    ) -> Result<(), erased_serde::Error>,
474}
475
476unsafe impl HasSchema for SchemaDeserialize {
477    fn schema() -> &'static Schema {
478        use std::{alloc::Layout, any::TypeId, sync::OnceLock};
479        static S: OnceLock<&'static Schema> = OnceLock::new();
480        let layout = Layout::new::<Self>();
481        S.get_or_init(|| {
482            SCHEMA_REGISTRY.register(SchemaData {
483                name: type_name::<Self>().into(),
484                full_name: format!("{}::{}", module_path!(), type_name::<Self>()).into(),
485                kind: SchemaKind::Primitive(Primitive::Opaque {
486                    size: layout.size(),
487                    align: layout.align(),
488                }),
489                type_id: Some(TypeId::of::<Self>()),
490                clone_fn: None,
491                drop_fn: None,
492                default_fn: None,
493                hash_fn: None,
494                eq_fn: None,
495                type_data: Default::default(),
496            })
497        })
498    }
499}
500
501impl SchemaDeserialize {
502    /// Use this [`SchemaDeserialize`] to deserialize data from the `deserializer` into the
503    /// `reference`.
504    pub fn deserialize<'a, 'de, D>(
505        &self,
506        reference: SchemaRefMut<'a>,
507        deserializer: D,
508    ) -> Result<(), D::Error>
509    where
510        D: serde::Deserializer<'de>,
511    {
512        let mut erased = <dyn erased_serde::Deserializer>::erase(deserializer);
513        (self.deserialize_fn)(reference, &mut erased)
514            .map_err(<<D as serde::Deserializer<'de>>::Error as serde::de::Error>::custom)
515    }
516}
517
518impl<T: HasSchema + for<'de> Deserialize<'de>> FromType<T> for SchemaDeserialize {
519    fn from_type() -> Self {
520        SchemaDeserialize {
521            deserialize_fn: |reference, deserializer| {
522                T::schema()
523                    .ensure_match(reference.schema())
524                    .map_err(|e| erased_serde::Error::custom(e.to_string()))?;
525                let data = T::deserialize(deserializer)?;
526
527                // SOUND: we ensured schemas match.
528                unsafe {
529                    reference.as_ptr().cast::<T>().write(data);
530                }
531
532                Ok(())
533            },
534        }
535    }
536}
537
#[cfg(test)]
mod test {
    use super::*;
    use bones_schema_macros::HasSchema;

    /// Fixture covering a string, a newtype struct, a vector, and a map.
    #[derive(HasSchema, Clone, Default)]
    #[schema_module(crate)]
    #[repr(C)]
    struct MyData {
        name: String,
        age: Age,
        favorite_things: SVec<String>,
        map: SMap<String, String>,
    }

    /// Newtype fixture; it appears in the YAML as a bare number.
    #[derive(HasSchema, Clone, Default)]
    #[schema_module(crate)]
    #[repr(C)]
    struct Age(u32);

    /// YAML form of the fixture, shared by both tests.
    const DEMO_YAML: &str = r"name: John
age: 8
favorite_things:
- jelly
- beans
map:
  hello: world
";

    /// Deserializing the demo YAML through the schema yields the expected field values.
    #[test]
    fn schema_deserializer() {
        let yaml = serde_yaml::Deserializer::from_str(DEMO_YAML);

        let loaded_box = SchemaDeserializer(MyData::schema())
            .deserialize(yaml)
            .unwrap();
        let loaded = loaded_box.cast_into::<MyData>();

        assert_eq!(loaded.name, "John");
        assert_eq!(loaded.age.0, 8);

        let expected_things = ["jelly".to_string(), "beans".to_string()]
            .into_iter()
            .collect::<SVec<_>>();
        assert_eq!(loaded.favorite_things, expected_things);

        let (expected_key, expected_value) = ("hello".to_string(), "world".to_string());
        assert_eq!(
            loaded.map.into_iter().next().unwrap(),
            (&expected_key, &expected_value)
        );
    }

    /// Serializing the fixture through the schema reproduces the demo YAML exactly.
    #[test]
    fn schema_serializer() {
        let person = MyData {
            name: "John".into(),
            age: Age(8),
            favorite_things: ["jelly".to_string(), "beans".to_string()]
                .into_iter()
                .collect(),
            map: [("hello".to_string(), "world".to_string())]
                .into_iter()
                .collect(),
        };

        let mut out = Vec::new();
        let mut serializer = serde_yaml::Serializer::new(&mut out);

        SchemaSerializer(person.as_schema_ref())
            .serialize(&mut serializer)
            .unwrap();

        assert_eq!(DEMO_YAML, String::from_utf8(out).unwrap());
    }
}