Search code examples
rust macros monomorphism

Capture monomorphized generics in a macro


I have written a trait to serialize objects into iterators of little endian bytes:

/// Serializes a value into an iterator over its little-endian bytes.
pub trait ToLeBytes: Sized {
    /// Iterator that yields the serialized bytes.
    ///
    /// The bound lives on the associated type itself, so generic code that only knows
    /// `T: ToLeBytes` can consume `T::Iter` as a byte iterator.
    type Iter: Iterator<Item = u8>;

    /// Returns an iterator over the little-endian byte representation of `self`.
    fn to_le_bytes(&self) -> Self::Iter;
}

I have implemented it for the primitive data types I need and also for heapless::Vec:

// The length casts below are guarded by the matching `try_from(SIZE)` check.
#[allow(clippy::cast_possible_truncation)]
#[cfg(feature = "heapless")]
impl<I, const SIZE: usize> ToLeBytes for heapless::Vec<I, SIZE>
where
    I: ToLeBytes,
    for<'a> <I as ToLeBytes>::Iter: Iterator<Item = u8> + 'a,
{
    type Iter = Box<dyn Iterator<Item = u8>>;

    /// Serializes the vector as a length prefix followed by the bytes of every item.
    ///
    /// The prefix uses the narrowest of `u8`/`u16`/`u32`/`u64` that can hold the
    /// capacity `SIZE`. Every item wraps the iterator built so far in another boxed
    /// `Chain`.
    fn to_le_bytes(&self) -> Self::Iter {
        let length = self.len();

        // Pick the prefix width from the capacity, not from the current length.
        let prefix: Box<dyn Iterator<Item = u8>> = if u8::try_from(SIZE).is_ok() {
            Box::new(<u8 as ToLeBytes>::to_le_bytes(&(length as u8)))
        } else if u16::try_from(SIZE).is_ok() {
            Box::new(<u16 as ToLeBytes>::to_le_bytes(&(length as u16)))
        } else if u32::try_from(SIZE).is_ok() {
            Box::new(<u32 as ToLeBytes>::to_le_bytes(&(length as u32)))
        } else if u64::try_from(SIZE).is_ok() {
            Box::new(<u64 as ToLeBytes>::to_le_bytes(&(length as u64)))
        } else {
            // No width fits: emit no prefix at all.
            Box::new(empty())
        };

        self.iter()
            .fold(prefix, |bytes, item| -> Box<dyn Iterator<Item = u8>> {
                Box::new(bytes.chain(<I as ToLeBytes>::to_le_bytes(item)))
            })
    }
}

However, since this code is intended to run on embedded systems with low performing hardware, I want to avoid heap allocations and thus want to get rid of the Boxes.

Of course, chaining the iterators in a loop is not possible in plain Rust without boxing, because each call to .chain() returns a new iterator type. Hence I thought that maybe a macro could do the trick, since I already did something similar for the derive macro of that trait.

However, I of course do not want to implement the body for every possible I and SIZE, but only for those that are actually used in the respective program. Therefore, I'd need to run the macro on the monomorphized code.

I tried to google how to do this, but did not find any results. How do I write a macro that is passed the monomorphized code of an impl block? I don't want an entire solution, just a nudge in the right direction.

Update

I think that I'm almost there, thanks to Chayim's comments:

use crate::ToLeBytes;
use std::array::IntoIter;
use std::iter::FlatMap;
use std::slice::Iter;

/// Byte iterator over a serialized container: the size header comes first, followed by
/// the bytes of every item.
pub struct ContainerIterator<'a, T, const HEADER_SIZE: usize>
where
    T: ToLeBytes,
{
    /// Bytes of the size header; `HEADER_SIZE` is the header width in bytes.
    size_iterator: IntoIter<u8, HEADER_SIZE>,
    /// Flattened byte streams of the borrowed items. The mapping function is stored as a
    /// plain function pointer so that the field type can be written out.
    items_iterator: FlatMap<Iter<'a, T>, <T as ToLeBytes>::Iter, fn(&T) -> <T as ToLeBytes>::Iter>,
}

