Tags: rust, design-patterns, polymorphism, traits

How to "mark" a function-like trait to be optionally executed in parallel?


Context

I have a predicate trait. It takes a value of some type T and returns a boolean for it.

trait Predicate<T> {
  fn evaluate(&self, t: &T) -> bool;
}

I also have an evaluator that evaluates the same predicate for each given T.

trait Evaluator<T> {
  fn evaluate(&self, is: &[T]) -> Vec<bool>;
}

Evaluator is implemented for every Predicate: it evaluates the Predicate for each T.

impl<T, P: Predicate<T>> Evaluator<T> for P {
  fn evaluate(&self, is: &[T]) -> Vec<bool> {
    is.iter().map(|i| self.evaluate(i)).collect()
  }
}

I also have Predicate<T> implemented for closures of type Fn(&T) -> bool. This lets me pass such a closure wherever an Evaluator is expected, instead of creating a struct and implementing the Evaluator trait for it every time I want to apply different logic.

Now I can create a closure and pass it to anything that takes an Evaluator:

  let p = |i: &i32| *i > 0;
  wants_evaluator(&p);
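The post does not show the closure impl or `wants_evaluator`. Below is a minimal sketch of how they might look. It assumes a blanket `impl<T, F: Fn(&T) -> bool> Predicate<T> for F` and a hypothetical `wants_evaluator` that returns its results so they can be inspected:

```rust
trait Predicate<T> {
    fn evaluate(&self, t: &T) -> bool;
}

trait Evaluator<T> {
    fn evaluate(&self, is: &[T]) -> Vec<bool>;
}

// Blanket impl from the question: every predicate is an evaluator.
impl<T, P: Predicate<T>> Evaluator<T> for P {
    fn evaluate(&self, is: &[T]) -> Vec<bool> {
        // Fully qualified call so it is obvious which `evaluate` is meant.
        is.iter().map(|i| Predicate::evaluate(self, i)).collect()
    }
}

// Assumed shape of the closure impl: any `Fn(&T) -> bool` is a predicate.
impl<T, F: Fn(&T) -> bool> Predicate<T> for F {
    fn evaluate(&self, t: &T) -> bool {
        self(t)
    }
}

// Hypothetical consumer of an evaluator.
fn wants_evaluator<E: Evaluator<i32>>(evaluator: &E) -> Vec<bool> {
    evaluator.evaluate(&[-1, 0, 1])
}

fn main() {
    let p = |i: &i32| *i > 0;
    println!("{:?}", wants_evaluator(&p)); // [false, false, true]
}
```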

Problem

This API has been pretty concise so far, but recently I needed to evaluate predicates in parallel for each solution. The logic itself isn't a problem with rayon: par_iter() does exactly what I need. The problem is that I don't want to lose my handy automatic closure-to-evaluator conversion, since I rely on it a lot. Ideally, I want to achieve something like this:

  let p = |i: &i32| *i > 0;
  let p_par = p.into_par(); // will be evaluated in parallel
  wants_evaluator(&p_par);

Parallel evaluation is rarely needed, because for most simple predicates parallelization only hurts performance. However, when I need it, I need it.

Failed solutions

I tried a wrapper struct, ParPredicate:

struct ParPredicate<T, P: Predicate<T>>(P, PhantomData<T>);

impl<T, P> Evaluator<T> for ParPredicate<T, P>
where
  T: Sync,
  P: Predicate<T> + Sync,
{
  fn evaluate(&self, is: &[T]) -> Vec<bool> {
    is.par_iter().map(|i| self.0.evaluate(i)).collect()
  }
}

However, Rust complains that this implementation conflicts with impl<T, P: Predicate<T>> Evaluator<T> for P { ...

conflicting implementations of trait `Evaluator<_>` for type `ParPredicate<_, _>`
downstream crates may implement trait `Predicate<_>` for type `ParPredicate<_, _>`

If I use a concrete type instead of T in Predicate, the error disappears. But that defeats the whole point of Predicate, which must be generic.
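The concrete-type observation can be checked with a minimal sketch that assumes `i32` as the element type. Once `T` is fixed, a downstream crate can no longer write `impl Predicate<i32> for ParPredicate<...>`. The trait, the wrapper and `i32` would all be foreign to that crate, so coherence accepts the wrapper impl. The `wants_evaluator` function is hypothetical. Plain `iter()` replaces rayon's `par_iter()` to keep the sketch dependency-free.

```rust
trait Predicate<T> {
    fn evaluate(&self, t: &T) -> bool;
}

trait Evaluator<T> {
    fn evaluate(&self, is: &[T]) -> Vec<bool>;
}

impl<T, P: Predicate<T>> Evaluator<T> for P {
    fn evaluate(&self, is: &[T]) -> Vec<bool> {
        is.iter().map(|i| Predicate::evaluate(self, i)).collect()
    }
}

impl<T, F: Fn(&T) -> bool> Predicate<T> for F {
    fn evaluate(&self, t: &T) -> bool {
        self(t)
    }
}

// Wrapper with the element type fixed to i32, so no PhantomData is needed.
// No downstream crate may write `impl Predicate<i32> for ParPredicate<_>`,
// so this impl cannot overlap with the blanket impl above.
struct ParPredicate<P>(P);

impl<P: Predicate<i32> + Sync> Evaluator<i32> for ParPredicate<P> {
    fn evaluate(&self, is: &[i32]) -> Vec<bool> {
        // The question calls rayon's `par_iter()` here; plain `iter()`
        // keeps this sketch dependency-free.
        is.iter().map(|i| Predicate::evaluate(&self.0, i)).collect()
    }
}

// Hypothetical consumer of an evaluator.
fn wants_evaluator<E: Evaluator<i32>>(evaluator: &E) -> Vec<bool> {
    evaluator.evaluate(&[-1, 0, 1])
}

fn main() {
    let p = |i: &i32| *i > 0;
    println!("{:?}", wants_evaluator(&p));
    println!("{:?}", wants_evaluator(&ParPredicate(p)));
}
```

The generic version is rejected for a specific reason. A downstream crate that owns some type `Local` could legally write `impl Predicate<Local> for ParPredicate<Local, SomeType>`. Both `Evaluator<Local>` impls would then apply to the same type.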

I also tried a kind of type-state pattern with Predicate:

trait ExecutionStrategy {}

enum Sequential {}
impl ExecutionStrategy for Sequential {}

enum Parallel {}
impl ExecutionStrategy for Parallel {}

trait Predicate<ExecutionStrategy, T> {
  fn evaluate(&self, i: &T) -> bool;
}

trait Evaluator<ExecutionStrategy, T> {
  fn evaluate(&self, is: &[T]) -> Vec<bool>;
}

impl<T, P> Evaluator<Sequential, T> for P
where
  P: Predicate<Sequential, T>,
{
  fn evaluate(&self, is: &[T]) -> Vec<bool> {
    is.iter().map(|i| self.evaluate(i)).collect()
  }
}

impl<T, P> Evaluator<Parallel, T> for P
where
  T: Sync,
  P: Predicate<Parallel, T> + Sync,
{
  fn evaluate(&self, is: &[T]) -> Vec<bool> {
    is.par_iter().map(|i| self.evaluate(i)).collect()
  }
}

The problem here is that I somehow need to convert a Predicate<Sequential, T> into a Predicate<Parallel, T>. I have no idea how to do it, even though they have essentially the same contract.

trait IntoPar<T> {
  fn into_par(self) -> impl Predicate<Parallel, T>;
}

impl<T, P: Predicate<Sequential, T>> IntoPar<T> for P {
  fn into_par(self) -> impl Predicate<Parallel, T> {
    // now what?
  }
}

All I want is to attach some kind of marker to my Predicate that lets Evaluator choose a different implementation based on that marker. All of this information is known at compile time, so I don't see why this couldn't be achieved in theory. But how do I do it while keeping the seamless conversion from a closure to an Evaluator?

I only want to explore the limits of Rust's type system. This code will be used only in my personal project, not in production.

Solution

  • You are on the right track. To solve the problem of converting to Parallel, remove the ExecutionStrategy parameter from Predicate (but keep it on Evaluator):

    trait Predicate<T> {
        fn evaluate(&self, i: &T) -> bool;
    }
    
    trait Evaluator<ExecutionStrategy, T> {
        fn evaluate(&self, is: &[T]) -> Vec<bool>;
    }
    
    impl<T, P> Evaluator<Sequential, T> for P
    where
        P: Predicate<T>,
    {
        fn evaluate(&self, is: &[T]) -> Vec<bool> {
            is.iter().map(|i| self.evaluate(i)).collect()
        }
    }
    
    impl<T, P> Evaluator<Parallel, T> for P
    where
        T: Sync,
        P: Predicate<T> + Sync,
    {
        fn evaluate(&self, is: &[T]) -> Vec<bool> {
            is.par_iter().map(|i| self.evaluate(i)).collect()
        }
    }
    
    fn wants_evaluator<S, E: Evaluator<S, i32>>(evaluator: E) { /* ... */ }
    

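    However, that has an unwanted consequence. A closure such as |v: &i32| true now implements both Evaluator<Sequential, i32> and Evaluator<Parallel, i32>, so the compiler can no longer infer S. Every existing call of wants_evaluator() then fails with "type annotations needed" unless the strategy is spelled out, e.g. wants_evaluator::<Sequential, _>(p).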

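    That can be solved by keeping the blanket impl for Sequential only and putting the Parallel impl on a wrapper type. The two impls implement Evaluator with different strategy parameters, so they cannot conflict the way ParPredicate did. Each type also has exactly one strategy, so S is inferred again: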

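    use std::marker::PhantomData;
    
    use rayon::prelude::*; // brings par_iter() into scope
    
    enum Sequential {}
    enum Parallel {}
    
    trait Predicate<T> {
        fn evaluate(&self, i: &T) -> bool;
    }
    
    // Closures of type Fn(&T) -> bool are predicates, as in the question.
    impl<T, F: Fn(&T) -> bool> Predicate<T> for F {
        fn evaluate(&self, i: &T) -> bool {
            self(i)
        }
    }
    
    trait Evaluator<ExecutionStrategy, T> {
        fn evaluate(&self, is: &[T]) -> Vec<bool>;
    }
    
    // Only the wrapper implements the Parallel strategy.
    pub struct ParallelPredicate<T, P>(P, PhantomData<T>);
    
    impl<T, P> Evaluator<Parallel, T> for ParallelPredicate<T, P>
    where
        T: Sync,
        P: Predicate<T> + Sync,
    {
        fn evaluate(&self, is: &[T]) -> Vec<bool> {
            is.par_iter().map(|i| self.0.evaluate(i)).collect()
        }
    }
    
    fn into_par<T, P: Predicate<T>>(predicate: P) -> ParallelPredicate<T, P> {
        ParallelPredicate(predicate, PhantomData)
    }
    
    // Every plain predicate (including closures) uses the Sequential strategy.
    impl<T, P> Evaluator<Sequential, T> for P
    where
        P: Predicate<T>,
    {
        fn evaluate(&self, is: &[T]) -> Vec<bool> {
            is.iter().map(|i| self.evaluate(i)).collect()
        }
    }
    
    // S is inferred: closures only implement Evaluator<Sequential, _>,
    // ParallelPredicate only implements Evaluator<Parallel, _>.
    fn wants_evaluator<S, E: Evaluator<S, i32>>(evaluator: E) {
        println!("{:?}", evaluator.evaluate(&[1, 2, 3]));
    }
    
    fn main() {
        wants_evaluator(|v: &i32| *v > 1);
        wants_evaluator(into_par(|v: &i32| *v > 1));
    }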
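This design provides a free function, `into_par(p)`, while the question asked for method syntax, `p.into_par()`. Below is a minimal sketch of that on top of the same design. It assumes a hypothetical extension trait `IntoPar`, blanket-implemented for every predicate, and a hypothetical `wants_evaluator` that returns its results. To keep the sketch dependency-free, `std::thread::scope` over chunks stands in for rayon's `par_iter()`. This is a simplified substitute, not rayon's work-stealing scheduler.

```rust
use std::marker::PhantomData;
use std::thread;

enum Sequential {}
enum Parallel {}

trait Predicate<T> {
    fn evaluate(&self, i: &T) -> bool;
}

trait Evaluator<S, T> {
    fn evaluate(&self, is: &[T]) -> Vec<bool>;
}

impl<T, F: Fn(&T) -> bool> Predicate<T> for F {
    fn evaluate(&self, i: &T) -> bool {
        self(i)
    }
}

// Plain predicates (including closures) are evaluated sequentially.
impl<T, P: Predicate<T>> Evaluator<Sequential, T> for P {
    fn evaluate(&self, is: &[T]) -> Vec<bool> {
        is.iter().map(|i| Predicate::evaluate(self, i)).collect()
    }
}

struct ParallelPredicate<T, P>(P, PhantomData<T>);

// Only the wrapper is evaluated in parallel.
impl<T, P> Evaluator<Parallel, T> for ParallelPredicate<T, P>
where
    T: Sync,
    P: Predicate<T> + Sync,
{
    fn evaluate(&self, is: &[T]) -> Vec<bool> {
        // Std-only stand-in for rayon: one scoped thread per chunk of input.
        let cores = thread::available_parallelism().map_or(1, |n| n.get());
        let chunk_len = is.len().div_ceil(cores).max(1);
        let predicate = &self.0;
        thread::scope(|s| {
            let handles: Vec<_> = is
                .chunks(chunk_len)
                .map(|chunk| {
                    s.spawn(move || {
                        chunk
                            .iter()
                            .map(|i| Predicate::evaluate(predicate, i))
                            .collect::<Vec<bool>>()
                    })
                })
                .collect();
            // Joining in spawn order keeps the results in input order.
            handles
                .into_iter()
                .flat_map(|h| h.join().unwrap())
                .collect::<Vec<bool>>()
        })
    }
}

// Hypothetical extension trait providing the `p.into_par()` syntax.
trait IntoPar<T>: Predicate<T> + Sized {
    fn into_par(self) -> ParallelPredicate<T, Self> {
        ParallelPredicate(self, PhantomData)
    }
}

impl<T, P: Predicate<T>> IntoPar<T> for P {}

// Hypothetical consumer; S is inferred from the argument's type.
fn wants_evaluator<S, E: Evaluator<S, i32>>(evaluator: &E) -> Vec<bool> {
    evaluator.evaluate(&[-2, -1, 0, 1, 2])
}

fn main() {
    let p = |i: &i32| *i > 0;
    let p_par = p.into_par(); // will be evaluated in parallel
    println!("{:?}", wants_evaluator(&p));
    println!("{:?}", wants_evaluator(&p_par));
}
```

Because `IntoPar` is blanket-implemented for every `Predicate<T>`, closures get `.into_par()` automatically. The strategy is still picked by type alone: a closure implements only `Evaluator<Sequential, _>`, and the wrapper implements only `Evaluator<Parallel, _>`.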