Search code examples
rustautomatic-differentiation

What is wrong with this implementation of matmul for automatic differentiation?


I have an autodiff tensor library in Rust where I implement Tensor as:

use std::cell::{RefCell, RefMut};
use std::collections::HashSet;
use std::hash::{Hash, Hasher};
use std::ops::{Deref, DerefMut, Index, IndexMut, AddAssign, Add};
use std::rc::Rc;
use uuid::Uuid;
use num_traits::Zero;

pub struct TensorData {
    pub data: NdArray<f64, 2>,
    pub grad: NdArray<f64, 2>,
    pub uuid: Uuid,
    backward: Option<fn(&TensorData)>,
    prev: Vec<Tensor>,
    op: Option<String>
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
pub struct NdArray<T: Clone, const N: usize> {
    pub shape: [usize; N],
    pub data: Vec<T>,
}

impl<T: Clone, const N: usize> NdArray<T, N> {
    pub fn new(array: &[T], shape: [usize; N]) -> Self {
        NdArray {
            shape,
            data: array.to_vec(),
        }
    }

    pub fn zeros(shape: [usize; N]) -> Self
    where
        T: Clone + Zero,
    {
        NdArray {
            shape,
            data: vec![T::zero(); shape.iter().product()],
        }
    }

    pub fn fill(&mut self, val: T) {
        self.data = vec![val; self.shape.iter().product()]
    }

    fn get_index(&self, idx: &[usize; N]) -> usize {
        let mut i = 0;
        for j in 0..self.shape.len() {
            if idx[j] >= self.shape[j] {
                println!("Index {} is out of bounds for dimension {} with size {}", idx[j], j, self.shape[j])
            }
            i = i * self.shape[j] + idx[j];
        }
        i
    }

    pub fn transpose(mut self) -> NdArray<T, N> {
        self.shape.reverse();
        self
    }
}

impl<T: Clone, const N: usize> Index<&[usize; N]> for NdArray<T, N> {
    type Output = T;

    fn index(&self, idx: &[usize; N]) -> &T {
        let i = self.get_index(idx);
        &self.data[i]
    }
}

impl<T: Clone, const N: usize> IndexMut<&[usize; N]> for NdArray<T, N> {
    fn index_mut(&mut self, idx: &[usize; N]) -> &mut T {
        let i = self.get_index(idx);
        &mut self.data[i]
    }
}

impl NdArray<f64, 2> {
    /// Finds the matrix product of 2 matrices
    pub fn matmul(&self, b: &NdArray<f64, 2>) -> NdArray<f64, 2> {
        assert_eq!(self.shape[1], b.shape[0]);
        let mut res: NdArray<f64, 2> = NdArray::zeros([self.shape[0], b.shape[1]]);
        for row in 0..self.shape[0] {
            for col in 0..b.shape[1] {
                for el in 0..b.shape[0] {
                    res[&[row, col]] += self[&[row, el]] * b[&[el, col]]
                }
            }
        }
        res
    }
}

impl<T: Clone + Add<Output = T>, const N: usize> AddAssign<NdArray<T, N>> for NdArray<T, N> {
    fn add_assign(&mut self, rhs: NdArray<T, N>) {
        let sum_vec: Vec<T> = self
            .data
            .iter()
            .zip(&rhs.data)
            .map(|(a, b)| a.clone() + b.clone())
            .collect();
        self.data = sum_vec;
    }
}

impl Deref for Tensor {
    type Target = Rc<RefCell<TensorData>>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for Tensor {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl Hash for Tensor {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.borrow().uuid.hash(state);
    }
}

impl PartialEq for Tensor {
    fn eq(&self, other: &Self) -> bool {
        self.borrow().uuid == other.borrow().uuid
    }
}

impl Eq for Tensor {}

#[derive(Clone)]
pub struct Tensor(Rc<RefCell<TensorData>>);

impl TensorData {
    fn new(data: NdArray<f64, 2>) -> TensorData {
        let shape = data.shape;
        TensorData {
            data, 
            grad: NdArray::zeros(shape),
            uuid: Uuid::new_v4(),
            backward: None,
            prev: Vec::new(),
            op: None
        }
    }
}

impl Tensor {
    pub fn new(array: NdArray<f64, 2>) -> Tensor {
        Tensor(Rc::new(RefCell::new(TensorData::new(array))))
    }

