
Probability Mass Function (PMF): plot probabilities as columns with matplotlib.pyplot.plot


I am trying to use 'matplotlib.pyplot' to plot discrete event probabilities as columns in a PMF graph. Instead of relying on the complicated logic of the 'hist' function, I hope to achieve my goal with 'plot' using drawstyle="steps-pre":

import matplotlib.pyplot as plt

# Method of the StatsFun class (the rest of the class is not shown)
def plot_pmf(self):
    """Plot PMF."""
    x,y = list(self.pmf_v.keys()), list(self.pmf_v.values())
    #_=plt.plot(x,y, marker='.', linestyle='none') # plot with dots
    _=plt.plot(x,y, drawstyle="steps-pre") # plot with columns?
    _=plt.margins(0.02)
    _=plt.title(self.title)
    _=plt.xlabel(self.x_label)
    _=plt.ylabel(self.y_label)
    plt.show()

This does not work, as the Iris data shows:

x = setosa["sepal_length"]
sf = StatsFun(x,"Setosa Sepal Length","length", "probability")
pmf = sf.pmf()
print(pmf)
sf.plot_pmf()

{5.1: 0.16, 4.9: 0.08, 4.7: 0.04, 4.6: 0.08, 5.0: 0.16, 5.4: 0.1, 4.4: 0.06, 4.8: 0.1, 4.3: 0.02, 5.8: 0.02, 5.7: 0.04, 5.2: 0.06, 5.5: 0.04, 4.5: 0.02, 5.3: 0.02}

[Image: plot produced by sf.plot_pmf()]
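One likely contributor to the odd shape: `plt.plot` joins points in the order they are given, and the printed PMF keys are not sorted by x, so the line jumps back and forth. Below is a minimal sketch of sorting the pairs before plotting; `sorted_xy` is a hypothetical helper, not part of `StatsFun`. Even with sorted input, `drawstyle="steps-pre"` still draws one connected staircase rather than separate columns.

```python
# PMF printed above (keys are not in ascending order of x).
pmf = {5.1: 0.16, 4.9: 0.08, 4.7: 0.04, 4.6: 0.08, 5.0: 0.16, 5.4: 0.1,
       4.4: 0.06, 4.8: 0.1, 4.3: 0.02, 5.8: 0.02, 5.7: 0.04, 5.2: 0.06,
       5.5: 0.04, 4.5: 0.02, 5.3: 0.02}

def sorted_xy(pmf):
    """Hypothetical helper: return x values in ascending order and their probabilities."""
    xs = sorted(pmf)
    ys = [pmf[x] for x in xs]
    return xs, ys

xs, ys = sorted_xy(pmf)
print(xs[:3], ys[:3])  # [4.3, 4.4, 4.5] [0.02, 0.06, 0.02]
```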

How can I use the 'plot' function to get a result similar to the image below? That image was created by 'plt.hist(data, weights=weights, bins = 100)', which works only because of 'bins = 100':

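import numpy as np
import matplotlib.pyplot as plt

data = setosa["sepal_length"]
# Each observation gets weight 1/n, so bar heights are relative frequencies
weights = np.ones_like(np.array(data))/float(len(np.array(data)))
print(weights, sum(weights))
#plt.hist(data, bins = 100) # does the same as next line
plt.hist(data, weights=weights, bins = 100)
plt.show()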

[0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02
 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02
 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02
 0.02 0.02 0.02 0.02 0.02 0.02 0.02 0.02] 1.0000000000000004

[Image: histogram produced by plt.hist(data, weights=weights, bins = 100)]
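The weights above are all 1/50 = 0.02, so summing them per distinct value gives count/n, which matches the printed PMF (for example, 0.16 for 5.1 corresponds to 8 of the 50 observations). A minimal sketch of computing such a PMF directly, assuming the sample is a plain list of values; `pmf_from_sample` is a hypothetical name and not necessarily how `StatsFun.pmf` is implemented, and the sample is made up for illustration:

```python
from collections import Counter

def pmf_from_sample(sample):
    """Hypothetical helper: map each distinct value to its relative frequency (count / n)."""
    n = len(sample)
    counts = Counter(sample)  # value -> number of occurrences, in first-seen order
    return {value: count / n for value, count in counts.items()}

sample = [5.1, 4.9, 4.7, 5.1, 5.0]  # toy data, not the Iris measurements
print(pmf_from_sample(sample))  # {5.1: 0.4, 4.9: 0.2, 4.7: 0.2, 5.0: 0.2}
```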


Solution

  • Your example data as (x, y) pairs:

    data = {5.1: 0.16, 4.9: 0.08, 4.7: 0.04, 4.6: 0.08,
            5.0: 0.16, 5.4: 0.1, 4.4: 0.06, 4.8: 0.1,
            4.3: 0.02, 5.8: 0.02, 5.7: 0.04, 5.2: 0.06,
            5.5: 0.04, 4.5: 0.02, 5.3: 0.02}
    

    Draw a line for each (x,y) pair from (x,0) to (x,y):

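    import matplotlib.pyplot as plt

    # One vertical line per value, from the x axis up to its probability
    for x,y in data.items():
        plt.plot((x,x),(0,y))
    plt.show()
    plt.close()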
    

    [Image: vertical lines drawn by the loop above]


    Each plt.plot call takes the next color from the default color cycle. If all the lines should be the same color, specify one:

    plt.plot((x,x),(0,y), 'black')
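A variant of the same idea without the Python loop, as a sketch: `Axes.vlines(x, ymin, ymax)` draws every vertical segment in a single call, and all segments share one color. The `Agg` backend line is only there so the sketch runs without a display; the axis labels are taken from the question's `StatsFun` arguments.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it to get an interactive window
import matplotlib.pyplot as plt

# PMF from the question: value -> probability.
data = {5.1: 0.16, 4.9: 0.08, 4.7: 0.04, 4.6: 0.08,
        5.0: 0.16, 5.4: 0.1, 4.4: 0.06, 4.8: 0.1,
        4.3: 0.02, 5.8: 0.02, 5.7: 0.04, 5.2: 0.06,
        5.5: 0.04, 4.5: 0.02, 5.3: 0.02}

fig, ax = plt.subplots()
# One call draws a segment from (x, 0) to (x, y) for every pair, all in black.
lines = ax.vlines(list(data.keys()), 0, list(data.values()), colors="black")
ax.set_xlabel("length")
ax.set_ylabel("probability")
# plt.show()  # uncomment in an interactive session
```

`plt.stem(x, y)` is another built-in option; it also puts a marker at the top of each line.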