python, matplotlib, pytorch, tensor

Matplotlib histogram of a PyTorch tensor


I have a tensor of size 10, with only 2 values: 0 and 1.

I want to plot a histogram of this tensor, simply using matplotlib.pyplot.hist. This is my code:

import torch
import matplotlib.pyplot as plt
t = torch.tensor([0., 1., 1., 1., 0., 1., 1., 0., 0., 0.])
print(t)
plt.hist(t, bins=2)
plt.show()

And the output:

[Image: the resulting histogram plot]

Why are there so many values in the histogram? Where did the rest of the values come from? How can I plot a correct histogram for my tensor?


Solution

  • plt.hist(t, bins=2) is not meant to work with tensors. To make it work, you can pass t.numpy() or t.tolist() instead. As far as I could find out, the way to compute a histogram with PyTorch is the torch.histc() function, and to plot the result you use plt.bar() as follows:

    import torch
    import matplotlib.pyplot as plt
    
    t = torch.tensor([0., 0., 1., 1., 0., 1., 1., 0., 0., 0.])
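    bins = 2
    # Count the values in `bins` equal-width bins spanning [0, 1]
    hist = torch.histc(t, bins=bins, min=0, max=1)
    
    x = range(bins)
    # Convert the counts to a plain list before handing them to matplotlib
    plt.bar(x, hist.tolist(), align='center')
    plt.xlabel('Bins')
    plt.show()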
    

    Some sources for plotting a histogram can be seen here and here. I could not find the root cause of the original behavior, and I would be glad if someone could explain it, but as far as I am aware, this is the way to plot a histogram of a tensor.

    I changed the tensor to have 4 '1.0' and 6 '0.0' values so that the two bars have different heights.
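
The other option mentioned above, converting the tensor before calling plt.hist, can be sketched as follows. This is a minimal example under a few assumptions that are not in the original post: the tensor lives on the CPU and does not require gradients (otherwise t.detach().cpu().numpy() would be needed), the Agg backend is selected so the script also runs without a display, and the output file name tensor_hist.png is only a placeholder.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import torch

t = torch.tensor([0., 0., 1., 1., 0., 1., 1., 0., 0., 0.])

# Convert to a NumPy array so matplotlib receives one dataset of 10 values.
values = t.numpy()

fig, ax = plt.subplots()
# range=(0, 1) fixes the bin edges at 0, 0.5 and 1, mirroring min/max in torch.histc
counts, edges, _ = ax.hist(values, bins=2, range=(0, 1))
ax.set_xlabel("Value")
ax.set_ylabel("Count")
fig.savefig("tensor_hist.png")  # hypothetical output file name

print(counts)  # per-bin counts returned by hist
print(edges)   # bin edges used by hist
```

With the tensor from the answer, the first bin should hold the 6 zeros and the second the 4 ones, which matches what torch.histc(t, bins=2, min=0, max=1) reports. In an interactive session, dropping the matplotlib.use("Agg") line and calling plt.show() instead of fig.savefig() displays the figure directly.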