python matplotlib torch tensor

How can I plot a tensor?


I need to plot a torch tensor of shape [9, 224, 224]. Is there a way to do this with matplotlib.pyplot?


Solution

  • You can plot it as images

    Matrices are easily plotted as images. If your data has more dimensions, it can be broken down into a number of 2-D matrices, each plotted in its own subplot. Here the 9 slices of size 224x224 are arranged in a 3x3 grid:

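    import matplotlib.pyplot as plt
    import numpy as np
    
    # Stand-in for the [9, 224, 224] data; for a torch tensor t, use t.detach().cpu().numpy()
    a = np.random.random((9, 224, 224)).reshape((3, 3, 224, 224))
    fig, ax = plt.subplots(3, 3)
    
    for i in range(3):
        for j in range(3):
            ax[i][j].imshow(a[i][j])  # one 224x224 slice per subplot
    
    plt.show()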

    It's worth considering whether this is appropriate for the data in question, but with a suitable layout it can work very well, for example when viewing the weights of a neural network.
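The same idea can be wrapped in a small helper that works for any number of slices, not just 9. This is a minimal sketch under a few assumptions: `plot_slices` is a hypothetical name, the input is an (N, H, W) NumPy array (convert a torch tensor first with `t.detach().cpu().numpy()`), and the non-interactive `Agg` backend is selected so the example runs without a display. The grid is made roughly square, and any leftover cells are left blank.

```python
import math

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove this line to get an interactive window
import matplotlib.pyplot as plt
import numpy as np


def plot_slices(x, ncols=None, cmap="viridis"):
    """Plot each 2-D slice of an (N, H, W) array in its own subplot.

    Hypothetical helper. For a torch tensor t, pass t.detach().cpu().numpy().
    """
    x = np.asarray(x)
    if x.ndim != 3:
        raise ValueError(f"expected an (N, H, W) array, got shape {x.shape}")
    n = x.shape[0]
    ncols = ncols or math.ceil(math.sqrt(n))  # near-square grid by default
    nrows = math.ceil(n / ncols)

    # squeeze=False keeps `axes` 2-D even when nrows or ncols is 1
    fig, axes = plt.subplots(nrows, ncols, squeeze=False,
                             figsize=(2 * ncols, 2 * nrows))
    for k, ax in enumerate(axes.flat):
        if k < n:
            ax.imshow(x[k], cmap=cmap)
            ax.set_title(f"channel {k}", fontsize=8)
        ax.axis("off")  # hide ticks and frame; cells past n stay empty
    fig.tight_layout()
    return fig, axes


# Usage with random data shaped like the question's [9, 224, 224] tensor
fig, axes = plot_slices(np.random.random((9, 224, 224)))
fig.savefig("slices.png")  # or plt.show() with an interactive backend
```

Passing `ncols=9` instead would put all nine slices in a single row; with 5 slices the default layout is a 2x3 grid with one empty cell.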