impl<'a, T, const HEADER_SIZE: usize> ContainerIterator<'a, T, HEADER_SIZE>
where
    T: ToLeBytes,
{
    /// Builds the iterator from an already serialized size header and the items to
    /// serialize after it.
    ///
    /// `items` must be borrowed for `'a`, because the stored slice iterator keeps that
    /// borrow alive for as long as the returned value exists.
    fn from_size_iterator_and_slice(
        size_iterator: IntoIter<u8, HEADER_SIZE>,
        items: &'a [T],
    ) -> Self {
        Self {
            size_iterator,
            // The field type names a function pointer, so the mapping function has to be
            // cast explicitly; a closure would have its own anonymous type.
            items_iterator: items
                .iter()
                .flat_map(<T as ToLeBytes>::to_le_bytes as fn(&T) -> <T as ToLeBytes>::Iter),
        }
    }
}

impl<'a, T, const HEADER_SIZE: usize> Iterator for ContainerIterator<'a, T, HEADER_SIZE>
where
    T: ToLeBytes,
{
    type Item = u8;

    /// Yields the header bytes first and, once those are used up, the item bytes.
    fn next(&mut self) -> Option<Self::Item> {
        match self.size_iterator.next() {
            // Header exhausted: continue with the serialized items.
            None => self.items_iterator.next(),
            header_byte => header_byte,
        }
    }
}

