
Possible workarounds for the vector cross product with ArrayFire.jl?


I'm trying to do vector math with ArrayFire.jl, but ArrayFire does not provide a vector cross-product function. Is there a performant workaround for computing it through Julia's ArrayFire.jl wrapper? Defining the function naively, as below, is very slow because of all the data transfer between the device and the host, and I don't understand the wrapper functions well enough to figure out how to avoid it.

    # Naive version: indexes individual elements and rebuilds the result array (slow)
    cross(a::ArrayFire.AFArray, b::ArrayFire.AFArray) = ArrayFire.AFArray([a[2]*b[3]-a[3]*b[2]; a[3]*b[1]-a[1]*b[3]; a[1]*b[2]-a[2]*b[1]]);

Solution

  • To answer my own question: the cross product can be computed with the circshift() function, which creates shifted copies of the vectors on the GPU; element-wise multiplication and subtraction of those copies then give the result. It's not the most elegant approach, but it works.

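    using ArrayFire
    import LinearAlgebra: cross   # add a method to the standard cross() instead of shadowing it

    function cross(a::AFArray{Float32,1}, b::AFArray{Float32,1})
        # Cyclic shifts stay on the device:
        # circshift(a, [-1]) = (a2, a3, a1), circshift(a, [-2]) = (a3, a1, a2)
        ashift  = circshift(a, [-1])
        ashift2 = circshift(a, [-2])
        bshift  = circshift(b, [-2])
        bshift2 = circshift(b, [-1])
        # (a2*b3 - a3*b2, a3*b1 - a1*b3, a1*b2 - a2*b1)
        return ashift .* bshift - ashift2 .* bshift2
    end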
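To check the shift pattern without a GPU, here is a minimal CPU sketch, assuming the identity carries over unchanged from AFArray to ordinary Julia arrays. It uses Base `circshift` on plain `Array`s and compares the result against `LinearAlgebra.cross`; the helper name `cross_circshift` is hypothetical. Because it shifts only along the first dimension, it also handles a 3×N matrix of column vectors in a single call, which is the kind of batching that tends to pay off on a GPU. Whether ArrayFire.jl's `circshift` accepts the same tuple of shifts for a matrix `AFArray` is an untested assumption.

```julia
using LinearAlgebra  # only for the reference cross() used in the comparison

# Hypothetical helper: the circshift-based cross product on plain Julia arrays.
# `a` and `b` are 3-vectors or 3×N matrices whose columns are 3-vectors;
# shifts are applied along the first dimension only.
function cross_circshift(a::AbstractVecOrMat, b::AbstractVecOrMat)
    # shift x by k positions along dim 1, leaving any other dims untouched
    s(x, k) = circshift(x, ntuple(d -> d == 1 ? k : 0, ndims(x)))
    # s(a, -1) = (a2, a3, a1), s(a, -2) = (a3, a1, a2), likewise for b
    return s(a, -1) .* s(b, -2) .- s(a, -2) .* s(b, -1)
end

a = [1.0, 2.0, 3.0]
b = [4.0, 5.0, 6.0]
@show cross_circshift(a, b) == cross(a, b)                        # single vectors
@show cross_circshift([a b], [b a]) == [cross(a, b) cross(b, a)]  # batched columns
```

For example, `cross_circshift([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])` returns `[-3.0, 6.0, -3.0]`, the same as `cross`.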