Search code examples
Tags: rust, serde

Custom deserializer to deserialize a HashMap over arbitrary types


I have a tricky situation: I have a collection of key-value pairs, and I need a custom serde Deserializer (a "custom data format" in the serde documentation) that can map those pairs onto a struct.

The key-value pairs look something like this:

("device:address:ip", "127.0.0.1"),
("device:address:port", "9001"),
("device:keys:0", "0000"),
("device:keys:1", "1111"),

The idea is to deserialize them onto a data model such as this one:

/// Network location of the device: IP address and port.
struct DeviceAddress {
    ip: String,
    port: u16,
}

/// Settings for a single device: its address and its list of keys.
struct DeviceConfig {
    address: DeviceAddress,
    keys: Vec<String>,
}

/// Root of the configuration data model.
struct Config {
    device: DeviceConfig,
}

In that scenario, the goal is to deserialize a HashMap<String, String> of such key-value pairs onto the Config struct. Instead of a &str, the deserializer's input is a **HashMap<String, String>**.

To be as clear as possible, as serde walks the data model:

  1. deserialize struct (Config)
  2. deserialize struct field (device)
  3. deserialize struct (DeviceConfig)
  4. deserialize struct field (address)
  5. deserialize struct (DeviceAddress)
  6. deserialize struct field (ip)
  7. deserialize value (String)

I need to construct this key:

device:address:ip

Once I have that key, I can look up the value in the HashMap of key-value pairs to complete step 7 (resolving the value of the struct field).

Question

I know the approach I need to take. Where I am stuck is this: inside the deserializer, how do I keep track of the current path as serde's deserialization machinery walks the data model?


Solution

  • I found a pathway to solving this deserialization problem by thinking more about how serde walks the data structure. Serde provides information about structs and their fields that a deserializer can remember while being walked by the serde deserialization infrastructure.

    For this particular deserialization implementation, the idea is to build up a key path as the deserializer walks over the data model, and then resolve a value from that key path when reaching a leaf, such as the value to be emitted for a struct field.

    It looks like this:

    /// Flat configuration source.
    ///
    /// Maps `:`-separated key paths such as `"device:address:ip"` to their raw,
    /// still-unparsed string values.
    #[derive(Debug, Clone, Default)]
    struct ConfigMap<'a>(HashMap<&'a str, &'a str>);
    
    /// Mutable walking state used while mapping a `ConfigMap` onto a data model.
    struct Deserializer<'a> {
        /// Source of raw values, looked up by the `:`-joined `key` path.
        config: &'a ConfigMap<'a>,
        /// Field names of the struct being walked that have not been visited yet.
        fields: Vec<&'static str>,
        /// Path segments (field names) leading to the current position.
        key: Vec<String>,
    }
    
    impl <'a> Deserializer<'a> {
        pub fn new(config: &'a ConfigMap<'a>) -> Self {
            Deserializer {
                config,
                fields: vec![],
                key: vec![]
            }
        }
        pub fn read<T: FromStr>(&mut self) -> Option<T> {
            let key = self.key.join(":");
            self.key.pop();
            self.config.0
                .get(key.as_str())
                .and_then(|v| v.parse().ok())
        }
    }
    
    impl <'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> {
        type Error = serde::de::value::Error;
    
        fn deserialize_any<V>(self, _: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            Err(serde::de::Error::custom(
                "Deserialization to target type is unsupported"))
        }
    
        fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<bool>().unwrap_or_default();
            visitor.visit_bool(value)
        }
    
        fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<i8>().unwrap_or_default();
            visitor.visit_i8(value)
        }
    
        fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<i16>().unwrap_or_default();
            visitor.visit_i16(value)
        }
    
        fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<i32>().unwrap_or_default();
            visitor.visit_i32(value)
        }
    
        fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<i64>().unwrap_or_default();
            visitor.visit_i64(value)
        }
    
        fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<u8>().unwrap_or_default();
            visitor.visit_u8(value)
        }
    
        fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<u16>().unwrap_or_default();
            visitor.visit_u16(value)
        }
    
        fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<u32>().unwrap_or_default();
            visitor.visit_u32(value)
        }
    
        fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<u64>().unwrap_or_default();
            visitor.visit_u64(value)
        }
    
        fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<f32>().unwrap_or_default();
            visitor.visit_f32(value)
        }
    
        fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<f64>().unwrap_or_default();
            visitor.visit_f64(value)
        }
    
        fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<char>().unwrap_or_default();
            visitor.visit_char(value)
        }
    
        fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            let value = self.read::<String>();
            visitor.visit_string(value.unwrap_or_default())
        }
    
        fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            visitor.visit_map(self)
        }
    
        fn deserialize_struct<V>(self, name: &'static str, fields: &'static [&'static str], visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            self.fields.clear();
            for field in fields.iter() {
                self.fields.push(field);
            }
    
            self.deserialize_map(visitor)
        }
    
        fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: Visitor<'de>
        {
            if let Some(field) = self.fields.pop() {
                self.key.push(field.to_string());
                let value = visitor.visit_str(field);
                return value;
            }
    
            Err(serde::de::Error::custom("invalid field"))
        }
    
        forward_to_deserialize_any! {
            enum ignored_any tuple tuple_struct seq
            newtype_struct unit_struct unit option
            byte_buf bytes str
        }
    }
    
    /// Presents the struct currently being walked as a map.
    ///
    /// Each key is produced by deserializing an identifier from the pending
    /// `fields`. Each value is deserialized at the key path extended by that
    /// field.
    impl<'de> MapAccess<'de> for Deserializer<'de> {
        type Error = serde::de::value::Error;

        fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
        where
            K: DeserializeSeed<'de>,
        {
            if self.fields.is_empty() {
                // All fields of this struct are done: drop the path segment
                // that led into it. At the top level the path is already empty
                // and the pop is a no-op.
                self.key.pop();
                Ok(None)
            } else {
                seed.deserialize(&mut *self).map(Some)
            }
        }

        fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
        where
            V: DeserializeSeed<'de>,
        {
            seed.deserialize(&mut *self)
        }
    }
    
    

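    This deserializer can now handle the following example. (Sequence fields such as `keys: Vec<String>` from the question are not supported by this version: `seq` is forwarded to `deserialize_any`, which returns an error.)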

    /// Leaf struct: values come from `device:address:ip` and `device:address:port`.
    #[derive(Default, Debug, Deserialize)]
    pub struct Address {
        pub ip: String,
        pub port: u16,
    }

    /// Nested struct holding an `Address` plus the `device:enabled` flag.
    #[derive(Default, Debug, Deserialize)]
    pub struct Device {
        pub address: Address,
        pub enabled: bool,
    }

    /// Root of the example data model.
    #[derive(Default, Debug, Deserialize)]
    pub struct Config {
        pub device: Device,
    }
    
    // Flat key/value pairs whose `:`-separated keys mirror the struct nesting.
    let pairs = [
        ("device:address:ip", "127.0.0.1"),
        ("device:address:port", "9206"),
        ("device:enabled", "true"),
    ];

    let map = ConfigMap(pairs.iter().copied().collect());
    let mut deserializer = Deserializer::new(&map);
    let config: Config = Config::deserialize(&mut deserializer).expect("failed to parse");

    println!("{:?}", config);

    // Config { device: Device { address: Address { ip: "127.0.0.1", port: 9206 }, enabled: true } }