Tags: python, machine-learning, keras, neural-network, keras-layer

Keras: input with size x*x generates unwanted output y*x


I have the following neural network in Keras:

import numpy as np
from tensorflow.keras import layers

inp = layers.Input((3,))
# Middle layers omitted
out_prop = layers.Dense(units=3, activation='softmax')(inp)
out_value = layers.Dense(units=1, activation='linear')(inp)

Then I prepared a pseudo-input to test my network:

inpu = np.array([[1,2,3],[4,5,6],[7,8,9]])

When I try to predict, this happens:

In [45]: nn.network.predict(inpu)
Out[45]:
[array([[0.257513  , 0.41672954, 0.32575747],
        [0.20175152, 0.4763418 , 0.32190666],
        [0.15986516, 0.53449154, 0.30564335]], dtype=float32),
 array([[-0.24281949],
        [-0.10461146],
        [ 0.11201331]], dtype=float32)]

As you can see, I expected two outputs: an array of size 3 and a single scalar value. Instead, I get a 3x3 matrix and an array with 3 elements. What am I doing wrong?


Solution

  • You are passing three input samples to the network:

    >>> inpu.shape
    (3, 3)  # three samples of size 3
    

    You also have two output layers: one outputs a vector of size 3 for each sample, and the other outputs a vector of size 1 (i.e. a scalar), again for each sample. Keras stacks the per-sample results along the first (batch) axis, so the output shapes are (3, 3) and (3, 1).

    Update: If you want your network to accept input samples of shape (3, 3) and output vectors of size 3 and 1 per sample, and you only want to use Dense layers, then you need a Flatten layer somewhere in the model, because a Dense layer applied to a higher-rank input acts only on the last axis and keeps the other axes. One option is to put it right after the input layer:

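    from tensorflow.keras import layers
    inp = layers.Input((3, 3))  # don't forget to set the correct input shape
    x = layers.Flatten()(inp)   # (batch, 3, 3) -> (batch, 9)
    # pass x to other Dense layers, ending with the two output heads:
    out_prop = layers.Dense(units=3, activation='softmax')(x)
    out_value = layers.Dense(units=1, activation='linear')(x)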
    

    Alternatively, you could reshape your data to (num_samples, 9) and feed it to a network whose input layer is Input((9,)), without using a Flatten layer.

    Update 2: As @Mete correctly pointed out in the comments, make sure the input array has a shape of (num_samples, 3, 3) if each input sample has a shape of (3, 3). This holds even when predicting on a single sample, which must be passed with shape (1, 3, 3).
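The output shapes in this answer follow from how a Dense layer treats its input: per the Keras documentation it computes `activation(dot(input, kernel) + bias)`, so each row of a 2D input is transformed independently and the batch axis is kept. Below is a minimal NumPy sketch of the two output heads from the question. The random weights are hypothetical stand-ins for the trained kernels, so the numbers will not match the `predict` output above; only the shapes matter.

```python
import numpy as np

# Hypothetical weights standing in for the learned kernels of the two heads.
rng = np.random.default_rng(0)
kernel_prop, bias_prop = rng.normal(size=(3, 3)), np.zeros(3)    # Dense(units=3)
kernel_value, bias_value = rng.normal(size=(3, 1)), np.zeros(1)  # Dense(units=1)

def softmax(z):
    # Row-wise softmax: each sample's scores are normalized independently.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def heads(batch):
    # Dense computes activation(dot(input, kernel) + bias) for every row (sample).
    prop = softmax(batch @ kernel_prop + bias_prop)
    value = batch @ kernel_value + bias_value  # linear activation
    return prop, value

inpu = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)

prop, value = heads(inpu)
print(prop.shape, value.shape)  # (3, 3) (3, 1): one row per input sample

# For one size-3 vector and one scalar, pass ONE sample while keeping the
# leading batch axis: shape (1, 3) rather than (3,).
prop1, value1 = heads(inpu[:1])
print(prop1.shape, value1.shape)  # (1, 3) (1, 1)
print(prop1[0], value1[0, 0])     # the size-3 vector and the scalar
```

With the real model, the same slicing should apply: `nn.network.predict(inpu[:1])` would return arrays of shape (1, 3) and (1, 1), and indexing with `[0]` (or `[0, 0]` for the scalar head) drops the batch axis.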
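For the setup where each sample is a 3x3 matrix, the data side can be prepared with NumPy as sketched below. The two sample matrices are hypothetical; what matters is the shapes: stack samples along a new leading axis to match Input((3, 3)), or reshape them to (num_samples, 9) for a network that starts with Input((9,)) and has no Flatten layer.

```python
import numpy as np

# Hypothetical data: two samples, each a 3x3 matrix, stacked along a new
# leading batch axis -> shape (num_samples, 3, 3), as Input((3, 3)) expects.
sample_a = np.arange(9, dtype=np.float32).reshape(3, 3)
sample_b = sample_a + 10
batch = np.stack([sample_a, sample_b])
print(batch.shape)  # (2, 3, 3)

# A single 3x3 sample still needs the batch axis: (3, 3) -> (1, 3, 3).
single = sample_a[np.newaxis, ...]
print(single.shape)  # (1, 3, 3)

# Flattening the data yourself: a row-major reshape keeps each sample's
# 9 values together in one row, giving shape (num_samples, 9).
flat = batch.reshape(len(batch), -1)
print(flat.shape)  # (2, 9)
```

Assuming a model built with Input((3, 3)) followed by Flatten and the two Dense heads, predicting on `batch` should give outputs of shape (2, 3) and (2, 1), and predicting on `single` should give (1, 3) and (1, 1).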