rust, iterator, closures

Pass function as an argument where inner function input is an Iterator


This is a follow-up question to "Use map object as function input".

I would like to pass the function compute_var_iter as an argument to another function but I'm struggling to make sense of the error messages:

This works:

use num::Float;

fn compute_var_iter<I, T>(vals: I) -> T
where
    I: Iterator<Item = T>,
    T: Float + std::ops::AddAssign,
{
    // online variance function
    let mut x = T::zero();
    let mut xsquare = T::zero();
    let mut len = T::zero();

    for val in vals {
        x += val;
        xsquare += val * val;
        len += T::one();
    }

    ((xsquare / len) - (x / len) * (x / len)) / (len - T::one()) * len
}

fn main() {
    let a = (1..10000001).map(|i| i as f64);
    let b = (1..10000001).map(|i| i as f64);

    dbg!(compute_var_iter(a.zip(b).map(|(a, b)| a * b)));
}

but when I try this:

use num::Float;

fn compute_var_iter<I, T>(vals: I) -> T
where
    I: Iterator<Item = T>,
    T: Float + std::ops::AddAssign,
{
    // online variance function
    let mut x = T::zero();
    let mut xsquare = T::zero();
    let mut len = T::zero();

    for val in vals {
        x += val;
        xsquare += val * val;
        len += T::one();
    }

    ((xsquare / len) - (x / len) * (x / len)) / (len - T::one()) * len
}

fn use_fun<I, Fa, T>(aggregator: Fa, values: &[T], weights: &[T]) -> T
where
    I: Iterator<Item = T>,
    T: Float + std::ops::AddAssign,
    Fa: Fn(I) -> T
{
   aggregator(values[0..10].iter().zip(weights).map(|(x, y)| *x * *y))
}

fn main() {
    let a: Vec<f64> = (1..10000001).map(|i| i as f64).collect();
    let b: Vec<f64> = (1..10000001).map(|i| i as f64).collect();
    
    dbg!(use_fun(compute_var_iter, &a, &b));
}


I get the errors:

error[E0308]: mismatched types
  --> src/main.rs:43:16
   |
37 | fn use_fun<I, Fa, T>(aggregator: Fa, values: &[T], weights: &[T]) -> T
   |            - this type parameter
...
43 |     aggregator(values[0..10].iter().zip(weights).map(|(x, y)| *x * *y))
   |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `I`, found struct `Map`
   |
   = note: expected type parameter `I`
                      found struct `Map<Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, T>>, [closure@src/main.rs:43:54: 43:70]>`

error[E0282]: type annotations needed
  --> src/main.rs:54:10
   |
54 |     dbg!(use_fun(compute_var_iter, &a, &b));
   |          ^^^^^^^ cannot infer type for type parameter `I` declared on the function `use_fun`

How should the type annotation work to get this running?


Solution

  • fn use_fun<I, Fa, T>(aggregator: Fa, values: &[T], weights: &[T]) -> T
    

    This signature says that I is a type parameter, meaning that the caller of use_fun is allowed to choose any iterator type.

       aggregator(values[0..10].iter().zip(weights).map(|(x, y)| *x * *y))
    

    But in the body of the function, you don't pass an iterator of type I; you pass an iterator of type Map<Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, T>>, [closure@src/main.rs:43:54: 43:70]>, which is the type of the chain of iterator adapters and a slice iterator that you made and passed.
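
To illustrate what "the caller chooses I" means, here is a minimal sketch. The `apply` function and the stand-in aggregator `sum_iter` are hypothetical names, not from the question, and `f64` is used so that no external crate is needed. It compiles because the caller supplies both the function and the iterator, so `I` is fixed at the call site and the body never has to produce a value of type `I` itself:

```rust
// Hypothetical helper: the caller passes both the aggregator and the
// iterator, so the caller (not the function body) decides what `I` is.
fn apply<I, Fa>(aggregator: Fa, input: I) -> f64
where
    I: Iterator<Item = f64>,
    Fa: Fn(I) -> f64,
{
    aggregator(input)
}

// Stand-in aggregator, generic over any iterator of f64.
fn sum_iter<I: Iterator<Item = f64>>(vals: I) -> f64 {
    vals.sum()
}

fn main() {
    let a: Vec<f64> = vec![1.0, 2.0, 3.0];
    let b: Vec<f64> = vec![4.0, 5.0, 6.0];

    // `I` is inferred as the Map<Zip<..>, closure> type built right here.
    let total = apply(sum_iter, a.iter().zip(&b).map(|(x, y)| x * y));
    println!("{total}");
}
```

In the question's `use_fun`, by contrast, the body builds the iterator, so its type is fixed by the body and cannot be left as a free parameter.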

    What you need to do, in principle, is change your function bounds to require the provided function to be a function which accepts that specific type (because it's not possible to require the function to be generic over any iterator). Unfortunately, it is currently not possible to write down that type in Rust source code (because it contains a closure type), using stable Rust. Here are some solutions:

    1. Use an unstable feature (available only in nightly builds of the compiler, not the stable release), type_alias_impl_trait, to give a name to the iterator type that contains the closure:

      #![feature(type_alias_impl_trait)]
      
      type MyIter<'a, T: 'a> = impl Iterator<Item = T> + 'a;
      fn make_iter<'a, T>(values: &'a [T], weights: &'a [T]) -> MyIter<'a, T>
      where
          T: Float + std::ops::AddAssign + 'a,
      {
          values[0..10].iter().zip(weights).map(|(x, y)| *x * *y)
      }
      
      fn use_fun<'a, Fa, T>(aggregator: Fa, values: &'a [T], weights: &'a [T]) -> T
      where
          T: Float + std::ops::AddAssign,
          Fa: Fn(MyIter<'a, T>) -> T,
      {
          aggregator(make_iter(values, weights))
      }
      

      (The make_iter helper function is necessary to let the type ... = impl ... know what concrete-but-unnameable type it is supposed to have, by putting MyIter in a return-type position.)

    2. Instead of using .map(closure), write your own iterator adapter (a struct and an impl Iterator for it), which will thus have a nameable type.
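
A minimal sketch of option 2, assuming a hypothetical adapter named `WeightedIter` and a stand-in aggregator `sum_iter` (neither appears in the question). The bounds use `Copy + Mul` from the standard library instead of `num::Float` so the example has no external dependencies:

```rust
use std::ops::Mul;
use std::slice::Iter;

// Hypothetical adapter: yields the element-wise product of two slices.
// Because it is an ordinary struct, its type can be named in a bound.
struct WeightedIter<'a, T> {
    values: Iter<'a, T>,
    weights: Iter<'a, T>,
}

impl<'a, T> Iterator for WeightedIter<'a, T>
where
    T: Copy + Mul<Output = T>,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        // Stop as soon as either slice is exhausted, like `zip` does.
        let v = self.values.next()?;
        let w = self.weights.next()?;
        Some(*v * *w)
    }
}

fn use_fun<'a, Fa, T>(aggregator: Fa, values: &'a [T], weights: &'a [T]) -> T
where
    T: Copy + Mul<Output = T>,
    Fa: Fn(WeightedIter<'a, T>) -> T,
{
    aggregator(WeightedIter {
        values: values.iter(),
        weights: weights.iter(),
    })
}

// Stand-in aggregator, generic over any iterator of f64.
fn sum_iter<I: Iterator<Item = f64>>(vals: I) -> f64 {
    vals.sum()
}

fn main() {
    let a = [1.0_f64, 2.0, 3.0];
    let b = [4.0_f64, 5.0, 6.0];
    println!("{}", use_fun(sum_iter, &a, &b));
}
```

A generic aggregator such as `compute_var_iter` should be usable the same way, since the compiler can infer its iterator parameter as `WeightedIter<'a, T>` from the `Fa` bound.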

    3. Declare the iterator's type with a function pointer in place of the closure type (a function pointer is nameable, and a non-capturing closure coerces to one, but the call might be slightly less efficient if the optimizer doesn't manage to eliminate the indirection):

      type MyIter<'a, T> = std::iter::Map<
          std::iter::Zip<std::slice::Iter<'a, T>, std::slice::Iter<'a, T>>,
          fn((&T, &T)) -> T,
      >;
      
      fn use_fun<'a, Fa, T>(aggregator: Fa, values: &'a [T], weights: &'a [T]) -> T
      where
          T: Float + std::ops::AddAssign,
          Fa: Fn(MyIter<'a, T>) -> T,
      {
          aggregator(values[0..10].iter().zip(weights).map(|(x, y)| *x * *y))
      }
      
    4. Use a boxed type-erased iterator, Box<dyn Iterator>. (This will likely be somewhat slower than the other options.)

      fn use_fun<'a, Fa, T>(aggregator: Fa, values: &'a [T], weights: &'a [T]) -> T
      where
          T: Float + std::ops::AddAssign,
          Fa: Fn(Box<dyn Iterator<Item = T> + 'a>) -> T,
      {
          aggregator(Box::new(values[0..10].iter().zip(weights).map(|(x, y)| *x * *y)))
      }
      
    5. Pass a collection instead of an iterator to the aggregator function.
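
A minimal sketch of option 5, assuming the aggregator takes a slice. The `mean` function is a hypothetical stand-in aggregator, and the bounds use only the standard library. The cost is one extra allocation for the intermediate `Vec`:

```rust
use std::ops::Mul;

fn use_fun<Fa, T>(aggregator: Fa, values: &[T], weights: &[T]) -> T
where
    T: Copy + Mul<Output = T>,
    Fa: Fn(&[T]) -> T,
{
    // Materialize the products so the aggregator only ever sees a slice,
    // whose type is easy to name.
    let products: Vec<T> = values
        .iter()
        .zip(weights)
        .map(|(x, y)| *x * *y)
        .collect();
    aggregator(products.as_slice())
}

// Hypothetical stand-in aggregator working on a slice.
fn mean(vals: &[f64]) -> f64 {
    vals.iter().sum::<f64>() / vals.len() as f64
}

fn main() {
    let a = [1.0_f64, 2.0, 3.0];
    let b = [2.0_f64, 2.0, 2.0];
    println!("{}", use_fun(mean, &a, &b));
}
```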

    6. Instead of passing an iterator to aggregator, make the aggregator accept the data in "push" fashion: perhaps it could implement the Extend trait, or just be a FnMut that is called with each value. This avoids the need for the aggregator to accept a type chosen by use_fun; it only needs to know about the number type.
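
A minimal sketch of the "push" approach from option 6, using the FnMut variant. The names `VarAcc` and `for_each_product` are hypothetical, and the accumulator is fixed to `f64` to keep the example free of external crates. It keeps the same running sums and final formula as `compute_var_iter`:

```rust
use std::ops::Mul;

// Hypothetical accumulator holding the same running sums as
// compute_var_iter in the question.
#[derive(Default)]
struct VarAcc {
    x: f64,
    xsquare: f64,
    len: f64,
}

impl VarAcc {
    fn push(&mut self, val: f64) {
        self.x += val;
        self.xsquare += val * val;
        self.len += 1.0;
    }

    // Same formula as the final expression in the question.
    fn finish(&self) -> f64 {
        let mean = self.x / self.len;
        ((self.xsquare / self.len) - mean * mean) / (self.len - 1.0) * self.len
    }
}

// Hypothetical driver: pushes each product into `sink`, so no iterator
// type ever appears in the signature.
fn for_each_product<T, F>(values: &[T], weights: &[T], mut sink: F)
where
    T: Copy + Mul<Output = T>,
    F: FnMut(T),
{
    for (x, y) in values.iter().zip(weights) {
        sink(*x * *y);
    }
}

fn main() {
    let a = [1.0_f64, 2.0, 3.0, 4.0];
    let b = [1.0_f64; 4];

    let mut acc = VarAcc::default();
    for_each_product(&a, &b, |val| acc.push(val));
    println!("{}", acc.finish());
}
```

The closure borrows `acc` mutably only for the duration of the call, so `finish` can be called afterwards.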