Search code examples
multidimensional-array, d

Retrieving the shape of a multidimensional array in D


/// Intended to return the shape of a (possibly nested) array, one length per
/// dimension (see the example calls below).
///
/// Sketch only: the shape-retrieval logic is elided (`// ...`), so as written
/// this returns an empty `shape`.
auto getMultidimensionalArrayShape( T )( T array )
{
  // Reject non-array arguments at compile time.
  static assert( isArray!( T ) );

  // Retrieve the shape of the input array and double-check that
  // the arrays from the same dimension have the same length.
  size_t[] shape;
  // ...

  return shape;
}

// Intended results; the shape is listed innermost dimension first, as in D's
// static-array type syntax (e.g. [3][2] means 2 rows of 3 elements).
getMultidimensionalArrayShape( [1, 2, 3] )  //< returns [3]
getMultidimensionalArrayShape( [[1, 2, 3], [4, 5, 6]] )  //< returns [3][2]
// and so on...

Naively, I would iterate depth-first to retrieve the size of the array at index 0 of each dimension, then check that the other arrays in the same dimension match that length, just to make sure that everything is consistent. But I suspect there must be a better way... does anyone have an idea?


Solution

  • So here's what I ended up writing. Basically, I collect all the arrays at the same depth level and retrieve their dimension size after making sure that their lengths are all equal. It seems to work nicely; the only thing is that I'm not sure what to do when the length of depthLevelArrays is 0. Does that case even make sense, or should I just throw an exception?

    import std.stdio;
    import std.traits;
    
    
    /// Recursively computes the shape of the arrays in `depthLevelArrays`,
    /// which all sit at the same nesting depth, and appends it to `shape`.
    ///
    /// Dimensions are appended innermost first: the deepest recursion level
    /// appends before its caller does. This matches D's static-array type
    /// syntax, so an `int[5][2]` yields `[5, 2]`.
    ///
    /// Params:
    ///   depthLevelArrays = every array found at the current depth.
    ///   shape            = output; receives one length per remaining dimension.
    /// Throws: `Exception` if arrays at the same depth differ in length.
    void internal( T )( T[] depthLevelArrays, ref size_t[] shape )
    {
      // An empty level means an enclosing dimension has length 0, so there is
      // nothing to measure. Report 0 for this and every deeper dimension
      // instead of indexing element 0 of an empty array (a RangeError).
      immutable size_t dimensionSize =
          depthLevelArrays.length == 0 ? 0 : depthLevelArrays[0].length;

      // Every array at this depth must have the same length.
      // Iterate by ref so static-array elements are not copied.
      foreach ( ref element; depthLevelArrays )
        if ( element.length != dimensionSize )
          throw new Exception( "Shape is not uniform" );

      static if ( isArray!( ForeachType!( T ) ) )
      {
        // Flatten one level: gather every child array so that the next
        // recursion checks the whole next depth at once.
        ForeachType!( T )[] children;
        children.reserve( depthLevelArrays.length * dimensionSize );
        foreach ( ref element; depthLevelArrays )
          children ~= element[];

        internal( children, shape );
      }

      shape ~= dimensionSize;
    }
    
    
    /// Computes the shape of a (possibly nested) array: one length per
    /// dimension, innermost dimension first.
    ///
    /// Throws: `Exception` (raised by `internal`) when arrays at the same
    /// depth differ in length.
    auto getMultidimensionalArrayShape( T )( T array )
    {
      static assert( isArray!( T ) );

      // Seed the recursion with a one-element level holding the whole array,
      // so the outermost length is checked and recorded like any other level.
      auto outermostLevel = [array];
      size_t[] dims;
      internal( outermostLevel, dims );
      return dims;
    }
    
    
    /// Demo: prints the shapes of 1-D, 2-D, 3-D and 4-D arrays.
    /// Each shape is printed innermost dimension first.
    void main()
    {
      // 1-D dynamic array of immutable doubles.
      immutable samples = [1.0, 2.5, 3.4];
      writeln( "shape1: ", getMultidimensionalArrayShape( samples ) ); //< [3]

      // 2-D static array: 2 rows of 5 ints.
      int[5][2] matrix = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]];
      writeln( "shape2: ", getMultidimensionalArrayShape( matrix ) );  //< [5][2]

      // 3-D dynamic array built from a literal.
      auto cube = [[[1, 2], [3, 4], [5, 6], [7, 8]],
                   [[9, 10], [11, 12], [13, 14], [15, 16]],
                   [[17, 18], [19, 20], [21, 22], [23, 24]]];
      writeln( "shape3: ", getMultidimensionalArrayShape( cube ) ); //< [2][4][3]

      // 4-D const static array (default-initialized).
      const int[5][4][3][2] tensor;
      writeln( "shape4: ", getMultidimensionalArrayShape( tensor ) ); //< [5][4][3][2]
    }