1use std::any::type_name;
2
3use erased_serde::Deserializer;
4use serde::{
5 de::{DeserializeSeed, Error},
6 Deserialize, Serialize,
7};
8
9use crate::prelude::*;
10
11pub use serializer_deserializer::*;
12mod serializer_deserializer {
13 use serde::{
14 de::{Unexpected, VariantAccess, Visitor},
15 ser::{SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant},
16 };
17 use ustr::{ustr, Ustr};
18
19 use super::*;
20
21 pub struct SchemaSerializer<'a>(pub SchemaRef<'a>);
26
27 impl<'a> Serialize for SchemaSerializer<'a> {
28 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
29 where
30 S: serde::Serializer,
31 {
32 if let Ok(u) = self.0.try_cast::<Ustr>() {
34 return serializer.serialize_str(u);
35 }
36
37 match self.0.access() {
38 SchemaRefAccess::Struct(s) => {
39 if s.fields().count() == 1 && s.fields().nth(0).unwrap().name.is_none() {
40 SchemaSerializer(s.fields().nth(0).unwrap().value).serialize(serializer)
41 } else {
42 let named = s.fields().nth(0).map(|x| x.name.is_some()).unwrap_or(false);
43
44 if named {
45 let mut ser_struct = serializer
46 .serialize_struct(&self.0.schema().name, s.fields().count())?;
47 for field in s.fields() {
48 ser_struct.serialize_field(
49 field.name.as_ref().unwrap(),
50 &SchemaSerializer(field.value),
51 )?;
52 }
53 ser_struct.end()
54 } else {
55 let mut seq = serializer.serialize_seq(Some(s.fields().count()))?;
56 for field in s.fields() {
57 seq.serialize_element(&SchemaSerializer(field.value))?;
58 }
59 seq.end()
60 }
61 }
62 }
63 SchemaRefAccess::Vec(v) => {
64 let mut seq = serializer.serialize_seq(Some(v.len()))?;
65 for item in v.iter() {
66 seq.serialize_element(&SchemaSerializer(item))?;
67 }
68 seq.end()
69 }
70 SchemaRefAccess::Map(m) => {
71 let mut map = serializer.serialize_map(Some(m.len()))?;
72 for (key, value) in m.iter() {
73 map.serialize_entry(&SchemaSerializer(key), &SchemaSerializer(value))?;
74 }
75 map.end()
76 }
77 SchemaRefAccess::Enum(e) => {
78 let variant_idx = e.variant_idx();
79 let variant_info = e.variant_info();
80 let access = e.value();
81 let field_count = access.fields().count();
82
83 if field_count == 0 {
84 serializer.serialize_unit_variant(
85 &self.0.schema().name,
86 variant_idx,
87 &variant_info.name,
88 )
89 } else if field_count == 1 && access.fields().nth(0).unwrap().name.is_none() {
90 serializer.serialize_newtype_variant(
91 &self.0.schema().name,
92 variant_idx,
93 &variant_info.name,
94 &SchemaSerializer(access.as_schema_ref()),
95 )
96 } else {
97 let mut ser_struct = serializer.serialize_struct_variant(
98 &self.0.schema().name,
99 variant_idx,
100 &variant_info.name,
101 field_count,
102 )?;
103
104 for field in access.fields() {
105 ser_struct.serialize_field(
106 field.name.as_ref().unwrap(),
107 &SchemaSerializer(field.value),
108 )?;
109 }
110
111 ser_struct.end()
112 }
113 }
114 SchemaRefAccess::Primitive(p) => match p {
115 PrimitiveRef::Bool(b) => serializer.serialize_bool(*b),
116 PrimitiveRef::U8(n) => serializer.serialize_u8(*n),
117 PrimitiveRef::U16(n) => serializer.serialize_u16(*n),
118 PrimitiveRef::U32(n) => serializer.serialize_u32(*n),
119 PrimitiveRef::U64(n) => serializer.serialize_u64(*n),
120 PrimitiveRef::U128(n) => serializer.serialize_u128(*n),
121 PrimitiveRef::I8(n) => serializer.serialize_i8(*n),
122 PrimitiveRef::I16(n) => serializer.serialize_i16(*n),
123 PrimitiveRef::I32(n) => serializer.serialize_i32(*n),
124 PrimitiveRef::I64(n) => serializer.serialize_i64(*n),
125 PrimitiveRef::I128(n) => serializer.serialize_i128(*n),
126 PrimitiveRef::F32(n) => serializer.serialize_f32(*n),
127 PrimitiveRef::F64(n) => serializer.serialize_f64(*n),
128 PrimitiveRef::String(n) => serializer.serialize_str(n),
129 PrimitiveRef::Opaque { .. } => {
130 use serde::ser::Error;
131 Err(S::Error::custom("Cannot serialize opaque types"))
132 }
133 },
134 }
135 }
136 }
137
138 pub struct SchemaDeserializer(pub &'static Schema);
143
144 impl<'de> DeserializeSeed<'de> for SchemaDeserializer {
145 type Value = SchemaBox;
146
147 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
148 where
149 D: serde::Deserializer<'de>,
150 {
151 let mut ptr = SchemaBox::default(self.0);
153
154 ptr.as_mut().deserialize(deserializer)?;
156
157 Ok(ptr)
158 }
159 }
160
161 impl<'a, 'de> DeserializeSeed<'de> for SchemaRefMut<'a> {
162 type Value = ();
163
164 fn deserialize<D>(mut self, deserializer: D) -> Result<Self::Value, D::Error>
165 where
166 D: serde::Deserializer<'de>,
167 {
168 if let Some(schema_deserialize) = self.schema().type_data.get::<SchemaDeserialize>() {
170 return schema_deserialize.deserialize(self, deserializer);
171 }
172
173 match &self.schema().kind {
174 SchemaKind::Struct(s) => {
175 if s.fields.len() == 1 && s.fields[0].name.is_none() {
177 unsafe { SchemaRefMut::from_ptr_schema(self.as_ptr(), s.fields[0].schema) }
180 .deserialize(deserializer)?
181 } else {
182 deserializer.deserialize_any(StructVisitor(self))?
183 }
184 }
185 SchemaKind::Vec(_) => deserializer.deserialize_seq(VecVisitor(self))?,
186 SchemaKind::Map { .. } => deserializer.deserialize_map(MapVisitor(self))?,
187 SchemaKind::Enum(_) => deserializer.deserialize_any(EnumVisitor(self))?,
188 SchemaKind::Box(_) => self.into_box().unwrap().deserialize(deserializer)?,
189 SchemaKind::Primitive(p) => {
190 match p {
191 Primitive::Bool => *self.cast_mut() = bool::deserialize(deserializer)?,
192 Primitive::U8 => *self.cast_mut() = u8::deserialize(deserializer)?,
193 Primitive::U16 => *self.cast_mut() = u16::deserialize(deserializer)?,
194 Primitive::U32 => *self.cast_mut() = u32::deserialize(deserializer)?,
195 Primitive::U64 => *self.cast_mut() = u64::deserialize(deserializer)?,
196 Primitive::U128 => *self.cast_mut() = u128::deserialize(deserializer)?,
197 Primitive::I8 => *self.cast_mut() = i8::deserialize(deserializer)?,
198 Primitive::I16 => *self.cast_mut() = i16::deserialize(deserializer)?,
199 Primitive::I32 => *self.cast_mut() = i32::deserialize(deserializer)?,
200 Primitive::I64 => *self.cast_mut() = i64::deserialize(deserializer)?,
201 Primitive::I128 => *self.cast_mut() = i128::deserialize(deserializer)?,
202 Primitive::F32 => *self.cast_mut() = f32::deserialize(deserializer)?,
203 Primitive::F64 => *self.cast_mut() = f64::deserialize(deserializer)?,
204 Primitive::String => *self.cast_mut() = String::deserialize(deserializer)?,
205 Primitive::Opaque { .. } => {
206 return Err(D::Error::custom(
207 "Opaque types must be #[repr(C)] or have `SchemaDeserialize` type \
208 data in order to be deserialized.",
209 ));
210 }
211 };
212 }
213 };
214
215 Ok(())
216 }
217 }
218
219 struct StructVisitor<'a>(SchemaRefMut<'a>);
220 impl<'a, 'de> Visitor<'de> for StructVisitor<'a> {
221 type Value = ();
222 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
223 write!(
224 formatter,
225 "asset metadata matching the schema: {:#?}",
226 self.0.schema()
227 )
228 }
229
230 fn visit_seq<A>(mut self, mut seq: A) -> Result<Self::Value, A::Error>
231 where
232 A: serde::de::SeqAccess<'de>,
233 {
234 let field_count = self.0.schema().kind.as_struct().unwrap().fields.len();
235
236 for i in 0..field_count {
237 let field = self.0.access_mut().field(i).unwrap().into_schema_ref_mut();
238 if seq.next_element_seed(field)?.is_none() {
239 break;
240 }
241 }
242
243 Ok(())
244 }
245
246 fn visit_map<A>(mut self, mut map: A) -> Result<Self::Value, A::Error>
247 where
248 A: serde::de::MapAccess<'de>,
249 {
250 while let Some(key) = map.next_key::<String>()? {
251 match self
252 .0
253 .access_mut()
254 .field(&key)
255 .map(|x| x.into_schema_ref_mut())
256 {
257 Ok(field) => {
258 map.next_value_seed(field)?;
259 }
260 Err(_) => {
261 let fields = &self.0.schema().kind.as_struct().unwrap().fields;
262 let mut msg = format!("unknown field `{key}`, ");
263 if !fields.is_empty() {
264 msg += "expected one of ";
265 for (i, field) in fields.iter().enumerate() {
266 msg += &field
267 .name
268 .as_ref()
269 .map(|x| format!("`{x}`"))
270 .unwrap_or_else(|| format!("`{i}`"));
271 if i < fields.len() - 1 {
272 msg += ", "
273 }
274 }
275 } else {
276 msg += "there are no fields"
277 }
278 return Err(A::Error::custom(msg));
279 }
280 }
281 }
282
283 Ok(())
284 }
285 }
286
287 struct VecVisitor<'a>(SchemaRefMut<'a>);
288 impl<'a, 'de> Visitor<'de> for VecVisitor<'a> {
289 type Value = ();
290 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
291 write!(
292 formatter,
293 "asset metadata matching the schema: {:#?}",
294 self.0.schema()
295 )
296 }
297
298 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
299 where
300 A: serde::de::SeqAccess<'de>,
301 {
302 let v = unsafe { &mut *(self.0.as_ptr() as *mut SchemaVec) };
304 loop {
305 let item_schema = v.schema();
306 let mut item = SchemaBox::default(item_schema);
307 let item_ref = item.as_mut();
308 if seq.next_element_seed(item_ref)?.is_none() {
309 break;
310 }
311 v.push_box(item);
312 }
313
314 Ok(())
315 }
316 }
317 struct MapVisitor<'a>(SchemaRefMut<'a>);
318 impl<'a, 'de> Visitor<'de> for MapVisitor<'a> {
319 type Value = ();
320 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
321 write!(
322 formatter,
323 "asset metadata matching the schema: {:#?}",
324 self.0.schema()
325 )
326 }
327
328 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
329 where
330 A: serde::de::MapAccess<'de>,
331 {
332 let v = unsafe { &mut *(self.0.as_ptr() as *mut SchemaMap) };
334 let is_ustr = v.key_schema() == Ustr::schema();
335 if v.key_schema() != String::schema() && !is_ustr {
336 return Err(A::Error::custom(
337 "Can only deserialize maps with `String` or `Ustr` keys.",
338 ));
339 }
340 while let Some(key) = map.next_key::<String>()? {
341 let key = if is_ustr {
342 SchemaBox::new(ustr(&key))
343 } else {
344 SchemaBox::new(key)
345 };
346 let mut value = SchemaBox::default(v.value_schema());
347 map.next_value_seed(value.as_mut())?;
348
349 v.insert_box(key, value);
350 }
351 Ok(())
352 }
353 }
354 struct EnumVisitor<'a>(SchemaRefMut<'a>);
355 impl<'a, 'de> Visitor<'de> for EnumVisitor<'a> {
356 type Value = ();
357 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
358 write!(
359 formatter,
360 "asset metadata matching the schema: {:#?}",
361 self.0.schema()
362 )
363 }
364
365 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
366 where
367 E: Error,
368 {
369 let enum_info = self.0.schema().kind.as_enum().unwrap();
370 let var_idx = enum_info
371 .variants
372 .iter()
373 .position(|x| x.name == v)
374 .ok_or_else(|| E::invalid_value(Unexpected::Str(v), &self))?;
375
376 if !enum_info.variants[var_idx]
377 .schema
378 .kind
379 .as_struct()
380 .unwrap()
381 .fields
382 .is_empty()
383 {
384 return Err(E::custom(format!(
385 "Cannot deserialize enum variant with fields from string: {v}"
386 )));
387 }
388
389 unsafe {
391 match enum_info.tag_type {
392 EnumTagType::U8 => self.0.as_ptr().cast::<u8>().write(var_idx as u8),
393 EnumTagType::U16 => self.0.as_ptr().cast::<u16>().write(var_idx as u16),
394 EnumTagType::U32 => self.0.as_ptr().cast::<u32>().write(var_idx as u32),
395 }
396 }
397
398 Ok(())
399 }
400
401 fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
402 where
403 A: serde::de::EnumAccess<'de>,
404 {
405 let (value_ptr, var_access) = data.variant_seed(EnumLoad(self.0))?;
406 var_access.newtype_variant_seed(value_ptr)?;
407 Ok(())
408 }
409 }
410
411 struct EnumLoad<'a>(SchemaRefMut<'a>);
412 impl<'a, 'de> DeserializeSeed<'de> for EnumLoad<'a> {
413 type Value = SchemaRefMut<'a>;
414
415 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
416 where
417 D: serde::Deserializer<'de>,
418 {
419 let var_name = String::deserialize(deserializer)?;
420 let enum_info = self.0.schema().kind.as_enum().unwrap();
421 let value_offset = self.0.schema().field_offsets()[0].1;
422 let (var_idx, var_schema) = enum_info
423 .variants
424 .iter()
425 .enumerate()
426 .find_map(|(idx, info)| (info.name == var_name).then_some((idx, info.schema)))
427 .ok_or_else(|| {
428 D::Error::custom(format!(
429 "Unknown enum variant `{var_name}`, expected one of: {}",
430 enum_info
431 .variants
432 .iter()
433 .map(|x| format!("`{}`", x.name))
434 .collect::<Vec<_>>()
435 .join(", ")
436 ))
437 })?;
438
439 match enum_info.tag_type {
442 EnumTagType::U8 => unsafe { self.0.as_ptr().cast::<u8>().write(var_idx as u8) },
443 EnumTagType::U16 => unsafe { self.0.as_ptr().cast::<u16>().write(var_idx as u16) },
444 EnumTagType::U32 => unsafe { self.0.as_ptr().cast::<u32>().write(var_idx as u32) },
445 }
446
447 if var_schema.kind.as_struct().is_none() {
448 return Err(D::Error::custom(
449 "All enum variant types must have a struct Schema",
450 ));
451 }
452
453 unsafe {
454 Ok(SchemaRefMut::from_ptr_schema(
455 self.0.as_ptr().add(value_offset),
456 var_schema,
457 ))
458 }
459 }
460 }
461}
462
463pub struct SchemaDeserialize {
469 pub deserialize_fn: for<'a, 'de> fn(
471 SchemaRefMut<'a>,
472 deserializer: &'a mut dyn Deserializer<'de>,
473 ) -> Result<(), erased_serde::Error>,
474}
475
476unsafe impl HasSchema for SchemaDeserialize {
477 fn schema() -> &'static Schema {
478 use std::{alloc::Layout, any::TypeId, sync::OnceLock};
479 static S: OnceLock<&'static Schema> = OnceLock::new();
480 let layout = Layout::new::<Self>();
481 S.get_or_init(|| {
482 SCHEMA_REGISTRY.register(SchemaData {
483 name: type_name::<Self>().into(),
484 full_name: format!("{}::{}", module_path!(), type_name::<Self>()).into(),
485 kind: SchemaKind::Primitive(Primitive::Opaque {
486 size: layout.size(),
487 align: layout.align(),
488 }),
489 type_id: Some(TypeId::of::<Self>()),
490 clone_fn: None,
491 drop_fn: None,
492 default_fn: None,
493 hash_fn: None,
494 eq_fn: None,
495 type_data: Default::default(),
496 })
497 })
498 }
499}
500
501impl SchemaDeserialize {
502 pub fn deserialize<'a, 'de, D>(
505 &self,
506 reference: SchemaRefMut<'a>,
507 deserializer: D,
508 ) -> Result<(), D::Error>
509 where
510 D: serde::Deserializer<'de>,
511 {
512 let mut erased = <dyn erased_serde::Deserializer>::erase(deserializer);
513 (self.deserialize_fn)(reference, &mut erased)
514 .map_err(<<D as serde::Deserializer<'de>>::Error as serde::de::Error>::custom)
515 }
516}
517
518impl<T: HasSchema + for<'de> Deserialize<'de>> FromType<T> for SchemaDeserialize {
519 fn from_type() -> Self {
520 SchemaDeserialize {
521 deserialize_fn: |reference, deserializer| {
522 T::schema()
523 .ensure_match(reference.schema())
524 .map_err(|e| erased_serde::Error::custom(e.to_string()))?;
525 let data = T::deserialize(deserializer)?;
526
527 unsafe {
529 reference.as_ptr().cast::<T>().write(data);
530 }
531
532 Ok(())
533 },
534 }
535 }
536}
537
#[cfg(test)]
mod test {
    use super::*;
    use bones_schema_macros::HasSchema;

    // Fixture covering a string, a newtype struct, a vector and a string-keyed map.
    #[derive(HasSchema, Clone, Default)]
    #[schema_module(crate)]
    #[repr(C)]
    struct MyData {
        name: String,
        age: Age,
        favorite_things: SVec<String>,
        map: SMap<String, String>,
    }

    // Newtype wrapper, which the schema (de)serializer treats transparently.
    #[derive(HasSchema, Clone, Default)]
    #[schema_module(crate)]
    #[repr(C)]
    struct Age(u32);

    // YAML form of the fixture shared by both tests.
    const DEMO_YAML: &str = r"name: John
age: 8
favorite_things:
- jelly
- beans
map:
  hello: world
";

    // The `favorite_things` list that corresponds to `DEMO_YAML`.
    fn demo_favorite_things() -> SVec<String> {
        ["jelly", "beans"].into_iter().map(String::from).collect()
    }

    #[test]
    fn schema_deserializer() {
        let yaml = serde_yaml::Deserializer::from_str(DEMO_YAML);

        let loaded = SchemaDeserializer(MyData::schema())
            .deserialize(yaml)
            .unwrap()
            .cast_into::<MyData>();

        assert_eq!(loaded.name, "John");
        assert_eq!(loaded.age.0, 8);
        assert_eq!(loaded.favorite_things, demo_favorite_things());
        assert_eq!(
            loaded.map.into_iter().next().unwrap(),
            (&"hello".to_string(), &"world".to_string())
        );
    }

    #[test]
    fn schema_serializer() {
        let fixture = MyData {
            name: "John".into(),
            age: Age(8),
            favorite_things: demo_favorite_things(),
            map: [("hello".to_string(), "world".to_string())]
                .into_iter()
                .collect(),
        };

        let mut out = Vec::new();
        let mut serializer = serde_yaml::Serializer::new(&mut out);
        SchemaSerializer(fixture.as_schema_ref())
            .serialize(&mut serializer)
            .unwrap();

        assert_eq!(DEMO_YAML, String::from_utf8(out).unwrap());
    }
}