Search code examples
c#matrixrotational-matrices

Rodrigues Rotation about an arbitrary axis


Suppose, I have a candidate vector v(vx, vy, vz). I want to rotate it theta degrees about an arbitrary axis that starts at vector s(sx,sy,sz) and ends at vector e(ex, ey, ez) when the origin of the axes is located at o(ox, oy, oz).

So, I am doing the Rodrigues rotation in the following source code, but it is not giving the correct results:

public class RotoTranslation
{
    private readonly Vec3 origin;
    private readonly double cosTheta;
    private readonly double sinTheta;
    private readonly Vec3 axis;
    private readonly Matrix3x3 rotationMatrix;
    private readonly Matrix3x3 translationMatrix;
    private readonly Matrix3x3 invTranslationMatrix;
    public RotoTranslation(
                Vec3 origin, 
                Vec3 start, 
                Vec3 end, 
                double angle_rad)
    {
        this.origin = origin;
        this.axis = Vec3.Normalize(end - start);
        this.cosTheta = Math.Cos(angle_rad);
        this.sinTheta = Math.Sin(angle_rad);

        Matrix3x3 uOuter = 
            Vec3.OuterProduct_mat(axis, axis);
        Matrix3x3 uCross = 
            new Matrix3x3(0.0f, -axis.z, axis.y,                   
                        axis.z, 0.0f, -axis.x,                
                        -axis.y, axis.x, 0.0f);

        rotationMatrix = 
             cosTheta * Matrix3x3.Identity()
             + (1.0f - cosTheta) * uOuter
             + sinTheta * uCross;

        translationMatrix = 
             Matrix3x3.CreateTranslation(-origin);
        rotationMatrix = 
             rotationMatrix * translationMatrix;
        invTranslationMatrix = 
            Matrix3x3.CreateTranslation(origin);
    }
    public Vec3 RotateVector(Vec3 vector)
    {
        Vec3 transformedVector = 
            Vec3.TransformNormal(
                vector - origin, 
                rotationMatrix);
        return Vec3.Transform(
                transformedVector, 
                invTranslationMatrix);
    }
}

Test

[TestMethod]
public void TestMethod1()
{
    Vec3 origin = new Vec3(1, 1, 1);
    Vec3 start = new Vec3(1, 1, 1);
    Vec3 end = new Vec3(4, 4, 4);
    Vec3 candidate = new Vec3(3,3,3);

    double degrees = 360;
    degrees = degrees * (Math.PI / 180.0);

    RotoTranslation rot = 
        new RotoTranslation(origin, 
                            start, 
                            end, 
                            degrees);

    Vec3 rotated = rot.RotateVector(candidate);

    Assert.AreEqual(candidate[0], rotated[0]);
    Assert.AreEqual(candidate[1], rotated[1]);
    Assert.AreEqual(candidate[2], rotated[2]);
}

If I rotate a vector around an axis 360 degrees, it should be at the same position as the initial vector. However, that is not the case here.

Can you tell me what I am doing wrong?

N.B. I must use Matrix3x3, rather than a Matrix4x4 (augmented matrix).


Full Source code

using System;
using System.Collections.Generic;
using Microsoft.VisualStudio.TestTools.UnitTesting;

public class Vec3
{
    public double x, y, z;

    public Vec3(double x, double y, double z)
    {
        this.x = x;
        this.y = y;
        this.z = z;
    }

    public static double DistanceSquared(Vec3 t1, Vec3 t2)
    {
        double x = t1.x - t2.x;
        double y = t1.y - t2.y;
        double z = t1.z - t2.z;
        return x * x + y * y + z * z;
    }

    public static double Distance(Vec3 t1, Vec3 t2)
    {
        if (t1 == null) throw new Exception("point1 is null");
        if (t2 == null) throw new Exception("point2 is null");

        return Math.Sqrt(DistanceSquared(t1, t2));
    }

