I've got a struct:
struct Embedding {
    values: Vec<f64>,
    original: String,
}
which I want to contain some original text (original), as well as the embedding of that text according to OpenAI's CLIP model.
I've got a list of strings in a YAML file (without their embeddings). I want to deserialise that list of strings into a Vec<Embedding>, where the values member of each Embedding should just be the embedding of the string. I can get the embedding via a function that looks like:
/// Using a pre-loaded CLIP model, embed some given text into a vector of f64's
fn embed_text(clip_model: &ClipModel, text: &str) -> Result<Vec<f64>> {...}
So I need to read the YAML file, deserialise each string, embed each string, and return the Embedding struct that has the embedding as well as the original string.
I can't quite figure out how to make this work using the serde_yaml or serde docs. All the talk of visitors and deserialize_* methods is kinda going over my head. I initially thought #[serde(deserialize_with = "callback")] was going to be my savior, but I couldn't get that to work because it only works for a field in a struct, not the struct itself.
I've solved this problem before, in a hacky way, and I'd rather do something cleaner if possible. Before, I would have made a RawEmbedding and an Embedding struct like so:
#[derive(Serialize, Deserialize)]
struct RawEmbedding {
    original: String,
}

struct Embedding {
    values: Vec<f64>,
    original: String,
}
And then I'd manually write an impl From<RawEmbedding> for Embedding {...} so that I could use serde to deserialise into a RawEmbedding, and then .into() to convert the RawEmbedding into an actual Embedding. I'm kinda tired of duplicating all my structs though; I'm sure there must be a better way that does the conversion logic inside of serde's deserialisation.
Note that I need access to the ClipModel (or at least a borrow of it) in order to do the embeddings, and loading this model from disk is quite expensive, so it's not an option to do it again and again for every string that I want to embed.
Since obtaining the values from the string to embed requires access to the ClipModel, we need a way to deserialize while having access to additional context. serde has a way to support exactly this, using its DeserializeSeed trait. The documentation actually calls this out as a specific use case: "If you ever find yourself looking for a way to pass data into a Deserialize impl, this trait is the way to do it."
While serde provides ways to derive its Deserialize trait, it does not provide a way to derive its DeserializeSeed trait. There has been discussion about supporting it in the past, but as things currently are we need to implement the trait manually. This unfortunately involves a bit of boilerplate.
I will assume, for the sake of the example, that the definition of ClipModel is as follows:
struct ClipModel;
DeserializeSeed differs from Deserialize in that it is usually not implemented on the type you want to deserialize into. Instead, it is implemented on a type containing the context you want to deserialize with. In our case, that context is a reference to ClipModel, so we can use it to call embed_text during deserialization. The simplest way to provide that context is to just implement DeserializeSeed directly on &ClipModel. Naturally, if this doesn't work (perhaps because you don't own the definition of ClipModel), you can also wrap &ClipModel in a newtype.
use serde::de::{DeserializeSeed, Deserializer, SeqAccess, Visitor};
use std::{fmt, fmt::Formatter};

impl<'a, 'de> DeserializeSeed<'de> for &'a ClipModel {
    type Value = Vec<Embedding>;

    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct ClipModelVisitor<'a>(&'a ClipModel);

        impl<'a, 'de> Visitor<'de> for ClipModelVisitor<'a> {
            type Value = Vec<Embedding>;

            fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
                formatter.write_str("a sequence of embedded strings")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let mut result = Vec::new();
                while let Some(text) = seq.next_element::<String>()? {
                    result.push(Embedding {
                        // Unwrap is just for demonstration; handle the error in real code!
                        values: embed_text(self.0, &text).unwrap(),
                        original: text,
                    })
                }
                Ok(result)
            }
        }

        deserializer.deserialize_seq(ClipModelVisitor(self))
    }
}
If this looks unfamiliar or confusing, I recommend familiarizing yourself with serde deserialization in general by reading "Implementing Deserialize". The main differences here are:
DeserializeSeed requires you to define a Value associated type to indicate the return type of the deserialization. In our case, we declare that we will return a Vec<Embedding>.
We pass the context (in this case, the &ClipModel) to our Visitor helper struct, so it can be used when accessing each deserialized string.
We can verify that this implementation works by providing a simple implementation of embed_text and a simple YAML input:
// Not sure what the error type is here, since it's not provided in the code included with your question.
fn embed_text(clip_model: &ClipModel, text: &str) -> Result<Vec<f64>, ()> {
    Ok(match text {
        "foo" => vec![0.1],
        "bar" => vec![1.2],
        _ => Vec::new(),
    })
}

fn main() {
    let clip_model = ClipModel;
    let result = clip_model.deserialize(serde_yaml::Deserializer::from_str("- foo\n- bar\n"));
    dbg!(result);
}
This will output something similar to the following:
[src\main.rs:63] result = Ok(
    [
        Embedding {
            values: [
                0.1,
            ],
            original: "foo",
        },
        Embedding {
            values: [
                1.2,
            ],
            original: "bar",
        },
    ],
)
You can run the above code at this playground.
This is the intended solution provided by the serde framework for what you are attempting to do. Hope that helps!