rust immutability traits lifetime ode

Assignment from function to variable defined as mutable reference?


I'm trying to use the ode_solvers crate to integrate my system of equations. To do so, I have a function that calculates dydx, and I call it inside the System implementation, as the example in the documentation describes:

type State = DVector<f64>;

impl ode_solvers::System<State> for NBody {
    fn system<'a>(&self, _t: Time, y: &State, mut dy: &'a mut State) {
        
        dy = get_derivatives(y, &self.masses, self.n_objects)
    }
}

When I do this I get type mismatch errors, because get_derivatives returns a State but dy expects a &mut State. If I change the function to return a mutable reference instead, I get lifetime issues I don't know how to resolve. And finally, I keep getting an error that I can't assign to immutable dy.

The example in the documentation shows this:

impl ode_solvers::System<State> for KeplerOrbit {
    // Equations of motion of the system
    fn system(&self, _t: Time, y: &State, dy: &mut State) {
        let r = (y[0] * y[0] + y[1] * y[1] + y[2] * y[2]).sqrt();

        dy[0] = y[3];
        dy[1] = y[4];
        dy[2] = y[5];
        dy[3] = - self.mu * y[0] / r.powi(3);
        dy[4] = - self.mu * y[1] / r.powi(3);
        dy[5] = - self.mu * y[2] / r.powi(3);
    }
}

I assume the example works because each value being assigned implements Copy, but my assignment is a whole vector, so it's unclear how to resolve this.


Solution

  • When you want to assign to the thing that a mutable reference points to, you must dereference it; otherwise you're trying to replace the reference in the local variable, which does no good even if you succeeded. These two lines:

    fn system<'a>(&self, _t: Time, y: &State, mut dy: &'a mut State) {     
        ...
        dy = get_derivatives(y, &self.masses, self.n_objects)
    

    should instead be

    fn system<'a>(&self, _t: Time, y: &State, dy: &'a mut State) {
        ...
        *dy = get_derivatives(y, &self.masses, self.n_objects);
    

    The added dereference operator * in *dy = ... is what you need. Note also that the variable dy does not need to be declared mut, because you are not reassigning the reference itself; that is, you can change mut dy: &'a mut State to dy: &'a mut State.

    The reason this doesn't happen when you are accessing a vector element dy[i] is that the indexing operator has a dereference built in; it expands to roughly *IndexMut::index_mut(&mut dy, i), and the extra &mut is handled by automatic dereferencing of the indexed operand. (In general, thanks to deref coercion, you can often ignore extra layers of references, like &&Foo and &mut &mut Foo, but you need to be careful with the difference between a specific type and any reference or references to it.)
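
To make the difference concrete, here is a minimal sketch that compiles without any external crates. It rests on several assumptions: State is a plain Vec<f64> standing in for the question's DVector<f64>, get_derivatives is a hypothetical one-argument placeholder (not the asker's real function), and system, system_indexed, and set_element are illustrative free functions rather than the ode_solvers trait method.

```rust
use std::ops::IndexMut;

// Assumption: Vec<f64> stands in for the question's DVector<f64>.
type State = Vec<f64>;

// Hypothetical placeholder: returns an owned State, like the asker's function.
fn get_derivatives(y: &State) -> State {
    y.iter().map(|v| 2.0 * v).collect()
}

// Whole-value assignment: `*dy` names the State the reference points to,
// so the caller's value is overwritten. `dy` itself is not declared `mut`.
fn system(y: &State, dy: &mut State) {
    *dy = get_derivatives(y);
}

// Element-wise assignment, as in the documentation example: no explicit `*`
// is needed because indexing dereferences `dy` automatically.
fn system_indexed(y: &State, dy: &mut State) {
    for i in 0..y.len() {
        dy[i] = 2.0 * y[i];
    }
}

// Roughly what `dy[i] = value` does when written out by hand.
fn set_element(dy: &mut State, i: usize, value: f64) {
    *IndexMut::index_mut(dy, i) = value;
}
```

Calling system(&y, &mut dy) with a caller-owned dy replaces its contents with the result of get_derivatives, while system_indexed writes into the existing storage one element at a time. Writing dy = get_derivatives(y) in place of *dy = ... fails to compile, because the right-hand side is a State and dy is a &mut State.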