
Pass N-D array by reference to variadic function


I'd like to make the function multi_dimensional accept a multidimensional array by reference.

Can this be done with a variation of the syntax below which works for three_dimensional?

#include <utility>

// this works, but number of dimensions must be known (not variadic)
template <size_t x, size_t y, size_t z>
void three_dimensional(int (&nd_array)[x][y][z]) {}

// error: parameter packs not expanded with ‘...’
template <size_t... dims>
void multi_dimensional(int (&nd_array)[dims]...) {}

int main() {
    int array[2][3][2] = {
        { {0,1}, {2,3}, {4,5} },
        { {6,7}, {8,9}, {10,11} }
    };
    three_dimensional(array); // OK
    // multi_dimensional(array); // error: no matching function
    return 0;
}

Solution

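  • The main problem is that you cannot make the number of array dimensions itself variadic: there is no way to write the parameter's declarator so that a pack expands into a sequence of nested bounds like [x][y][z]. Even when spelled correctly, a pack expansion in a parameter list declares one parameter per pack element (here, several separate one-dimensional arrays), never a single array with several bounds. So whichever way you go, you will almost certainly need a recursive approach of some sort to deal with the individual array layers. What that approach should look like depends mainly on what you're planning to do with the array once it's been passed to you.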

    If all you really want is a function that can be given any multi-dimensional array, write a function template that accepts anything but only participates in overload resolution when that anything is an array:

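    #include <cstddef>
    #include <type_traits>
    
    // accepts any built-in array by reference; SFINAE removes it for non-arrays (C++17)
    template <typename T>
    std::enable_if_t<std::is_array_v<T>> multi_dimensional(T& a)
    {
        constexpr std::size_t dimensions = std::rank_v<T>;  // e.g. 3 for int[2][3][2]
    
        // ...
    }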
    

    However, this alone will probably not get you very far: unless you only want to look at the topmost layer of the structure, doing anything meaningful with the array will most likely require recursively walking through its subarrays.

    Another approach is to use a recursive template to peel back the individual array levels, for example:

    #include <cstddef>
    
    // base case: a one-dimensional array, the bottom layer
    template <typename T, std::size_t N>
    void multi_dimensional(T (&a)[N])
    {
        // ...
    }
    
    // this matches any array with more than one dimension and, being more
    // specialized, is preferred over the overload above for such arrays
    template <typename T, std::size_t N, std::size_t M>
    void multi_dimensional(T (&a)[N][M])
    {
        // peel off one dimension, invoke function for each element on next layer
        for (auto& sub : a)
            multi_dimensional(sub);
    }
    

    I would, however, suggest at least considering std::array<> instead of raw arrays, as the syntax and special behavior of raw arrays (decay to pointers, no copy or assignment) tends to turn everything into a confusing mess in no time. In general, it might be worth implementing your own multi-dimensional array type, such as an NDArray<int, 2, 3, 2>, which internally works with a flattened representation and just maps multi-dimensional indices to a linear index. One advantage of this approach (besides the cleaner syntax) is that you can easily change the mapping, e.g., switch from row-major to column-major layout for performance optimization.

    To implement a general nD array with static dimensions, I would introduce a helper class to encapsulate the recursive computation of a linear index from an nD index:

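    #include <cstddef>
    #include <utility>
    
    // maps an nD index to a linear offset in row-major order (last index varies fastest)
    template <std::size_t... D>
    struct row_major;
    
    // base case: one dimension left, the index is the offset
    template <std::size_t D_n>
    struct row_major<D_n>
    {
        static constexpr std::size_t SIZE = D_n;
    
        std::size_t operator ()(std::size_t i_n) const
        {
            return i_n;
        }
    };
    
    // recursive case: i_1 selects a block whose size is the product of the remaining dimensions
    template <std::size_t D_1, std::size_t... D_n>
    struct row_major<D_1, D_n...> : private row_major<D_n...>
    {
        static constexpr std::size_t SIZE = D_1 * row_major<D_n...>::SIZE;
    
        template <typename... Tail>
        std::size_t operator ()(std::size_t i_1, Tail&&... tail) const
        {
            return i_1 * row_major<D_n...>::SIZE
                 + row_major<D_n...>::operator ()(std::forward<Tail>(tail)...);
        }
    };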
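For reference, row-major order means the last index varies fastest. For dimensions D_1, ..., D_n and indices 0 <= i_k < D_k, the linear offset and a worked example are:

```latex
\operatorname{offset}(i_n) = i_n, \qquad
\operatorname{offset}(i_1, \dots, i_n)
  = i_1 \prod_{k=2}^{n} D_k + \operatorname{offset}(i_2, \dots, i_n)
  = \sum_{j=1}^{n} i_j \prod_{k=j+1}^{n} D_k

D = (2, 3, 5),\ (i_1, i_2, i_3) = (1, 2, 3):\quad
  1 \cdot 15 + 2 \cdot 5 + 3 = 28
```

A recursion of the form i_1 + D_1 * offset(i_2, ..., i_n) instead makes the first index vary fastest, which is column-major order; for the same index it gives 1 + 2 * (2 + 3 * 3) = 23.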
    

    And then:

    template <typename T, std::size_t... D>
    class NDArray
    {
        using memory_layout_t = row_major<D...>;
    
        T data[memory_layout_t::SIZE]{};   // flat storage, value-initialized
    
    public:
        // arr(i, j, k) maps the nD index to a linear offset via the layout helper
        template <typename... Args>
        T& operator ()(Args&&... args)
        {
            memory_layout_t memory_layout;
            return data[memory_layout(std::forward<Args>(args)...)];
        }
    };
    
    int main()
    {
        NDArray<int, 2, 3, 5> arr;
        arr(1, 2, 3) = 42;
        int x = arr(1, 2, 3);   // reads back 42
    }