
Correctly make blanket implementation of argmax trait for iterators


I decided to try writing a trait in Rust with a blanket implementation. As a test, I picked a trait that returns the argmax of an iterator together with the maximum element. Right now the implementation is:

use num::Bounded;

trait Argmax<T> {
    fn argmax(self) -> (usize, T);
}

impl<I, T> Argmax<T> for I
where
    I: Iterator<Item = T>,
    T: std::cmp::PartialOrd + Bounded,
{
    fn argmax(self) -> (usize, T) {
        self.enumerate()
            .fold((0, T::min_value()), |(i_max, val_max), (i, val)| {
                if val >= val_max {
                    (i, val)
                } else {
                    (i_max, val_max)
                }
            })
    }
}

Testing it with this code

fn main() {
    let v = vec![1., 2., 3., 4., 2., 3.];
    println!("v: {:?}", v);
    let (i_max, v_max) = v.iter().copied().argmax();
    println!("i_max: {}\nv_max: {}", i_max, v_max);
}

works, while

fn main() {
    let v = vec![1., 2., 3., 4., 2., 3.];
    println!("v: {:?}", v);
    let (i_max, v_max) = v.iter().argmax();
    println!("i_max: {}\nv_max: {}", i_max, v_max);
}

doesn't compile and gives this error:

  --> src/main.rs:27:35
   |
27 |     let (i_max, v_max) = v.iter().argmax();
   |                                   ^^^^^^ method cannot be called on `std::slice::Iter<'_, {float}>` due to unsatisfied trait bounds
   |
   = note: the following trait bounds were not satisfied:
           `<&std::slice::Iter<'_, {float}> as Iterator>::Item = _`
           which is required by `&std::slice::Iter<'_, {float}>: Argmax<_>`
           `&std::slice::Iter<'_, {float}>: Iterator`
           which is required by `&std::slice::Iter<'_, {float}>: Argmax<_>`

error: aborting due to previous error

For more information about this error, try `rustc --explain E0599`.

I figure the problem comes from the fact that .iter() yields references while .iter().copied() yields the values themselves, but I still can't wrap my head around the error message and how to make it generic and working with looping over references.

EDIT: Someone recommended implementing the above with an associated type instead of a generic type parameter, and I ended up with this working implementation, posted here for later reference:

trait Argmax {
    type Maximum;

    fn argmax(self) -> Option<(usize, Self::Maximum)>;
}

impl<I> Argmax for I
where
    I: Iterator,
    I::Item: std::cmp::PartialOrd,
{
    type Maximum = I::Item;

    fn argmax(mut self) -> Option<(usize, Self::Maximum)> {
        let v0 = match self.next() {
            Some(v) => v,
            None => return None,
        };

        Some(
            self.enumerate()
                .fold((0, v0), |(i_max, val_max), (i, val)| {
                    if val > val_max {
                        (i + 1, val) // Add 1 as index is one off due to next() above
                    } else {
                        (i_max, val_max)
                    }
                }),
        )
    }
}

This implementation also drops the Bounded dependency. Instead, it checks whether the iterator is empty and, if not, initializes the current maximum with the first element the iterator returns. Because it compares with a strict >, it returns the index of the first maximum it finds.
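As a quick check, here is the associated-type version again (condensed: self.next()? replaces the explicit match, with the same behaviour) plus a small main. It is a sketch that only uses the standard library. It works with both .iter() and .iter().copied() because the standard library implements PartialOrd for references to types that implement PartialOrd, so one blanket impl covers both item types.

```rust
// The associated-type version from the EDIT, condensed so this compiles on its own.
trait Argmax {
    type Maximum;

    fn argmax(self) -> Option<(usize, Self::Maximum)>;
}

impl<I> Argmax for I
where
    I: Iterator,
    I::Item: PartialOrd, // also satisfied by &f64, since &A: PartialOrd<&B> when A: PartialOrd<B>
{
    type Maximum = I::Item;

    fn argmax(mut self) -> Option<(usize, Self::Maximum)> {
        // Empty iterator: return None instead of a made-up index.
        let v0 = self.next()?;
        Some(self.enumerate().fold((0, v0), |(i_max, val_max), (i, val)| {
            if val > val_max {
                (i + 1, val) // +1 because next() already consumed element 0
            } else {
                (i_max, val_max)
            }
        }))
    }
}

fn main() {
    let v = vec![1., 2., 3., 4., 2., 3.];

    // Iterating by reference: Maximum = &f64.
    let by_ref: Option<(usize, &f64)> = v.iter().argmax();
    println!("{:?}", by_ref); // Some((3, 4.0))

    // Iterating by value: Maximum = f64.
    let by_val: Option<(usize, f64)> = v.iter().copied().argmax();
    println!("{:?}", by_val); // Some((3, 4.0))

    // Ties: the strict > keeps the first maximum.
    println!("{:?}", [5, 1, 5].iter().argmax()); // Some((0, 5))

    // Empty input.
    let empty: Vec<f64> = Vec::new();
    println!("{:?}", empty.iter().argmax()); // None
}
```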


Solution

  • I still can't wrap my head around the error message

    Unfortunately, the error message is cryptic: it doesn't tell you what <&std::slice::Iter<'_, {float}> as Iterator>::Item actually is, which is the key fact, only what it isn't. (It possibly doesn't help that {float}, a not-yet-chosen numeric type, is involved. I'm also not sure what the & is doing there, since no reference to an iterator is involved.)

    However, if you look up the documentation for std::slice::Iter<'a, T>, you will find that its item type is &'a T, so in this case it is &'a {float}.
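    The difference between the two item types can be made explicit with type annotations. In this sketch, first_item is a hypothetical helper (not part of any library) whose return type is Option<I::Item>, so each annotation below only compiles if the stated item type is correct.

```rust
use std::slice::Iter;

// Hypothetical helper: returns the first item of any iterator, so the
// iterator's Item type appears in the return type at each call site.
fn first_item<I: Iterator>(mut it: I) -> Option<I::Item> {
    it.next()
}

fn main() {
    let v: Vec<f64> = vec![1., 2., 3.];

    // std::slice::Iter<'_, f64> yields &f64 ...
    let it: Iter<'_, f64> = v.iter();
    let r: Option<&f64> = first_item(it);

    // ... while .copied() adapts it into an iterator over f64.
    let x: Option<f64> = first_item(v.iter().copied());

    println!("{:?} {:?}", r, x); // Some(1.0) Some(1.0)
}
```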

    This tells you what you already knew: the iterator is over references. Unfortunately, the error message doesn't say much about the rest of the problem. But the docs for num::Bounded show that Bounded is not implemented for references to numbers. That is not surprising: a reference must point to a value that exists in memory, so it can be tricky or impossible to construct references that aren't borrowing some existing data structure. (I think it might be possible in this case, but num hasn't implemented it.)

  • and how to make it generic and working with looping over references.

    It's not possible as long as you use the Bounded trait. Bounded is not implemented for references to primitive numbers, and Rust won't accept two overlapping blanket implementations, one for iterators over T and another for iterators over &T.

    (You could implement Bounded for a type of your own, MyWrapper<f32>, and references to it, but then users have to deal with that wrapper.)

    Here are some options:

    1. Keep the code you currently have, and live with the need to write .copied(). This situation is not at all uncommon with other iterators; don't make code more hairy just to avoid one extra method call.

    2. Write a version of argmax() with return type Option<(usize, T)> that returns None when the iterator is empty. Then there is no need for Bounded, and the code works with only a PartialOrd bound. It also won't return a meaningless index and value for an empty iterator, which is generally considered a virtue in Rust code. The caller can always use .unwrap_or_else() if (0, T::min_value()) is an appropriate answer for their application.

    3. Write a version of argmax() which takes a separate initial value, rather than using T::min_value().
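    A minimal sketch of option 3, assuming a new trait named ArgmaxFrom with a method argmax_from (both names are made up for this example). It is the question's original fold with the seed supplied by the caller instead of T::min_value(), so it needs only PartialOrd, and it also works over references as long as the seed is a reference too.

```rust
// Hypothetical trait for option 3: the caller supplies the starting maximum,
// so neither Bounded nor the num crate is needed.
trait ArgmaxFrom: Iterator + Sized {
    fn argmax_from(self, init: Self::Item) -> (usize, Self::Item);
}

impl<I> ArgmaxFrom for I
where
    I: Iterator,
    I::Item: PartialOrd,
{
    fn argmax_from(self, init: I::Item) -> (usize, I::Item) {
        // Same fold as the question's code, seeded with `init`.
        // `>=` keeps the last maximum, matching the original version.
        self.enumerate()
            .fold((0, init), |(i_max, val_max), (i, val)| {
                if val >= val_max {
                    (i, val)
                } else {
                    (i_max, val_max)
                }
            })
    }
}

fn main() {
    let v = vec![1., 2., 3., 4., 2., 3.];

    // By value, seeded with negative infinity.
    println!("{:?}", v.iter().copied().argmax_from(f64::NEG_INFINITY)); // (3, 4.0)

    // By reference: the seed is a reference to a local value.
    let floor = f64::NEG_INFINITY;
    println!("{:?}", v.iter().argmax_from(&floor)); // (3, 4.0)
}
```

    One caveat of this design: if every element is below init (or the iterator is empty), the result is (0, init), the same kind of meaningless index that option 2 avoids by returning None.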