
Library housing CNN shape calculation in a function?

I keep re-implementing the same free function, which computes a convolutional layer's output shape from its hyperparameters, and occasionally its unit tests as well. I am growing tired of it.

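The PyTorch nn.Conv3d docs give the output shape formulae. For each spatial dimension (D, H, W) they reduce to:

    out = floor((in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)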
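For example, applying that formula to an input depth of D_in = 16 with kernel_size = 3, padding = 1, dilation = 1 and stride = 2 gives D_out = floor((16 + 2 * 1 - 1 * (3 - 1) - 1) / 2 + 1) = floor(15 / 2 + 1) = floor(8.5) = 8.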


Is there a library (preferably PyTorch, TensorFlow, or NumPy) that provides a function implementing this formula?

Here is what I just implemented for a PyTorch-based project using Python 3.10+, but I would rather just import this.

def conv_conversion(
    in_shape: tuple[int, ...],
    kernel_size: int | tuple[int, ...],
    padding: int | tuple[int, ...] = 0,
    dilation: int | tuple[int, ...] = 1,
    stride: int | tuple[int, ...] = 1,
) -> tuple[int, ...]:
    """Compute a conv layer's output spatial shape, matching nn.Conv*d's defaults."""

    def to_tuple(value: int | tuple[int, ...]) -> tuple[int, ...]:
        return (value,) * len(in_shape) if isinstance(value, int) else value

    k, p = to_tuple(kernel_size), to_tuple(padding)
    dil, s = to_tuple(dilation), to_tuple(stride)
    # Floor division implements the floor() in the PyTorch formula
    return tuple(
        (in_shape[i] + 2 * p[i] - dil[i] * (k[i] - 1) - 1) // s[i] + 1
        for i in range(len(in_shape))
    )
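Since unit tests are part of the chore, one way to gain confidence in the closed form is to compare it against a brute-force count of the positions a dilated kernel can occupy in the padded input. This is a minimal sketch in plain Python with no PyTorch dependency; `formula_len` and `brute_force_len` are hypothetical helpers written for illustration, and configurations where the kernel does not fit at all are skipped.

```python
import itertools


def formula_len(n: int, k: int, p: int, d: int, s: int) -> int:
    # Closed-form output length (same expression as conv_conversion uses)
    return (n + 2 * p - d * (k - 1) - 1) // s + 1


def brute_force_len(n: int, k: int, p: int, d: int, s: int) -> int:
    # Slide the dilated kernel over the padded input in steps of the stride
    # and count the start positions where the whole kernel still fits.
    padded = n + 2 * p
    span = d * (k - 1) + 1  # extent covered by one dilated kernel
    return sum(1 for start in range(0, padded, s) if start + span <= padded)


# Exhaustively compare the two over a small grid of hyperparameters
grid = itertools.product(range(1, 12), range(1, 5), range(3), range(1, 4), range(1, 4))
for n, k, p, d, s in grid:
    if n + 2 * p >= d * (k - 1) + 1:  # skip configs where the kernel cannot fit
        assert formula_len(n, k, p, d, s) == brute_force_len(n, k, p, d, s)
print("formula matches brute force")
```

A comparison like this can be kept as a regression test for whichever implementation ends up being imported.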


  • Keras has such a function, conv_output_length, which computes the output length of a single dimension given the input length, kernel size, padding mode, stride, and dilation rate.

    The main downside is that it does not accept an explicit padding amount, only named padding modes: "valid" (no padding), "same" (pad the input so that, with stride 1, the output has the same spatial dimensions as the input), and "full" (pad so that every pixel of the image is convolved the same number of times).

    You could wrap it to get the output shape for all spatial dimensions:

    # Keras 2.x location; the module path may differ in other Keras/TensorFlow versions
    from keras.utils.conv_utils import conv_output_length

    def conv_output_shape(input_shape, kernel_shape, strides, padding, dilate):
        # assuming only spatial dimensions (no batch or channel axes)
        dims = range(len(kernel_shape))
        output_shape = [
            conv_output_length(input_shape[d], kernel_shape[d], padding, strides[d], dilate[d])
            for d in dims
        ]
        # a zero-length input dimension stays zero-length
        output_shape = tuple(
            0 if input_shape[d] == 0 else output_shape[d] for d in dims
        )
        return output_shape
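To see how the named modes relate to the explicit-padding formula from the question, here is a from-scratch sketch of the length arithmetic (not imported from Keras; `named_mode_len` and `explicit_len` are hypothetical names). It assumes the usual definitions: "valid" adds no padding, "same" yields ceil(n / stride) outputs, and "full" pads by the dilated kernel extent minus one on each side. Under those assumptions, "valid" matches explicit padding 0 and "full" matches explicit padding dilation * (kernel_size - 1); "same" with stride > 1 can need asymmetric padding, so it is not compared here.

```python
def ceil_div(a: int, b: int) -> int:
    # Ceiling division for integers without going through floats
    return -(-a // b)


def named_mode_len(n: int, k: int, mode: str, s: int = 1, d: int = 1) -> int:
    span = d * (k - 1) + 1  # extent covered by one dilated kernel
    if mode == "valid":  # no padding
        return ceil_div(n - span + 1, s)
    if mode == "same":  # ceil(n / stride) outputs
        return ceil_div(n, s)
    if mode == "full":  # pad span - 1 on both sides
        return ceil_div(n + span - 1, s)
    raise ValueError(f"unsupported padding mode: {mode!r}")


def explicit_len(n: int, k: int, p: int, s: int = 1, d: int = 1) -> int:
    # PyTorch-style formula with an explicit symmetric padding p
    return (n + 2 * p - d * (k - 1) - 1) // s + 1


# Check the "valid" and "full" equivalences over a small grid
for n in range(5, 20):
    for k in (1, 2, 3, 5):
        for s in (1, 2, 3):
            for d in (1, 2):
                if n < d * (k - 1) + 1:
                    continue  # kernel does not fit in "valid" mode
                assert named_mode_len(n, k, "valid", s, d) == explicit_len(n, k, 0, s, d)
                assert named_mode_len(n, k, "full", s, d) == explicit_len(n, k, d * (k - 1), s, d)

print(named_mode_len(32, 3, "same", s=2))  # 16
```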