I am very new to Rust. Currently, I am looking for a way to generate a matrix whose dimensions come from a tuple.
use itertools::zip;
use ndarray::Array;

fn main() {
        let mut layer_width: [u64; 4] = [784, 512, 256, 10]; // in- & output layers of the nn
        init_nn(&mut layer_width);
}

fn init_nn(layer_width: &mut [u64; 4]) {
    for (layer_in, layer_out) in zip(&layer_width[..4], &layer_width[1..]) {
        let mut params = Array::zeros((layer_in, layer_out)); // error
    }
}
The iteration through the zip works fine and I get output for both layer_in and layer_out, but when creating the matrix I get the following error:
the trait bound `(&i64, &i64): ndarray::Dimension` is not satisfied
the trait `ndarray::Dimension` is not implemented for `(&i64, &i64)`
note: required because of the requirements on the impl of `ndarray::IntoDimension` for `(&i64, &i64)`rustc(E0277)
main.rs(13, 39): the trait `ndarray::Dimension` is not implemented for `(&i64, &i64)`
I would really appreciate any help from the community on this. Many thanks.
The issue is that you're passing a (&i64, &i64) to Array::zeros(), which is not valid: ndarray shapes are built from usize values, so pass a (usize, usize) instead. After fixing that, the code will still not compile, because we haven't given the compiler any way of knowing the element type; that error resolves itself once you do something like assigning a value to the array.
Here's working code:
use itertools::zip;
use ndarray::Array;

fn main() {
        let mut layer_width: [usize; 4] = [784, 512, 256, 10]; // in- & output layers of the nn
        init_nn(&mut layer_width);
}

fn init_nn(layer_width: &mut [usize; 4]) {
    for (&layer_in, &layer_out) in zip(&layer_width[..4], &layer_width[1..]) {
        let mut params = Array::zeros((layer_in, layer_out));
        // Dummy assignment so the compiler can infer the element type
        params[(0, 0)] = 1;
    }
}
Notice the added & in for (&layer_in, &layer_out). The items yielded by zip() have type (&usize, &usize), so the pattern destructures each tuple and dereferences the references into plain usize values. Equivalently, you could have kept the pattern as (layer_in, layer_out) and written Array::zeros((*layer_in, *layer_out)).