rust, serde

Is there a better way to derive serde impls for untagged Rust enums with unit variants?


My team has been struggling to derive Deserialize and Serialize impls for an enum that should be serialized as a string, with some predefined string values and with an "other" case. The gist of the problem is that we don't know of a way to support unit variants in an untagged enum without lots of boilerplate.

I want to write an enum like this:

#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum BinaryComparisonOperator {
    LessThan,
    GreaterThan,
    Equal,
    Custom(String),
}

And I want it to serialize like this:

Rust                                             JSON
BinaryComparisonOperator::LessThan               "less_than"
BinaryComparisonOperator::GreaterThan            "greater_than"
BinaryComparisonOperator::Equal                  "equal"
BinaryComparisonOperator::Custom("other_op")     "other_op"

But that doesn't work. The unit variants (LessThan, GreaterThan, Equal) serialize as expected, but the Custom variant serializes with an external tag like this:

{ "custom": "other_op" } // not what we want!

I tried adding #[serde(untagged)] to the enum to remove the tag. That fixes the Custom case, but then the unit variants no longer round-trip: in our tests, every known string deserializes to the Custom variant when it shouldn't, and every unit variant serializes to null.

We came up with this version, which works but is far more complicated than I would like:

use serde::{de, Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Deserialize)]
#[serde(untagged)]
pub enum BinaryComparisonOperator {
    #[serde(deserialize_with = "parse_less_than")]
    LessThan,
    #[serde(deserialize_with = "parse_greater_than")]
    GreaterThan,
    #[serde(deserialize_with = "parse_equal")]
    Equal,
    Custom(String),
}

impl Serialize for BinaryComparisonOperator {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            BinaryComparisonOperator::LessThan => serializer.serialize_str("less_than"),
            BinaryComparisonOperator::GreaterThan => serializer.serialize_str("greater_than"),
            BinaryComparisonOperator::Equal => serializer.serialize_str("equal"),
            // `s` is already a `&String` here, so no extra borrow is needed.
            BinaryComparisonOperator::Custom(s) => serializer.serialize_str(s),
        }
    }
}

fn string_p<'de, D>(expected: String, input: String) -> Result<(), D::Error>
where
    D: de::Deserializer<'de>,
{
    if input == expected {
        Ok(())
    } else {
        Err(de::Error::custom("invalid value"))
    }
}

fn parse_less_than<'de, D>(deserializer: D) -> Result<(), D::Error>
where
    D: de::Deserializer<'de>,
{
    let s = String::deserialize(deserializer)?;
    // Arguments are (expected, input), matching string_p's signature.
    string_p::<'de, D>("less_than".to_owned(), s)
}

// lots more parse helpers...

I saw https://github.com/serde-rs/serde/issues/1158 which seems to imply that our complicated solution is necessary. But I'm hoping that maybe there is a better solution out there that we have overlooked?


Solution

  • You can use two enums.

    #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
    #[serde(untagged)]
    pub enum SerdeBinCompOp {
        Known(KnownOp),
        Custom(String),
    }
    
    #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
    #[serde(rename_all = "snake_case")]
    pub enum KnownOp {
        LessThan,
        GreaterThan,
        Equal,
    }
    

    This allows you to leave the parent enum untagged while using the child enum's variant names.

    The disadvantage is that if you want the final enum to be flat, you need two enums that both list every known variant, plus conversion implementations between them. A macro could generate that boilerplate.

    #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
    #[serde(from = "SerdeBinCompOp", into = "SerdeBinCompOp")]
    pub enum BinaryComparisonOperator {
        LessThan,
        GreaterThan,
        Equal,
        Custom(String),
    }
    
    impl From<BinaryComparisonOperator> for SerdeBinCompOp {
        fn from(bco: BinaryComparisonOperator) -> Self {
            match bco {
                BinaryComparisonOperator::LessThan => Self::Known(KnownOp::LessThan),
                BinaryComparisonOperator::GreaterThan => Self::Known(KnownOp::GreaterThan),
                BinaryComparisonOperator::Equal => Self::Known(KnownOp::Equal),
                BinaryComparisonOperator::Custom(s) => Self::Custom(s),
            }
        }
    }
    
    impl From<SerdeBinCompOp> for BinaryComparisonOperator {
        fn from(sbco: SerdeBinCompOp) -> Self {
            match sbco {
                SerdeBinCompOp::Known(KnownOp::LessThan) => Self::LessThan,
                SerdeBinCompOp::Known(KnownOp::GreaterThan) => Self::GreaterThan,
                SerdeBinCompOp::Known(KnownOp::Equal) => Self::Equal,
                SerdeBinCompOp::Custom(s) => Self::Custom(s),
            }
        }
    }
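
One possible shape for the macro mentioned above is sketched here. The macro name op_enums! and its flat / wire / known keywords are invented for this example. To keep the sketch free of external dependencies it derives only standard traits; the comments mark where the serde derives and attributes from the code above would be passed in.

```rust
// Hypothetical helper macro: from one variant list, declare the flat enum,
// the untagged wrapper, the known-variant enum, and both From conversions.
macro_rules! op_enums {
    (
        $(#[$flat_attr:meta])*
        flat $flat:ident;
        $(#[$wire_attr:meta])*
        wire $wire:ident;
        $(#[$known_attr:meta])*
        known $known:ident { $($variant:ident),+ $(,)? }
    ) => {
        $(#[$flat_attr])*
        pub enum $flat { $($variant,)+ Custom(String) }

        $(#[$wire_attr])*
        pub enum $wire { Known($known), Custom(String) }

        $(#[$known_attr])*
        pub enum $known { $($variant),+ }

        impl From<$flat> for $wire {
            fn from(op: $flat) -> Self {
                match op {
                    $($flat::$variant => Self::Known($known::$variant),)+
                    $flat::Custom(s) => Self::Custom(s),
                }
            }
        }

        impl From<$wire> for $flat {
            fn from(op: $wire) -> Self {
                match op {
                    $($wire::Known($known::$variant) => Self::$variant,)+
                    $wire::Custom(s) => Self::Custom(s),
                }
            }
        }
    };
}

op_enums! {
    // With serde: add Deserialize, Serialize to the derive, plus
    // #[serde(from = "SerdeBinCompOp", into = "SerdeBinCompOp")].
    #[derive(Clone, Debug, PartialEq)]
    flat BinaryComparisonOperator;
    // With serde: add Deserialize, Serialize, plus #[serde(untagged)].
    #[derive(Clone, Debug, PartialEq)]
    wire SerdeBinCompOp;
    // With serde: add Deserialize, Serialize, plus #[serde(rename_all = "snake_case")].
    #[derive(Clone, Debug, PartialEq)]
    known KnownOp { LessThan, GreaterThan, Equal }
}
```

Adding a new operator then means touching only the variant list in the invocation; the attributes are forwarded unchanged, so each enum keeps control over its own derives.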