
Rust Serde - Custom deserialization from map or list


I'm having trouble writing a custom deserializer with serde. My goal is to deserialize either a list of structs or a map of structs (keyed by the "id" field of each individual struct) into a list.
In other words, I want code like the following to work:

#[cfg(test)]
mod tests {
    #[test]
    fn deserialize() {
        #[derive(Debug, serde::Deserialize)]
        struct List {
            entries: Vec<Entry>,
        }

        #[derive(Debug, serde::Deserialize)]
        struct Entry {
            #[serde(flatten)]
            class: Class,
            label: Box<str>,
        }

        #[derive(Debug, serde::Deserialize)]
        #[serde(tag = "class")]
        enum Class {
            ClassOne(ClassOne),
            ClassTwo(ClassTwo),
        }

        #[derive(Debug, serde::Deserialize)]
        struct ClassOne {
            value: u32,
        }

        #[derive(Debug, serde::Deserialize)]
        struct ClassTwo {
            value: f32,
        }

        // ------------------------------

        let input_list = r#"
            entries:
                - class: ClassOne
                  value: 1234
                  label: Test#1
                - class: ClassTwo
                  value: 1.234
                  label: Test#2
                - class: ClassOne
                  value: 4321
                  label: Test#3
        "#;

        let input_map = r#"
            entries:
                ClassOne:
                    value: 1234
                    label: Test#1
                ClassTwo:
                    value: 1.234
                    label: Test#2
                ClassOne:
                    value: 4321
                    label: Test#3
        "#;

        // This works
        let list = serde_yaml::from_str::<List>(input_list).unwrap();
        println!("{list:?}");

        // This panics (as expected)
        //let list = serde_yaml::from_str::<List>(input_map).unwrap();
        //println!("{list:?}");
    }
}

My first idea was to use the Visitor trait. The solution needs to be as generic as possible, because I have several structures that, per the specification, behave like the one in the test above. However, I got stuck here (the error is shown as a comment):

use serde::{
    de::{MapAccess, SeqAccess, Visitor},
    Deserialize,
};

pub(crate) struct MapOrListVisitor<T, A> {
    insert: Box<dyn Fn(&str, &mut A) -> T>,
}

impl<'de, T, A> MapOrListVisitor<T, A>
where
    T: Deserialize<'de>,
    A: MapAccess<'de>,
{
    pub fn new(insert: impl Fn(&str, &mut A) -> T + 'static) -> Self {
        Self {
            insert: Box::new(insert),
        }
    }
}

impl<'de, T, A> Visitor<'de> for MapOrListVisitor<T, A>
where
    T: Deserialize<'de>,
    A: MapAccess<'de>
{
    type Value = Box<[T]>;

    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        todo!()
    }

    fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
    where
        M: MapAccess<'de>,
    {
        let entries = Vec::<T>::with_capacity(map.size_hint().unwrap_or_default());

        while let Some(key) = map.next_key::<Box<str>>()? {
            let entry = (&self.insert)(&key, &mut map);
            // mismatched types
            //     expected mutable reference `&mut A`
            //     found mutable reference `&mut M`

            todo!()
        }

        Ok(entries.into_boxed_slice())
    }

    fn visit_seq<S>(self, seq: S) -> Result<Self::Value, S::Error>
    where
        S: SeqAccess<'de>,
    {
        todo!()
    }
}

Would anyone know how I can do this?
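The compile error above arises because A is a type parameter of MapOrListVisitor and is fixed when the visitor is constructed, whereas M in visit_map is chosen by the deserializer at call time. Closures cannot be generic, so a Box<dyn Fn(&str, &mut A) -> T> can never accept an arbitrary M. One way around this is to put the generic parameter on a trait method instead. That is the same shape as serde's DeserializeSeed::deserialize<D>. The std-only sketch below illustrates the idea. Access, Insert, Counter, Fixed, KeyAndNext and collect_with are hypothetical stand-ins, not serde APIs.

```rust
// Hypothetical stand-in for serde's MapAccess: something that yields values on demand.
trait Access {
    fn next_value(&mut self) -> Option<u32>;
}

// Two unrelated implementors, like the different MapAccess types each format provides.
struct Counter(u32);

impl Access for Counter {
    fn next_value(&mut self) -> Option<u32> {
        self.0 += 1;
        Some(self.0)
    }
}

struct Fixed(Vec<u32>);

impl Access for Fixed {
    fn next_value(&mut self) -> Option<u32> {
        self.0.pop()
    }
}

// The *method* is generic over M, so every call site picks its own access type.
// This mirrors DeserializeSeed::deserialize<D>, and is what a boxed closure cannot express.
trait Insert {
    type Output;
    fn insert<M: Access>(&self, key: &str, access: &mut M) -> Self::Output;
}

// Hypothetical inserter: pairs each key with the next value pulled from the access.
struct KeyAndNext;

impl Insert for KeyAndNext {
    type Output = (String, Option<u32>);

    fn insert<M: Access>(&self, key: &str, access: &mut M) -> Self::Output {
        (key.to_owned(), access.next_value())
    }
}

// Generic over M just like visit_map<M: MapAccess<'de>>; the inserter adapts to any M.
fn collect_with<I: Insert, M: Access>(inserter: &I, keys: &[&str], access: &mut M) -> Vec<I::Output> {
    keys.iter().map(|key| inserter.insert(key, access)).collect()
}
```

The same inserter value works with both Counter and Fixed, which is the flexibility the closure-based design lacks.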


Solution

  • A generic solution to this problem is tricky, because the two data formats you need to support cannot both be handled at the same time using serde's derive attributes. The solution @PitaJ presented in the comments above (which, for reference, is viewable here) is, as they rightly pointed out, not generic: it essentially reimplements deserialization of the map case for every single struct contained within Class. To solve this problem without constantly tuning our deserialization logic, we need a Deserialize implementation for Entry that works in both the seq and map contexts, and then logic that deserializes Entrys in both data formats using that implementation.

    Warning: this solution involves a good amount of boilerplate. As I mentioned, we won't be able to take advantage of serde's derive attributes in many parts of it, including the flatten attribute you used in your minimal example. However, once the boilerplate is in place, very few changes are required to add more variants to the Class enum.

    Deserializing Entrys

    One of the toughest parts about the two data formats in your question is that we encounter the Class variant at different points during deserialization. For input_list, the class field sits alongside the other fields in the flattened struct, while in input_map the variant is the map's key. Rather than implementing a separate deserialization strategy for each case, we want a single strategy that handles both. Therefore, deserializing the Entry struct should always begin by obtaining the Class variant. From that point, we can use the variant to deserialize the rest of the flattened Entry struct.

    Begin by defining a ClassVariant enum. Defining a separate enum just to encode the variant may seem weird, but this is actually a common strategy for manually deserializing enums (see serde's own Deserialize implementation for Result, for example).

    #[derive(Debug, Deserialize)]
    enum ClassVariant {
        ClassOne,
        ClassTwo,
    }
    

    To reiterate, in the case of input_list, this is deserialized from a field. In the case of input_map, this is deserialized from a key.
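    Both paths end in the same ClassVariant value; only the source of the variant name differs. Below is a std-only sketch of that idea. The helper names parse_variant, variant_from_field and variant_from_key are hypothetical. The string matching only approximates what the derived Deserialize does when it reads a unit variant's name.

```rust
use std::collections::BTreeMap;

#[derive(Debug, Clone, Copy, PartialEq)]
enum ClassVariant {
    ClassOne,
    ClassTwo,
}

// Roughly what the derived Deserialize does when it reads a unit variant's name.
fn parse_variant(name: &str) -> Result<ClassVariant, String> {
    match name {
        "ClassOne" => Ok(ClassVariant::ClassOne),
        "ClassTwo" => Ok(ClassVariant::ClassTwo),
        other => Err(format!("unknown variant `{other}`")),
    }
}

// input_list: the variant is the value of the `class` field, next to the other fields.
fn variant_from_field(fields: &BTreeMap<String, String>) -> Result<ClassVariant, String> {
    let name = fields.get("class").ok_or("missing field `class`")?;
    parse_variant(name)
}

// input_map: the variant is the map key; the value holds only the remaining fields.
fn variant_from_key(key: &str) -> Result<ClassVariant, String> {
    parse_variant(key)
}
```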

    ClassVariant as a key

    Let's address the case of ClassVariant being deserialized from a key first, as this is the simpler case. Given a deserialized ClassVariant, we can deserialize an Entry using the variant as a seed by implementing DeserializeSeed. While we can't use the flatten attribute, we can instead use the serde-value crate to implement the exact same strategy the derived code uses (see prior discussion for a similar problem).

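    use std::collections::BTreeMap;
    use std::fmt;
    
    use serde::de::{self, DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor};
    use serde::Deserialize;
    use serde_value::{Value, ValueDeserializer};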
    
    #[derive(Debug, Deserialize)]
    struct ClassOne {
        value: u32,
    }
    
    #[derive(Debug, Deserialize)]
    struct ClassTwo {
        value: f32,
    }
    
    #[derive(Debug)]
    enum Class {
        ClassOne(ClassOne),
        ClassTwo(ClassTwo),
    }
    
    impl<'de> DeserializeSeed<'de> for ClassVariant {
        type Value = Entry;
    
        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: Deserializer<'de>,
        {
            enum Field {
                Label,
                Other(String),
            }
    
            impl<'de> Deserialize<'de> for Field {
                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
                where
                    D: Deserializer<'de>,
                {
                    struct FieldVisitor;
    
                    impl<'de> Visitor<'de> for FieldVisitor {
                        type Value = Field;
    
                        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                            formatter.write_str("`label`")
                        }
    
                        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
                        where
                            E: de::Error,
                        {
                            match value {
                                "label" => Ok(Field::Label),
                                name => Ok(Field::Other(name.to_owned())),
                            }
                        }
                    }
    
                    deserializer.deserialize_identifier(FieldVisitor)
                }
            }
    
            struct EntryVisitor(ClassVariant);
    
            impl<'de> Visitor<'de> for EntryVisitor {
                type Value = Entry;
    
                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str("struct Entry")
                }
    
                fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
                where
                    A: MapAccess<'de>,
                {
                    let mut label = None;
                    let mut values = BTreeMap::new();
    
                    while let Some(field) = map.next_key()? {
                        match field {
                            Field::Label => {
                                if label.is_some() {
                                    return Err(de::Error::duplicate_field("label"));
                                }
                                label = Some(map.next_value()?);
                            }
                            Field::Other(name) => {
                                values.insert(Value::String(name), map.next_value()?);
                            }
                        }
                    }
    
                    let label = label.ok_or_else(|| de::Error::missing_field("label"))?;
                    let class = match self.0 {
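                        // Name ValueDeserializer's error type explicitly so the compiler
                        // does not have to infer it through the `?` conversion.
                        ClassVariant::ClassOne => Class::ClassOne(ClassOne::deserialize(
                            ValueDeserializer::<A::Error>::new(Value::Map(values)),
                        )?),
                        ClassVariant::ClassTwo => Class::ClassTwo(ClassTwo::deserialize(
                            ValueDeserializer::<A::Error>::new(Value::Map(values)),
                        )?),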
                    };
    
                    Ok(Entry { class, label })
                }
            }
    
            deserializer.deserialize_map(EntryVisitor(self))
        }
    }
    

    This is similar to the code presented in Manually implementing Deserialize for a struct, but instead of implementing Deserialize, we implement DeserializeSeed, which returns a value of type Entry. The overall strategy is to deserialize each key as either the label field or an unknown Other field, which holds one of the flattened fields belonging to Class's variant. We collect the unknown fields into a BTreeMap, which can then be re-deserialized using serde_value::ValueDeserializer as either ClassOne or ClassTwo, depending on the ClassVariant seed. This re-deserialization lets us derive a Deserialize implementation for the contained structs, and works nearly identically to serde's flatten attribute (see the source code, which uses a private deserializer for the same purpose).

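    Note also that we have to allocate an owned String for every unknown field name (the name.to_owned() call), because serde_value::Value cannot borrow strings from the input. This is a limitation of serde-value, and there is an open issue about it that has had little traction.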
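    The two-pass idea can be modeled without serde. The first pass buffers every field except label. The second pass re-reads that buffer as the struct the variant selects. In this sketch, Value is a hypothetical, much-reduced stand-in for serde_value::Value, and entry_from_fields plays the role of the DeserializeSeed implementation. The hand-written matching in the second pass stands in for the derived ClassOne and ClassTwo implementations, so treat it as an illustration, not serde's actual behavior.

```rust
use std::collections::BTreeMap;

// Hypothetical, much-reduced stand-in for serde_value::Value.
#[derive(Debug, Clone, PartialEq)]
enum Value {
    U64(u64),
    F64(f64),
    Str(String),
}

#[derive(Debug, Clone, Copy)]
enum ClassVariant {
    ClassOne,
    ClassTwo,
}

#[derive(Debug, PartialEq)]
struct ClassOne {
    value: u32,
}

#[derive(Debug, PartialEq)]
struct ClassTwo {
    value: f32,
}

#[derive(Debug, PartialEq)]
enum Class {
    ClassOne(ClassOne),
    ClassTwo(ClassTwo),
}

#[derive(Debug, PartialEq)]
struct Entry {
    class: Class,
    label: String,
}

fn entry_from_fields(variant: ClassVariant, fields: Vec<(String, Value)>) -> Result<Entry, String> {
    // Pass 1: take `label`, buffer every other field by name.
    let mut label = None;
    let mut rest = BTreeMap::new();
    for (name, value) in fields {
        if name == "label" {
            let Value::Str(s) = value else {
                return Err("`label` must be a string".into());
            };
            if label.replace(s).is_some() {
                return Err("duplicate field `label`".into());
            }
        } else {
            rest.insert(name, value);
        }
    }
    let label = label.ok_or("missing field `label`")?;

    // Pass 2: re-read the buffered fields as the struct the variant selects.
    // Buffered fields other than `value` are ignored here.
    let class = match (variant, rest.get("value")) {
        (ClassVariant::ClassOne, Some(Value::U64(n))) => {
            let value = u32::try_from(*n).map_err(|e| e.to_string())?;
            Class::ClassOne(ClassOne { value })
        }
        (ClassVariant::ClassTwo, Some(Value::F64(x))) => Class::ClassTwo(ClassTwo { value: *x as f32 }),
        // Accept integer input for the float field as well.
        (ClassVariant::ClassTwo, Some(Value::U64(n))) => Class::ClassTwo(ClassTwo { value: *n as f32 }),
        (_, None) => return Err("missing field `value`".into()),
        (_, Some(other)) => return Err(format!("invalid type for `value`: {other:?}")),
    };
    Ok(Entry { class, label })
}
```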

    ClassVariant as a field

    In the case of input_list, the ClassVariant is provided in the class field. Therefore, if we first extract the variant using the same buffering strategy we used above for re-deserializing the other fields, we can reuse the rest of the deserialization logic we have already implemented.

    We therefore implement Deserialize for Entry, but only extract the class field, deferring the remainder of the deserialization to ClassVariant's DeserializeSeed implementation.

    #[derive(Debug)]
    struct Entry {
        class: Class,
        label: Box<str>,
    }
    
    impl<'de> Deserialize<'de> for Entry {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            enum Field {
                Class,
                Other(String),
            }
    
            impl<'de> Deserialize<'de> for Field {
                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
                where
                    D: Deserializer<'de>,
                {
                    struct FieldVisitor;
    
                    impl<'de> Visitor<'de> for FieldVisitor {
                        type Value = Field;
    
                        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                            formatter.write_str("`class`")
                        }
    
                        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
                        where
                            E: de::Error,
                        {
                            match value {
                                "class" => Ok(Field::Class),
                                name => Ok(Field::Other(name.to_owned())),
                            }
                        }
                    }
    
                    deserializer.deserialize_identifier(FieldVisitor)
                }
            }
    
            struct EntryVisitor;
    
            impl<'de> Visitor<'de> for EntryVisitor {
                type Value = Entry;
    
                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str("struct Entry")
                }
    
                fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
                where
                    A: MapAccess<'de>,
                {
                    let mut class: Option<ClassVariant> = None;
                    let mut values = BTreeMap::new();
    
                    while let Some(field) = map.next_key()? {
                        match field {
                            Field::Class => {
                                if class.is_some() {
                                    return Err(de::Error::duplicate_field("class"));
                                }
                                class = Some(map.next_value()?);
                            }
                            Field::Other(name) => {
                                values.insert(Value::String(name), map.next_value()?);
                            }
                        }
                    }
    
                    let class = class.ok_or_else(|| de::Error::missing_field("class"))?;
    
                    class.deserialize(ValueDeserializer::new(Value::Map(values)))
                }
            }
    
            deserializer.deserialize_map(EntryVisitor)
        }
    }
    

    Now we can fully deserialize Entrys in both the seq and map contexts.
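    The reuse can be modeled with plain functions. In the std-only sketch below, entry_with_variant stands in for the DeserializeSeed implementation. entry_from_list_item stands in for Entry's Deserialize implementation, and entry_from_map_item stands in for the map-key path. All three names, and the simplified string-valued Entry, are hypothetical.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum ClassVariant {
    ClassOne,
    ClassTwo,
}

// Simplified, hypothetical Entry: the per-variant struct is reduced to a raw string.
#[derive(Debug, PartialEq)]
struct Entry {
    variant: ClassVariant,
    label: String,
    value: String,
}

type Fields = Vec<(String, String)>;

fn parse_variant(name: &str) -> Result<ClassVariant, String> {
    match name {
        "ClassOne" => Ok(ClassVariant::ClassOne),
        "ClassTwo" => Ok(ClassVariant::ClassTwo),
        other => Err(format!("unknown variant `{other}`")),
    }
}

// Shared tail, in the role of DeserializeSeed for ClassVariant: the variant is already known.
fn entry_with_variant(variant: ClassVariant, fields: Fields) -> Result<Entry, String> {
    let mut label = None;
    let mut value = None;
    for (name, v) in fields {
        match name.as_str() {
            "label" => label = Some(v),
            "value" => value = Some(v),
            _ => {} // other fields are ignored in this sketch
        }
    }
    Ok(Entry {
        variant,
        label: label.ok_or("missing field `label`")?,
        value: value.ok_or("missing field `value`")?,
    })
}

// Map context: the key supplies the variant, the value goes straight to the shared tail.
fn entry_from_map_item(key: &str, fields: Fields) -> Result<Entry, String> {
    entry_with_variant(parse_variant(key)?, fields)
}

// Seq context, in the role of Entry's Deserialize: extract `class`, then defer the rest.
fn entry_from_list_item(fields: Fields) -> Result<Entry, String> {
    let mut variant = None;
    let mut rest = Fields::new();
    for (name, v) in fields {
        if name == "class" {
            variant = Some(parse_variant(&v)?);
        } else {
            rest.push((name, v));
        }
    }
    let variant = variant.ok_or("missing field `class`")?;
    entry_with_variant(variant, rest)
}
```

Both entry points funnel into one function, so a list item and a map item that carry the same data produce equal entries.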

    Deserializing List as a seq or a map

    We still need to write a bit more custom deserialization logic for the List struct. However, due to our excellent Entry implementation, this is relatively painless. By using the deserialize_with attribute, we can easily derive a Deserialize implementation for List by providing a custom function to deserialize the entries field.

    As you mentioned, the Visitor trait can be used here to instruct the deserializer how to visit both maps and seqs, and we can use Deserializer::deserialize_any() to let the deserializer choose which one to use based on the context. For seqs, we can rely on Entry's Deserialize implementation, and for maps we first deserialize a ClassVariant and rely on its DeserializeSeed implementation.
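    Stripped of serde, that dispatch is a match over two input shapes that both produce Vec<Entry>. In this std-only sketch, the hypothetical Entries enum stands in for what deserialize_any hands the visitor. The caller picks the variant directly here, whereas in serde the format decides which visit_* method runs. The field, with_class, from_fields and entries helpers are also hypothetical, and class names are kept as unvalidated strings for brevity.

```rust
// Hypothetical stand-in for the two shapes deserialize_any can hand the visitor.
enum Entries {
    // Seq: every element carries its own `class` field.
    Seq(Vec<Vec<(String, String)>>),
    // Map: the key is the class, the value holds the remaining fields.
    Map(Vec<(String, Vec<(String, String)>)>),
}

#[derive(Debug, PartialEq)]
struct Entry {
    class: String,
    label: String,
}

// Look up a field by name in a flat list of (name, value) pairs.
fn field(fields: &[(String, String)], name: &str) -> Result<String, String> {
    fields
        .iter()
        .find(|(k, _)| k == name)
        .map(|(_, v)| v.clone())
        .ok_or_else(|| format!("missing field `{name}`"))
}

// Role of the DeserializeSeed impl: the class is already known.
fn with_class(class: String, fields: &[(String, String)]) -> Result<Entry, String> {
    Ok(Entry { class, label: field(fields, "label")? })
}

// Role of Entry's own Deserialize impl: read `class` first, then defer.
fn from_fields(fields: &[(String, String)]) -> Result<Entry, String> {
    with_class(field(fields, "class")?, fields)
}

// Role of EntriesVisitor: one branch per shape, both producing Vec<Entry>.
fn entries(input: Entries) -> Result<Vec<Entry>, String> {
    match input {
        Entries::Seq(items) => items.iter().map(|f| from_fields(f)).collect(),
        Entries::Map(pairs) => pairs
            .iter()
            .map(|(class, f)| with_class(class.clone(), f))
            .collect(),
    }
}
```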

    fn deserialize_entries<'de, D>(deserializer: D) -> Result<Vec<Entry>, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct EntriesVisitor;
    
        impl<'de> Visitor<'de> for EntriesVisitor {
            type Value = Vec<Entry>;
    
            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("either a sequence or map of entries")
            }
    
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let mut entries = Vec::new();
    
                while let Some(entry) = seq.next_element()? {
                    entries.push(entry);
                }
    
                Ok(entries)
            }
    
            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let mut entries = Vec::new();
    
                while let Some(variant) = map.next_key::<ClassVariant>()? {
                    entries.push(map.next_value_seed(variant)?);
                }
    
                Ok(entries)
            }
        }
    
        deserializer.deserialize_any(EntriesVisitor)
    }
    
    #[derive(Debug, Deserialize)]
    struct List {
        #[serde(deserialize_with = "deserialize_entries")]
        entries: Vec<Entry>,
    }
    

    This code can deserialize both input_list and input_map in your example. It is also generic; to expand to more structs in Class, simply add the new struct variant in the Class enum, add a new variant in ClassVariant, and slightly modify the DeserializeSeed implementation to call the new struct's Deserialize implementation for that variant. While adding a new struct still requires a tiny bit of boilerplate, it is nothing compared to reimplementing the Deserialize implementations for each of those cases.

    The full solution can be run, with your examples, on this playground.