Tags: generics, rust, traits

Expose function and trait that hides internal trait bounds


Goal

I am attempting to provide a generic function that generates a list of random floats drawn from the standard normal distribution. In particular, I am using the rand and rand_distr crates to do the actual sampling. My main goal is to provide a single trait Float along with the generic function randn, such that consumers need no other imports (in particular, nothing from the rand* crates used internally).

pub trait Float {}
impl Float for f32 {}
impl Float for f64 {}

pub fn randn<N: Float>() -> Vec<N> {
    ...
}

MRE and Error

Below is a minimal reproducible example of the issue I am facing. It contains 3 modules:

  • rand - Function signatures and dummy implementations standing in for the rand and rand_distr crates
  • library - The library code (the Float trait and the randn function)
  • consumer - An example consumer of the library code
mod rand {
    pub struct StandardNormal;

    pub trait Distribution<N> {
        fn sample(&self) -> N;
    }

    impl Distribution<f32> for StandardNormal {
        fn sample(&self) -> f32 {
            0.0
        }
    }

    impl Distribution<f64> for StandardNormal {
        fn sample(&self) -> f64 {
            0.0
        }
    }
}

mod library {
    use super::{Distribution, StandardNormal};

    pub trait Numeric {}
    impl Numeric for i32 {}
    impl Numeric for i64 {}
    impl Numeric for f32 {}
    impl Numeric for f64 {}

    pub trait Float: Numeric {}
    impl Float for f32 {}
    impl Float for f64 {}

    pub fn randn<F: Float>() -> Vec<F> {
        vec![StandardNormal.sample()]
    }
}

mod consumer {
    use super::{randn, Float};

    pub struct Consumer<F: Float> {
        f: Vec<F>,
    }

    impl<F: Float> Consumer<F> {
        pub fn new() -> Self {
            Self { f: randn() }
        }
    }
}

pub use library::{Float, randn};
pub use rand::{Distribution, StandardNormal};

This code generates the following error message:

error[E0277]: the trait bound `StandardNormal: Distribution<F>` is not satisfied
  --> src/lib.rs:35:29
   |
35 |         vec![StandardNormal.sample()]
   |                             ^^^^^^ the trait `Distribution<F>` is not implemented for `StandardNormal`
   |
help: consider introducing a `where` clause, but there might be an alternative better way to express this requirement
   |
34 |     pub fn randn<F: Float>() -> Vec<F> where StandardNormal: Distribution<F> {
   |                                        +++++++++++++++++++++++++++++++++++++

For more information about this error, try `rustc --explain E0277`.
error: could not compile `tester` (lib) due to 1 previous error

What I've Tried

I attempted to alter the trait definition of Float to include a where clause:

pub trait Float: Numeric + Sized where StandardNormal: Distribution<Self> {}

This resulted in multiple errors, the first one being:

error[E0277]: the trait bound `StandardNormal: Distribution<F>` is not satisfied
  --> src/lib.rs:34:21
   |
34 |     pub fn randn<F: Float>() -> Vec<F> {
   |                     ^^^^^ the trait `Distribution<F>` is not implemented for `StandardNormal`
   |
note: required by a bound in `Float`
  --> src/lib.rs:30:60
   |
30 |     pub trait Float: Numeric + Sized where StandardNormal: Distribution<Self> {}
   |                                                            ^^^^^^^^^^^^^^^^^^ required by this bound in `Float`
help: consider introducing a `where` clause, but there might be an alternative better way to express this requirement
   |
34 |     pub fn randn<F: Float>() -> Vec<F> where StandardNormal: Distribution<F> {
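This second error follows from how Rust treats where clauses on a trait: only bounds whose subject is Self itself (that is, supertraits) are implied wherever F: Float is used. Any other where clause, such as StandardNormal: Distribution<Self>, is checked but not implied, so it has to be proven again at every use site. A minimal sketch of the implied half, using a hypothetical zero method and zeros function that are not part of the question:

```rust
// Hypothetical traits illustrating which trait bounds are implied.
pub trait Numeric: Sized {
    fn zero() -> Self;
}

impl Numeric for f32 {
    fn zero() -> Self {
        0.0
    }
}

impl Numeric for f64 {
    fn zero() -> Self {
        0.0
    }
}

// `Numeric` is a supertrait (a bound on `Self`), so `F: Float` implies `F: Numeric`.
pub trait Float: Numeric {}
impl Float for f32 {}
impl Float for f64 {}

// Compiles without restating `F: Numeric`, because supertrait bounds are implied.
pub fn zeros<F: Float>(n: usize) -> Vec<F> {
    (0..n).map(|_| F::zero()).collect()
}

// By contrast, a where clause whose subject is not `Self`, such as
//     pub trait Float: Numeric + Sized where StandardNormal: Distribution<Self> {}
// is NOT implied by `F: Float`; each generic user would have to repeat
// `where StandardNormal: Distribution<F>`, which is what the error above asks for.
```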

Current Best Solution

The only workable solution I have found is to include a where clause on both the randn function and the impl block of the Consumer struct:

mod library {
    ...
    pub fn randn<F: Float>() -> Vec<F> where StandardNormal: Distribution<F> {
        ...
    }
    ...
}
...
mod consumer {
    ...
    impl<F: Float> Consumer<F> where StandardNormal: Distribution<F> {
        ...
    }
    ...
}
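For reference, a self-contained sketch of this where-clause workaround. It assumes a stub module named rand_stub (a hypothetical rename of the question's dummy rand module, to avoid clashing with the real rand crate); its sample takes no RNG argument, unlike the real Distribution::sample.

```rust
// Hypothetical stand-in for rand/rand_distr (the real `sample` also takes an RNG).
mod rand_stub {
    pub struct StandardNormal;

    pub trait Distribution<N> {
        fn sample(&self) -> N;
    }

    impl Distribution<f32> for StandardNormal {
        fn sample(&self) -> f32 {
            0.0
        }
    }

    impl Distribution<f64> for StandardNormal {
        fn sample(&self) -> f64 {
            0.0
        }
    }
}

mod library {
    use crate::rand_stub::{Distribution, StandardNormal};

    pub trait Numeric {}
    impl Numeric for f32 {}
    impl Numeric for f64 {}

    pub trait Float: Numeric {}
    impl Float for f32 {}
    impl Float for f64 {}

    // The bound has to be spelled out on the function itself...
    pub fn randn<F: Float>() -> Vec<F>
    where
        StandardNormal: Distribution<F>,
    {
        vec![StandardNormal.sample()]
    }
}

mod consumer {
    use crate::library::{randn, Float};
    // ...and every generic caller must repeat it, so it has to import these too.
    use crate::rand_stub::{Distribution, StandardNormal};

    pub struct Consumer<F: Float> {
        pub f: Vec<F>,
    }

    impl<F: Float> Consumer<F>
    where
        StandardNormal: Distribution<F>,
    {
        pub fn new() -> Self {
            Self { f: randn() }
        }
    }
}

pub use consumer::Consumer;
pub use library::{randn, Float};
```

Both where clauses are required: dropping either one reproduces the E0277 error, and the consumer module has to import StandardNormal and Distribution just to write its bound.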

Is it possible to provide an API so that StandardNormal and Distribution are hidden from the consumer?


Solution

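  • The first attempt to fix the problem is the right idea, but it doesn't work because of a known compiler limitation (it is technically expected behavior, but it is unintuitive and may eventually be fixed): of a trait's where clauses, only bounds on Self (supertraits) are implied wherever the trait is used as a bound. A clause like StandardNormal: Distribution<Self> is checked when Float is implemented, but it is not assumed inside randn<F: Float>, so every generic function has to restate it. It is, however, possible to work around the problem with a couple of helper traits.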

    The main idea of this workaround is that associated type bounds are special: a bound written on an associated type inside a supertrait (Trait<Assoc: Bound>) is implied wherever the trait is used. So if StandardNormal were an associated type, the code would work. We can turn it into one with two helper traits (one would suffice, but using two should make it easier to make StandardNormal replaceable in the future):

    pub trait MakeAssociated<A: ?Sized>: MakeAssociatedInner<A, Same=A> {}
    impl<A: ?Sized, B: ?Sized> MakeAssociated<A> for B {}
    
    pub trait MakeAssociatedInner<A: ?Sized> {
        type Same: ?Sized;
    }
    impl<A: ?Sized, B: ?Sized> MakeAssociatedInner<A> for B {
        type Same = A;
    }
    

    Then we define the Float trait as

    pub trait Float: Numeric + Sized + MakeAssociated<StandardNormal, Same: Distribution<Self>> {}
    

    It can be useful to hide these traits from the user, since they exist only to help the compiler. This can be done by sealing them inside a private module:

    pub trait Float: Numeric + Sized + sealed::MakeAssociated<StandardNormal, Same: Distribution<Self>> {}
    
    mod sealed {
        pub trait MakeAssociated<A: ?Sized>: MakeAssociatedInner<A, Same=A> {}
        impl<A: ?Sized, B: ?Sized> MakeAssociated<A> for B {}
        
        pub trait MakeAssociatedInner<A: ?Sized> {
            type Same: ?Sized;
        }
        impl<A: ?Sized, B: ?Sized> MakeAssociatedInner<A> for B {
            type Same = A;
        }
    }
    

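    And this works: neither randn nor Consumer::new needs a where clause anymore. Note that the associated type bound syntax (Same: Distribution<Self>) requires Rust 1.79 or newer.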

    Note that a consumer who implements Float for their own type still has to implement Distribution<Self> for StandardNormal, so these types are not entirely hidden; but code that merely uses Float as a bound never has to name them.
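Putting the answer's pieces together, a self-contained sketch of the associated-type-bound workaround. It again uses the hypothetical rand_stub module in place of rand/rand_distr, and it assumes a compiler with stable associated type bounds (Rust 1.79 or newer).

```rust
// Hypothetical stand-in for rand/rand_distr (the real `sample` also takes an RNG).
mod rand_stub {
    pub struct StandardNormal;

    pub trait Distribution<N> {
        fn sample(&self) -> N;
    }

    impl Distribution<f32> for StandardNormal {
        fn sample(&self) -> f32 {
            0.0
        }
    }

    impl Distribution<f64> for StandardNormal {
        fn sample(&self) -> f64 {
            0.0
        }
    }
}

mod library {
    use crate::rand_stub::{Distribution, StandardNormal};

    pub trait Numeric {}
    impl Numeric for f32 {}
    impl Numeric for f64 {}

    // The bound on the associated type `Same` (which equals `StandardNormal`)
    // is implied wherever `F: Float` appears.
    pub trait Float:
        Numeric + Sized + sealed::MakeAssociated<StandardNormal, Same: Distribution<Self>>
    {
    }
    impl Float for f32 {}
    impl Float for f64 {}

    // No `where StandardNormal: Distribution<F>` needed here.
    pub fn randn<F: Float>() -> Vec<F> {
        vec![StandardNormal.sample()]
    }

    // Helper traits live in a private module, so consumers cannot name them.
    mod sealed {
        pub trait MakeAssociated<A: ?Sized>: MakeAssociatedInner<A, Same = A> {}
        impl<A: ?Sized, B: ?Sized> MakeAssociated<A> for B {}

        pub trait MakeAssociatedInner<A: ?Sized> {
            type Same: ?Sized;
        }
        impl<A: ?Sized, B: ?Sized> MakeAssociatedInner<A> for B {
            type Same = A;
        }
    }
}

mod consumer {
    // Only the library's public API is imported; nothing from rand_stub.
    use crate::library::{randn, Float};

    pub struct Consumer<F: Float> {
        pub f: Vec<F>,
    }

    impl<F: Float> Consumer<F> {
        pub fn new() -> Self {
            Self { f: randn() }
        }
    }
}

pub use consumer::Consumer;
pub use library::{randn, Float};
```

Because both helper traits have blanket impls for every type, they do not restrict which types can implement Float; their only job is to turn StandardNormal into an associated type whose bound the compiler will imply.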