
How to implement a custom deserializer using serde that allows parsing untagged enums with a non-self-describing data format


I have a data format that I am currently parsing with serde. The actual format isn't important, but assume it's a string like key_value_also_value_key2_value2_key3_value3, where there's no way to tell where a value ends and the next key starts. I think this is what serde calls a non-self-describing data format.

I have a working Deserializer that can take this input and parse it into a struct (the struct's field names can then be used to figure out where keys start and stop). The implementation looks like this:

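use serde::de::{self, value::MapDeserializer, Visitor};
use serde::forward_to_deserialize_any;

// `Error` (which implements `serde::de::Error`) and the `Result<T>` alias are
// assumed to be defined in the crate root, as the code below implies.
use crate::{Error, Result};

#[derive(Debug)]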
pub struct Deserializer<'de> {
    // This string starts with the input data and when something has been read, it is removed from here
    input: &'de str,
}

impl<'de> Deserializer<'de> {
    pub fn from_str(input: &'de str) -> Self {
        Deserializer { input }
    }
}

impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
    type Error = Error;

    /// Deserialize to a struct. This is currently the only thing we support as we need the struct so we can tell the field names apart
    fn deserialize_struct<V>(
        self,
        _name: &'static str,
        fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        // Parses the string into a HashMap given the field names from the struct
        // the actual logic is not important
        let map = crate::parser::emission_factor_id_to_map(self.input, fields)?;

        // Wrap the map in a MapDeserializer, which implements MapAccess, so it can be passed to visitor.visit_map()
        let map_access: MapDeserializer<_, crate::Error> = MapDeserializer::new(map.into_iter());

        // Parsing consumes the whole input string, so reset it to "" afterwards.
        self.input = "";

        // Let the derived Deserialize impl pull the fields out of the map
        visitor.visit_map(map_access)
    }

    // Serde requires a great many methods we'll never use; they are all forwarded to deserialize_any, which just returns an error.
    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        Err(crate::Error::Message(
            "Something unexpected happened - deserialize_any was called but it should never be"
                .to_string(),
        ))
    }

    forward_to_deserialize_any! {
        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
        bytes byte_buf option unit unit_struct newtype_struct seq tuple
        tuple_struct map enum identifier ignored_any
    }
}
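The parsing step hidden behind crate::parser::emission_factor_id_to_map isn't shown, but the reason the field names are needed can be illustrated with a hypothetical splitter (split_by_fields is an invented name, not the asker's function). It splits on '_' and treats any token equal to a known field name as the start of a new key, so it assumes no value ever contains a token identical to a field name.

```rust
use std::collections::HashMap;

/// Hypothetical splitter (not the asker's `emission_factor_id_to_map`).
/// Splits on '_' and treats any token equal to a known field name as the
/// start of a new key; every other token is appended to the current value.
/// Assumes no value contains a token identical to a field name.
fn split_by_fields(input: &str, fields: &[&str]) -> Result<HashMap<String, String>, String> {
    let mut map = HashMap::new();
    let mut current_key: Option<&str> = None;
    let mut current_value: Vec<&str> = Vec::new();

    for token in input.split('_') {
        if fields.contains(&token) {
            // A known field name ends the previous key's value.
            if let Some(key) = current_key.take() {
                map.insert(key.to_string(), current_value.join("_"));
                current_value.clear();
            }
            current_key = Some(token);
        } else if current_key.is_some() {
            // Not a field name, so it belongs to the current value.
            current_value.push(token);
        } else {
            return Err(format!("input does not start with a known field: {token}"));
        }
    }

    // Flush the last key/value pair.
    if let Some(key) = current_key {
        map.insert(key.to_string(), current_value.join("_"));
    }
    Ok(map)
}
```

Without the field list, "also" or "value" could just as well be keys, which is why only deserialize_struct (where serde passes `fields`) can drive this kind of parser.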

So far so good. I now want to allow this Deserializer to deserialize an untagged enum such as:

#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum MyEnum {
    Foo(FooStruct),
    Bar(BarStruct),
}

However, deserializing this enum goes straight to deserialize_any (serde's derived untagged enums first buffer the input through deserialize_any before trying each variant), and there I can't figure out how to do any of my parsing. I have access to the input string, but no way to get the struct fields of the different variants, or to forward the call to deserialize_struct.

I can get the behaviour I want by writing the logic myself, outside the deserializer, with something like this:

pub fn parse(input: &str) -> Result<MyEnum> {
    // Try each variant against the full input; the first one that parses wins.
    if let Ok(foo) = my_deserializer::from_str(input) {
        return Ok(MyEnum::Foo(foo));
    }

    if let Ok(bar) = my_deserializer::from_str(input) {
        return Ok(MyEnum::Bar(bar));
    }

    Err(crate::Error::Message(
        "input did not match any variant of MyEnum".to_string(),
    ))
}

Is there any way to achieve the above inside my deserializer without having to implement the logic myself for each enum?


Solution

  • Here is a somewhat hacky solution I came up with:

    For non-self-describing formats, you need to run a fresh deserializer over the original data for each enum variant. Because a deserializer is consumed by value and the Deserializer trait does not require Clone, I had to get at the underlying data another way: by repurposing one of the deserialize_* methods of the Deserializer implementation to return all of the remaining input.

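    // Assumes `input: &'de [u8]` (my deserializer parses a byte-based format like RESP).
    // With the question's `input: &'de str`, use `self.input.as_bytes()` and reset it to "".
    // Also remove `bytes` from the forward_to_deserialize_any! list, or the method is defined twice.
    fn deserialize_bytes<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        // Consume the entire remaining input and hand it to the visitor as bytes, so a
        // fresh deserializer can be built over it for each variant of an untagged enum.
        let bytes = self.input;
        self.input = &[];
        visitor.visit_bytes(bytes)
    }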

    Then you can implement a custom Deserialize for your untagged enum that builds a new deserializer for each variant:

    use std::fmt;

    use serde::de::{self, Deserialize, Visitor};

    // `A`, `B`, `from_slice` and `crate::parser::ParseError` (which must implement
    // Display) come from my own crate.
    impl<'de> Deserialize<'de> for UntaggedEnum {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            struct NonSelfDescribingUntaggedEnumVisitor;

            impl<'de> Visitor<'de> for NonSelfDescribingUntaggedEnumVisitor {
                type Value = UntaggedEnum;

                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str("one of the variants of UntaggedEnum")
                }

                fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
                where
                    E: de::Error,
                {
                    // Run a fresh deserializer over the full input for each variant, in order.
                    // Matching (instead of `if let` + `unwrap_err`) avoids using a moved value.
                    let variant_a: Result<A, crate::parser::ParseError> = from_slice(v);
                    let a_err = match variant_a {
                        Ok(res) => return Ok(UntaggedEnum::A(res)),
                        Err(err) => err,
                    };

                    let variant_b: Result<B, crate::parser::ParseError> = from_slice(v);
                    let b_err = match variant_b {
                        Ok(res) => return Ok(UntaggedEnum::B(res)),
                        Err(err) => err,
                    };

                    Err(E::custom(format!(
                        "no fitting variant found.\nError for variant A was: {}\nError for variant B was: {}",
                        a_err, b_err
                    )))
                }
            }

            deserializer.deserialize_bytes(NonSelfDescribingUntaggedEnumVisitor)
        }
    }
    

    I'm sure you can turn this into a macro as well :)
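As a rough, std-only sketch of that macro idea: the hypothetical untagged_parse! below tries each `Variant => parser` pair against the same input in declaration order and returns the first success, or one error listing every failure. The Value enum and parse_number/parse_flag functions are invented stand-ins; in the real thing the parsers would be calls like from_slice(v) inside visit_bytes, and the String error would be turned into E::custom(...).

```rust
/// Hypothetical macro: tries each `Variant => parser` pair against the same
/// input, in order, and returns the first success or one error listing every failure.
macro_rules! untagged_parse {
    ($input:expr, $enum:ident { $($variant:ident => $parser:expr),+ $(,)? }) => {{
        let input = $input;
        let mut errors: Vec<String> = Vec::new();
        let mut result: Option<$enum> = None;
        $(
            // Only try this variant if no earlier one succeeded.
            if result.is_none() {
                match $parser(input) {
                    Ok(value) => result = Some($enum::$variant(value)),
                    Err(err) => errors.push(format!("{}: {}", stringify!($variant), err)),
                }
            }
        )+
        result.ok_or_else(|| format!("no fitting variant found:\n{}", errors.join("\n")))
    }};
}

/// Hypothetical untagged enum used only to exercise the macro.
#[derive(Debug, PartialEq)]
enum Value {
    Number(u32),
    Flag(bool),
}

/// Hypothetical per-variant parsers standing in for `from_slice::<A>(v)` etc.
fn parse_number(input: &str) -> Result<u32, String> {
    input.parse::<u32>().map_err(|e| e.to_string())
}

fn parse_flag(input: &str) -> Result<bool, String> {
    input.parse::<bool>().map_err(|e| e.to_string())
}

fn parse_value(input: &str) -> Result<Value, String> {
    untagged_parse!(input, Value { Number => parse_number, Flag => parse_flag })
}
```

As with serde's own untagged enums, variant order matters: the first variant that parses wins, so more specific variants should come first.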