
Unexpected plots on matplotlib histograms


I am quite a beginner with matplotlib, so apologies if this seems like a dumb question.

I have a CSV file with the weight values of the individual neurons in each layer of my deep learning model. Since the model has four layers, the file has one row per layer:

    weight_1,weight_2......weight_n
    weight_1,weight_2......weight_n
    weight_1,weight_2......weight_n
    weight_1,weight_2......weight_n
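Because layers can have different numbers of neurons, the rows may differ in length, so a loader that expects a rectangular table (such as `np.loadtxt`) may not fit this file. A minimal sketch of reading it row by row, assuming comma-separated values with one layer per row; the inline sample data and the helper name `read_layer_weights` are made up for illustration:

```python
import csv
import io

import numpy as np

# Made-up stand-in for the real CSV file: one row of weights per layer.
# Rows may differ in length because layers can have different sizes.
sample = io.StringIO(
    "0.12,-0.40,0.33,0.05\n"
    "0.90,-0.10\n"
    "0.25,0.25,-0.75\n"
    "-0.05,0.60,0.11,0.02,-0.30\n"
)


def read_layer_weights(f):
    """Hypothetical helper: return one float array per non-empty CSV row."""
    return [np.array(row, dtype=float) for row in csv.reader(f) if row]


layers = read_layer_weights(sample)
for i, w in enumerate(layers, start=1):
    print(f"layer {i}: {w.size} weights")
```

With a real file, the same helper would be called as `with open(path, newline="") as f: layers = read_layer_weights(f)`.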

I want to extract the weights of each layer and plot their distributions. I already have code for this and it works, but for some epochs the histograms contain odd extra colors that look like additional histograms. I am attaching a sample image:

[Image: sample histogram]

As you can see, there is a pinkish part that is partly hidden behind the blue bulk of the histogram. Can someone please help me understand what that is?

My code currently looks like this (assume the file is already loaded into `csv_reader`):

        for row in csv_reader:
            a = np.array(row)
            a_float = a.astype(float)  # np.float was removed in NumPy 1.24
            plt.hist(a_float, bins=20)
            plt.xlabel("weight_range")
            plt.ylabel("frequency")

Please note that four different plots (images) are generated once the loop finishes, since the CSV file has four rows; I have only posted the sample image for one of them. I did not try to plot all the rows in one graph.

EDIT

I reduced the number of bins, and now the effect is more prominent. Here is another sample image:

[Image: another sample]

SOLVED

Adding plt.figure() inside the loop solved it. Please check the comments and the answer below for details. The updated loop is as follows:

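        for row in csv_reader:
            a = np.array(row)
            a_float = a.astype(float)  # np.float was removed in NumPy 1.24
            plt.figure()               # start a new figure for each row (layer)
            plt.hist(a_float, bins=20)
            plt.xlabel("weight_range")
            plt.ylabel("frequency")
            plt.show()                 # display (or plt.savefig) before closing,
            plt.close()                # otherwise the closed figure is never shown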
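For reference, here is a self-contained sketch of the same one-figure-per-layer fix using Matplotlib's object-oriented API, where `plt.subplots()` creates a fresh figure and Axes on every iteration. The random weights, the headless `Agg` backend, and saving into in-memory buffers instead of files are assumptions made so the sketch runs anywhere:

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# Made-up weights: four "layers" with different numbers of neurons.
layers = [rng.normal(0.0, 0.1, size=n) for n in (64, 128, 32, 10)]

bars_per_figure = []
images = []
for i, weights in enumerate(layers, start=1):
    fig, ax = plt.subplots()           # fresh figure and Axes for this layer only
    ax.hist(weights, bins=20)
    ax.set_xlabel("weight_range")
    ax.set_ylabel("frequency")
    ax.set_title(f"layer {i}")
    bars_per_figure.append(len(ax.patches))  # bars from this single histogram
    buf = io.BytesIO()                 # swap for a path such as f"layer_{i}.png"
    fig.savefig(buf, format="png")
    images.append(buf.getvalue())
    plt.close(fig)                     # release the figure once it is saved

print(bars_per_figure)
print(plt.get_fignums())
```

Passing the figure object to `plt.close(fig)` makes it explicit which figure is released, rather than relying on whichever one pyplot currently considers "current".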

Solution

  • I tried to reproduce your error, and most likely you are plotting several histograms in one plot:

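    %matplotlib inline
    import matplotlib.pyplot as plt
    import numpy as np
    
    # Two rows of 100 random values, standing in for two CSV rows
    arrays = np.random.random((2, 100))
    
    fig = plt.figure()
    ax = fig.add_subplot(111)
    for array in arrays:
        # Every call draws on the same Axes, so the histograms overlap,
        # and each one gets the next color in the default color cycle
        ax.hist(array, bins=20)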
    

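To check this diagnosis, the sketch below repeats the original loop (no `plt.figure()` inside it) on made-up data and inspects the result. `plt.hist` draws on the current Axes, so all four rows end up on a single figure, and each call takes the next color from Matplotlib's default property cycle; overlapping bars then hide parts of one another, which matches the partly covered extra colors in the question. The headless `Agg` backend and the random data are assumptions for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_hex

rng = np.random.default_rng(1)
rows = rng.random((4, 100))  # made-up stand-in for the four CSV rows

# Same pattern as the original loop: no plt.figure() inside it, so the first
# plt.hist() call creates one figure and every later call reuses its Axes.
for row in rows:
    plt.hist(row, bins=20)

ax = plt.gca()
n_figures = len(plt.get_fignums())  # open figures
n_bars = len(ax.patches)            # bars drawn on the shared Axes
# Each plt.hist() call took the next color from the default property cycle.
n_colors = len({to_hex(p.get_facecolor()) for p in ax.patches})
print(n_figures, n_bars, n_colors)
plt.close("all")
```

With default settings this should report one figure holding 80 bars (4 rows × 20 bins) in four distinct colors, rather than four separate figures.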