python matplotlib error-handling conv-neural-network index-error

Visualizing Feature maps: IndexError: too many indices for array


Following this tutorial I am trying to visualize feature maps.

My model looks as follows:

model.summary()
Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_5 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
efficientnet-b0 (Functional)    (None, 7, 7, 1280)   4049564     input_5[0][0]                    
__________________________________________________________________________________________________
flatten_4 (Flatten)             (None, 62720)        0           efficientnet-b0[0][0]            
__________________________________________________________________________________________________
branch_0_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_1_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_2_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_3_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_4_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_5_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_6_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_0_output (Dense)         (None, 35)           8995        branch_0_Dense_16000[0][0]       
__________________________________________________________________________________________________
branch_1_output (Dense)         (None, 35)           8995        branch_1_Dense_16000[0][0]       
__________________________________________________________________________________________________
branch_2_output (Dense)         (None, 35)           8995        branch_2_Dense_16000[0][0]       
__________________________________________________________________________________________________
branch_3_output (Dense)         (None, 35)           8995        branch_3_Dense_16000[0][0]       
__________________________________________________________________________________________________
branch_4_output (Dense)         (None, 35)           8995        branch_4_Dense_16000[0][0]       
__________________________________________________________________________________________________
branch_5_output (Dense)         (None, 35)           8995        branch_5_Dense_16000[0][0]       
__________________________________________________________________________________________________
branch_6_output (Dense)         (None, 35)           8995        branch_6_Dense_16000[0][0]       
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 245)          0           branch_0_output[0][0]            
                                                                 branch_1_output[0][0]            
                                                                 branch_2_output[0][0]            
                                                                 branch_3_output[0][0]            
                                                                 branch_4_output[0][0]            
                                                                 branch_5_output[0][0]            
                                                                 branch_6_output[0][0]            
__________________________________________________________________________________________________
reshape_4 (Reshape)             (None, 7, 35)        0           concatenate_4[0][0]              
==================================================================================================
Total params: 116,508,561
Trainable params: 116,466,545
Non-trainable params: 42,016

I would now like to visualize the layer with index 10, branch_0_output (None, 35). The relevant layers and their indices are:

3 branch_0_Dense_16000 (None, 256)
4 branch_1_Dense_16000 (None, 256)
5 branch_2_Dense_16000 (None, 256)
6 branch_3_Dense_16000 (None, 256)
7 branch_4_Dense_16000 (None, 256)
8 branch_5_Dense_16000 (None, 256)
9 branch_6_Dense_16000 (None, 256)
10 branch_0_output (None, 35)
11 branch_1_output (None, 35)
12 branch_2_output (None, 35)
13 branch_3_output (None, 35)
14 branch_4_output (None, 35)
15 branch_5_output (None, 35)
16 branch_6_output (None, 35)

I followed the code in the tutorial and preprocessed the image. I would now like to plot the 35 (?) feature maps of this layer. I used the tutorial's plotting code and changed the value of square; it is 1 here, but I tried several values:

# plot all 35 maps
square = 1
ix = 1
for _ in range(square):
    for _ in range(square):
        # specify subplot and turn off axis
        ax = pyplot.subplot(square, square, ix)
        ax.set_xticks([])
        ax.set_yticks([])
        # plot filter channel in grayscale
        pyplot.imshow(feature_maps[0, :, :, ix-1], cmap='gray')
        ix += 1
# show the figure
pyplot.show()

Regardless of which number I tried, I got this error message:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-28-4c1f464f6978> in <module>()
      9                 ax.set_yticks([])
     10                 # plot filter channel in grayscale
---> 11                 pyplot.imshow(feature_maps[0, :, ix-1], cmap='gray')
     12                 ix += 1
     13 # show the figure

IndexError: too many indices for array

Can someone tell me what I need to change?

Thanks a lot!


Solution

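  • The error "too many indices for array" is raised on line 11 because feature_maps has fewer dimensions than the index expression assumes. branch_0_output is a Dense layer with output shape (None, 35), so for a single image feature_maps has shape (1, 35): it holds 35 scalar activations, not 35 two-dimensional maps. Indexing it with three or four indices, as in feature_maps[0, :, :, ix-1], asks for more axes than the array has. To plot feature maps, take the output of a layer with a 4-D output, such as efficientnet-b0 with shape (None, 7, 7, 1280).

    The grid size also has to match the number of maps. With square = 1 you get a 1*1 grid, which has room for only one map. To plot 64 maps, set square = 8, which gives an 8*8 grid.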
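
The shape problem can be reproduced without the model. The following is a minimal sketch, not the tutorial's code: dense_out and conv_out are stand-in NumPy arrays whose shapes are assumed from the model.summary() output in the question, and grid_side and plot_feature_maps are hypothetical helper names introduced here for illustration.

```python
import math

import numpy as np


def grid_side(n_maps):
    """Side length of the smallest square grid with room for n_maps subplots."""
    return math.ceil(math.sqrt(n_maps))


def plot_feature_maps(feature_maps, max_maps=None):
    """Plot the channels of a (1, height, width, channels) array in a square grid."""
    if feature_maps.ndim != 4:
        raise ValueError(
            "expected a 4-D array (batch, height, width, channels), "
            f"got shape {feature_maps.shape}"
        )
    # Imported here so the shape check above can run without matplotlib installed.
    from matplotlib import pyplot

    n_channels = feature_maps.shape[-1]
    n_maps = n_channels if max_maps is None else min(max_maps, n_channels)
    side = grid_side(n_maps)
    for ix in range(n_maps):
        # Subplot positions are 1-based, channel indices are 0-based.
        ax = pyplot.subplot(side, side, ix + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        pyplot.imshow(feature_maps[0, :, :, ix], cmap="gray")
    pyplot.show()


# Stand-ins for model outputs (shapes assumed from model.summary()).
dense_out = np.zeros((1, 35))             # branch_0_output: no spatial axes
conv_out = np.random.rand(1, 7, 7, 1280)  # efficientnet-b0: 1280 maps of 7x7

# Four indices on a 2-D array reproduce the error from the question.
try:
    dense_out[0, :, :, 0]
    index_error = None
except IndexError as exc:
    index_error = str(exc)

# The same index expression works on the 4-D array and yields one 7x7 map.
first_map = conv_out[0, :, :, 0]
```

Calling plot_feature_maps(conv_out, max_maps=64) would then draw the first 64 maps in an 8*8 grid, while plot_feature_maps(dense_out) raises a ValueError that names the offending shape instead of failing inside imshow. For the 35 values of a Dense layer, a bar chart or a single-row image is probably a better fit than a grid of maps.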