    pub fn shape(&self) -> [usize; 2] {
        self.borrow().data.shape
    }

    pub fn grad(&self) -> NdArray<f64, 2> {
            self.borrow().grad.clone()
    }

    pub fn grad_mut(&self) -> impl DerefMut<Target = NdArray<f64, 2>> + '_ {
        RefMut::map((*self.0).borrow_mut(), |mi| &mut mi.grad)
    }

    pub fn backward(&self) {
        let mut topo: Vec<Tensor> = vec![];
        let mut visited: HashSet<Tensor> = HashSet::new();
        self._build_topo(&mut topo, &mut visited);
        topo.reverse();

        self.grad_mut().fill(1.0);
        for v in topo {
            if let Some(backprop) = v.borrow().backward {
                backprop(&v.borrow());
            }
        }
    }

    fn _build_topo(&self, topo: &mut Vec<Tensor>, visited: &mut HashSet<Tensor>) {
        if visited.insert(self.clone()) {
            self.borrow().prev.iter().for_each(|child| {
                child._build_topo(topo, visited);
            });
            topo.push(self.clone());
        }
    }

    pub fn matmul(&self, rhs: &Tensor) -> Tensor {
        let a_shape = self.shape();
        let b_shape = rhs.shape();
        assert_eq!(a_shape[1], b_shape[0]);
        let res: NdArray<f64, 2> = self.borrow().data.matmul(&rhs.borrow().data);
        let out = Tensor::new(res);
        out.borrow_mut().prev = vec![self.clone(), rhs.clone()];
        out.borrow_mut().op = Some(String::from("matmul"));
        out.borrow_mut().backward = Some(|value: &TensorData| {
            let lhs = value.prev[0].borrow().data.clone();
            let rhs = value.prev[1].borrow().data.clone();
            let da = value.grad.clone().matmul(&rhs.transpose());
            let db = lhs.transpose().matmul(&value.grad.clone());
            value.prev[0].borrow_mut().grad += da;
            value.prev[1].borrow_mut().grad += db;
        });
        out
    }
}

fn main() {
    let a_array = NdArray::new(&[1.0, 2.0, 3.0, 4.0], [2, 2]);
    let a = Tensor::new(a_array);
    let b = a.matmul(&a);
    b.backward();
    println!("Grad: {:?}", a.grad())
}

That is, each Tensor<N> wraps an NdArray containing its data and an NdArray containing its gradient. The core NdArray type contains basic array operations, including matmul(), dot(), etc.

I am trying to implement the matmul() method for autodiff, and this is what I have so far. I tested this matmul implementation, but this matmul implementation gives the incorrect gradient of:

[[7.0  9.0 ]
 [11.0 13.0]]

Rather than the correct gradient of:

[[7.0 11.0]
 [9.0 13.0]]

What part of the matmul implementation is incorrect?


Solution

  • I found the issue - it was in my implementation of transpose(). I implemented a new transpose() for NdArray<f64, 2>. The working, correct implementation of transpose() is:

    impl NdArray<f64, 2> {
        pub fn transpose(&self) -> NdArray<f64, 2> {
            let mut shape = self.shape.clone();
            shape.reverse();
            let mut result = NdArray::zeros(shape);
            for i in 0..shape[0] {
                for j in 0..shape[1] {
                    result[&[i, j]] = self[&[j, i]];
                }
            }
            result
        }
    }
    

    With this change to the code the gradients are accurate.