Tags: rust, deserialization, serde, toml

How to handle deserialization of mixed-type values in TOML (string/integer)?


When deserializing a TOML file, how can I instruct the deserializer to treat a value for a specific key as a string, regardless of whether it appears to be an integer?

Background

I want to parse a TOML file where the value of count could be either an integer or a string:

count = 3
# or count = -1
# or count = "all"

Valid values for count (in the TOML file) are integers in the range -1..=i32::MAX, or the string "all". When deserializing the value, "all" should be interpreted as -1.

The Problem

This is my code, using a custom deserializer, and it pretty much works ...

use serde::{de::Error, Deserialize, Deserializer};

#[derive(Debug, Deserialize)]
struct ConfigFileOpts {
    #[serde(deserialize_with = "ds_i32_or_string")]
    count: i32,
}

fn ds_i32_or_string<'de, D>(deserializer: D) -> Result<i32, D::Error>
where
    D: Deserializer<'de>,
{
    let res = String::deserialize(deserializer);
    match res {
        Ok(s) => match s.parse::<i32>() {
            Ok(n) if n >= -1 => Ok(n),
            _ => match s.trim().to_lowercase().as_str() {
                "all" => Ok(-1),
                _ => Err(D::Error::custom("Value must be >= -1 or 'all'.")),
            },
        },
        Err(e) => Err(e),
    }
}

fn main() {
    let cfg_opts: ConfigFileOpts = toml::from_str("count = 3").unwrap();
    println!("deserialized count: {:?}", cfg_opts.count);
}

... however, it only works when integer values are written as strings, i.e. count = "3" instead of count = 3. When count is written as count = 3 I'm getting the following error message:

TomlError {
  message: "invalid type: integer `3`, expected a string",
  original: Some("count = 3"),
  keys: ["count"],
  span: Some(8..9)
}

Currently, the deserializer is not called for values that "look like" integers. Is there a way to ensure my deserializer is invoked for both integer and string values? Any suggestions would be appreciated.


Solution

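  • Your deserializer is in fact called for integer values too; the error comes from String::deserialize inside it, which only accepts TOML strings. Keeping your original design, you can derive Deserialize for an untagged enum and let serde try each variant in turn: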

    use serde::{de::Error, Deserialize, Deserializer};
    
    #[derive(Debug, Deserialize)]
    struct ConfigFileOpts {
        #[serde(deserialize_with = "ds_i32_or_string")]
        count: i32,
    }
    
    fn ds_i32_or_string<'de, D>(deserializer: D) -> Result<i32, D::Error>
    where
        D: Deserializer<'de>,
    {
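        // With `untagged`, serde tries the variants in declaration order
        // and keeps the first one that matches the input value.
        #[derive(Debug, Deserialize)]
        #[serde(untagged)]
        enum Count {
            S(String),
            I(i32),
        }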
    
        let count = Count::deserialize(deserializer)?;
        match count {
            Count::S(s) => {
                if s.trim().to_lowercase().as_str() == "all" {
                    Ok(-1)
                } else {
                    Err(D::Error::custom("Value must be >= -1 or 'all'."))
                }
            }
            Count::I(n) => {
                if n >= -1 {
                    Ok(n)
                } else {
                    Err(D::Error::custom("Value must be >= -1"))
                }
            }
        }
    }
    
    fn main() {
        let cfg_opts: ConfigFileOpts = toml::from_str("count = 42").unwrap();
        assert_eq!(42, cfg_opts.count);
    
        let cfg_opts: ConfigFileOpts = toml::from_str("count = \"all\"").unwrap();
        assert_eq!(-1, cfg_opts.count);
    
        let cfg_opts: ConfigFileOpts = toml::from_str("count = -1").unwrap();
        assert_eq!(-1, cfg_opts.count);
    
        let ret: Result<ConfigFileOpts, _> = toml::from_str("count = -42");
        assert!(ret.is_err());
    }
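
The range check and the "all" mapping do not depend on serde, so one option is to keep them in a plain function that the deserializer calls once the value has been decoded. The following is a minimal sketch using only the standard library; `RawCount` and `normalize_count` are hypothetical names introduced for illustration and are not part of serde or toml.

```rust
/// Hypothetical stand-in for the two shapes the TOML value can take
/// (an assumption for this sketch, mirroring the untagged `Count` enum).
#[derive(Debug, Clone, PartialEq)]
enum RawCount {
    Int(i64),
    Text(String),
}

/// Applies the rule from the question: integers from -1 up to i32::MAX
/// pass through unchanged, and the string "all" maps to -1.
fn normalize_count(raw: &RawCount) -> Result<i32, String> {
    match raw {
        RawCount::Int(n) => {
            // Reject anything below -1 or too large to fit in an i32.
            if *n >= -1 && *n <= i64::from(i32::MAX) {
                Ok(*n as i32)
            } else {
                Err(format!("count must be in -1..=i32::MAX, got {}", n))
            }
        }
        // Surrounding whitespace and letter case are ignored, as in the
        // original `trim().to_lowercase()` comparison.
        RawCount::Text(s) if s.trim().eq_ignore_ascii_case("all") => Ok(-1),
        RawCount::Text(s) => {
            Err(format!("count must be an integer or \"all\", got {:?}", s))
        }
    }
}
```

Like the enum-based solution above, this rejects numeric strings such as "3". Inside a serde deserializer, the `String` error could be converted with `D::Error::custom`.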