Search code examples
json rust serde

Using serde_json to serialise maps with non-String keys


I've written a test case:

use serde::{Serialize, Deserialize};
use std::collections::BTreeMap;
use std::fmt;

/// An incline described by a `rise` and a `distance`, each stored as a `u8`.
///
/// `Serialize`/`Deserialize` are derived here; the derived `Serialize` does not
/// emit a plain string, which is why serde_json reports "key must be a string"
/// when this type is used as a map key.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct Incline {
    rise: u8,
    distance: u8,
}

impl Incline {
    /// Builds an `Incline` from its `rise` and `distance`.
    pub fn new(rise: u8, distance: u8) -> Self {
        Self { rise, distance }
    }
}

impl fmt::Display for Incline {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}:{}", self.rise, self.distance)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` renders the two fields separated by a colon.
    #[test]
    fn display_format() {
        let rendered = Incline::new(4, 3).to_string();
        assert_eq!(rendered, "4:3");
    }

    /// Serialises a one-entry map keyed by an `Incline` and expects the key to
    /// appear in its `Display` form.
    #[test]
    fn serialisation() {
        let mut hills: BTreeMap<Incline, &str> = BTreeMap::new();
        hills.insert(Incline::new(4, 3), "a steep hill");

        // Kept as a bare `unwrap()`: a serialisation error should fail the test.
        let json = serde_json::to_string(&hills).unwrap();

        assert_eq!(json, r#"{"4:3":"a steep hill"}"#);
    }
}

The display_format test passes as expected.

The serialisation test fails with an error:

thread 'tests::serialisation' panicked at 'called `Result::unwrap()` on an `Err` value: Error("key must be a string", line: 0, column: 0)', src/lib.rs:40:54

How do I tell serde_json to use Incline's implementation of std::fmt::Display::fmt to turn the Incline::new(4,3) into "4:3"?


Solution

  • With a little more searching I realised that I had to implement the Serialize trait for Incline myself instead of deriving it.

    This does the job:

    use serde::{Serialize, Serializer};
    
    /// An incline described by a `rise` and a `distance`, each stored as a `u8`.
    ///
    /// The derived ordering compares `rise` first and `distance` second.
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct Incline {
        rise: u8,
        distance: u8,
    }

    impl Incline {
        /// Builds an `Incline` from its `rise` and `distance`.
        pub fn new(rise: u8, distance: u8) -> Self {
            Self { rise, distance }
        }
    }
    
    /// Serialises an `Incline` as the single string `"<rise>:<distance>"`
    /// (for example `"4:3"`), which is what allows it to be used as a JSON
    /// object key.
    impl Serialize for Incline {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            // `collect_str` accepts anything that implements `Display`, so the
            // serializer can write the formatted text itself; there is no need to
            // build a temporary `String` with `format!` first. Serializers that do
            // not specialise `collect_str` fall back to `serialize_str`.
            serializer.collect_str(&format_args!("{}:{}", self.rise, self.distance))
        }
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
        use std::collections::BTreeMap;

        /// A map keyed by `Incline` serialises to a JSON object whose key is the
        /// `"rise:distance"` string.
        #[test]
        fn serialisation() {
            let hills: BTreeMap<Incline, &str> =
                std::iter::once((Incline::new(4, 3), "a steep hill")).collect();

            let json = serde_json::to_string(&hills).unwrap();

            assert_eq!(json, r#"{"4:3":"a steep hill"}"#);
        }
    }
    

    In full, serialisation and deserialisation look like:

    use serde::{Serialize, Serializer, Deserialize, Deserializer};
    use serde::de::{self, Visitor, Unexpected};
    use std::fmt;
    use std::str::FromStr;
    use regex::Regex;
    
    /// A slope expressed as a `rise` over a `distance`, both held as `u8`.
    ///
    /// Equality and ordering are derived; ordering looks at `rise` before
    /// `distance`.
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct Incline {
        rise: u8,
        distance: u8,
    }

    impl Incline {
        /// Creates an `Incline` with the given `rise` and `distance`.
        pub fn new(rise: u8, distance: u8) -> Self {
            Self { rise, distance }
        }
    }
    
    /// Serialises an `Incline` as the single string `"<rise>:<distance>"`
    /// (for example `"4:3"`), the same text that `InclineVisitor` parses back.
    impl Serialize for Incline {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            // `collect_str` accepts anything that implements `Display`, so the
            // serializer can write the formatted text itself; there is no need to
            // build a temporary `String` with `format!` first. Serializers that do
            // not specialise `collect_str` fall back to `serialize_str`.
            serializer.collect_str(&format_args!("{}:{}", self.rise, self.distance))
        }
    }
    
    /// Serde visitor that turns a string such as `"4:3"` into an `Incline`.
    struct InclineVisitor;
    
    impl<'de> Visitor<'de> for InclineVisitor {
        type Value = Incline;
    
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("a colon-separated pair of integers between 0 and 255")
        }
    
        fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            let re = Regex::new(r"(\d+):(\d+)").unwrap(); // PERF: move this into a lazy_static!
            if let Some(nums) = re.captures_iter(s).next() {
                if let Ok(rise) = u8::from_str(&nums[1]) { // nums[0] is the whole match, so we must skip that
                    if let Ok(distance) = u8::from_str(&nums[2]) {
                        Ok(Incline::new(rise, distance))
                    } else {
                        Err(de::Error::invalid_value(Unexpected::Str(s), &self))
                    }
                } else {
                    Err(de::Error::invalid_value(Unexpected::Str(s), &self))
                }
            } else {
                Err(de::Error::invalid_value(Unexpected::Str(s), &self))
            }
        }
    
    }
    
    /// Deserialises an `Incline` from its string form by delegating to
    /// `InclineVisitor`.
    impl<'de> Deserialize<'de> for Incline {
        fn deserialize<D>(deserializer: D) -> Result<Incline, D::Error>
        where
            D: Deserializer<'de>,
        {
            // `InclineVisitor` only implements `visit_str` and never keeps the
            // text, so request a borrowed `str` rather than hinting (via
            // `deserialize_string`) that an owned `String` would be useful.
            deserializer.deserialize_str(InclineVisitor)
        }
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
        use std::collections::BTreeMap;

        /// JSON form of the one-entry map used by both tests.
        const STEEP_HILL_JSON: &str = r#"{"4:3":"a steep hill"}"#;

        /// The in-memory map that `STEEP_HILL_JSON` describes.
        fn steep_hill() -> BTreeMap<Incline, &'static str> {
            let mut hills = BTreeMap::new();
            hills.insert(Incline::new(4, 3), "a steep hill");
            hills
        }

        /// The map serialises to a JSON object keyed by `"rise:distance"`.
        #[test]
        fn serialisation() {
            let json = serde_json::to_string(&steep_hill()).unwrap();

            assert_eq!(json, STEEP_HILL_JSON);
        }

        /// The same JSON parses back into an equal map.
        #[test]
        fn deserialisation() {
            let parsed: BTreeMap<Incline, &str> = serde_json::from_str(STEEP_HILL_JSON).unwrap();

            assert_eq!(parsed, steep_hill());
        }
    }