    public Vec3 Subtract(Vec3 rhs)
    {
        return new Vec3(this.x - rhs.x, this.y - rhs.y, this.z - rhs.z);
    }

    public static Vec3 operator -(Vec3 a, Vec3 b)
    {
        return new Vec3(a.x - b.x, a.y - b.y, a.z - b.z);
    }


    public Vec3 Scale(double rhs)
    {
        return new Vec3(this.x * rhs, this.y * rhs, this.z * rhs);
    }

    public double MagnitudeSquared()
    {
        return this.x * this.x + this.y * this.y + this.z * this.z;
    }

    public static double Dot(Vec3 a, Vec3 b)
    {
        return a.x * b.x + a.y * b.y + a.z * b.z;
    }


    public Vec3 Cross(Vec3 other)
    {
        double a = this.y * other.z - this.z * other.y;
        double b = this.z * other.x - this.x * other.z;
        double c = this.x * other.y - this.y * other.x;
        return new Vec3(a, b, c);
    }

    public override string ToString()
    {
        return $"{x,8:0.000}{y,8:0.000}{z,8:0.000}";
    }

    public Vec3(string x, string y, string z)
    {
        this.x = Convert.ToDouble(x);
        this.y = Convert.ToDouble(y);
        this.z = Convert.ToDouble(z);
    }

    public Vec3(string xyz)
    {
        string[] vals = xyz.Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);

        this.x = Convert.ToDouble(vals[0].Trim());
        this.y = Convert.ToDouble(vals[1].Trim());
        this.z = Convert.ToDouble(vals[2].Trim());
    }

    public Vec3()
    {
    }

    public static Vec3 Normalize(Vec3 vec)
    {
        double mag = Math.Sqrt(vec.x * vec.x + vec.y * vec.y + vec.z * vec.z);
        return new Vec3(vec.x / mag, vec.y / mag, vec.z / mag);
    }

    public static Vec3 OuterProduct(Vec3 a, Vec3 b)
    {
        double x = a.y * b.z - a.z * b.y;
        double y = a.z * b.x - a.x * b.z;
        double z = a.x * b.y - a.y * b.x;
        return new Vec3(x, y, z);
    }

    public static Matrix3x3 OuterProduct_mat(Vec3 lhs, Vec3 rhs)
    {
        double[,] data = new double[3, 3];

        data[0, 0] = lhs.x * rhs.x;
        data[0, 1] = lhs.x * rhs.y;
        data[0, 2] = lhs.x * rhs.z;

        data[1, 0] = lhs.y * rhs.x;
        data[1, 1] = lhs.y * rhs.y;
        data[1, 2] = lhs.y * rhs.z;

        data[2, 0] = lhs.z * rhs.x;
        data[2, 1] = lhs.z * rhs.y;
        data[2, 2] = lhs.z * rhs.z;

        return new Matrix3x3(data[0, 0], data[0, 1], data[0, 2],
                             data[1, 0], data[1, 1], data[1, 2],
                             data[2, 0], data[2, 1], data[2, 2]);
    }

    public static Vec3 operator -(Vec3 v)
    {
        return new Vec3(-v.x, -v.y, -v.z);
    }

    public static Vec3 Transform(Vec3 v, Matrix3x3 m)
    {
        double x = m[0] * v.x + m[1] * v.y + m[2] * v.z;
        double y = m[3] * v.x + m[4] * v.y + m[5] * v.z;
        double z = m[6] * v.x + m[7] * v.y + m[8] * v.z;
        return new Vec3(x, y, z);
    }

    public static Vec3 TransformNormal(Vec3 normal, Matrix3x3 matrix)
    {
        return new Vec3(
            matrix[0] * normal.x + matrix[3] * normal.y + matrix[6] * normal.z,
            matrix[1] * normal.x + matrix[4] * normal.y + matrix[7] * normal.z,
            matrix[2] * normal.x + matrix[5] * normal.y + matrix[8] * normal.z
        );
    }

    public static Vec3 operator +(Vec3 v1, Vec3 v2)
    {
        return new Vec3(v1.x + v2.x, v1.y + v2.y, v1.z + v2.z);
    }

    public double this[int i]
    {
        get
        {
            switch (i)
            {
                case 0:
                    return x;
                case 1:
                    return y;
                case 2:
                    return z;
                default:
                    throw new IndexOutOfRangeException();
            }
        }
        set
        {
            switch (i)
            {
                case 0:
                    x = value;
                    break;
                case 1:
                    y = value;
                    break;
                case 2:
                    z = value;
                    break;
                default:
                    throw new IndexOutOfRangeException();
            }
        }
    }
}

