Tags: arrays, rust, thread-safety, json-deserialization, serde

How can I deserialize an array in a background thread using Rust's serde?


I have a JSON file containing an inner array (e.g. {"a": "b", "c": [...]}) that can be extremely long: too long to fit into memory, so it cannot be deserialized into a Vec the usual way.

I saw an approach written in C# that did the following:

  • when deserializing, for each item, put the item on a ConcurrentQueue
  • consumers run on a background thread and pull deserialized items off the queue for processing.
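For orientation, here is a rough std-only Rust analogue of that C# design. It is a sketch under illustrative assumptions: an `std::sync::mpsc` channel stands in for `ConcurrentQueue`, a plain loop stands in for the deserializer, and doubling each number stands in for "processing". `run_pipeline` is a hypothetical name.

```rust
use std::sync::mpsc;
use std::thread;

/// Hypothetical helper: the calling thread produces items onto a channel
/// (the "queue") while a background thread consumes and processes them.
fn run_pipeline(items: Vec<u32>) -> Vec<u32> {
    let (tx, rx) = mpsc::channel::<u32>();

    // Consumer: runs in the background and processes items as they arrive.
    let consumer = thread::spawn(move || {
        // `rx.iter()` blocks for each item and ends once every sender is dropped.
        rx.iter().map(|item| item * 2).collect::<Vec<u32>>()
    });

    // Producer: in the real program this loop would be the deserializer
    // emitting one array element at a time.
    for item in items {
        tx.send(item).expect("consumer thread hung up");
    }
    drop(tx); // closing the channel lets the consumer's loop finish

    consumer.join().expect("consumer thread panicked")
}

fn main() {
    println!("{:?}", run_pipeline(vec![1, 2, 3]));
}
```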

Now, I could simply make my deserializer do the processing itself, as described here. However, I'd like to try implementing an approach similar to the one above. I've gotten far enough to discover channels, and to find that types implementing the Deserializer and SeqAccess traits are not required to be Send, and therefore cannot be moved across threads.

I've tried a few configurations similar to what's described here, but that isn't quite the approach I'd like to take.

Here's some sample code:

// generator.rs
// This lets me replace `Vec<T>` with `ChannelGenerator<T>`
// and treat it the same, except `next()` pulls
// a T off the channel rather than from a whole Vec in memory.

use std::sync::mpsc::Receiver;

pub struct ChannelGenerator<T> {
    pub(crate) receiver: Receiver<T>,
}

impl<T> Iterator for ChannelGenerator<T> {
    type Item = T;

    // Blocks until an item arrives; returns None once the sender is dropped.
    fn next(&mut self) -> Option<Self::Item> {
        self.receiver.recv().ok()
    }
}
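// struct.rs

use serde::Deserialize;

use crate::sync_array_serde::generator::ChannelGenerator;

#[derive(Deserialize)]
pub struct File {
    pub a: String,
    pub b: String,
    // Filled by the custom deserializer below instead of a Vec.
    #[serde(deserialize_with = "deserialize_to_channel")]
    pub items: ChannelGenerator<Item>,
}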
use std::{
    fmt,
    marker::PhantomData,
    sync::mpsc::{sync_channel, Receiver, SyncSender},
    thread,
};

use serde::{
    de::{SeqAccess, Visitor},
    Deserialize, Deserializer,
};

use crate::sync_array_serde::generator::ChannelGenerator;

pub struct ChannelVisitor<T> {
    pub sender: SyncSender<T>,
    pub f: PhantomData<fn() -> T>,
}

impl<'de, T> Visitor<'de> for ChannelVisitor<T>
where
    T: Deserialize<'de> + Send,
{
    type Value = ();

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("an array of objects of type T")
    }

    fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
    where
        S: SeqAccess<'de>,
    {
        while let Some(n) = seq.next_element::<T>()? {
            if self.sender.send(n).is_err() {
                break;
            }
        }
        Ok(())
    }
}

// An implementation of a serde deserializer that returns a channel receiver immediately,
// and in a background thread starts deserializing objects and sending them to a channel,
// up to a max amount of memory.
// The deserialized struct then has a ChannelGenerator, which iterates by receiving items from the channel.
pub fn deserialize_to_channel<'de, T, D>(deserializer: D) -> Result<ChannelGenerator<T>, D::Error>
where
    T: Deserialize<'de> + Send,
    D: Deserializer<'de>,
{
    let (sender, receiver) = sync_channel::<T>(0);

    // compiler error here about not being able to move D into a thread safely
    thread::spawn(move || {
        let visitor = ChannelVisitor {
            sender,
            f: PhantomData,
        };
        deserializer.deserialize_seq(visitor).unwrap();
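    })
    // Even if this compiled, joining here would block: `sync_channel(0)` makes
    // every `send` wait for a receiver, but `receiver` is only returned below.
    .join()
    .unwrap();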
    Ok(ChannelGenerator { receiver })
}
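As an aside that does not touch the Send error itself: `std::sync::mpsc::Receiver` already implements `IntoIterator` (and has a borrowing `iter()`), so a `ChannelGenerator`-style wrapper can simply delegate to the standard library's blocking iterator. A minimal sketch; `produce_and_collect` is a hypothetical demo helper.

```rust
use std::sync::mpsc::{self, IntoIter, Receiver};
use std::thread;

/// A variant of `ChannelGenerator` that wraps the receiver's own iterator.
pub struct ChannelGenerator<T> {
    inner: IntoIter<T>,
}

impl<T> ChannelGenerator<T> {
    pub fn new(receiver: Receiver<T>) -> Self {
        ChannelGenerator {
            inner: receiver.into_iter(),
        }
    }
}

impl<T> Iterator for ChannelGenerator<T> {
    type Item = T;

    // Blocks until the next item arrives; yields `None` once every sender is dropped.
    fn next(&mut self) -> Option<T> {
        self.inner.next()
    }
}

/// Hypothetical demo: a producer thread sends `0..n` through a rendezvous channel.
fn produce_and_collect(n: u32) -> Vec<u32> {
    let (tx, rx) = mpsc::sync_channel::<u32>(0);
    let producer = thread::spawn(move || {
        for i in 0..n {
            tx.send(i).expect("receiver dropped");
        }
    });
    let items: Vec<u32> = ChannelGenerator::new(rx).collect();
    producer.join().expect("producer panicked");
    items
}

fn main() {
    println!("{:?}", produce_and_collect(3));
}
```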

The problem I'm stuck on now is this compiler error:

D: Deserializer cannot be sent between threads safely

I've tried spawning the thread inside visit_seq, but I get a similar issue with S: SeqAccess. Neither D nor S is required to implement Send, and I can't add that bound myself because it would make my implementation more restrictive than the Visitor trait allows.

Any insight into how I could implement an array deserializer that pushes each item onto a queue as it is deserialized, while the deserialized objects are read off the queue and handled on a separate thread?

Writing that last piece out, I'm realizing maybe I should spawn the receiving thread first in the background, and then start deserializing... I'll try that out and get back to this. Any hints in the meantime would be appreciated!

Edit: okay, that didn't work. I added two more channels to pass the input and output back and forth between two external threads, but the main problem remains: the deserialized outer struct can't be returned until every inner item has been put onto the channel, so the receiver is only handed back after all the work is done. I need some way to either pass the receiver into the deserializing function, or to run the deserialization at that level in a background thread.

The latest code can be found here; the entry point is src/main.rs.


Solution

    You are right that you need some way to pass a SyncSender to the deserialize function. The best way to pass state into deserialization is to implement DeserializeSeed instead of Deserialize. The documentation for DeserializeSeed states:

    If you ever find yourself looking for a way to pass data into a Deserialize impl, this trait is the way to do it.

    In your case, you need to pass in a sender so that each item can be sent to another thread as soon as it is deserialized.

    Unfortunately, DeserializeSeed implementations are not derivable, so we will need to write our own trait implementation. It contains a bit of boilerplate, but I'll break it down as best I can.

    Deserializing the Items

    In serde, we can deserialize from a "seed" which contains state needed for the deserialization. In our case, the state contains the SyncSender. So, we can define a seed for deserializing the items as follows:

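    // Imports used by the snippets in this answer.
    use std::{
        fmt,
        sync::mpsc::{sync_channel, SyncSender},
        thread,
    };
    
    use serde::de::{self, DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor};
    use serde::Deserialize;
    
    struct ItemsSeed<T> {
        sender: SyncSender<T>,
    }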
    

    Now we implement DeserializeSeed for ItemsSeed. The implementation looks very similar to a normal Deserialize implementation, and the Visitor we define is very similar to the one in your example code:

    impl<'de, T> DeserializeSeed<'de> for ItemsSeed<T>
    where
        T: Deserialize<'de>,
    {
        type Value = ();
    
        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: Deserializer<'de>,
        {
            struct ItemsVisitor<T> {
                sender: SyncSender<T>,
            }
    
            impl<'de, T> Visitor<'de> for ItemsVisitor<T>
            where
                T: Deserialize<'de>,
            {
                type Value = ();
    
                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str("an array of objects of type T")
                }
    
                fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                where
                    A: SeqAccess<'de>,
                {
                    while let Some(n) = seq.next_element()? {
                        // Might want some better error handling here.
                        if self.sender.send(n).is_err() {
                            break;
                        }
                    }
                    Ok(())
                }
            }
    
            deserializer.deserialize_seq(ItemsVisitor {
                sender: self.sender,
            })
        }
    }
    

    The main difference between this and a normal Deserialize implementation is that we have access to the sender during deserialization. We also declare the output type, which here is just type Value = ();, because every deserialized item has already been sent to another thread to be processed there.

    Deserializing the File struct

    Since your example includes data besides the incredibly long array, we want to deserialize that data into a File struct at the same time. As I mentioned before, we can't actually derive a DeserializeSeed implementation like we can a Deserialize implementation, so unfortunately the #[derive(Deserialize)] turns into a lot of boilerplate. This is the DeserializeSeed implementation for File:

    #[derive(Debug)]
    pub struct File {
        pub a: String,
        pub b: String,
    }
    
    pub struct FileSeed<T> {
        sender: SyncSender<T>,
    }
    
    impl<'de, T> DeserializeSeed<'de> for FileSeed<T>
    where
        T: Deserialize<'de>,
    {
        type Value = File;
    
        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: Deserializer<'de>,
        {
            enum Field {
                A,
                B,
                Items,
            }
    
            impl<'de> Deserialize<'de> for Field {
                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
                where
                    D: Deserializer<'de>,
                {
                    struct FieldVisitor;
    
                    impl<'de> Visitor<'de> for FieldVisitor {
                        type Value = Field;
    
                        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                            formatter.write_str("`a`, `b`, or `items`")
                        }
    
                        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
                        where
                            E: de::Error,
                        {
                            match value {
                                "a" => Ok(Field::A),
                                "b" => Ok(Field::B),
                                "items" => Ok(Field::Items),
                                _ => Err(de::Error::unknown_field(value, FIELDS)),
                            }
                        }
                    }
    
                    deserializer.deserialize_identifier(FieldVisitor)
                }
            }
    
            struct FileVisitor<T> {
                sender: SyncSender<T>,
            }
    
            impl<'de, T> Visitor<'de> for FileVisitor<T>
            where
                T: Deserialize<'de>,
            {
                type Value = File;
    
                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str("struct File")
                }
    
                fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
                where
                    A: MapAccess<'de>,
                {
                    let mut a = None;
                    let mut b = None;
                    let mut items = None;
    
                    while let Some(key) = map.next_key()? {
                        match key {
                            Field::A => {
                                if a.is_some() {
                                    return Err(de::Error::duplicate_field("a"));
                                }
                                a = Some(map.next_value()?);
                            }
                            Field::B => {
                                if b.is_some() {
                                    return Err(de::Error::duplicate_field("b"));
                                }
                                b = Some(map.next_value()?);
                            }
                            Field::Items => {
                                if items.is_some() {
                                    return Err(de::Error::duplicate_field("items"));
                                }
                                items = Some(map.next_value_seed(ItemsSeed {
                                    sender: self.sender.clone(),
                                })?);
                            }
                        }
                    }
    
                    if items.is_none() {
                        return Err(de::Error::missing_field("items"));
                    }
    
                    Ok(File {
                        a: a.ok_or_else(|| de::Error::missing_field("a"))?,
                        b: b.ok_or_else(|| de::Error::missing_field("b"))?,
                    })
                }
            }
    
            const FIELDS: &[&str] = &["a", "b", "items"];
            deserializer.deserialize_struct(
                "File",
                FIELDS,
                FileVisitor {
                    sender: self.sender,
                },
            )
        }
    }
    

    This is very similar to the "Manually implementing Deserialize for a struct" example on the official serde website. The main difference is that when we deserialize the items, we pass the SyncSender state in using the next_value_seed() method. Additionally, we don't include items in the File struct at all, because deserializing them returns no data: it's all processed in a separate thread.

    Running deserialization on a separate thread

    Now we put everything together. We can create our synchronous channel, spawn the deserialization in a separate thread, and process the received values as deserialization is occurring in the background.

    fn main() {
        let (sender, receiver) = sync_channel::<u32>(0);
    
        // Deserialize in a separate thread.
        let deserialize_thread = thread::spawn(|| {
            let mut deserializer = serde_json::de::Deserializer::from_str(
                "{\"a\": \"foo\", \"b\": \"bar\", \"items\": [0, 1, 2, 3, 4, 5]}",
            );
            FileSeed { sender }.deserialize(&mut deserializer)
        });
    
        while let Ok(value) = receiver.recv() {
            // Process the deserialized values here.
            dbg!(value);
        }
    
        // You can also access the `File` after deserializing is complete.
        dbg!(deserialize_thread.join());
    }
    

    Note that I chose u32 for T, but any T that implements Deserialize will work; since the items are sent to another thread, T must also be Send.

    Running this program outputs the following:

    [src/main.rs:197] value = 0
    [src/main.rs:197] value = 1
    [src/main.rs:197] value = 2
    [src/main.rs:197] value = 3
    [src/main.rs:197] value = 4
    [src/main.rs:197] value = 5
    [src/main.rs:200] deserialize_thread.join() = Ok(
        Ok(
            File {
                a: "foo",
                b: "bar",
            },
        ),
    )
    

    You can see and run my whole solution on this playground.
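On the 0 passed to sync_channel in main: a capacity of zero makes it a rendezvous channel, so every send waits until the consumer takes the item and the deserializer holds at most the one item it is handing over. A capacity of n lets up to n parsed items wait in the channel, which is one way to cap the memory held in flight. Here is a std-only sketch of that behavior, using try_send (which reports Full instead of blocking) so the limit is observable; `buffered_before_full` is a hypothetical helper.

```rust
use std::sync::mpsc::{sync_channel, TrySendError};

/// Counts how many items a bounded channel accepts before a send would
/// block, assuming nothing receives in the meantime.
fn buffered_before_full(capacity: usize) -> usize {
    let (tx, _rx) = sync_channel::<u32>(capacity); // `_rx` keeps the receiver alive
    let mut accepted = 0;
    loop {
        match tx.try_send(accepted as u32) {
            Ok(()) => accepted += 1,
            // A blocking `send` would wait here until the receiver catches up.
            Err(TrySendError::Full(_)) => return accepted,
            Err(TrySendError::Disconnected(_)) => unreachable!("receiver is still alive"),
        }
    }
}

fn main() {
    for capacity in [0, 1, 16] {
        println!("capacity {capacity}: {} item(s) buffered", buffered_before_full(capacity));
    }
}
```

With capacity 0 nothing is buffered, because no receiver is blocked in recv when try_send is called.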