I'm trying to do vector math using ArrayFire.jl, but ArrayFire does not provide a vector cross product function. Is there a performant workaround for computing it with Julia's ArrayFire.jl wrapper? Defining the function naively, as below, is really slow because of all the data transfer between the device and the host, and I don't understand the wrapper functions well enough to figure out how to avoid it.
```julia
using ArrayFire
cross(a::ArrayFire.AFArray, b::ArrayFire.AFArray) = ArrayFire.AFArray([a[2]*b[3]-a[3]*b[2]; a[3]*b[1]-a[1]*b[3]; a[1]*b[2]-a[2]*b[1]])
```
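To answer my own question: the cross product can be computed with the `circshift()` function, which creates shifted copies of the vectors on the GPU, followed by element-wise multiplication and subtraction. No individual elements are indexed, so the per-element host transfers of the naive version are avoided. It's not the most elegant approach, but it works.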
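Why the four shifts give the cross product: a short check on a 3-vector, writing $S^{k}v$ for `circshift(v, k)` ($S$ is shorthand introduced here for illustration, not an ArrayFire name):

$$
\begin{aligned}
S^{-1}a &= (a_2, a_3, a_1), & S^{-2}a &= (a_3, a_1, a_2),\\
S^{-1}b &= (b_2, b_3, b_1), & S^{-2}b &= (b_3, b_1, b_2),
\end{aligned}
$$

$$
(S^{-1}a)\odot(S^{-2}b) - (S^{-2}a)\odot(S^{-1}b)
= (a_2 b_3 - a_3 b_2,\; a_3 b_1 - a_1 b_3,\; a_1 b_2 - a_2 b_1)
= a \times b ,
$$

where $\odot$ is element-wise multiplication. The two products line up term by term with the textbook component formula used in the naive version.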
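```julia
using ArrayFire

# Cross product of two 3-element device vectors using only whole-array
# operations (no scalar indexing), so nothing is copied element by element.
function cross(a::AFArray{T,1}, b::AFArray{T,1}) where T
    ashift  = circshift(a, [-1])  # [a2, a3, a1]
    ashift2 = circshift(a, [-2])  # [a3, a1, a2]
    bshift  = circshift(b, [-2])  # [b3, b1, b2]
    bshift2 = circshift(b, [-1])  # [b2, b3, b1]
    # Element-wise: [a2*b3 - a3*b2, a3*b1 - a1*b3, a1*b2 - a2*b1]
    return ashift .* bshift - ashift2 .* bshift2
end
```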
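If you want to sanity-check the shift pattern without a GPU, a minimal sketch is to apply the same formula to ordinary CPU `Vector`s with Base's `circshift` and compare against `LinearAlgebra.cross`. The name `cross_shift` is hypothetical, chosen only to avoid clashing with `LinearAlgebra.cross`; this checks the index arithmetic, not GPU performance.

```julia
using LinearAlgebra  # only for the reference `cross` used in the comparison

# Hypothetical CPU version of the circular-shift cross product.
function cross_shift(a::AbstractVector, b::AbstractVector)
    length(a) == length(b) == 3 || throw(DimensionMismatch("expected two 3-vectors"))
    # circshift(v, -1) == [v[2], v[3], v[1]]; circshift(v, -2) == [v[3], v[1], v[2]]
    return circshift(a, -1) .* circshift(b, -2) .- circshift(a, -2) .* circshift(b, -1)
end

a = Float32[1, 2, 3]
b = Float32[4, 5, 6]
cross_shift(a, b)               # Float32[-3.0, 6.0, -3.0]
cross_shift(a, b) == cross(a, b)  # true
```

Because the products here are small integers, the Float32 results match `cross` exactly; for general floating-point inputs, compare with `isapprox` instead of `==`.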