Search code examples
Tags: rust, traits

How to implement ops::Mul on a struct so it works with numerical types as well as another struct?


I have implemented a Point3D struct:

use std::ops;
/// A point in three-dimensional space with `f32` coordinates.
///
/// The type is plain data, so it derives `Clone` and `Copy`; `Default`
/// yields the origin `(0.0, 0.0, 0.0)`.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Point3D {
    /// X coordinate.
    pub x: f32,
    /// Y coordinate.
    pub y: f32,
    /// Z coordinate.
    pub z: f32,
}

/// Component-wise addition of two borrowed points, yielding a new owned point.
impl ops::Add<&Point3D> for &Point3D {
    type Output = Point3D;

    /// Returns `self + other`, coordinate by coordinate.
    fn add(self, other: &Point3D) -> Self::Output {
        let x = self.x + other.x;
        let y = self.y + other.y;
        let z = self.z + other.z;
        Point3D { x, y, z }
    }
}

/// Component-wise subtraction of two borrowed points, yielding a new owned point.
impl ops::Sub<&Point3D> for &Point3D {
    type Output = Point3D;

    /// Returns `self - other`, coordinate by coordinate.
    fn sub(self, other: &Point3D) -> Self::Output {
        let (dx, dy, dz) = (self.x - other.x, self.y - other.y, self.z - other.z);
        Point3D { x: dx, y: dy, z: dz }
    }
}

/// `&a * &b` between two points computes their dot product as an `f32`.
impl ops::Mul<&Point3D> for &Point3D {
    type Output = f32;

    /// Returns `self.x*other.x + self.y*other.y + self.z*other.z`.
    fn mul(self, other: &Point3D) -> Self::Output {
        // Summed left to right, matching `(xx + yy) + zz`.
        let partial = self.x * other.x + self.y * other.y;
        partial + self.z * other.z
    }
}

//Scalar impl of ops::Mul here

/// Unit tests for the `Point3D` operator impls.
#[cfg(test)]
mod tests {
    use super::*;

    /// `&a + &b` adds the points coordinate by coordinate.
    #[test]
    fn addition_point_3d() {
        let point1 = Point3D {
            x: 1.0,
            y: 2.0,
            z: 3.0,
        };
        let point2 = Point3D {
            x: 4.0,
            y: 5.0,
            z: 6.0,
        };
        let result = &point1 + &point2;
        assert_eq!(
            result,
            Point3D {
                x: 5.0,
                y: 7.0,
                z: 9.0
            },
            "Testing Addition with {:?} and {:?}",
            point1,
            point2
        );
    }

    /// `&a - &b` subtracts the points coordinate by coordinate.
    #[test]
    fn subtraction_point_3d() {
        let point1 = Point3D {
            x: 1.0,
            y: 2.0,
            z: 3.0,
        };
        let point2 = Point3D {
            x: 4.0,
            y: 5.0,
            z: 6.0,
        };
        let result = &point1 - &point2;
        assert_eq!(
            result,
            Point3D {
                x: -3.0,
                y: -3.0,
                z: -3.0
            },
            "Testing Subtraction with {:?} and {:?}",
            point1,
            point2
        );
    }

    /// `&a * &b` between two points is their dot product: 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn point3d_point3d_multiplication() {
        let point1 = Point3D {
            x: 1.0,
            y: 2.0,
            z: 3.0,
        };
        let point2 = Point3D {
            x: 4.0,
            y: 5.0,
            z: 6.0,
        };
        let result = &point1 * &point2;
        assert_eq!(
            result, 32.0,
            "Testing Multiplication with {:?} and {:?}",
            point1, point2
        );
    }

    // Disabled until a scalar `ops::Mul` impl for `&Point3D` exists.
    /*
    #[test]
    fn point3d_scalar_multiplication() {
        let point1 = Point3D { x: 1.0, y: 2.0, z: 3.0};
        let scalar = 3.5;
        let result = &point1 * &scalar;
        assert_eq!(result, Point3D { x: 3.5, y: 7.0, z: 10.5 }, "Testing Multiplication with {:?} and {:?}", point1, scalar);
    }
    */
}

I would like to use generics in my multiplication impl so that if I pass it another Point3D struct it computes the dot product, but if I pass it a basic numeric type (integer, unsigned integer, f32, f64) it multiplies x, y, and z by the scalar value. How would I do this?


Solution

  • To do this with generics, you first need to make your Point3D struct generic over its coordinate type:

    use std::ops::{Mul, Add};
    
    /// A point in three-dimensional space, generic over the coordinate type `T`.
    ///
    /// `Clone`, `Copy` and `Default` are derived, so they are available whenever
    /// `T` implements them (e.g. every primitive numeric type).
    #[derive(Debug, Clone, Copy, Default, PartialEq)]
    pub struct Point3D<T> {
        pub x: T,
        pub y: T,
        pub z: T,
    }
    

    Your implementation of multiplying a Point3D by a numeric scalar would then be:

    /// `&point * scalar` scales every coordinate by `scalar`.
    ///
    /// The scalar must have the same type `T` as the coordinates.
    impl<T> Mul<T> for &Point3D<T>
    where
        T: Mul<Output = T> + Copy,
    {
        type Output = Point3D<T>;

        fn mul(self, scalar: T) -> Self::Output {
            // `T: Copy` lets the closure reuse `scalar` and read fields out of `&self`.
            let scale = |component: T| component * scalar;
            Point3D {
                x: scale(self.x),
                y: scale(self.y),
                z: scale(self.z),
            }
        }
    }
    

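    The where clause is needed because the generic T must implement both Mul (with Output = T) and Copy. Copy is required because rhs is used in all three multiplications, and because self.x, self.y and self.z are read out from behind the &Point3D<T> reference. Note that the scalar must have the same type T as the coordinates: a Point3D<f32> can be multiplied by an f32, but not directly by an i32 or an f64.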

    Your dot product implementation also needs to become generic:

    /// `&a * &b` between two points computes their dot product as a `T`.
    impl<T> Mul<&Point3D<T>> for &Point3D<T>
    where
        T: Mul<Output = T> + Add<Output = T> + Copy,
    {
        type Output = T;

        fn mul(self, other: &Point3D<T>) -> Self::Output {
            // Keeps the original call order: mul x, mul y, add, mul z, add.
            let partial = self.x * other.x + self.y * other.y;
            partial + self.z * other.z
        }
    }
    

    The extra Add bound is needed because the dot product sums the three products of type T.