Why does using sum() here require a higher-ranked trait bound?


I'm writing a generic matrix library in Rust and wanted to write a dot product function.

This:

pub fn dot<'a, R: Scalar>(&'a self, rhs: &'a Matrix<R, M, 1>) -> T
where
    T: Sum<&'a T>,
    &'a Self: Mul<&'a Vector<R, M>, Output = Self>,
{
    (self * rhs).elements().sum::<T>()
}

caused an error I didn't understand:

304 |     pub fn dot<'a, R: Scalar>(&'a self, rhs: &'a Matrix<R, M, 1>) -> T
    |                -- lifetime `'a` defined here
...
309 |         (self * rhs).elements().sum::<T>()
    |         ^^^^^^^^^^^^-----------
    |         |
    |         creates a temporary which is freed while still in use
    |         argument requires that borrow lasts for `'a`
310 |     }
    |     - temporary value is freed at the end of this statement

After searching for similar issues and blindly trying different things, I arrived at this solution:

pub fn dot<'a, R: Scalar>(&'a self, rhs: &'a Matrix<R, M, 1>) -> T
where
    T: for<'b> Sum<&'b T>,
    &'a Self: Mul<&'a Vector<R, M>, Output = Self>,
{
    (self * rhs).elements().sum::<T>()
}

This works, but I don't understand why it's necessary. Why doesn't the temporary last for the whole function body? And what lifetime does 'b refer to?

Context:

Matrix is

#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Matrix<T, const M: usize, const N: usize>
where
    T: Scalar,
{
    data: [[T; N]; M], // Column-Major order
}

and Vector is

pub type Vector<T, const N: usize> = Matrix<T, N, 1>;

elements() returns a flattened iterator over data. mul returns the same type as Self. I get the same error if I change the dot() body to:

self.clone().data[0].iter().sum::<T>()

so it isn't an issue with the other methods on Matrix.

The same thing also happens if I assign self * rhs to a variable first.


Solution

  • The problem is that you're using too many lifetime annotations as it is. If you actually think about the lifetimes you've given the compiler, it's right: you are dropping the temporary while it's still in use. With the bound T: Sum<&'a T> you've told it that a T can be made from &'a Ts, and 'a is the same lifetime for which self is borrowed. But &Self * &Vector<R, M> returns a new, owned Self, call it P, that lives only inside the function body (as an unnamed temporary, only until the end of the statement); call that lifetime 'p.

    Now, elements probably has a signature along the lines of fn elements<'a>(&'a self) -> impl Iterator<Item = &'a T> (maybe some of those lifetimes are actually elided, but they're still there). So when you call P.elements, you're calling, essentially, elements(&'p P) -> impl Iterator<Item = &'p T>.
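    To make the elided lifetime visible, here is a minimal sketch. It uses a hypothetical stand-in type Grid rather than the question's Matrix, and it assumes elements is a flattening iterator as described above. The elided and the explicit signatures mean the same thing: every item borrows from whatever self is.

```rust
// Hypothetical stand-in for the question's Matrix (no Scalar bound needed here).
struct Grid<T, const M: usize, const N: usize> {
    data: [[T; N]; M],
}

impl<T, const M: usize, const N: usize> Grid<T, M, N> {
    // Elided form: each &T borrows from `self`.
    fn elements(&self) -> impl Iterator<Item = &T> + '_ {
        self.data.iter().flatten()
    }

    // The same signature with the lifetime spelled out as 'p.
    fn elements_explicit<'p>(&'p self) -> impl Iterator<Item = &'p T> + 'p {
        self.data.iter().flatten()
    }
}

fn main() {
    let g: Grid<i32, 2, 2> = Grid { data: [[1, 2], [3, 4]] };
    // Both iterators yield &i32 values borrowed from `g`.
    let a: i32 = g.elements().sum();
    let b: i32 = g.elements_explicit().sum();
    println!("{a} {b}"); // prints "10 10"
}
```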

    Then, when you call sum::<T> on that, the compiler only knows how to sum a new T from &'a Ts, so it requires the borrow 'p to last at least as long as 'a. It can't: 'a is a lifetime parameter chosen by the caller, so it outlives the entire call, while P is freed before dot returns. (This is also why binding self * rhs to a local variable doesn't help: the variable still dies before 'a ends.) That's why you're getting the error: you've tied the T you're returning to the lifetime of self for no reason.
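    The same mismatch can be reproduced without any matrix code. In this sketch, the function names sum_borrowed and sum_local are hypothetical. A bound of T: Sum<&'a T> works when the summed items really are borrowed for 'a. A value local to the function, however, can never be borrowed for the caller's 'a. The failing version is left commented out so the block still compiles.

```rust
use std::iter::Sum;

// Accepted: every item is borrowed from `xs`, which the caller guarantees
// lives for 'a.
fn sum_borrowed<'a, T>(xs: &'a [T]) -> T
where
    T: Sum<&'a T>,
{
    xs.iter().sum()
}

// Rejected by the borrow checker (commented out so this block compiles):
// `local` is dropped before the function returns, so it cannot be borrowed
// for the caller-chosen lifetime 'a that the Sum bound demands.
//
// fn sum_local<'a, T>(xs: &'a [T]) -> T
// where
//     T: Sum<&'a T> + Clone,
// {
//     let local: Vec<T> = xs.to_vec();
//     local.iter().sum()
// }

fn main() {
    println!("{}", sum_borrowed(&[1, 2, 3])); // prints "6"
}
```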

    Here is a rule of thumb I haven't thought through too carefully, but that might still be useful: if you borrow self immutably and the return type doesn't capture a lifetime (i.e. it isn't a reference, a trait object, an impl Trait, or a caller-supplied generic), you probably don't need to specify the arguments' lifetimes.

    That is to say: you don't care about the lifetimes of self or rhs because you're returning an owned type that doesn't reference either of them.

    Of course, you'll still need lifetimes in your where clause, and that's where for<'x> _: ... comes in. That's called a higher-ranked trait bound (HRTB), and it basically says, "for any lifetime I plug into 'x, what comes after the : will hold." That's what you care about: that summing works for references of any lifetime, including the short-lived borrow of the temporary, not only for references that live for one particular lifetime chosen in advance.
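    Here is a small standalone sketch of that idea. The helper sum_doubled is hypothetical and not part of the question's library. It builds a Vec local to the function and sums borrows of it. Under for<'x> T: Sum<&'x T>, 'x is not one fixed lifetime. At the sum call it is effectively instantiated with however long that local borrow lasts. The same holds for 'b in the question's fix.

```rust
use std::iter::Sum;
use std::ops::Add;

// Hypothetical helper: doubles every element into a local Vec, then sums it.
fn sum_doubled<T>(xs: &[T]) -> T
where
    T: Copy + Add<Output = T>,
    // Higher-ranked: T can be summed from &'x T for *every* lifetime 'x,
    // including the short borrow of `doubled` below.
    for<'x> T: Sum<&'x T>,
{
    let doubled: Vec<T> = xs.iter().map(|&x| x + x).collect();
    doubled.iter().sum()
}

fn main() {
    println!("{}", sum_doubled(&[1, 2, 3])); // prints "12"
    println!("{}", sum_doubled(&[0.5, 1.5])); // prints "4"
}
```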

    So you can simply rewrite dot's signature like this:

    pub fn dot<R: Scalar>(&self, rhs: &Matrix<R, M, 1>) -> T
    where
        for<'x> T: Sum<&'x T>,
        for<'y> &'y Self: Mul<&'y Vector<R, M>, Output = Self>,
    { // ...
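    For reference, here is a self-contained sketch of the rewritten signature in context. It makes several assumptions the question doesn't confirm:

    - Scalar is reduced to a hypothetical marker trait.
    - rhs uses the same scalar type T rather than a separate R.
    - &Vector * &Vector is taken to be an element-wise product that returns an owned Vector.

    The body is the question's original (self * rhs).elements().sum::<T>(), unchanged.

```rust
use std::iter::Sum;
use std::ops::Mul;

// Hypothetical stand-in for the question's Scalar trait.
pub trait Scalar: Copy + Mul<Output = Self> {}
impl Scalar for i32 {}
impl Scalar for f64 {}

#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Matrix<T, const M: usize, const N: usize>
where
    T: Scalar,
{
    data: [[T; N]; M],
}

pub type Vector<T, const N: usize> = Matrix<T, N, 1>;

impl<T: Scalar, const M: usize, const N: usize> Matrix<T, M, N> {
    pub fn new(data: [[T; N]; M]) -> Self {
        Self { data }
    }

    // Flattened iterator over `data`; each &T borrows from `self`.
    pub fn elements(&self) -> impl Iterator<Item = &T> + '_ {
        self.data.iter().flatten()
    }
}

// Assumed semantics: element-wise product of two column vectors.
impl<'a, T: Scalar, const M: usize> Mul<&'a Vector<T, M>> for &'a Vector<T, M> {
    type Output = Vector<T, M>;

    fn mul(self, rhs: &'a Vector<T, M>) -> Vector<T, M> {
        let mut data = self.data;
        for (row, rhs_row) in data.iter_mut().zip(rhs.data.iter()) {
            row[0] = row[0] * rhs_row[0];
        }
        Matrix::new(data)
    }
}

impl<T: Scalar, const M: usize> Matrix<T, M, 1> {
    // No lifetime parameters on the method: the result is an owned T.
    pub fn dot(&self, rhs: &Vector<T, M>) -> T
    where
        for<'x> T: Sum<&'x T>,
        for<'y> &'y Self: Mul<&'y Vector<T, M>, Output = Self>,
    {
        // `self * rhs` is a temporary; the HRTB lets `sum` use its short borrow.
        (self * rhs).elements().sum::<T>()
    }
}

fn main() {
    let a: Vector<i32, 3> = Matrix::new([[1], [2], [3]]);
    let b: Vector<i32, 3> = Matrix::new([[4], [5], [6]]);
    println!("{}", a.dot(&b)); // prints "32" (1*4 + 2*5 + 3*6)
}
```

    Because dot no longer names any lifetime, callers can pass borrows of any length, and the only place lifetimes appear is in the where clause, where they are quantified with for<...>.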