public class Matrix3x3
{
    const int rows = 3;
    const int cols = 3;
    private List<double> data2d;

    public Matrix3x3()
    {
        data2d = new List<double>(new double[rows * cols]);
    }

    public Matrix3x3(double x1, double y1, double z1,
                     double x2, double y2, double z2,
                     double x3, double y3, double z3)
    {
        data2d = new List<double>(new double[rows * cols]);
        data2d[0] = x1; data2d[1] = y1; data2d[2] = z1;
        data2d[3] = x2; data2d[4] = y2; data2d[5] = z2;
        data2d[6] = x3; data2d[7] = y3; data2d[8] = z3;
    }

    public double this[int i]
    {
        get { return data2d[i]; }
        set { data2d[i] = value; }
    }

    public static Matrix3x3 Add(Matrix3x3 lhs, Matrix3x3 rhs)
    {
        Matrix3x3 result = new Matrix3x3();
        for (int i = 0; i < rows * cols; i++)
        {
            result[i] = lhs[i] + rhs[i];
        }
        return result;
    }

    public static Matrix3x3 Sub(Matrix3x3 lhs, Matrix3x3 rhs)
    {
        Matrix3x3 result = new Matrix3x3();
        for (int i = 0; i < rows * cols; i++)
        {
            result[i] = lhs[i] - rhs[i];
        }
        return result;
    }

    public static Matrix3x3 mul_scalar(Matrix3x3 lhs, double rhs)
    {
        Matrix3x3 result = new Matrix3x3();
        for (int i = 0; i < rows * cols; i++)
        {
            result[i] = lhs[i] * rhs;
        }
        return result;
    }

    public static Matrix3x3 div_scalar(Matrix3x3 lhs, double rhs)
    {
        Matrix3x3 result = new Matrix3x3();
        for (int i = 0; i < rows * cols; i++)
        {
            result[i] = lhs[i] / rhs;
        }
        return result;
    }

    public static Vec3 mul_vec_mut(Matrix3x3 lhs, Vec3 rhs)
    {
        double x = lhs[0] * rhs.x + lhs[1] * rhs.y + lhs[2] * rhs.z;
        double y = lhs[3] * rhs.x + lhs[4] * rhs.y + lhs[5] * rhs.z;
        double z = lhs[6] * rhs.x + lhs[7] * rhs.y + lhs[8] * rhs.z;
        return new Vec3(x, y, z);
    }

    public double det()
    {
        Matrix3x3 lhs = this;
        double a = lhs[0] * (lhs[4] * lhs[8] - lhs[5] * lhs[7]);
        double b = lhs[1] * (lhs[3] * lhs[8] - lhs[5] * lhs[6]);
        double c = lhs[2] * (lhs[3] * lhs[7] - lhs[4] * lhs[6]);
        double returns = a - b + c;
        return returns;
    }

