Tags: pytorch, loss-function

Loss for binary sparsity


I have binary images (like the one below) at the output of my net. I need the '1's to be farther from each other (not connected), so that they form a sparse binary image (without white blobs), something like salt-and-pepper noise. I am looking for a way to define a loss in PyTorch that penalizes based on the density of the '1's.

Thanks.

[Binary image]


Solution

  • It depends on how you're generating said image. Since neural networks are trained by backpropagation, I'm fairly sure your binary image is not the direct output of your network (i.e., not the thing you're applying the loss to), because gradients can't flow through binary (discrete) variables. I suspect you do something like pixel-wise binary cross-entropy or similar and then threshold.
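    For reference, here is a minimal sketch of that common setup (the names `nn_output` and `target`, the shapes, and the 0.5 threshold are my assumptions, not taken from your code):

    import torch
    import torch.nn as nn

    # stand-ins for the real network output and ground truth (assumptions)
    nn_output = torch.randn(1, 1, 64, 64)                 # raw per-pixel scores
    target = torch.randint(0, 2, (1, 1, 64, 64)).float()  # binary ground truth

    probs = torch.sigmoid(nn_output)                         # squash to [0, 1]
    bce = nn.functional.binary_cross_entropy(probs, target)  # trainable loss

    # the binary image only appears after thresholding, outside the backward pass
    binary_image = (probs > 0.5).float()
    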

    I assume your code works like this: you densely regress real-valued numbers and then apply thresholding, likely using sigmoid to map from [-inf, inf] to [0, 1]. If so, you can do the following. Build a convolution kernel which is 0 in the center and 1 elsewhere, with its size determined by how big you want your "sparsity gaps" to be.

    # 0 at the center, 1 elsewhere; the kernel size sets how wide the
    # "sparsity gap" around each '1' should be
    kernel = torch.tensor([
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
    ]).reshape(1, 1, 5, 5)  # (out_channels, in_channels, H, W) for conv2d
    

    Then you apply sigmoid to your real-valued output to squash it to [0, 1]:

    squashed = torch.sigmoid(nn_output)
    

    Then you convolve squashed with kernel, which gives you a relaxed (real-valued) count of each pixel's non-zero neighbors.

    neighborhood = nn.functional.conv2d(squashed, kernel, padding=2)  # padding=2 keeps the output the same H x W
    

    Your loss is then the product of each pixel's value in squashed with the corresponding value in neighborhood:

    sparsity_loss = (squashed * neighborhood).mean()
    

    If you think of this loss applied to your binary image, the term for a given pixel p is non-zero if and only if both p and at least one of its neighbors have value 1 (in fact, it equals the number of p's non-zero neighbors), and 0 otherwise. Since we apply it to non-binary numbers in the [0, 1] range, it is a differentiable approximation of that.
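    As a quick sanity check, you can evaluate this loss on two hypothetical toy images (my own construction, not from your data): a connected white blob should score high, while isolated, well-separated '1's should score zero:

    import torch
    import torch.nn as nn

    kernel = torch.ones(1, 1, 5, 5)
    kernel[0, 0, 2, 2] = 0.  # 0 at the center, 1 elsewhere

    def sparsity_loss(img):
        # img: (N, 1, H, W) with values in [0, 1]
        neighborhood = nn.functional.conv2d(img, kernel, padding=2)
        return (img * neighborhood).mean()

    blob = torch.zeros(1, 1, 8, 8)
    blob[0, 0, 2:5, 2:5] = 1.   # a 3x3 white blob of connected '1's

    scattered = torch.zeros(1, 1, 8, 8)
    scattered[0, 0, 0, 0] = 1.  # three '1's, each outside the others' 5x5 windows
    scattered[0, 0, 0, 7] = 1.
    scattered[0, 0, 7, 3] = 1.

    print(sparsity_loss(blob))       # > 0: every blob pixel has lit neighbors
    print(sparsity_loss(scattered))  # 0: no '1' sees another '1' nearby
    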

    Please note that I left out some details in the code above, such as making sure squashed has the (N, 1, H, W) shape that nn.functional.conv2d expects.
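
    If it helps, here is how the pieces might fit together in a single objective. This is only a sketch: the BCE reconstruction term and the weight `lam` are my assumptions about your training setup, and `lam` would need tuning.

    import torch
    import torch.nn as nn

    kernel = torch.ones(1, 1, 5, 5)
    kernel[0, 0, 2, 2] = 0.

    def total_loss(nn_output, target, lam=0.1):
        # nn_output: raw (N, 1, H, W) scores; target: binary ground truth
        squashed = torch.sigmoid(nn_output)
        bce = nn.functional.binary_cross_entropy(squashed, target)
        neighborhood = nn.functional.conv2d(squashed, kernel, padding=2)
        sparsity = (squashed * neighborhood).mean()
        return bce + lam * sparsity  # trade off reconstruction vs. sparsity
    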