
Generic parameter of array type


I'd like to create a struct that is type-indexed (or, as I guess Rust programmers would say, has a generic parameter of some array type). Something like this:

struct Frame<const N: u8, const NS: [u8; N]> {
    frame_num: u8,
    subframe_num: u8,
}

impl<const N: u8, const NS: [u8; N]> Frame<N, NS> {
    pub fn start() -> Self {
        Frame{ frame_num: 0, subframe_num: 0, }
    }

    pub fn next(&mut self) {
        self.subframe_num += 1;
        if self.subframe_num == NS[self.frame_num as usize] {
            self.subframe_num = 0;
            self.frame_num += 1;
            if self.frame_num == N {
                self.frame_num = 0;
            }
        }
    }
}

...

let foo: Frame<2, [5, 8]> = Frame::start();

This doesn't typecheck because, quoting rustc:

error[E0770]: the type of const parameters must not depend on other generic parameters
  --> src/main.rs:32:42
   |
32 | struct Frame<const N: u8, const NS: [u8; N]> {
   |                                          ^ the type must not depend on the parameter `N`

I tried using the const_generic_wrap crate, which seems to be aimed at exactly this use case. The extra noise it introduces, while not ideal, is not too bad:

#![feature(generic_const_exprs)]

use const_generic_wrap::*;
use std::marker::PhantomData;

struct Frame<const N: u8, NS> where NS: ConstWrap<BaseType = [u8; N as usize]> {
    frame_num: u8,
    subframe_num: u8,
    phantom: PhantomData<NS>,
}

impl<const N: u8, NS> Frame<N, NS> where NS: ConstWrap<BaseType = [u8; N as usize]> {
    pub fn start() -> Self {
        Frame{ frame_num: 0, subframe_num: 0, phantom: PhantomData, }
    }

    pub fn next(&mut self) {
        self.subframe_num += 1;
        if self.subframe_num == NS::VALUE[self.frame_num as usize] {
            self.subframe_num = 0;
            self.frame_num += 1;
            if self.frame_num == N {
                self.frame_num = 0;
            }
        }
    }
}

But then I get into trouble when trying to instantiate this type. I can't write Frame<2, [5, 8]> as the type, because the second parameter now has to be a type implementing ConstWrap, not an array value. It seems const_generic_wrap exposes no way of producing a ConstWrap type for anything other than simple types like u32.
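For illustration, here is a minimal sketch of the marker-type idea that a ConstWrap-style parameter relies on: any type can carry an array as an associated constant, so a hand-written marker type can stand in for `[5, 8]`. The trait `ConstVal`, the marker `Subframes58`, and the switch of `N` from `u8` to `usize` are assumptions made for this sketch; `ConstVal` is not const_generic_wrap's trait, whose exact definition is not shown here. Using `usize` lets `[u8; N]` appear in the bound without `generic_const_exprs`, so this version should compile on stable Rust.

```rust
use std::marker::PhantomData;

/// Hypothetical ConstWrap-like trait (not the crate's): a type that carries a constant value.
trait ConstVal {
    type BaseType;
    const VALUE: Self::BaseType;
}

/// Hand-written marker type carrying the subframe counts `[5, 8]`.
struct Subframes58;

impl ConstVal for Subframes58 {
    type BaseType = [u8; 2];
    const VALUE: Self::BaseType = [5, 8];
}

/// `N` is the number of frames; `NS` carries one subframe count per frame.
struct Frame<const N: usize, NS>
where
    NS: ConstVal<BaseType = [u8; N]>,
{
    frame_num: u8,
    subframe_num: u8,
    phantom: PhantomData<NS>,
}

impl<const N: usize, NS> Frame<N, NS>
where
    NS: ConstVal<BaseType = [u8; N]>,
{
    pub fn start() -> Self {
        Frame { frame_num: 0, subframe_num: 0, phantom: PhantomData }
    }

    /// Advances one subframe, rolling over to the next frame and back to frame 0.
    pub fn next(&mut self) {
        self.subframe_num += 1;
        // The where clause lets `NS::VALUE` be indexed as a `[u8; N]`.
        if self.subframe_num == NS::VALUE[self.frame_num as usize] {
            self.subframe_num = 0;
            self.frame_num += 1;
            if self.frame_num as usize == N {
                self.frame_num = 0;
            }
        }
    }
}
```

Here `Frame<2, Subframes58>` plays the role of `Frame<2, [5, 8]>`; the cost is one hand-written marker type per subframe layout.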

Is there a way to achieve what I want? I'm willing to use any nightly/unstable features.


Solution

  • I managed to cobble together something via a type-level length-indexed list:

    #![feature(generic_const_exprs)]
    use const_generic_wrap::*;
    use std::marker::PhantomData;
    
    // A type-level list of `N` values of type `T`; `VALUE` materializes it as an array.
    trait TList<T, const N: usize> {
        const VALUE: [T; N];
    }
    
    // The empty list.
    #[derive(Debug)]
    struct Nil<T> {
        phantom: PhantomData<T>
    }
    
    impl<T> TList<T, 0> for Nil<T> {
        const VALUE: [T; 0] = [];
    }
    
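    // A non-empty list: the head value is carried by the wrapper type `X`,
    // and `Tail` is a list of length `N`.
    #[derive(Debug)]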
    struct Cons<T: Sized, const N: usize, X, Tail> where X: ConstWrap<BaseType = T>{
        head: PhantomData<X>,
        tail: PhantomData<Tail>,
    }
    
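    // Returns `[x, tail[0], ..., tail[N-1]]`; it uses `while` because
    // `for` loops are not allowed in a `const fn`.
    const fn cons_array<T: Copy, const N: usize>(x: T, tail: [T; N]) -> [T; N+1] {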
        let mut res = [x; N+1];
    
        let mut i = 0;
        while i < N {
            res[i+1] = tail[i];
            i += 1;
        }
    
        res
    }
    
    // Prepending one element to a length-`N` list gives a list of length `N + 1`.
    impl<T: Copy, const N: usize, X: ConstWrap<BaseType = T>, Tail: TList<T, N>> TList<T, {N+1}> for Cons<T, N, X, Tail> {
        const VALUE: [T; N + 1] = cons_array(X::VALUE, Tail::VALUE);
    }
    

    With this, we can define Frame in a straightforward way:

    // `N` is the number of frames; `NS` lists the subframe count of each frame.
    #[derive(Debug)]
    struct Frame<const N: u8, NS> where NS: TList<u8, {N as usize}> {
        frame_num: u8,
        subframe_num: u8,
        phantom: PhantomData<NS>,
    }
    
    impl<const N: u8, NS> Frame<N, NS> where NS: TList<u8, {N as usize}> {
        pub fn start() -> Self {
            Frame{ frame_num: 0, subframe_num: 0, phantom: PhantomData, }
        }
    
        pub fn next(&mut self) {
            self.subframe_num += 1;
            if self.subframe_num == NS::VALUE[self.frame_num as usize] {
                self.subframe_num = 0;
                self.frame_num += 1;
                if self.frame_num == N {
                    self.frame_num = 0;
                }
            }
        }
    }
    

    Here's an example of using it; it's screaming for a macro to make the syntax nicer, but at least it works as a proof of concept:

    pub fn main() {
        let mut foo: Frame<2, Cons<u8, 1, WrapU8<8>, Cons<u8, 0, WrapU8<3>, Nil<u8>>>> = Frame::start();
        for _ in 0..20 {
            println!("{foo:?}");
            foo.next();
        }
    }
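If the element type can be fixed to `u8`, a similar type-level cons list can be sketched on stable Rust: the head becomes a plain `const H: u8` parameter, so no wrapper type is needed, and instead of building an array with `generic_const_exprs`, each node exposes its length and a recursive element lookup. The names `U8List`, `U8Nil`, `U8Cons`, `LEN` and `get` are hypothetical, and `get` is an ordinary (non-`const`) function, since trait methods cannot be `const` on stable.

```rust
use std::marker::PhantomData;

/// A list of `u8` values encoded entirely in a type.
trait U8List {
    /// Number of elements in the list.
    const LEN: usize;
    /// Returns element `i` by recursing down the list; panics if `i >= LEN`.
    fn get(i: usize) -> u8;
}

/// The empty list.
struct U8Nil;

/// A non-empty list: head value `H` followed by the list `Tail`.
struct U8Cons<const H: u8, Tail>(PhantomData<Tail>);

impl U8List for U8Nil {
    const LEN: usize = 0;
    fn get(i: usize) -> u8 {
        panic!("index {i} out of bounds")
    }
}

impl<const H: u8, Tail: U8List> U8List for U8Cons<H, Tail> {
    const LEN: usize = 1 + Tail::LEN;
    fn get(i: usize) -> u8 {
        if i == 0 { H } else { Tail::get(i - 1) }
    }
}

/// Frame counter whose per-frame subframe counts come from `NS`;
/// the number of frames is `NS::LEN`, so no separate `N` parameter is needed.
struct Frame<NS: U8List> {
    frame_num: u8,
    subframe_num: u8,
    phantom: PhantomData<NS>,
}

impl<NS: U8List> Frame<NS> {
    pub fn start() -> Self {
        Frame { frame_num: 0, subframe_num: 0, phantom: PhantomData }
    }

    /// Advances one subframe, rolling over to the next frame and back to frame 0.
    pub fn next(&mut self) {
        self.subframe_num += 1;
        if self.subframe_num == NS::get(self.frame_num as usize) {
            self.subframe_num = 0;
            self.frame_num += 1;
            if self.frame_num as usize == NS::LEN {
                self.frame_num = 0;
            }
        }
    }
}
```

With this, the example type becomes `Frame<U8Cons<8, U8Cons<3, U8Nil>>>`. The trade-off is a short recursive walk on each `get` instead of indexing a precomputed array, and the element type is no longer generic.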
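One way to hide the nesting is a `macro_rules!` macro used in type position that also computes each node's tail length. Because this block has to compile on its own, it declares placeholder structs with the same shape as the answer's `Nil` and `Cons` and the crate's `WrapU8`; these placeholders and the macro names `count` and `subframes` are assumptions for the sketch, and in real code you would use the actual types instead.

```rust
use std::marker::PhantomData;

// Placeholders with the same shape as the answer's `Nil<T>` and
// `Cons<T, N, X, Tail>` and the crate's `WrapU8<V>` (assumed, for a standalone build).
struct Nil<T>(PhantomData<T>);
struct Cons<T, const N: usize, X, Tail>(PhantomData<(T, X, Tail)>);
struct WrapU8<const V: u8>;

/// Counts comma-separated expressions, yielding a `usize` constant expression.
macro_rules! count {
    () => { 0usize };
    ($head:expr $(, $rest:expr)*) => { 1usize + count!($($rest),*) };
}

/// Expands `subframes!(8, 3)` to
/// `Cons<u8, 1, WrapU8<8>, Cons<u8, 0, WrapU8<3>, Nil<u8>>>`,
/// filling in each node's tail length with `count!`.
macro_rules! subframes {
    () => { Nil<u8> };
    ($head:expr $(, $rest:expr)*) => {
        Cons<u8, { count!($($rest),*) }, WrapU8<{ $head }>, subframes!($($rest),*)>
    };
}
```

With the real types in scope, the declaration in `main` should then be writable as `Frame<2, subframes!(8, 3)>` (untested with the nightly `generic_const_exprs` setup above).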