Tags: python, tensorflow, keras, conv-neural-network, mask

Define a binary mask in Keras


I have an input image of shape [X, Y, 3] and a pair of coordinates (x, y). I want to create a mask from these coordinates and then multiply it with the input image. The mask should be a binary matrix with the same spatial size as the image, with ones at [x:x+p_size, y:y+p_size] and zeros elsewhere.

How can I define this mask in Keras (TensorFlow backend)?

Note that this operation happens inside the model, so simply using NumPy won't help.

img = Input(shape=(32,32,3))
xy = Input(shape=(2,))  # x and y coordinates for the mask
mask = ?
output = keras.layers.Multiply()([img, mask])
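To pin down what the mask should contain, here is a plain NumPy reference for a single image, written outside any model. It is only meant for checking results (NumPy cannot be used inside the model, as noted above); the helper name `reference_mask` is hypothetical, and `x` is assumed to index rows and `y` columns, matching the slice `[x:x+p_size, y:y+p_size]`.

```python
import numpy as np

def reference_mask(height, width, x, y, p_size):
    """Binary (height, width, 1) mask with ones in [x:x+p_size, y:y+p_size].

    Hypothetical helper for testing only; NumPy slicing silently clips
    the patch if it extends past the image border.
    """
    mask = np.zeros((height, width, 1), dtype=np.float32)
    mask[x:x + p_size, y:y + p_size] = 1.0
    return mask

# Usage: mask a 10x10 RGB test image with a 3x3 patch at (2, 4)
img = np.arange(100, dtype=np.float32).reshape(10, 10, 1).repeat(3, axis=-1)
masked = img * reference_mask(10, 10, x=2, y=4, p_size=3)  # broadcasts over channels
print(masked[:, :, 0])
```

The trailing channel axis of size 1 lets one mask broadcast over all three color channels, which is also what a Keras version needs to do.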

Solution

  • You can do the whole thing with a Lambda layer implementing a custom function:

    from keras.models import Model
    from keras.layers import Input, Lambda
    from keras import backend as K
    import numpy as np
    
    # Masking function factory
    def mask_img(x_size, y_size=None):
        if y_size is None:
            y_size = x_size
        # Masking function
        def mask_func(tensors):
            img, xy = tensors
            img_shape = K.shape(img)
            # Make indexing arrays
            xx = K.arange(img_shape[1])
            yy = K.arange(img_shape[2])
            # Get coordinates
            xy = K.cast(xy, img_shape.dtype)
            x = xy[:, 0:1]  # row offset (image axis 1), shape (batch, 1)
            y = xy[:, 1:2]  # column offset (image axis 2), shape (batch, 1)
            # Make X and Y masks
            mask_x = (xx >= x) & (xx < x + x_size)
            mask_y = (yy >= y) & (yy < y + y_size)
            # Make full mask
            mask = K.expand_dims(mask_x, 2) & K.expand_dims(mask_y, 1)
            # Add channels dimension
            mask = K.expand_dims(mask, -1)
            # Multiply image and mask
            mask = K.cast(mask, img.dtype)
            return img * mask
        return mask_func
    
    # Model
    img = Input(shape=(10, 10, 3))  # Small size for test
    xy = Input(shape=(2,))
    output = Lambda(mask_img(3))([img, xy])
    model = Model(inputs=[img, xy], outputs=output)
    
    # Test
    img_test = np.arange(100).reshape((1, 10, 10, 1)).repeat(3, axis=-1)
    xy_test = np.array([[2, 4]])
    output_test = model.predict(x=[img_test, xy_test])
    print(output_test[0, :, :, 0])
    

    Output:

    [[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
     [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
     [ 0.  0.  0.  0. 24. 25. 26.  0.  0.  0.]
     [ 0.  0.  0.  0. 34. 35. 36.  0.  0.  0.]
     [ 0.  0.  0.  0. 44. 45. 46.  0.  0.  0.]
     [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
     [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
     [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
     [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
     [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]]
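The core of `mask_func` is a broadcasting trick: comparing a row of indices of shape `(X,)` against per-sample coordinates of shape `(batch, 1)` yields a `(batch, X)` boolean mask, and the outer AND of the row and column masks gives `(batch, X, Y)`. The same steps can be replayed in NumPy to inspect the intermediate shapes. This is an illustration of the logic only, with NumPy stand-ins for the `K.*` calls and a hypothetical second sample at (0, 0) added to show that every image in the batch gets its own patch.

```python
import numpy as np

x_size = y_size = 3                       # patch size, as in mask_img(3)
img = np.arange(100).reshape((1, 10, 10, 1)).repeat(3, axis=-1)
img = np.concatenate([img, img])          # batch of 2 identical images
xy = np.array([[2, 4], [0, 0]])           # second sample is hypothetical

xx = np.arange(img.shape[1])              # row indices, shape (10,)
yy = np.arange(img.shape[2])              # column indices, shape (10,)
x = xy[:, 0:1]                            # shape (2, 1)
y = xy[:, 1:2]                            # shape (2, 1)

mask_x = (xx >= x) & (xx < x + x_size)    # shape (2, 10): selected rows per sample
mask_y = (yy >= y) & (yy < y + y_size)    # shape (2, 10): selected columns per sample
mask = mask_x[:, :, None] & mask_y[:, None, :]  # outer AND, shape (2, 10, 10)
mask = mask[..., None]                    # add channel axis, shape (2, 10, 10, 1)

out = img * mask.astype(img.dtype)        # broadcasts over the 3 channels
print(mask_x.shape, mask.shape, out.shape)
print(out[0, :, :, 0])                    # same pattern as the Keras output above
print(out[1, :3, :3, 0])                  # patch in the top-left corner
```

Building the mask from comparisons, rather than slicing with `[x:x+p_size]`, is what lets a single vectorized expression handle a different (x, y) for every sample in the batch.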