
How to create a custom deserialiser for an internally tagged enum with Serde in Rust


I have a bunch of entities coming out of a DynamoDB table. Each entity has a type field identifying what kind of entity it is. The type field is a single-item list containing the type name as a string. This is a weird format, but it is what I have to deal with.

I successfully implemented a custom deserialiser that turns the type field into an enum using Serde. My current approach is to first deserialise into a simplified struct that only has the type field, then match on that type and run the proper deserialiser for it.

I saw that Serde supports internally tagged enums, which would also deserialise the remaining fields that belong to each type into the matching enum variant, but I don't know how to adapt my current deserialiser to opt into this feature.

For reference, this is the current deserialiser that I have:

use serde::{Deserialize, Deserializer};
use types::*;

#[derive(Debug)]
enum Type {
    Story,
    Layer,
    Unknown(String),
}

impl<'de> Deserialize<'de> for Type {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // The type arrives as a single-item list, e.g. ["Story"].
        let s = <Vec<String>>::deserialize(deserializer)?;
        // Report an empty list as an error instead of panicking on `s[0]`.
        let k = match s.first() {
            Some(k) => k.as_str(),
            None => return Err(serde::de::Error::custom("the type list is empty")),
        };
        Ok(match k {
            "Story" => Type::Story,
            "Layer" => Type::Layer,
            v => Type::Unknown(v.to_string()),
        })
    }
}

// `Serialize` is not derived here because `Type` does not implement it.
#[derive(Deserialize, Debug)]
pub struct SimpleItem {
    #[serde(rename = "type")]
    item_type: Type,
}

My ideal scenario would be this instead:

use serde::{Deserialize, Deserializer, Serialize};
use types::*;

#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
enum Type {
    Story { id: String, name: String, duration: u32 },
    Layer { id: String, layout: String },
    Unknown,
}

And then just use that to get the right types deserialised directly.

This is what my data may look like: { type: ["Story"], ...otherFields }


Solution

  • A custom implementation of Deserialize is usually the best solution for this kind of “complex” deserialization:

    use std::fmt;

    use serde::de::value::MapAccessDeserializer;
    use serde::de::{self, Deserializer, MapAccess};
    use serde::Deserialize;

    impl<'de> Deserialize<'de> for Type {
        fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
            deserializer.deserialize_map(DeVisitor)
        }
    }
    
    struct DeVisitor;
    impl<'de> de::Visitor<'de> for DeVisitor {
        type Value = Type;
        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("a type")
        }
        fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
            // could use another custom Deserialize to avoid
            // allocation and improve error messages
            let type_key = map.next_key::<String>()?;
            if type_key.as_deref() != Some("type") {
                return Err(de::Error::missing_field("type"));
            }
            // You can also use a Vec<String> here and drop TypeField: less
            // boilerplate, at the cost of worse error messages
            let r#type = map.next_value::<TypeField>()?;
            match r#type {
                TypeField::Story => {
                    #[derive(Deserialize)]
                    struct Story {
                        id: String,
                        name: String,
                        duration: u32,
                    }
                    // The rest of the map holds the fields of this variant
                    let story = Story::deserialize(MapAccessDeserializer::new(map))?;
                    // Construct your Type
                    Ok(Type::Story {
                        id: story.id,
                        name: story.name,
                        duration: story.duration,
                    })
                }
                // etc: handle the remaining variants the same way
                TypeField::Layer | TypeField::Unknown(_) => todo!(),
            }
        }
    }
    
    enum TypeField {
        Story,
        Layer,
        Unknown(String),
    }
    
    // The manual Deserialize implementation you had above
    // goes here (implemented for TypeField)
    

    One big limitation of the above approach is that it requires the "type" field to be the first field. If you want to avoid this, the easiest approach depends on the data format. For example, with JSON it might look something like:

    let mut type_field = None;
    // RawValue requires serde_json's `raw_value` feature
    let mut data = <Vec<(String, Box<serde_json::value::RawValue>)>>::new();
    while let Some(key) = map.next_key()? {
        if key == "type" {
            if type_field.is_some() {
                return Err(de::Error::duplicate_field("type"));
            }
            type_field = Some(map.next_value()?);
        } else {
            data.push((key, map.next_value()?));
        }
    }
    // Now you can match on type_field and use MapDeserializer
    // to deserialize from `data`