Rust deserialise and transform some content


I've got a struct:

struct Embedding {
    values: Vec<f64>,
    original: String,
}

which should hold some original text (original) as well as the embedding of that text (values) according to OpenAI's CLIP model.

I've got a list of strings in a YAML file (without their embeddings). I want to deserialise that list of strings into a Vec<Embedding>, where the values member of an Embedding should just be the embedded string. I can get the embedding via a function that looks like:

/// Using a pre-loaded CLIP model, embed some given text into a vector of f64's
fn embed_text(clip_model: &ClipModel, text: &str) -> Result<Vec<f64>> {...}

So I need to read the YAML file, deserialise each string, embed each string, and return the Embedding struct that has the embedding as well as the original string.

I can't quite figure out how to make this work using the serde_yaml or serde docs. All the talk of visitors and deserialize_* methods is kinda going over my head. I initially thought #[serde(deserialize_with = "callback")] was going to be my savior, but I couldn't get that to work because it only applies to a field in a struct, not the struct itself.

I've solved this problem before, in a hacky way, and I'd rather do something cleaner if possible. Before, I would have made a RawEmbedding and an Embedding struct like so:

#[derive(Serialize, Deserialize)]
struct RawEmbedding {
    original: String,
}

struct Embedding {
    values: Vec<f64>,
    original: String,
}

And then I'd manually write an impl From<RawEmbedding> for Embedding {...} so that I could use serde to deserialise into a RawEmbedding, and then .into() to convert the RawEmbedding into an actual Embedding. I'm kinda tired of duplicating all my structs though; I'm sure there must be a better way that does the conversion logic inside serde's deserialisation.

Note that I need access to the ClipModel (or at least a borrow of it) in order to do the embeddings. Loading this model from disk is quite expensive, so reloading it for every string that I want to embed is not an option.


Solution

  • Since computing values from the original string requires access to the ClipModel, we need a way to deserialize while having access to additional context. serde supports exactly this through its DeserializeSeed trait. The documentation actually calls this out as a specific use case: "If you ever find yourself looking for a way to pass data into a Deserialize impl, this trait is the way to do it."

    While serde provides ways to derive its Deserialize trait, it does not provide a way to derive its DeserializeSeed trait. There has been discussion about supporting it in the past, but as things currently are we need to implement the trait manually. This unfortunately involves a bit of boilerplate.

    I will assume, for the sake of the example, that the definition of ClipModel is as follows:

    struct ClipModel;
    

    DeserializeSeed differs from Deserialize in that it is usually not implemented on the type you want to deserialize into. Instead, it is implemented on a type containing the context you want to deserialize with. In our case, that context is a reference to ClipModel, which lets us call embed_text during deserialization. The simplest way to provide that context is to implement DeserializeSeed directly on &ClipModel. Naturally, if this doesn't work (perhaps because you don't own the definition of ClipModel), you can also wrap &ClipModel in a newtype.

    use serde::de::{DeserializeSeed, Deserializer, SeqAccess, Visitor};
    use std::{fmt, fmt::Formatter};
    
    impl<'a, 'de> DeserializeSeed<'de> for &'a ClipModel {
        type Value = Vec<Embedding>;
    
        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: Deserializer<'de>,
        {
            struct ClipModelVisitor<'a>(&'a ClipModel);
    
            impl<'a, 'de> Visitor<'de> for ClipModelVisitor<'a> {
                type Value = Vec<Embedding>;
    
                fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
                    formatter.write_str("a sequence of embedded strings")
                }
    
                fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                where
                    A: SeqAccess<'de>,
                {
                    let mut result = Vec::new();
                    while let Some(text) = seq.next_element::<String>()? {
                        result.push(Embedding {
                            // Unwrap is just for demonstration; handle the error in real code!
                            values: embed_text(self.0, &text).unwrap(),
                            original: text,
                        })
                    }
                    Ok(result)
                }
            }
    
            deserializer.deserialize_seq(ClipModelVisitor(self))
        }
    }
    

    If this looks unfamiliar or confusing, I recommend familiarizing yourself with serde deserialization in general by reading "Implementing Deserialize". The main differences here are:

    • DeserializeSeed requires you to define a Value associated type to indicate the return type of the deserialization. In our case, we declare that we will return a Vec<Embedding>.
    • We must provide our context type (&ClipModel) to our Visitor helper struct, so it can be used when accessing each deserialized string.

    We can verify that this implementation works by providing a simple implementation of embed_text and a simple YAML input. Note that Embedding needs #[derive(Debug)] for dbg! to be able to print the result:

    // Not sure what the error type is here, since it's not provided in the code included with your question.
    fn embed_text(clip_model: &ClipModel, text: &str) -> Result<Vec<f64>, ()> {
        Ok(match text {
            "foo" => vec![0.1],
            "bar" => vec![1.2],
            _ => Vec::new(),
        })
    }
    
    fn main() {
        let clip_model = ClipModel;
    
        let result = clip_model.deserialize(serde_yaml::Deserializer::from_str("- foo\n- bar\n"));
        dbg!(result);
    }
    

    This will output something similar to the following:

    [src\main.rs:63] result = Ok(
        [
            Embedding {
                values: [
                    0.1,
                ],
                original: "foo",
            },
            Embedding {
                values: [
                    1.2,
                ],
                original: "bar",
            },
        ],
    )
    

    You can run the above code at this playground.

    This is the intended solution provided by the serde framework for what you are attempting to do. Hope that helps!