Search code examples
rust serde serde-json

Deserialize a JSON string or array of strings into a Vec


I'm writing a crate that interfaces with a JSON web API. One endpoint usually returns responses of the form { "key": ["value1", "value2"] }, but sometimes there's only one value for the key, and the endpoint returns { "key": "value" } instead of { "key": ["value"] }.

I wanted to write something generic for this that I could use with #[serde(deserialize_with)] like so:

/// Example target type from the question: both fields are routed through the
/// same `deserialize_string_or_seq_string` helper and collect into a `Vec`.
#[derive(Deserialize)]
struct SomeStruct {
    // Elements are the newtype `SomeStringNewType`.
    #[serde(deserialize_with = "deserialize_string_or_seq_string")]
    field1: Vec<SomeStringNewType>,

    // Elements are `SomeTypeWithCustomDeserializeFromStr`, which has a
    // hand-written `Deserialize` impl.
    #[serde(deserialize_with = "deserialize_string_or_seq_string")]
    field2: Vec<SomeTypeWithCustomDeserializeFromStr>,
}

/// Newtype wrapper around `String` that uses the derived `Deserialize` impl.
#[derive(Deserialize)]
struct SomeStringNewType(String);

/// String wrapper whose `Deserialize` impl is written by hand rather than derived.
struct SomeTypeWithCustomDeserializeFromStr(String);
// NOTE(review): the trait is named here without a `'de` lifetime parameter;
// presumably this targets a pre-1.0 Serde release — confirm before reusing.
impl ::serde::de::Deserialize for SomeTypeWithCustomDeserializeFromStr {
    // Some custom implementation here
}

How can I write a deserialize_string_or_seq_string to be able to do this?


Solution

  • This solution works for Serde 1.0.

    The approach I found also required writing a custom deserializer, because I needed one that calls visitor.visit_newtype_struct in order to try deserializing newtypes, and none of the deserializers built into serde appear to do so. (I was expecting to find something like the ValueDeserializer family of types.)

    A self-contained example is below. The SomeStruct is deserialized correctly for both inputs, one where the values are JSON arrays of strings, and the other where they're just strings.

    #[macro_use]
    extern crate serde;
    #[macro_use]
    extern crate serde_derive;
    extern crate serde_json;
    
    fn main() {
        #[derive(Debug, Deserialize)]
        struct SomeStringNewType(String);
    
        #[derive(Debug)]
        struct SomeTypeWithCustomDeserializeFromStr(String);
        impl<'de> ::serde::Deserialize<'de> for SomeTypeWithCustomDeserializeFromStr {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: ::serde::Deserializer<'de> {
                struct Visitor;
    
                impl<'de> ::serde::de::Visitor<'de> for Visitor {
                    type Value = SomeTypeWithCustomDeserializeFromStr;
    
                    fn expecting(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
                        write!(f, "a string")
                    }
    
                    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> where E: ::serde::de::Error {
                        Ok(SomeTypeWithCustomDeserializeFromStr(v.to_string() + " custom"))
                    }
                }
    
                deserializer.deserialize_any(Visitor)
            }
        }
    
        #[derive(Debug, Deserialize)]
        struct SomeStruct {
            #[serde(deserialize_with = "deserialize_string_or_seq_string")]
            field1: Vec<SomeStringNewType>,
    
            #[serde(deserialize_with = "deserialize_string_or_seq_string")]
            field2: Vec<SomeTypeWithCustomDeserializeFromStr>,
        }
    
        let x: SomeStruct = ::serde_json::from_str(r#"{ "field1": ["a"], "field2": ["b"] }"#).unwrap();
        println!("{:?}", x);
        assert_eq!(x.field1[0].0, "a");
        assert_eq!(x.field2[0].0, "b custom");
    
        let x: SomeStruct = ::serde_json::from_str(r#"{ "field1": "c", "field2": "d" }"#).unwrap();
        println!("{:?}", x);
        assert_eq!(x.field1[0].0, "c");
        assert_eq!(x.field2[0].0, "d custom");
    }
    
    /// Deserializes a string or a sequence of strings into a vector of the target type.
    pub fn deserialize_string_or_seq_string<'de, T, D>(deserializer: D) -> Result<Vec<T>, D::Error>
        where T: ::serde::Deserialize<'de>, D: ::serde::Deserializer<'de> {
    
        struct Visitor<T>(::std::marker::PhantomData<T>);
    
        impl<'de, T> ::serde::de::Visitor<'de> for Visitor<T>
            where T: ::serde::Deserialize<'de> {
    
            type Value = Vec<T>;
    
            fn expecting(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
                write!(f, "a string or sequence of strings")
            }
    
            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
                where E: ::serde::de::Error {
    
                let value = {
                    // Try parsing as a newtype
                    let deserializer = StringNewTypeStructDeserializer(v, ::std::marker::PhantomData);
                    ::serde::Deserialize::deserialize(deserializer)
                }.or_else(|_: E| {
                    // Try parsing as a str
                    let deserializer = ::serde::de::IntoDeserializer::into_deserializer(v);
                    ::serde::Deserialize::deserialize(deserializer)
                })?;
                Ok(vec![value])
            }
    
            fn visit_seq<A>(self, visitor: A) -> Result<Self::Value, A::Error>
                where A: ::serde::de::SeqAccess<'de> {
    
                ::serde::Deserialize::deserialize(::serde::de::value::SeqAccessDeserializer::new(visitor))
            }
        }
    
        deserializer.deserialize_any(Visitor(::std::marker::PhantomData))
    }
    
    // Tries to deserialize the given string as a newtype
    struct StringNewTypeStructDeserializer<'a, E>(&'a str, ::std::marker::PhantomData<E>);
    
    impl<'de, 'a, E> ::serde::Deserializer<'de> for StringNewTypeStructDeserializer<'a, E> where E: ::serde::de::Error {
        type Error = E;
    
        fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> where V: ::serde::de::Visitor<'de> {
            visitor.visit_newtype_struct(self)
        }
    
        fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error> where V: ::serde::de::Visitor<'de> {
            // Called by newtype visitor
            visitor.visit_str(self.0)
        }
    
        forward_to_deserialize_any! {
            bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str bytes
            byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct map
            struct enum identifier ignored_any
        }
    }