
Implementing static bounds checking in Rust using `typenum`


Problem statement

I have a Rust parsing library for wire-encoded byte buffers, and I'm trying to build a "view" data structure, `ByteSliceReader`, that subparsers can use and that enforces bounds checking at compile time. Since const generics don't appear to be expressive enough for this (correct me if I'm wrong), I tried the typenum crate instead. However, I'm getting compiler errors that are hard to understand, so I'm looking for guidance on how to resolve them.

(I don't think I can use generic-array, since subparsers need to call functions like `f32::from_le_bytes`, which takes a `[u8; 4]`.)
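For contrast, the runtime-checked way to feed `f32::from_le_bytes` from an ordinary slice might look like the sketch below. `read_f32_le` is a hypothetical helper, not part of the library. Every read returns an `Option` that the subparser has to handle, and that per-read cost is what the compile-time view is meant to remove.

```rust
/// Hypothetical helper: decode a little-endian f32 starting at `offset`,
/// returning None if the four bytes are not all inside `buf`.
pub fn read_f32_le(buf: &[u8], offset: usize) -> Option<f32> {
    // Guard against `offset + 4` overflowing usize.
    let end = offset.checked_add(4)?;
    // `get` performs the runtime bounds check; the slice is exactly 4 bytes long.
    let slice = buf.get(offset..end)?;
    let mut bytes = [0u8; 4];
    bytes.copy_from_slice(slice);
    Some(f32::from_le_bytes(bytes))
}
```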

Implementation

The data structure looks roughly like this (except that there's actually a ring buffer storing the data instead of an array):

use typenum::*;

struct ByteSliceReader<'a, const LEN: usize, Cursor> {
    buffer: &'a [u8; LEN],
    cursor: core::marker::PhantomData<Cursor>,
}

Here `Cursor` is intended to be a typenum unsigned integer (`U0`, `U1`, ...) that tracks how many bytes the subparser has consumed. The struct has one method, `pop`, which consumes a number of bytes given by a const generic parameter:

impl<'a, const LEN: usize, Cursor: Unsigned> ByteSliceReader<'a, LEN, Cursor>
where
    Const<LEN>: ToUInt,
    U<LEN>: Unsigned,
{
    #[must_use]
    fn pop<const NUM: usize>(self) -> (ByteSliceReader<'a, LEN, Sum<Cursor, U<NUM>>>, [u8; NUM])
    where
        Const<NUM>: ToUInt,
        Cursor: core::ops::Add<U<NUM>>,
        Sum<Cursor, U<NUM>>: IsLessOrEqual<U<LEN>>,
    {
        let mut byte_window = [0; NUM];
        for (i, byte) in byte_window.iter_mut().enumerate() {
            *byte = self.buffer[Cursor::to_usize() + i];
        }

        (
            ByteSliceReader::<LEN, _> {
                buffer: self.buffer,
                cursor: core::marker::PhantomData,
            },
            byte_window,
        )
    }
}

This all compiles fine. But when I try to use it in this example helper function (intended to extract and parse some bytes), it doesn't compile:

fn get_serial_and_part_num<const LEN: usize, Cursor>(
    reader: ByteSliceReader<'_, LEN, Cursor>,
) -> (ByteSliceReader<'_, LEN, Sum<Sum<Cursor, U1>, U1>>, (u8, u8))
where
    Cursor: Unsigned,
    Cursor: core::ops::Add<U1>,
    Sum<Cursor, U1>: core::ops::Add<U1> + Unsigned,
    Const<LEN>: ToUInt,
    U<LEN>: Unsigned,
{
    let (reader, serial_num) = reader.pop::<1>();
    let (reader, part_num) = reader.pop::<1>();
    (reader, (serial_num[0], part_num[0]))
}

Error

error[E0599]: the method `pop` exists for struct `ByteSliceReader<'_, LEN, Cursor>`, but its trait bounds were not satisfied
   --> fsw/sensors/ring_parser.rs:528:43
    |
483 |     struct ByteSliceReader<'a, const LEN: usize, Cursor> {
    |     ---------------------------------------------------- method `pop` not found for this struct
...
528 |         let (reader, serial_num) = reader.pop::<1>();
    |                                           ^^^ method cannot be called on `ByteSliceReader<'_, LEN, Cursor>` due to unsatisfied trait bounds
    |
    = note: the following trait bounds were not satisfied:
            `Cursor: Add<<typenum::Const<_> as typenum::ToUInt>::Output>`
help: consider restricting the type parameter to satisfy the trait bound
    |
526 |         U<LEN>: Unsigned, Cursor: Add<<typenum::Const<_> as typenum::ToUInt>::Output>
    |                         ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The compiler's suggestion contains a `_`, which I interpret as `1`, but applying the suggestion doesn't help. I've tried various permutations of bounds and keep getting similar errors. Is there a way to achieve what I want?
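On the "correct me if I'm wrong" point about const generics: stable Rust does not let a generic const parameter appear in a computed type such as `ByteSliceReader<'a, LEN, { CURSOR + NUM }>`, because that requires the unstable `generic_const_exprs` feature. A similar guarantee can be approximated if the caller names the next cursor and a constant assertion checks it. The sketch below is an alternative to the typenum approach, not a fix for it. `PopCheck` and `demo` are hypothetical. A violation surfaces as a post-monomorphization error, raised when `pop` is instantiated with concrete values, rather than as a trait-bound error at the call site.

```rust
// The cursor is an ordinary const parameter here; no typenum involved.
pub struct ByteSliceReader<'a, const LEN: usize, const CURSOR: usize> {
    buffer: &'a [u8; LEN],
}

impl<'a, const LEN: usize> ByteSliceReader<'a, LEN, 0> {
    pub fn new(buffer: &'a [u8; LEN]) -> Self {
        ByteSliceReader { buffer }
    }
}

// Hypothetical helper: evaluating `OK` fails the build when a pop is out of bounds.
struct PopCheck<const LEN: usize, const CURSOR: usize, const NUM: usize, const NEXT: usize>;

impl<const LEN: usize, const CURSOR: usize, const NUM: usize, const NEXT: usize>
    PopCheck<LEN, CURSOR, NUM, NEXT>
{
    const OK: () = assert!(CURSOR + NUM == NEXT && NEXT <= LEN, "pop out of bounds");
}

impl<'a, const LEN: usize, const CURSOR: usize> ByteSliceReader<'a, LEN, CURSOR> {
    // The caller names the new cursor `NEXT`; the check ties it to CURSOR + NUM.
    #[must_use]
    pub fn pop<const NUM: usize, const NEXT: usize>(
        self,
    ) -> (ByteSliceReader<'a, LEN, NEXT>, [u8; NUM]) {
        // Naming the constant forces its evaluation for these concrete parameters.
        let () = PopCheck::<LEN, CURSOR, NUM, NEXT>::OK;
        let mut window = [0u8; NUM];
        window.copy_from_slice(&self.buffer[CURSOR..NEXT]);
        (ByteSliceReader { buffer: self.buffer }, window)
    }
}

// The question's helper; every intermediate cursor has to be spelled out.
pub fn get_serial_and_part_num<const LEN: usize, const C0: usize, const C1: usize, const C2: usize>(
    reader: ByteSliceReader<'_, LEN, C0>,
) -> (ByteSliceReader<'_, LEN, C2>, (u8, u8)) {
    let (reader, serial_num) = reader.pop::<1, C1>();
    let (reader, part_num) = reader.pop::<1, C2>();
    (reader, (serial_num[0], part_num[0]))
}

pub fn demo() -> (u8, u8, f32) {
    let data = [7u8, 42, 0x00, 0x00, 0x80, 0x3f];
    let reader = ByteSliceReader::new(&data);
    let (reader, (serial, part)) = get_serial_and_part_num::<6, 0, 1, 2>(reader);
    // `reader.pop::<4, 7>()` would be rejected at build time instead of panicking.
    let (_reader, float_bytes) = reader.pop::<4, 6>();
    (serial, part, f32::from_le_bytes(float_bytes))
}
```

The cost shows up in the helper: every intermediate cursor (`C1` here) becomes a parameter the caller has to supply, which is the bookkeeping typenum's `Sum` is meant to automate.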


Solution

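  • The immediate problem is that you need to tell the compiler that the `Output` of `Const<1>` as `ToUInt` is `U1` (the error shows it failing to resolve `<typenum::Const<_> as typenum::ToUInt>::Output` inside the generic helper):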

        Const<1>: ToUInt<Output = U1>,
    

    However, that then leads to further unsatisfied trait bounds (namely the `IsLessOrEqual` ones), which need to be added as well:

        Sum<Cursor, U1>: IsLessOrEqual<U<LEN>>,
        Sum<Sum<Cursor, U1>, U1>: IsLessOrEqual<U<LEN>>,
    

    See it on the playground.
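One further point about the question's `pop`: typenum's `IsLessOrEqual` is implemented for every pair of unsigned integers and reports the result through its `Output` type (`True` or `False`). A bound like `Sum<Cursor, U<NUM>>: IsLessOrEqual<U<LEN>>` therefore only requires the comparison to be computable; it does not reject a read past `LEN`. Writing `IsLessOrEqual<U<LEN>, Output = True>`, and using the same form in the helper's bounds, is what turns it into a real check. To make the sketch below runnable without the typenum crate, it models this with hypothetical Peano-style types (`Z`, `S`, `LessOrEqual`, `Reader`) and one-byte pops. It illustrates the pattern and does not reproduce typenum itself.

```rust
use std::marker::PhantomData;

// Hypothetical stand-ins for typenum's unsigned integers and True/False bits.
pub struct Z;
pub struct S<N>(PhantomData<N>);
pub struct True;
pub struct False;

pub trait Nat {
    const VALUE: usize;
}
impl Nat for Z {
    const VALUE: usize = 0;
}
impl<N: Nat> Nat for S<N> {
    const VALUE: usize = N::VALUE + 1;
}

// Like IsLessOrEqual: implemented for every pair, with the answer in `Output`.
pub trait LessOrEqual<Rhs> {
    type Output;
}
impl<Rhs> LessOrEqual<Rhs> for Z {
    type Output = True;
}
impl<N> LessOrEqual<Z> for S<N> {
    type Output = False;
}
impl<N: LessOrEqual<M>, M> LessOrEqual<S<M>> for S<N> {
    type Output = <N as LessOrEqual<M>>::Output;
}

pub struct Reader<'a, Len, Cursor> {
    buf: &'a [u8],
    _marker: PhantomData<(Len, Cursor)>,
}

impl<'a, Len: Nat> Reader<'a, Len, Z> {
    // The only runtime length check; later pops are checked through the types.
    pub fn new(buf: &'a [u8]) -> Option<Self> {
        (buf.len() == Len::VALUE).then(|| Reader { buf, _marker: PhantomData })
    }
}

impl<'a, Len: Nat, Cursor: Nat> Reader<'a, Len, Cursor> {
    // Mirrors the question's bound: the comparison must exist, its answer is ignored.
    pub fn pop_weak(self) -> (Reader<'a, Len, S<Cursor>>, u8)
    where
        S<Cursor>: LessOrEqual<Len>,
    {
        let byte = self.buf[Cursor::VALUE];
        (Reader { buf: self.buf, _marker: PhantomData }, byte)
    }

    // Pinning `Output = True` makes reading past `Len` a compile error.
    pub fn pop(self) -> (Reader<'a, Len, S<Cursor>>, u8)
    where
        S<Cursor>: LessOrEqual<Len, Output = True>,
    {
        let byte = self.buf[Cursor::VALUE];
        (Reader { buf: self.buf, _marker: PhantomData }, byte)
    }
}

type Two = S<S<Z>>;

pub fn read_two() -> (u8, u8) {
    let data = [7u8, 42];
    let r = Reader::<Two, Z>::new(&data).expect("length matches");
    let (r, serial) = r.pop();
    let (_r, part) = r.pop();
    // A third `.pop()` would not compile: S<S<S<Z>>> <= Two has Output = False.
    (serial, part)
}

pub fn weak_bound_overruns() -> bool {
    let data = [7u8, 42];
    std::panic::catch_unwind(|| {
        let r = Reader::<Two, Z>::new(&data).expect("length matches");
        let (r, _) = r.pop_weak();
        let (r, _) = r.pop_weak();
        // Compiles despite the bound, then panics indexing data[2].
        let _ = r.pop_weak();
    })
    .is_err()
}
```

`weak_bound_overruns` returns `true` because the third `pop_weak` type-checks and then panics at runtime. Changing it to `pop` turns the same mistake into a compile error.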