
Custom Lambda layer for Kronecker product in Keras: trouble with the dimension reserved for batch_size


I am using Keras 2.1.5 with the TensorFlow backend to create a model for image classification. In my model, I would like to combine the input and the output of a convolution layer by computing their Kronecker product. I've written functions that compute the Kronecker product of two 3D tensors using the Keras backend functions.

from keras import backend as K
from keras.layers import Concatenate

def kronecker_product(mat1, mat2):
    # Computes the Kronecker product of two matrices.
    m1, n1 = mat1.get_shape().as_list()
    mat1_rsh = K.reshape(mat1, [m1, 1, n1, 1])
    m2, n2 = mat2.get_shape().as_list()
    mat2_rsh = K.reshape(mat2, [1, m2, 1, n2])
    return K.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])

def kronecker_product3D(tensors):
    tensor1 = tensors[0]
    tensor2 = tensors[1]
    # Separate the channel slices of each tensor and compute the Kronecker product of every pair of slices
    m1, n1, o1 = tensor1.get_shape().as_list()
    m2, n2, o2 = tensor2.get_shape().as_list()
    x_list = []
    for ind1 in range(o1):
        for ind2 in range(o2):
            x_list.append(kronecker_product(tensor1[:,:,ind1], tensor2[:,:,ind2]))
    return K.reshape(Concatenate()(x_list), [m1 * m2, n1 * n2, o1 * o2])
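The reshape trick in kronecker_product relies on broadcasting: reshaping the first matrix to (m1, 1, n1, 1) and the second to (1, m2, 1, n2) gives an (m1, m2, n1, n2) product, and merging the axis pairs yields the (m1*m2, n1*n2) Kronecker product. Here is a small sanity check of that idea, using NumPy in place of the Keras backend so it runs without building a model (kron_via_broadcast is a hypothetical helper name):

```python
import numpy as np

def kron_via_broadcast(a, b):
    """Kronecker product of two 2D arrays using the same reshapes as kronecker_product."""
    m1, n1 = a.shape
    m2, n2 = b.shape
    # a -> (m1, 1, n1, 1), b -> (1, m2, 1, n2); broadcasting gives (m1, m2, n1, n2)
    prod = a.reshape(m1, 1, n1, 1) * b.reshape(1, m2, 1, n2)
    # Merging (m1, m2) into rows and (n1, n2) into columns gives kron(a, b)
    return prod.reshape(m1 * m2, n1 * n2)

a = np.arange(6.0).reshape(2, 3)
b = np.arange(8.0).reshape(4, 2)
print(kron_via_broadcast(a, b).shape)  # (8, 6)
print(np.allclose(kron_via_broadcast(a, b), np.kron(a, b)))  # True
```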

Then I tried to use a Lambda layer to wrap the operation in a Keras layer:

cb = Convolution2D(12, (3,3), padding='same')(x)
x = Lambda(kronecker_product3D)([x, cb])

but received the error "ValueError: too many values to unpack (expected 3)". I expected each input to be a 3D tensor, but it actually has 4 dimensions: Keras reserves the first dimension for the batch size. I do not know how to deal with this extra dimension, whose size is dynamic (it appears as None when the graph is built).

I've searched a lot, but cannot find any example function that deals manually with the batch dimension.

I would be grateful for any tips or help. Thank you very much!


Solution

  • Easy solution:

    Simply add the batch dimension to your shape calculations and reshapes, passing -1 for it in K.reshape so the batch size is inferred at run time:

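    from keras import backend as K

    def kronecker_product(mat1, mat2):
        # Computes the Kronecker product of two matrices for every sample in the batch.
        # The batch size is None when the graph is built, so it is skipped here
        # and passed to K.reshape as -1.
        _, m1, n1 = mat1.get_shape().as_list()
        mat1_rsh = K.reshape(mat1, [-1, m1, 1, n1, 1])
        _, m2, n2 = mat2.get_shape().as_list()
        mat2_rsh = K.reshape(mat2, [-1, 1, m2, 1, n2])
        # Broadcasting gives (batch, m1, m2, n1, n2); merging axis pairs gives the Kronecker product
        return K.reshape(mat1_rsh * mat2_rsh, [-1, m1 * m2, n1 * n2])
    
    def kronecker_product3D(tensors):
        tensor1 = tensors[0]
        tensor2 = tensors[1]
        # Separate the channel slices of each tensor and compute the Kronecker product of every pair of slices
        _, m1, n1, o1 = tensor1.get_shape().as_list()
        _, m2, n2, o2 = tensor2.get_shape().as_list()
        x_list = []
        for ind1 in range(o1):
            for ind2 in range(o2):
                x_list.append(kronecker_product(tensor1[:,:,:,ind1], tensor2[:,:,:,ind2]))
        # Stack the (batch, m1*m2, n1*n2) products on a new last axis so that each pair of
        # slices becomes one output channel (index ind1 * o2 + ind2). Concatenating along the
        # last axis and reshaping would instead interleave values from different slices.
        return K.stack(x_list, axis=-1)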
    

    For a harder solution, I would try to find a way to avoid iterating over the channels, but that may be much more complex than I thought...
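One possible loop-free approach is to extend the same broadcasting trick to the channel axis as well. The sketch below is written in NumPy only so it can be checked without building a model; kron3d_vectorized is a hypothetical name, and it assumes static spatial and channel sizes. Porting it to Keras should mean replacing reshape with K.reshape using the same shapes (with -1 for the batch axis), but that has not been verified inside a Lambda layer here.

```python
import numpy as np

def kron3d_vectorized(t1, t2):
    """Per-sample Kronecker product of every pair of channel slices, without Python loops.

    t1: (batch, m1, n1, o1), t2: (batch, m2, n2, o2)
    returns (batch, m1*m2, n1*n2, o1*o2); channel i*o2 + j holds kron(t1[..., i], t2[..., j]).
    """
    _, m1, n1, o1 = t1.shape
    _, m2, n2, o2 = t2.shape
    # Interleave size-1 axes so every axis of t1 lines up with a size-1 axis of t2 and vice versa
    a = t1.reshape(-1, m1, 1, n1, 1, o1, 1)
    b = t2.reshape(-1, 1, m2, 1, n2, 1, o2)
    # Broadcasting gives (batch, m1, m2, n1, n2, o1, o2); merging adjacent axis pairs
    # produces the row, column and channel indices of the Kronecker products
    return (a * b).reshape(-1, m1 * m2, n1 * n2, o1 * o2)

rng = np.random.default_rng(0)
t1 = rng.standard_normal((2, 3, 4, 2))  # batch=2, m1=3, n1=4, o1=2
t2 = rng.standard_normal((2, 2, 3, 3))  # batch=2, m2=2, n2=3, o2=3
out = kron3d_vectorized(t1, t2)
print(out.shape)  # (2, 6, 12, 6)

# Loop-based reference in the same channel order as kronecker_product3D (ind1 outer, ind2 inner)
slices = [np.stack([np.kron(t1[s, :, :, i], t2[s, :, :, j]) for s in range(2)])
          for i in range(2) for j in range(3)]
print(np.allclose(out, np.stack(slices, axis=-1)))  # True

# Joining the slices along the last axis and then reshaping interleaves them instead
concat = np.concatenate(slices, axis=-1).reshape(-1, 6, 12, 6)
print(np.allclose(out, concat))  # False
```

The broadcast intermediate has exactly as many elements as the output, so removing the loop should not cost memory beyond the output itself; the output already grows as (m1*m2) x (n1*n2) x (o1*o2) per sample, which is large for image-sized inputs.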