
Perform broadcast additions between ndarray slices created from a numpy array


I'm trying to write Rust code that can be called from Python. For simplicity, this code should just take a two-dimensional boolean array and XOR the second row into the first one. I've tried to write this code:

use numpy::PyArray2;
use pyo3::prelude::{pymodule, PyModule, PyResult, Python};

#[pymodule]
fn state_generator(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    #[pyfn(m)]
    fn random_cnots_transformation(_py: Python<'_>, x: &PyArray2<bool>) {
        let mut array = unsafe { x.as_array_mut() };
        let source = array.row(1);
        let mut target = array.row_mut(0);

        target ^= source;
    }

    Ok(())
}

but compiling with maturin fails with these errors:

error[E0271]: type mismatch resolving `<ViewRepr<&mut bool> as RawData>::Elem == ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
  --> src/lib.rs:35:16
   |
35 |         target ^= source;
   |                ^^ expected `bool`, found `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
   |
   = note: expected type `bool`
            found struct `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
   = note: required for `ArrayBase<ViewRepr<&mut bool>, Dim<[usize; 1]>>` to implement `BitXorAssign<ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>>`

error[E0277]: the trait bound `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>: ScalarOperand` is not satisfied
  --> src/lib.rs:35:16
   |
35 |         target ^= source;
   |                ^^ the trait `ScalarOperand` is not implemented for `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
   |
   = help: the following other types implement trait `ScalarOperand`:
             bool
             isize
             i8
             i16
             i32
             i64
             i128
             usize
           and 9 others
   = note: required for `ArrayBase<ViewRepr<&mut bool>, Dim<[usize; 1]>>` to implement `BitXorAssign<ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>>`

error[E0271]: type mismatch resolving `<ViewRepr<&bool> as RawData>::Elem == ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
  --> src/lib.rs:35:16
   |
35 |         target ^= source;
   |                ^^ expected `bool`, found `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
   |
   = note: expected type `bool`
            found struct `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>`
   = note: required for `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>` to implement `BitXorAssign`
   = note: 1 redundant requirement hidden
   = note: required for `ArrayBase<ViewRepr<&mut bool>, Dim<[usize; 1]>>` to implement `BitXorAssign<ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>>`

error[E0277]: the trait bound `ViewRepr<&bool>: DataMut` is not satisfied
  --> src/lib.rs:35:16
   |
35 |         target ^= source;
   |                ^^ the trait `DataMut` is not implemented for `ViewRepr<&bool>`
   |
   = help: the trait `DataMut` is implemented for `ViewRepr<&'a mut A>`
   = note: required for `ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>` to implement `BitXorAssign`
   = note: 1 redundant requirement hidden
   = note: required for `ArrayBase<ViewRepr<&mut bool>, Dim<[usize; 1]>>` to implement `BitXorAssign<ArrayBase<ViewRepr<&bool>, Dim<[usize; 1]>>>`

I thought the problem might come from using ^=, so I tried to use += instead, but that failed with this error:

error[E0368]: binary assignment operation `+=` cannot be applied to type `ArrayBase<ViewRepr<&mut bool>, Dim<[usize; 1]>>`
  --> src/lib.rs:35:9
   |
35 |         target += source;
   |         ------^^^^^^^^^^
   |         |
   |         cannot use `+=` on type `ArrayBase<ViewRepr<&mut bool>, Dim<[usize; 1]>>`

I've read this answer, but I don't see what is different in my case. If I'm not mistaken, the only difference is that in my code array is an ArrayViewMut<bool, Ix2>, while in the answer it is an Array2.

What should I change in my code to perform such an operation?

In case it matters, I'm using cargo 1.72.0, and here is my Cargo.toml:

[package]
name = "state_generator"
version = "0.1.0"
authors = ["Tristan NEMOZ"]
edition = "2021"

[lib]
crate-type = ["cdylib"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ndarray = "0.15.6"
numpy = "0.19.0"
rand = "0.8.5"

[dependencies.pyo3]
version = "0.19.2"
features = ["extension-module"]

Solution

  • The error message is confusing, but the crux of it is that XOR assignment (^=) between two arrays is only implemented with a borrowed right-hand side, &ArrayView (more precisely, &ArrayBase), and not with an owned ArrayView. So the fix is simple:

    target ^= &source;
    

    But then you'll face another error:

    error: cannot borrow `array` as mutable because it is also borrowed as immutable
     --> src/lib.rs:5:22
      |
    4 |     let source = array.row(1);
      |                  ------------ immutable borrow occurs here
    5 |     let mut target = array.row_mut(0);
      |                      ^^^^^^^^^^^^^^^^ mutable borrow occurs here
    6 |
    7 |     target ^= &source;
      |               ------- immutable borrow later used here
    

    This can be solved by taking both rows from a single mutable iterator over the rows, which hands out disjoint mutable views:

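    // rows_mut() yields each row as a separate ArrayViewMut1<bool>,
    // so row 0 and row 1 can be borrowed mutably at the same time.
    let mut rows = array.rows_mut().into_iter();
    let mut target = rows.next().unwrap(); // row 0
    let source = rows.next().unwrap(); // row 1

    target ^= &source;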