    public static Matrix3x3 mul_mat_mut(Matrix3x3 lhs, Matrix3x3 rhs)
    {
        Matrix3x3 result = new Matrix3x3(
            lhs[0] * rhs[0] + lhs[1] * rhs[3] + lhs[2] * rhs[6], // row 1, column 1
            lhs[0] * rhs[1] + lhs[1] * rhs[4] + lhs[2] * rhs[7], // row 1, column 2
            lhs[0] * rhs[2] + lhs[1] * rhs[5] + lhs[2] * rhs[8], // row 1, column 3
            lhs[3] * rhs[0] + lhs[4] * rhs[3] + lhs[5] * rhs[6], // row 2, column 1
            lhs[3] * rhs[1] + lhs[4] * rhs[4] + lhs[5] * rhs[7], // row 2, column 2
            lhs[3] * rhs[2] + lhs[4] * rhs[5] + lhs[5] * rhs[8], // row 2, column 3
            lhs[6] * rhs[0] + lhs[7] * rhs[3] + lhs[8] * rhs[6], // row 3, column 1
            lhs[6] * rhs[1] + lhs[7] * rhs[4] + lhs[8] * rhs[7], // row 3, column 2
            lhs[6] * rhs[2] + lhs[7] * rhs[5] + lhs[8] * rhs[8]); // row 3, column 3

        return result;
    }

    public static Matrix3x3 inverse(Matrix3x3 lhs)
    {
        Matrix3x3 temp = null;
        double _det = lhs.det();
        if (_det != 0.0)
        {
            double inv_det = 1.0 / _det;
            temp = new Matrix3x3(
                lhs[4] * lhs[8] - lhs[5] * lhs[7],
                lhs[2] * lhs[7] - lhs[1] * lhs[8],
                lhs[1] * lhs[5] - lhs[2] * lhs[4],
                lhs[5] * lhs[6] - lhs[3] * lhs[8],
                lhs[0] * lhs[8] - lhs[2] * lhs[6],
                lhs[2] * lhs[3] - lhs[0] * lhs[5],
                lhs[3] * lhs[7] - lhs[4] * lhs[6],
                lhs[1] * lhs[6] - lhs[0] * lhs[7],
                lhs[0] * lhs[4] - lhs[1] * lhs[3]);
        }
        return temp;
    }

    public static Matrix3x3 Identity()
    {
        return new Matrix3x3(1.0, 0.0, 0.0,
                                        0.0, 1.0, 0.0,
                                        0.0, 0.0, 1.0);
    }

    public static Matrix3x3 OuterProduct(Matrix3x3 a, Matrix3x3 b)
    {
        Matrix3x3 result = new Matrix3x3();
        for (int i = 0; i < rows; i++)
        {
            for (int j = 0; j < cols; j++)
            {
                result[i * cols + j] = a[i] * b[j];
            }
        }
        return result;
    }

    public static Matrix3x3 operator *(Matrix3x3 lhs, double rhs)
    {
        Matrix3x3 result = new Matrix3x3();
        for (int i = 0; i < rows * cols; i++)
        {
            result[i] = lhs[i] * rhs;
        }
        return result;
    }

    public static Matrix3x3 operator *(double lhs, Matrix3x3 rhs)
    {
        return rhs * lhs;
    }

    public static Matrix3x3 operator +(Matrix3x3 lhs, Matrix3x3 rhs)
    {
        Matrix3x3 result = new Matrix3x3();
        for (int i = 0; i < rows * cols; i++)
        {
            result[i] = lhs[i] + rhs[i];
        }
        return result;
    }

    public static Matrix3x3 CreateTranslation(Vec3 translation)
    {
        return new Matrix3x3(
            1.0f, 0.0f, 0.0f,
            0.0f, 1.0f, 0.0f,
            translation.x, translation.y, translation.z
        );
    }

    public static Matrix3x3 operator *(Matrix3x3 lhs, Matrix3x3 rhs)
    {
        Matrix3x3 result = new Matrix3x3();
        for (int i = 0; i < rows; i++)
        {
            for (int j = 0; j < cols; j++)
            {
                double sum = 0;
                for (int k = 0; k < cols; k++)
                {
                    sum += lhs[i * cols + k] * rhs[k * cols + j];
                }
                result[i * cols + j] = sum;
            }
        }
        return result;
    }
}

