Tags: rust, enums, serde, bincode

How to Serialize Enum in Rust with Bincode While Retaining Enum Discriminant Instead of Index?


I'm serializing an enum in Rust with bincode, but the output contains the variant's index instead of its assigned discriminant. Here's the enum I'm trying to serialize:

    #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
    #[repr(u64)]
    pub enum StreamOutputFormat {
        A(X) = 0x01,
        B(Y) = 0x02,
        C(Z) = 0x03,
    }

In this code, X, Y, and Z represent structs.

When I serialize an instance of StreamOutputFormat::C(Z {some_z_stuff...}) with bincode like so:

    let sof = StreamOutputFormat::C(Z {some_z_stuff...});
    println!("sof: {:?}", bincode::serialize(&sof));

The output I get is:

    [2, 0, 0, 0, 0, 0, 0, 0, ...]

This is a problem because, for interoperability with other components, the serialized tag has to be the variant's discriminant (0x03 in this case), not its index. For comparison, the unit-only enums elsewhere in my codebase serialize correctly with Serialize_repr and Deserialize_repr from the serde_repr crate.

What is the correct way to serialize (and deserialize) this type of enum in bincode so that I receive the enum variant discriminant instead of its index?


Solution

    What makes this complicated is that serde's enum deserialization only allows a string, bytes, or a u32 for the enum tag (each format picks one of those three). This is hardcoded into each format; bincode, for example, always reads the tag as a u32. As far as serde is concerned, an enum with a u64 tag is not really an enum at all.

    So you have to serialize and deserialize your enum as something other than an enum. I've chosen a tuple of (u64 discriminant, payload), which is probably as close to an enum as you'll get. Serialization is quite simple, since we know the type of everything.

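    // Remove Serialize and Deserialize from the enum's #[derive(...)] first;
    // the derived impls would conflict with these manual ones.
    impl serde::Serialize for StreamOutputFormat {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            // serialize_tuple returns an S::SerializeTuple; its methods
            // (serialize_element, end) come from this trait.
            use serde::ser::SerializeTuple;
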
            match self {
                StreamOutputFormat::A(x) => {
                    let mut tuple = serializer.serialize_tuple(2)?;
                    tuple.serialize_element(&0x01u64)?;
                    tuple.serialize_element(x)?;
                    tuple.end()
                }
                StreamOutputFormat::B(y) => {
                    let mut tuple = serializer.serialize_tuple(2)?;
                    tuple.serialize_element(&0x02u64)?;
                    tuple.serialize_element(y)?;
                    tuple.end()
                }
                StreamOutputFormat::C(z) => {
                    let mut tuple = serializer.serialize_tuple(2)?;
                    tuple.serialize_element(&0x03u64)?;
                    tuple.serialize_element(z)?;
                    tuple.end()
                }
            }
        }
    }
    

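    I've kept this extremely repetitive for clarity. If a variant has more than one field, increase the length passed to serialize_tuple and serialize each field as its own element. Also make sure the discriminant literals are u64 (note the u64 suffix): an unsuffixed literal like &0x01 defaults to i32 and would be written as 4 bytes instead of 8.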
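    To make the resulting bytes concrete: bincode 1.x's top-level bincode::serialize uses fixed-width little-endian integers and writes tuple elements back to back with no length prefix, so the impl above should emit the 8-byte discriminant followed directly by the payload. The std-only sketch below hand-rolls that layout so you can compare it with what your other components expect. X, Y and Z here are hypothetical stand-ins holding a single integer each, and encode is an illustrative name, not a bincode API.

```rust
// Hypothetical stand-ins for the question's X, Y, Z: one integer field each.
#[derive(Debug, PartialEq)]
pub struct X(pub u8);
#[derive(Debug, PartialEq)]
pub struct Y(pub u16);
#[derive(Debug, PartialEq)]
pub struct Z(pub u32);

#[derive(Debug, PartialEq)]
#[repr(u64)]
pub enum StreamOutputFormat {
    A(X) = 0x01,
    B(Y) = 0x02,
    C(Z) = 0x03,
}

/// Hand-rolled equivalent of serializing the (u64 discriminant, payload)
/// tuple, assuming bincode 1.x's fixint, little-endian encoding.
pub fn encode(value: &StreamOutputFormat) -> Vec<u8> {
    let mut out = Vec::new();
    match value {
        StreamOutputFormat::A(x) => {
            out.extend_from_slice(&0x01u64.to_le_bytes()); // tuple element 0: 8-byte tag
            out.extend_from_slice(&x.0.to_le_bytes()); // tuple element 1: payload
        }
        StreamOutputFormat::B(y) => {
            out.extend_from_slice(&0x02u64.to_le_bytes());
            out.extend_from_slice(&y.0.to_le_bytes());
        }
        StreamOutputFormat::C(z) => {
            out.extend_from_slice(&0x03u64.to_le_bytes());
            out.extend_from_slice(&z.0.to_le_bytes());
        }
    }
    out
}
```

    For C(Z(42)) this gives [3, 0, 0, 0, 0, 0, 0, 0, 42, 0, 0, 0]: the first eight bytes are the discriminant the question asks for, rather than the 4-byte u32 index bincode writes for a regular serde enum.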

    Now the deserialization.

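    impl<'de> serde::Deserialize<'de> for StreamOutputFormat {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            // Error provides invalid_length/invalid_value; Visitor is implemented
            // below (items declared in this block can see these imports).
            use serde::de::{Error, Visitor};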
    
            struct StreamOutputFormatVisitor;
    
            impl<'de> Visitor<'de> for StreamOutputFormatVisitor {
                type Value = StreamOutputFormat;
    
                fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                    write!(
                        formatter,
                        "a tuple of size 2 consisting of a u64 discriminant and a value"
                    )
                }
    
                fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                where
                    A: serde::de::SeqAccess<'de>,
                {
                    let discriminant: u64 = seq
                        .next_element()?
                        .ok_or_else(|| A::Error::invalid_length(0, &self))?;
                    match discriminant {
                        0x01 => {
                            let x = seq
                                .next_element()?
                                .ok_or_else(|| A::Error::invalid_length(1, &self))?;
                            Ok(StreamOutputFormat::A(x))
                        }
                        0x02 => {
                            let y = seq
                                .next_element()?
                                .ok_or_else(|| A::Error::invalid_length(1, &self))?;
                            Ok(StreamOutputFormat::B(y))
                        }
                        0x03 => {
                            let z = seq
                                .next_element()?
                                .ok_or_else(|| A::Error::invalid_length(1, &self))?;
                            Ok(StreamOutputFormat::C(z))
                        }
                        d => Err(A::Error::invalid_value(
                            serde::de::Unexpected::Unsigned(d),
                            &"0x01, 0x02, or 0x03",
                        )),
                    }
                }
            }
    
            deserializer.deserialize_tuple(2, StreamOutputFormatVisitor)
        }
    }
    

    Getting past the boilerplate, we have a call to deserialize_tuple, which will call visit_seq on the visitor. Inside that method, the visitor consumes a u64 to use as the discriminant, and then consumes the inner data based on that discriminant.
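    To see what visit_seq amounts to on the wire without involving serde, here is a std-only sketch that decodes the same layout by hand. It again assumes bincode 1.x's fixint little-endian encoding and hypothetical single-integer X, Y and Z; take and decode are illustrative names. The key point matches the visitor: the payload type is unknown until the discriminant has been read.

```rust
#[derive(Debug, PartialEq)]
pub struct X(pub u8);
#[derive(Debug, PartialEq)]
pub struct Y(pub u16);
#[derive(Debug, PartialEq)]
pub struct Z(pub u32);

#[derive(Debug, PartialEq)]
#[repr(u64)]
pub enum StreamOutputFormat {
    A(X) = 0x01,
    B(Y) = 0x02,
    C(Z) = 0x03,
}

/// Copies the next N bytes out of `bytes`, advancing `pos`.
fn take<const N: usize>(bytes: &[u8], pos: &mut usize) -> Result<[u8; N], String> {
    let end = *pos + N;
    let chunk = bytes.get(*pos..end).ok_or("unexpected end of input")?;
    let mut buf = [0u8; N];
    buf.copy_from_slice(chunk);
    *pos = end;
    Ok(buf)
}

pub fn decode(bytes: &[u8]) -> Result<StreamOutputFormat, String> {
    let mut pos = 0;
    // Tuple element 0: the u64 discriminant (8 bytes, little-endian).
    let discriminant = u64::from_le_bytes(take::<8>(bytes, &mut pos)?);
    // Tuple element 1: which payload type to read depends on the discriminant,
    // just like the match inside visit_seq.
    match discriminant {
        0x01 => Ok(StreamOutputFormat::A(X(u8::from_le_bytes(take::<1>(bytes, &mut pos)?)))),
        0x02 => Ok(StreamOutputFormat::B(Y(u16::from_le_bytes(take::<2>(bytes, &mut pos)?)))),
        0x03 => Ok(StreamOutputFormat::C(Z(u32::from_le_bytes(take::<4>(bytes, &mut pos)?)))),
        d => Err(format!("invalid discriminant {:#x}, expected 0x01, 0x02, or 0x03", d)),
    }
}
```

    Unknown tags are rejected before any payload is read, which is the same behavior the d => Err(...) arm gives the serde version.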

    Methods that don't work

    You can't get there by padding the enum with dummy variants so that each variant's index matches its discriminant, since the tag would still be written as a u32. Another thing you might try is deserializing through a (u64, UntaggedEnum), where UntaggedEnum is:

    #[derive(Deserialize)]
    #[serde(untagged)]
    enum UntaggedEnum {
        A(X),
        B(Y),
        C(Z),
    }
    

    This doesn't work because non-self-describing formats can't handle untagged enums. On top of that, even self-describing formats could fail if the data is valid for more than one variant, since there is no simple way to conditionally deserialize the second element of a tuple depending on the first element. It is also inefficient since it will try to deserialize the enum even if the u64 was invalid.

    Notes

    You may have added #[repr(u64)] without knowing that serde enum tags are u32. If a u32 tag works for you, you can use #[repr(u32)] instead and go through serde's enum (de)serialization, which is slightly simpler (very similar to what the Deserialize derive generates). As far as I could tell, you would still need to map each discriminant to its variant by hand.

    Notably, the code I wrote never references the actual discriminants given in the enum definition. This is pretty standard for Rust, since discriminants on enums with fields have very little functionality: you need a pointer cast just to read them, and they are no help in deciding whether to deserialize an X, a Y, or a Z. Removing them would have no effect on serialization.
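    If you would rather have the declared discriminants drive serialization instead of repeating the literals, the Rust Reference documents reading the tag of a primitive-representation enum by casting a pointer to the enum into a pointer to the tag type. A minimal sketch, with the same hypothetical single-integer X, Y and Z:

```rust
#[derive(Debug, PartialEq)]
pub struct X(pub u8);
#[derive(Debug, PartialEq)]
pub struct Y(pub u16);
#[derive(Debug, PartialEq)]
pub struct Z(pub u32);

#[derive(Debug, PartialEq)]
#[repr(u64)]
pub enum StreamOutputFormat {
    A(X) = 0x01,
    B(Y) = 0x02,
    C(Z) = 0x03,
}

impl StreamOutputFormat {
    /// Returns the declared discriminant (0x01, 0x02 or 0x03) of `self`.
    pub fn discriminant(&self) -> u64 {
        // SAFETY: a #[repr(u64)] enum is laid out as a #[repr(C)] union of
        // #[repr(C)] structs that each start with the u64 tag, so the tag is
        // at offset 0 and correctly aligned for a u64 read.
        unsafe { *(self as *const Self).cast::<u64>() }
    }
}
```

    In the Serialize impl you could then call tuple.serialize_element(&self.discriminant())? once and match only for the payload element. Deserialization still has to spell out which tag selects which type, so this only removes half of the duplication.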

    If you intend to modify this enum, or if it has a large quantity of variants, then it would be a good use of time to turn this into a macro. This would not be too hard as a declarative macro since all the values you need are literals like 0x01 and the repetition is obvious.
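    A rough sketch of the shape such a macro could take, using a hypothetical name tagged_enum! and the same single-integer stand-ins for X, Y and Z. To stay dependency-free it only generates the enum plus two match-based helpers from one list of (variant, type, tag) entries. The serde Serialize/Deserialize impls above would be generated the same way, with a $(...)* repetition producing their match arms.

```rust
/// Hypothetical helper macro: declares a #[repr(u64)] enum whose variants each
/// wrap one payload type, plus tag helpers generated from the same list.
macro_rules! tagged_enum {
    (
        $(#[$meta:meta])*
        $vis:vis enum $name:ident {
            $($variant:ident($inner:ty) = $tag:literal),* $(,)?
        }
    ) => {
        $(#[$meta])*
        #[repr(u64)]
        $vis enum $name {
            $($variant($inner) = $tag),*
        }

        impl $name {
            /// Declared discriminant of this variant (a plain match, no unsafe).
            $vis fn discriminant(&self) -> u64 {
                match self {
                    $(Self::$variant(_) => $tag,)*
                }
            }

            /// Name of the variant selected by `tag`, or None if it is unknown.
            $vis fn variant_name(tag: u64) -> Option<&'static str> {
                match tag {
                    $($tag => Some(stringify!($variant)),)*
                    _ => None,
                }
            }
        }
    };
}

#[derive(Debug, PartialEq)]
pub struct X(pub u8);
#[derive(Debug, PartialEq)]
pub struct Y(pub u16);
#[derive(Debug, PartialEq)]
pub struct Z(pub u32);

tagged_enum! {
    #[derive(Debug, PartialEq)]
    pub enum StreamOutputFormat {
        A(X) = 0x01,
        B(Y) = 0x02,
        C(Z) = 0x03,
    }
}
```

    Because every tag appears exactly once in the macro input, adding a variant updates the enum and both helpers together, which is the main point of going the macro route.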

    I haven't checked out bincode 2.0, which includes its own Decode trait that doesn't use serde. It might be possible to decode u64 enum tags, but the structure is completely different from serde, so I didn't look too much into it.

    This may not match the format you want. Judging by the u64 repr, whatever you're trying to deserialize wasn't made by a serde enum, so it's impossible for me to know if your format will match the tuple format I've used.