python numpy multidimensional-array dimensions

What is a more elegant and Pythonic way to create a numpy array with one fixed dimension size and all others dynamic?


I have a function that creates a numpy array with the first dimension fixed, while all other dimensions can be changed dynamically. I wrote a very simple function that fulfills this, but I suspect it is not the right way to do it and was hoping for a more Pythonic solution.

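import numpy as np

def test(arg, first_dim=5):
    # Start the shape with the fixed first dimension
    new_dims = []
    new_dims.append(first_dim)
    if type(arg) == int:
        # A single int adds exactly one more dimension
        new_dims.extend([arg])
    else:
        # Any iterable of ints adds one dimension per element
        new_dims.extend(arg)
    return np.zeros(shape=new_dims)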

P.S.: One could argue for putting the 5 directly into arg, but that is not possible in my specific case.


Solution

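  • You can simply do:

    def test(arg, first_dim=5):
        # Wrap a lone int in a list; turn any other iterable (list, tuple, ...) into a list
        return np.zeros(shape=[first_dim] + ([arg] if type(arg) is int else list(arg)))

    It behaves the same as your code for an int or any iterable of ints. The list(arg) call matters: without it, passing a tuple would raise a TypeError, because a tuple cannot be concatenated to a list with +.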
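
As a possible variation on the one-liner above, here is a minimal sketch that builds the shape as a tuple with unpacking instead of list concatenation. The name `make_zeros` and the parameter name `rest` are made up for illustration, and the `numbers.Integral` check is my own assumption about what you want: it also accepts numpy integer scalars such as `np.int64`, which `type(arg) is int` would reject.

```python
import numbers

import numpy as np


def make_zeros(rest, first_dim=5):
    """Return a zero array of shape (first_dim, *rest).

    `rest` may be a single integer or any iterable of integers.
    (Hypothetical helper name; not part of numpy.)
    """
    # Normalize a lone integer to a 1-tuple so both cases unpack the same way
    if isinstance(rest, numbers.Integral):
        rest = (rest,)
    # Tuple unpacking works for lists, tuples, and other iterables alike
    return np.zeros((first_dim, *rest))


print(make_zeros(3).shape)
print(make_zeros((2, 3)).shape)
print(make_zeros([2, 3], first_dim=4).shape)
```

Whether this is nicer than the list version is mostly a matter of taste; the main practical difference is that it does not care whether the caller passes a list or a tuple.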