
Rust mismatch between input and output types of trait object


So my issue is that I have a layer trait with input and output types as follows:

pub trait Layer {
    type Input: Dimension;
    type Output: Dimension;
    
    fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output>;
}

And this implementation of it for a dense layer:

impl<A: Activation> Layer for DenseLayer<A> {
    type Input = Ix2;
    type Output = Ix2;

    fn forward(&mut self, input: &Array2<f32>) -> Array2<f32> {
        assert_eq!(input.shape()[1], self.weights.shape()[0], "Input width must match weight height.");
        let z = input.dot(&self.weights) + &self.biases;

        self.activation.activate(&z)
    }
}

I have these associated types so that my forward and backward functions can, for example, take a two-dimensional array as input but return a one-dimensional one. I also have a wrapper around these layers, NeuralNetwork, whose forward method should run the input through all of the layers:

pub struct NeuralNetwork<'a, L>
where
    L: Layer + 'a,
{
    layers: Vec<L>,
    loss_function: &'a dyn Cost,
}

impl<'a, L> NeuralNetwork<'a, L>
where
    L: Layer + 'a,
{
    pub fn new(layers: Vec<L>, loss_function: &'a dyn Cost) -> Self {
        NeuralNetwork { layers, loss_function }
    }

    pub fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, L::Input>) -> ArrayBase<OwnedRepr<f32>, L::Output> {
        let mut output = input.clone();

        // todo fix the layer forward changing input to output
        // causing mismatch in the input and output dimensions of forward
        for layer in &mut self.layers {
            output = layer.forward(&output);
        }

        output
    }
}

The problem is the for loop. On the first iteration I pass in a value of type Input and get back a value of type Output from layer.forward. On the next iteration that Output value is passed back into layer.forward, which only accepts Input. At least, that is what I think is happening. This might seem like a really simple issue, but I am genuinely unsure how to fix it.
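To make the motivation for separate Input and Output types concrete, here is a minimal sketch of a layer whose output has one fewer dimension than its input. To keep it runnable with the standard library alone, it assumes plain `Vec`s in place of ndarray arrays, and `RowSum` is a hypothetical layer that is not part of the code above. It also shows the loop's problem in isolation: once a layer turns a `Matrix` into a `Vector`, that output cannot be passed back into a `forward` that expects a `Matrix`.

```rust
// Assumption: plain Vecs stand in for ndarray's Array2<f32> and Array1<f32>
// so this sketch compiles with the standard library alone.
type Matrix = Vec<Vec<f32>>;
type Vector = Vec<f32>;

pub trait Layer {
    type Input;
    type Output;
    fn forward(&mut self, input: &Self::Input) -> Self::Output;
}

/// Hypothetical layer: sums each row, so a 2-D input becomes a 1-D output.
pub struct RowSum;

impl Layer for RowSum {
    type Input = Matrix;
    type Output = Vector;

    fn forward(&mut self, input: &Matrix) -> Vector {
        input.iter().map(|row| row.iter().sum::<f32>()).collect()
    }
}

fn main() {
    let mut layer = RowSum;
    let once: Vector = layer.forward(&vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
    println!("{:?}", once); // [3.0, 7.0]

    // This second call would not compile: `once` is a Vector (Output),
    // but forward expects a &Matrix (Input). It is the same mismatch the
    // for loop in NeuralNetwork::forward runs into.
    // let twice = layer.forward(&once);
}
```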

Edit 1:

Reproducible example:

use ndarray::{Array, Array2, ArrayBase, Dimension, OwnedRepr};
use ndarray_rand::RandomExt; // provides Array::random (from the ndarray-rand crate)

pub trait Layer {
    type Input: Dimension;
    type Output: Dimension;

    fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output>;
}

// A Dense Layer struct
pub struct DenseLayer {
    weights: Array2<f32>,
    biases: Array2<f32>,
}

impl DenseLayer {
    pub fn new(input_size: usize, output_size: usize) -> Self {
        let weights = Array::random((input_size, output_size), rand::distributions::Uniform::new(-0.5, 0.5));
        let biases = Array::zeros((1, output_size));
        DenseLayer { weights, biases }
    }
}

impl Layer for DenseLayer {
    type Input = ndarray::Ix2;  // Two-dimensional input
    type Output = ndarray::Ix2; // Two-dimensional output

    fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output> {
        assert_eq!(input.shape()[1], self.weights.shape()[0], "Input width must match weight height.");
        let z = input.dot(&self.weights) + &self.biases;
        z // Return the output directly without activation
    }
}

// Neural Network struct
pub struct NeuralNetwork<'a, L>
where
    L: Layer + 'a,
{
    layers: Vec<L>,
}

impl<'a, L> NeuralNetwork<'a, L>
where
    L: Layer + 'a,
{
    pub fn new(layers: Vec<L>) -> Self {
        NeuralNetwork { layers }
    }

    pub fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, L::Input>) -> ArrayBase<OwnedRepr<f32>, L::Output> {
        let mut output = input.clone();

        for layer in &mut self.layers {
            output = layer.forward(&output);
        }

        output
    }
}

fn main() {
    // Create a neural network with one Dense Layer
    let dense_layer = DenseLayer::new(3, 2);
    let mut nn = NeuralNetwork::new(vec![dense_layer]);

    // Create an example input (1 batch, 3 features)
    let input = Array::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).unwrap();
    
    // Forward pass
    let output = nn.forward(&input);
    println!("Output: {:?}", output);
}


Solution

  • There are two things you need to change to get NeuralNetwork::forward to compile:

    • You need to restrict the Layer bound so that the Input and Output associated types are the same type.
    • You need to ensure that this type implements Clone so that input.clone() can clone the underlying array instead of cloning the reference.

    These bounds will communicate these restrictions to the compiler (note the introduction of a new generic parameter T on the impl block):

    impl<'a, L, T> NeuralNetwork<'a, L>
    where
        L: Layer<Input = T, Output = T> + 'a,
        T: Clone,
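Here is a minimal, self-contained sketch of the whole `forward` method with those two bounds in place. To keep it runnable without external crates, it assumes plain `Vec<Vec<f32>>` in place of ndarray's `Array2<f32>`, uses a hypothetical `Scale` layer, and drops the unused `'a` lifetime; the `Layer<Input = T, Output = T>` and `T: Clone` bounds are the same as in the snippet above.

```rust
// Assumption: Vec<Vec<f32>> stands in for ndarray's Array2<f32>.
type Matrix = Vec<Vec<f32>>;

pub trait Layer {
    type Input;
    type Output;
    fn forward(&mut self, input: &Self::Input) -> Self::Output;
}

/// Hypothetical layer that multiplies every element by a constant (Matrix -> Matrix).
pub struct Scale(pub f32);

impl Layer for Scale {
    type Input = Matrix;
    type Output = Matrix;

    fn forward(&mut self, input: &Matrix) -> Matrix {
        input
            .iter()
            .map(|row| row.iter().map(|x| x * self.0).collect::<Vec<f32>>())
            .collect()
    }
}

pub struct NeuralNetwork<L> {
    layers: Vec<L>,
}

impl<L, T> NeuralNetwork<L>
where
    // Input == Output == T, so each layer's output can be fed to the next layer.
    L: Layer<Input = T, Output = T>,
    // Clone lets forward start from an owned copy of the caller's input.
    T: Clone,
{
    pub fn new(layers: Vec<L>) -> Self {
        NeuralNetwork { layers }
    }

    pub fn forward(&mut self, input: &T) -> T {
        let mut output = input.clone();
        for layer in &mut self.layers {
            output = layer.forward(&output);
        }
        output
    }
}

fn main() {
    let mut nn = NeuralNetwork::new(vec![Scale(2.0), Scale(3.0)]);
    let output = nn.forward(&vec![vec![1.0, 2.0]]);
    println!("{:?}", output); // [[6.0, 12.0]]
}
```

Because `T` only appears through `L`'s associated types, the compiler infers it from the layer type, so callers never have to name it.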
    

    Note that you should consider moving NeuralNetwork::new to an impl block with minimal restrictions, as there's no reason it needs these restrictions applied to it.
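A minimal sketch of that split, assuming plain `Vec`s in place of ndarray arrays and using two hypothetical layers: `Double`, whose input and output types match, and `Flatten`, whose types differ. Because `new` lives in an unrestricted impl block, a network of `Flatten` layers can still be constructed; only calling `forward` on it is rejected.

```rust
pub trait Layer {
    type Input;
    type Output;
    fn forward(&mut self, input: &Self::Input) -> Self::Output;
}

/// Hypothetical 1-D -> 1-D layer (Vec<f32> stands in for Array1<f32>).
pub struct Double;

impl Layer for Double {
    type Input = Vec<f32>;
    type Output = Vec<f32>;
    fn forward(&mut self, input: &Vec<f32>) -> Vec<f32> {
        input.iter().map(|x| x * 2.0).collect()
    }
}

/// Hypothetical 2-D -> 1-D layer: Input and Output differ.
pub struct Flatten;

impl Layer for Flatten {
    type Input = Vec<Vec<f32>>;
    type Output = Vec<f32>;
    fn forward(&mut self, input: &Vec<Vec<f32>>) -> Vec<f32> {
        input.concat()
    }
}

pub struct NeuralNetwork<L> {
    layers: Vec<L>,
}

// Unrestricted block: constructing a network places no requirements on L.
impl<L> NeuralNetwork<L> {
    pub fn new(layers: Vec<L>) -> Self {
        NeuralNetwork { layers }
    }

    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

// Restricted block: only forward needs Input == Output and Clone.
impl<L, T> NeuralNetwork<L>
where
    L: Layer<Input = T, Output = T>,
    T: Clone,
{
    pub fn forward(&mut self, input: &T) -> T {
        let mut output = input.clone();
        for layer in &mut self.layers {
            output = layer.forward(&output);
        }
        output
    }
}

fn main() {
    // Compiles: `new` has no bounds, even though Flatten's Input != Output.
    let flat_net = NeuralNetwork::new(vec![Flatten]);
    println!("{}", flat_net.num_layers()); // 1
    // Calling forward on flat_net would not compile: Flatten fails the bound.

    let mut nn = NeuralNetwork::new(vec![Double, Double]);
    println!("{:?}", nn.forward(&vec![1.0, 2.0])); // [4.0, 8.0]
}
```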


    There are a few other compile-time errors, but I assume they are unrelated to the problem you're trying to solve. In particular, it's not clear to me why you have a 'a lifetime on NeuralNetwork; you can remove it completely and the code still compiles.
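If you later need layers whose dimensions differ (the "2-D in, 1-D out" case from the question), a single `Vec<L>` can't hold them, because every element is the same type `L` with one fixed Input and Output. One option, offered here as an alternative rather than as part of the fix above, is to compose layers statically with a hypothetical `Chain` type whose bound requires the first layer's Output to equal the second layer's Input. As a minimal sketch, it assumes plain `Vec`s in place of ndarray arrays and hypothetical `Scale` and `RowSum` layers.

```rust
// Assumption: Vecs stand in for ndarray's Array2<f32> / Array1<f32>.
type Matrix = Vec<Vec<f32>>;
type Vector = Vec<f32>;

pub trait Layer {
    type Input;
    type Output;
    fn forward(&mut self, input: &Self::Input) -> Self::Output;
}

/// Hypothetical combinator: runs `first`, then feeds its output to `second`.
pub struct Chain<A, B> {
    pub first: A,
    pub second: B,
}

impl<A, B> Layer for Chain<A, B>
where
    A: Layer,
    // The compiler checks that A's output type is exactly B's input type.
    B: Layer<Input = A::Output>,
{
    type Input = A::Input;
    type Output = B::Output;

    fn forward(&mut self, input: &Self::Input) -> Self::Output {
        let hidden = self.first.forward(input);
        self.second.forward(&hidden)
    }
}

/// Hypothetical 2-D -> 2-D layer: multiplies every element by a constant.
pub struct Scale(pub f32);

impl Layer for Scale {
    type Input = Matrix;
    type Output = Matrix;
    fn forward(&mut self, input: &Matrix) -> Matrix {
        input
            .iter()
            .map(|row| row.iter().map(|x| x * self.0).collect::<Vec<f32>>())
            .collect()
    }
}

/// Hypothetical 2-D -> 1-D layer: sums each row.
pub struct RowSum;

impl Layer for RowSum {
    type Input = Matrix;
    type Output = Vector;
    fn forward(&mut self, input: &Matrix) -> Vector {
        input.iter().map(|row| row.iter().sum::<f32>()).collect()
    }
}

fn main() {
    // Matrix -> Matrix -> Vector, checked at compile time.
    let mut net = Chain { first: Scale(2.0), second: RowSum };
    println!("{:?}", net.forward(&vec![vec![1.0, 2.0], vec![3.0, 4.0]])); // [6.0, 14.0]
}
```

Nesting `Chain`s extends this to deeper networks, at the cost of a type that grows with every layer. Another common route is to use ndarray's dynamic-dimensional `ArrayD<f32>` (dimension type `IxDyn`) as both Input and Output for every layer, so the `Input = Output` bound holds trivially and shape mismatches surface at run time instead of compile time.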