public class RotoTranslation
{
    private readonly Vec3 origin;
    private readonly double cosTheta;
    private readonly double sinTheta;
    private readonly Vec3 axis;
    private readonly Matrix3x3 rotationMatrix;
    private readonly Matrix3x3 translationMatrix;
    private readonly Matrix3x3 invTranslationMatrix;

    public RotoTranslation(Vec3 origin, Vec3 start, Vec3 end, double angle_rad)
    {
        this.origin = origin;
        this.axis = Vec3.Normalize(end - start);
        this.cosTheta = Math.Cos(angle_rad);
        this.sinTheta = Math.Sin(angle_rad);

        Matrix3x3 uOuter = Vec3.OuterProduct_mat(axis, axis);
        Matrix3x3 uCross = new Matrix3x3(
                                           0.0f, -axis.z, axis.y,
                                           axis.z, 0.0f, -axis.x,
                                          -axis.y, axis.x, 0.0f
                                        );

        rotationMatrix = cosTheta * Matrix3x3.Identity()
                                 + (1.0f - cosTheta) * uOuter
                                 + sinTheta * uCross;

        translationMatrix = Matrix3x3.CreateTranslation(-origin);
        rotationMatrix = rotationMatrix * translationMatrix;
        invTranslationMatrix = Matrix3x3.CreateTranslation(origin);
    }

    public Vec3 RotateVector(Vec3 vector)
    {
        Vec3 transformedVector = Vec3.TransformNormal(vector - origin, rotationMatrix);
        return Vec3.Transform(transformedVector, invTranslationMatrix);
    }
}

[TestClass]
public class RotoTranslationUnitTest
{
    [TestMethod]
    public void TestMethod1()
    {
        Vec3 origin = new Vec3(1, 1, 1);
        Vec3 start = new Vec3(1, 1, 1);
        Vec3 end = new Vec3(4, 4, 4);
        Vec3 candidate = new Vec3(3,3,3);

        double degrees = 360;
        degrees = degrees * (Math.PI / 180.0);

        RotoTranslation rot = new RotoTranslation(origin, start, end, degrees);

        Vec3 rotated = rot.RotateVector(candidate);

        Assert.AreEqual(candidate[0], rotated[0]);
        Assert.AreEqual(candidate[1], rotated[1]);
        Assert.AreEqual(candidate[2], rotated[2]);
    }
}

Solution

  • Your problem is in this method

    public static Matrix3x3 CreateTranslation(Vec3 translation)
    {
        return new Matrix3x3(
            1.0f, 0.0f, 0.0f,
            0.0f, 1.0f, 0.0f,
            translation.x, translation.y, translation.z
        );
    }
    

    To represent an offset in 3D with a matrix you need a 4×4 matrix arranged as follows

                        | 1    0    0    x |
    translate3(x,y,z) = | 0    1    0    y |
                        | 0    0    1    z |
                        | 0    0    0    1 |
    

    Instead of handling the translation with a matrix operation, just change the code to use linear algebra

    public RotoTranslation(Vec3 origin, Vec3 start, Vec3 end, double angle_rad)
    {
        this.origin = origin;
        this.axis = Vec3.Normalize(end - start);
        this.cosTheta = Math.Cos(angle_rad);
        this.sinTheta = Math.Sin(angle_rad);
    
        Matrix3x3 uOuter = Vec3.OuterProduct_mat(axis, axis);
        Matrix3x3 uCross = new Matrix3x3(
                                           0.0f, -axis.z, axis.y,
                                           axis.z, 0.0f, -axis.x,
                                          -axis.y, axis.x, 0.0f
                                        );
    
        rotationMatrix = cosTheta * Matrix3x3.Identity()
                                 + (1.0f - cosTheta) * uOuter
                                 + sinTheta * uCross;
    
    }
    
    public Vec3 RotateVector(Vec3 vector)
    {
        Vec3 transformedVector = vector - origin;
        transformedVector = Vec3.Transform(transformedVector, rotationMatrix );
        return transformedVector + origin;
    }
    

    translationMatrix and InvTranslationMatrix are not needed. You have the origin vector stored and that all you need.