generics, rust, traits, const-generics

Satisfying a trait bound with a const generic expression, is it possible?


I am trying to make use of the currently unstable feature generic_const_exprs to allow the users of my library to know the resulting dimensions of the types they generate.

My use case is much more complex, but I've created a minimal example with a reproducible error. The main idea is that, given a Tensor<N> as input, I want to output a Tensor<M>, where M is N + 1. Tensor<N> is a trait, and it is implemented for both Constant<N> and Variable<N>. This is the code:

#![allow(incomplete_features)]
#![feature(generic_const_exprs)]

struct Variable<const N: usize>;
struct Constant<const N: usize>;

trait Tensor<const N: usize> {
    fn get_dim(&self) -> usize {
        N
    }
}
trait ConvertTo<Y> {
    fn convert(&self) -> Y;
}

impl<const N: usize> Tensor<N> for Variable<N> {}
impl<const N: usize> Tensor<N> for Constant<N> {}

impl<const N: usize, const M: usize> ConvertTo<Constant<M>> for Variable<N> {
    fn convert(&self) -> Constant<M> {
        Constant::<M>
    }
}

impl<const N: usize, const M: usize> ConvertTo<Variable<M>> for Constant<N> {
    fn convert(&self) -> Variable<M> {
        Variable::<M>
    }
}

fn convert_plus_one<const N: usize, X, Y>(x: X) -> Y
where
    X: Tensor<N> + ConvertTo<Y>,
    Y: Tensor<{ N + 1 }>,
{
    x.convert()
}

fn main() {
    let x = Constant::<3>;
    let y = convert_plus_one(x);
    // At this point the compiler should know that y is a Variable<N> with N = 4
    // and it implements Tensor<4>, because Tensor<N> is implemented for Variable<N>
    assert_eq!(y.get_dim(), 4);
}

And this is the compiler error:

   Compiling playground v0.0.1 (/playground)
error[E0277]: the trait bound `Variable<{_: usize}>: Tensor<{ N + 1 }>` is not satisfied
  --> src/main.rs:41:13
   |
41 |     let y = convert_plus_one(x);
   |             ^^^^^^^^^^^^^^^^ the trait `Tensor<{ N + 1 }>` is not implemented for `Variable<{_: usize}>`
   |
   = help: the trait `Tensor<N>` is implemented for `Variable<N>`
note: required by a bound in `convert_plus_one`
  --> src/main.rs:34:8
   |
31 | fn convert_plus_one<const N: usize, X, Y>(x: X) -> Y
   |    ---------------- required by a bound in this
...
34 |     Y: Tensor<{ N + 1 }>,
   |        ^^^^^^^^^^^^^^^^^ required by this bound in `convert_plus_one`

For more information about this error, try `rustc --explain E0277`.
error: could not compile `playground` due to previous error

I am running out of ideas on how to fix this. Am I missing something, or is this just impossible to do in the current state of generic_const_exprs?



Solution

  • Thanks to a suggestion from @lcnr in the rust-lang Zulip chat, I managed to make it work by using associated items on the traits (an associated const on Tensor and an associated type on ConvertTo). The trick was to express all my bounds in a single expression. ❤️

    From this:

    where
        X: Tensor<N> + ConvertTo<Y>,
        Y: Tensor<{ N + 1 }>,
    

    To this:

    where
        X: Tensor + ConvertTo<{<X as Tensor>::N + 1}>,
    

    The original example didn't work because Rust evaluates each trait bound independently. On the one hand it tries to prove that Constant<3>: ConvertTo<?>, and on the other that ?: Tensor<4>. The two bounds only pin down the unknown type ? if they are considered together.
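
    One way to see the missing link, as a minimal sketch that runs on stable Rust (no generic_const_exprs needed): in the original impls, M is independent of N, so Constant<3> converts to a Variable<M> for whatever M the caller picks. The helper convert_with_hint is hypothetical and not part of the original code; it only shows that the output dimension has to come from somewhere outside the ConvertTo bound.

    ```rust
    // Same shape as the original traits, trimmed to one conversion direction.
    struct Variable<const N: usize>;
    struct Constant<const N: usize>;

    trait Tensor<const N: usize> {
        fn get_dim(&self) -> usize {
            N
        }
    }
    trait ConvertTo<Y> {
        fn convert(&self) -> Y;
    }

    impl<const N: usize> Tensor<N> for Variable<N> {}
    impl<const N: usize> Tensor<N> for Constant<N> {}

    // `M` is not tied to `N` here: a `Constant<N>` converts to *any* `Variable<M>`.
    impl<const N: usize, const M: usize> ConvertTo<Variable<M>> for Constant<N> {
        fn convert(&self) -> Variable<M> {
            Variable::<M>
        }
    }

    // Hypothetical helper (an assumption for illustration only):
    // the caller chooses the output dimension `M` explicitly.
    fn convert_with_hint<const M: usize>() -> usize {
        let x = Constant::<3>;
        // The type annotation is what selects `M`; the `ConvertTo` bound alone does not.
        let y: Variable<M> = x.convert();
        y.get_dim()
    }
    ```

    Calling convert_with_hint::<4>() and convert_with_hint::<7>() both compile, which suggests that nothing in the ConvertTo impl itself says the result should be N + 1.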

    Associated consts and types on traits provide the syntax needed to put all the bounds in a single expression. Here is the final result, which compiles (on nightly):

    #![allow(incomplete_features)]
    #![feature(generic_const_exprs)]
    #![feature(associated_type_bounds)]
    
    struct Variable<const N: usize>;
    struct Constant<const N: usize>;
    
    trait Tensor {
        const N: usize;
    
        fn get_dim(&self) -> usize {
            Self::N
        }
    }
    trait ConvertTo<const N: usize> {
        type To;
        
        fn convert(&self) -> Self::To;
    }
    
    impl<const N: usize> Tensor for Variable<N> {
        const N: usize = N;
    }
    impl<const N: usize> Tensor for Constant<N> {
        const N: usize = N;
    }
    
    impl<const N: usize, const M: usize> ConvertTo<M> for Variable<N> {
        type To = Constant<M>;
    
        fn convert(&self) -> Self::To {
            Constant::<M>
        }
    }
    
    impl<const N: usize, const M: usize> ConvertTo<M> for Constant<N> {
        type To = Variable<M>;
        
        fn convert(&self) -> Self::To {
            Variable::<M>
        }
    }
    
    fn convert_plus_one<X>(x: X) -> <X as ConvertTo<{<X as Tensor>::N + 1}>>::To
    where
        X: Tensor + ConvertTo<{<X as Tensor>::N + 1}>,
    {
        x.convert()
    }
    
    fn main() {
        let x = Constant::<3>;
        let y = convert_plus_one(x);
        assert_eq!(y.get_dim(), 4);
    }
    

    And now I can rest.
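
    A side note on which part of this needs the unstable feature, as a minimal sketch under my own assumptions: the associated const on Tensor works on stable Rust as long as N + 1 is only computed as a value. The helper dim_plus_one is hypothetical and not part of the code above.

    ```rust
    struct Variable<const N: usize>;

    trait Tensor {
        const N: usize;

        fn get_dim(&self) -> usize {
            Self::N
        }
    }

    impl<const N: usize> Tensor for Variable<N> {
        // The right-hand `N` is the const generic parameter of the impl.
        const N: usize = N;
    }

    // Hypothetical helper (an assumption for illustration only):
    // `X::N + 1` is an ordinary value-level expression here, so stable Rust accepts it.
    fn dim_plus_one<X: Tensor>(_x: &X) -> usize {
        X::N + 1
    }
    ```

    Using { <X as Tensor>::N + 1 } inside a type, as in the signature of convert_plus_one, is the step that requires generic_const_exprs.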