/// A [`ContainerIterator`] whose size header is 1, 2, 4 or 8 bytes wide.
///
/// One variant exists per header width so that a single nameable type covers all of them.
pub enum SizedContainerIterator<'a, T>
where
    T: ToLeBytes,
{
    /// One-byte size header.
    U8(ContainerIterator<'a, T, 1>),
    /// Two-byte size header.
    U16(ContainerIterator<'a, T, 2>),
    /// Four-byte size header.
    U32(ContainerIterator<'a, T, 4>),
    /// Eight-byte size header.
    U64(ContainerIterator<'a, T, 8>),
}

impl<'a, T> SizedContainerIterator<'a, T>
where
    T: ToLeBytes,
{
    /// Creates a byte iterator over `items`, prefixed with their count.
    ///
    /// The width of the prefix is the narrowest of `u8`/`u16`/`u32`/`u64` that can hold
    /// `capacity`; the value written is `items.len()`, truncated to that width.
    ///
    /// `items` is borrowed for `'a` because the returned iterator keeps reading from it.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` does not fit into a `u64`.
    // The casts are guarded by the `try_from(capacity)` checks as long as
    // `items.len() <= capacity`, which is the caller's responsibility.
    #[allow(clippy::cast_possible_truncation)]
    pub fn new(items: &'a [T], capacity: usize) -> Self {
        if u8::try_from(capacity).is_ok() {
            SizedContainerIterator::U8(ContainerIterator::from_size_iterator_and_slice(
                <u8 as ToLeBytes>::to_le_bytes(&(items.len() as u8)),
                items,
            ))
        } else if u16::try_from(capacity).is_ok() {
            SizedContainerIterator::U16(ContainerIterator::from_size_iterator_and_slice(
                <u16 as ToLeBytes>::to_le_bytes(&(items.len() as u16)),
                items,
            ))
        } else if u32::try_from(capacity).is_ok() {
            SizedContainerIterator::U32(ContainerIterator::from_size_iterator_and_slice(
                <u32 as ToLeBytes>::to_le_bytes(&(items.len() as u32)),
                items,
            ))
        } else if u64::try_from(capacity).is_ok() {
            SizedContainerIterator::U64(ContainerIterator::from_size_iterator_and_slice(
                <u64 as ToLeBytes>::to_le_bytes(&(items.len() as u64)),
                items,
            ))
        } else {
            unreachable!("vec size exceeds u64");
        }
    }
}

impl<'a, T> Iterator for SizedContainerIterator<'a, T>
where
    T: ToLeBytes,
{
    type Item = u8;

    /// Forwards to the wrapped iterator of whichever header width is active.
    fn next(&mut self) -> Option<Self::Item> {
        match *self {
            Self::U8(ref mut inner) => inner.next(),
            Self::U16(ref mut inner) => inner.next(),
            Self::U32(ref mut inner) => inner.next(),
            Self::U64(ref mut inner) => inner.next(),
        }
    }
}

However, now I have the issue of specifying the lifetime for the Iter type:

// Attempt to serialize a `heapless::Vec` without boxing by returning a
// `SizedContainerIterator` that borrows the vector's items.
#[allow(clippy::cast_possible_truncation)]
#[cfg(feature = "heapless")]
impl<I, const SIZE: usize> ToLeBytes for heapless::Vec<I, SIZE>
where
    I: ToLeBytes,
    for<'a> <I as ToLeBytes>::Iter: Iterator<Item = u8> + 'a,
{
    // NOTE(review): this line is the open problem. The iterator borrows from `self`, but
    // `ToLeBytes::Iter` has no lifetime parameter that could name that borrow, and the
    // anonymous lifetime `'_` is not accepted in this position.
    type Iter = SizedContainerIterator<'_, I>;

    // The prefix width is derived from the capacity `SIZE`; the vector is passed on as a
    // slice of its current items.
    fn to_le_bytes(&self) -> Self::Iter {
        SizedContainerIterator::new(self, SIZE)
    }
}

Solution

  • Thanks to Chayim's hints, I got it working:

    #![cfg(feature = "heapless")]
    use crate::ToLeBytes;
    use std::array::IntoIter;
    
    /// Iterator over the bytes of a container's length prefix, which is 1, 2, 4 or 8
    /// bytes wide.
    #[derive(Debug)]
    pub enum SizePrefixIterator {
        /// One-byte prefix.
        U8(IntoIter<u8, 1>),
        /// Two-byte prefix.
        U16(IntoIter<u8, 2>),
        /// Four-byte prefix.
        U32(IntoIter<u8, 4>),
        /// Eight-byte prefix.
        U64(IntoIter<u8, 8>),
    }
    
    impl SizePrefixIterator {
        /// Serializes `len` as a prefix whose width is the narrowest of
        /// `u8`/`u16`/`u32`/`u64` that can hold `capacity`.
        ///
        /// `len` is truncated to the chosen width.
        ///
        /// # Panics
        ///
        /// Panics if `capacity` does not fit into a `u64`.
        #[allow(clippy::cast_possible_truncation)]
        pub fn new(len: usize, capacity: usize) -> Self {
            if u8::try_from(capacity).is_ok() {
                return Self::U8(<u8 as ToLeBytes>::to_le_bytes(len as u8));
            }

            if u16::try_from(capacity).is_ok() {
                return Self::U16(<u16 as ToLeBytes>::to_le_bytes(len as u16));
            }

            if u32::try_from(capacity).is_ok() {
                return Self::U32(<u32 as ToLeBytes>::to_le_bytes(len as u32));
            }

            if u64::try_from(capacity).is_ok() {
                return Self::U64(<u64 as ToLeBytes>::to_le_bytes(len as u64));
            }

            unreachable!("container size exceeds u64")
        }
    }
    
    impl Iterator for SizePrefixIterator {
        type Item = u8;

        /// Yields the next prefix byte from whichever width variant is active.
        fn next(&mut self) -> Option<Self::Item> {
            match *self {
                Self::U8(ref mut bytes) => bytes.next(),
                Self::U16(ref mut bytes) => bytes.next(),
                Self::U32(ref mut bytes) => bytes.next(),
                Self::U64(ref mut bytes) => bytes.next(),
            }
        }
    }
    
    #[cfg(feature = "heapless")]
    impl<T, const SIZE: usize> ToLeBytes for heapless::Vec<T, SIZE>
    where
        T: Sized + ToLeBytes,
    {
        /// Length prefix chained with the flattened byte streams of the items.
        type Iter = std::iter::Chain<
            size_prefix_iterator::SizePrefixIterator,
            FlatMap<
                <Self as IntoIterator>::IntoIter,
                <T as ToLeBytes>::Iter,
                fn(T) -> <T as ToLeBytes>::Iter,
            >,
        >;

        /// Consumes the vector and yields its length prefix followed by the bytes of
        /// every item, without any heap allocation.
        fn to_le_bytes(self) -> Self::Iter {
            // A function pointer (rather than a closure) keeps `Self::Iter` nameable.
            let serialize_item: fn(T) -> <T as ToLeBytes>::Iter = <T as ToLeBytes>::to_le_bytes;

            // The length has to be read before the vector is consumed below.
            let prefix = size_prefix_iterator::SizePrefixIterator::new(self.len(), SIZE);

            prefix.chain(self.into_iter().flat_map(serialize_item))
        }
    }