Search code examples
Tags: python, numpy, array-broadcasting

Is there a function that can apply NumPy's broadcasting rules to a list of shapes and return the final shape?


This is not a question about how broadcasting works (i.e., it is not a duplicate of the existing questions that explain broadcasting).

I would just like to find a function that can apply NumPy's broadcasting rules to a list of shapes and return the final shape, for example:

>>> broadcast_shapes([6], [4, 2, 3, 1], [2, 1, 1])
[4, 2, 3, 6]

Thanks!


Solution

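  • Here is another direct implementation, which happens to beat the others on this example. An honorable mention goes to @hpaulj's approach combined with @Warren Weckesser's hack, which is almost as fast and much more concise. (On NumPy 1.20 and later, the built-in `np.broadcast_shapes` provides this directly.)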

    def bs_pp(*shapes):
        """Return the shape obtained by broadcasting all of *shapes* together.

        Each shape is a sequence of ints (list or tuple).  Shapes are aligned on
        their trailing dimensions; a size of 1 stretches to match the other size.
        Calling with no shapes returns the scalar shape ``()``.

        Raises:
            ValueError: if an aligned dimension holds two different sizes,
                neither of which is 1.
        """
        if not shapes:
            # Broadcasting nothing yields a scalar shape, like np.broadcast_shapes().
            return ()
        # Start from the longest shape: every other shape aligns to its tail.
        longest = max(shapes, key=len)
        out = list(longest)
        for shape in shapes:
            if shape is longest:
                continue
            # A negative start index aligns `shape` with the trailing dims of `out`.
            for i, size in enumerate(shape, -len(shape)):
                if size != 1 and size != out[i]:
                    if out[i] != 1:
                        raise ValueError("Cannot broadcast these shapes")
                    out[i] = size
        return tuple(out)
    
    def bs_mq1(*shapes):
        """Return the broadcast shape of *shapes* by padding to a common rank.

        Each shape (list or tuple of ints) is left-padded with 1s to the rank of
        the longest shape, then each dimension takes the first non-1 size seen.
        Calling with no shapes returns the scalar shape ``()``.

        Raises:
            ValueError: if an aligned dimension holds two different sizes,
                neither of which is 1.
        """
        if not shapes:
            return ()
        max_rank = max(len(shape) for shape in shapes)
        # list(shape) lets tuple shapes through; list + tuple would raise TypeError.
        padded = [[1] * (max_rank - len(shape)) + list(shape) for shape in shapes]
        final_shape = [1] * max_rank
        for shape in padded:
            for dim, size in enumerate(shape):
                if size == 1:
                    continue  # a 1 stretches to whatever the other shapes need
                final_size = final_shape[dim]
                if final_size == 1:
                    final_shape[dim] = size
                elif final_size != size:
                    raise ValueError("Cannot broadcast these shapes")
        return tuple(final_shape)
    
    import numpy as np
    
    def bs_mq2(*shapes):
        """Return the broadcast shape of *shapes* using vectorized NumPy operations.

        Shapes (lists or tuples of ints) are left-padded with 1s to a common rank
        and stacked into a 2-D array with one row per shape.  Calling with no
        shapes returns the scalar shape ``()``.

        Raises:
            ValueError: if an aligned dimension holds two different sizes,
                neither of which is 1.
        """
        if not shapes:
            return ()
        max_rank = max(len(shape) for shape in shapes)
        padded = np.array([[1] * (max_rank - len(shape)) + list(shape)
                           for shape in shapes], dtype=np.intp)
        # Temporarily map 1 -> -1 so the column-wise max prefers any real size
        # (including 0) over a stretchable 1.
        masked = padded.copy()
        masked[masked == 1] = -1
        final_shape = masked.max(axis=0)
        final_shape[final_shape == -1] = 1
        # The max alone cannot detect conflicts such as 2 vs 3 (or 0 vs 3):
        # every entry must be either 1 or equal to the chosen size.
        if not ((padded == 1) | (padded == final_shape)).all():
            raise ValueError("Cannot broadcast these shapes")
        return tuple(int(size) for size in final_shape)
    
    def bs_hp_ww(*shapes):
        """Broadcast *shapes* by letting NumPy broadcast zero-size dummy arrays.

        A trailing 0 is appended to every shape, so each dummy array created by
        np.empty has no elements no matter how large the shape is.  All dummies
        share that trailing 0, so it never causes a conflict; it is sliced off the
        broadcast shape at the end.  Calling with no shapes returns ``()``.

        Raises:
            ValueError: propagated from np.broadcast when the shapes are
                incompatible.
        """
        if not shapes:
            return ()
        # (*shape, 0) works for both list and tuple shapes (list + [0] does not).
        dummies = [np.empty((*shape, 0), int) for shape in shapes]
        return np.broadcast(*dummies).shape[:-1]
    
    L = [6], [4, 2, 3, 1], [2, 1, 1]

    from timeit import timeit

    # Verify every implementation *before* benchmarking, so a wrong result fails
    # fast instead of being timed; comparing against the known answer (4, 2, 3, 6)
    # also catches the case where all four agree on a wrong shape.
    assert bs_pp(*L) == bs_mq1(*L) == bs_mq2(*L) == bs_hp_ww(*L) == (4, 2, 3, 6)

    # timeit returns the total seconds for `number` calls; with number=10_000,
    # total / 10 is the average time per call in milliseconds.
    print('pp:       ', timeit(lambda: bs_pp(*L), number=10_000)/10)
    print('mq 1:     ', timeit(lambda: bs_mq1(*L), number=10_000)/10)
    print('mq 2:     ', timeit(lambda: bs_mq2(*L), number=10_000)/10)
    print('hpaulj/ww:', timeit(lambda: bs_hp_ww(*L), number=10_000)/10)
    

    Sample run (average time per call, in milliseconds):

    pp:        0.0021552839782088993
    mq 1:      0.00398325570859015
    mq 2:      0.01497043427079916
    hpaulj/ww: 0.003267909213900566