python matplotlib keyerror

KeyError: 3 when plotting in a for loop with matplotlib.pyplot.scatter


I have a dataset with 10 columns. Columns 4 and 8 hold the data I want to plot, and column 9 determines the marker type.

mkr_dict = {'Fluoride': 'D', 'Chloride': 'P', 'Bromide': 'o', 'Iodide' : '^'}
for kind in mkr_dict:
    d = df[df.Halogen==kind]
    plt.scatter(d[3], d[7], 
                s = 20,
                c = 'red',
                marker=mkr_dict[kind])
plt.show()

I keep getting KeyError: 3 when I run this. Can someone give me some advice?


Solution

  • As the following minimal example shows, the code from the question runs fine when the DataFrame's columns carry the default integer labels (0, 1, 2, ...).

    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    
    a = np.random.rand(50,8)
    df = pd.DataFrame(a)
    df["Halogen"] = np.random.choice(['Fluoride','Chloride','Bromide','Iodide'], size=50)
    
    
    mkr_dict = {'Fluoride': 'D', 'Chloride': 'P', 'Bromide': 'o', 'Iodide' : '^'}
    for kind in mkr_dict:
        d = df[df.Halogen==kind]
        plt.scatter(d[3], d[7], 
                    s = 20,
                    c = 'red',
                    marker=mkr_dict[kind])
    plt.show()
    

    However, if the columns are named rather than numbered, something like d[3] will fail: pandas treats 3 as a column label, finds no column with that label, and raises KeyError: 3. You would then need to select the column by its position,

    d.iloc[:,3]
    

    or by its name (in the case below, e.g. d.D or d["D"]). A version using iloc:

    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    
    a = np.random.rand(50,8)
    df = pd.DataFrame(a, columns=list("ABCDEFGH"))
    df["Halogen"] = np.random.choice(['Fluoride','Chloride','Bromide','Iodide'], size=50)
    
    
    mkr_dict = {'Fluoride': 'D', 'Chloride': 'P', 'Bromide': 'o', 'Iodide' : '^'}
    for kind in mkr_dict:
        d = df[df.Halogen==kind]
        plt.scatter(d.iloc[:,3], d.iloc[:,7], 
                    s = 20,
                    c = 'red',
                    marker=mkr_dict[kind])
    plt.show()
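
To check which case applies to your own data, it may help to look at the column labels before plotting. The sketch below assumes the same made-up frame as above (columns named A to H plus Halogen, which are illustrative and not the asker's real column names). It reproduces the failing label lookup and shows that selecting by position and by name return the same column:

```python
import numpy as np
import pandas as pd

# Illustrative data only: the column names "A".."H" are assumptions,
# standing in for whatever names the real dataset uses.
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.random((50, 8)), columns=list("ABCDEFGH"))
df["Halogen"] = rng.choice(["Fluoride", "Chloride", "Bromide", "Iodide"], size=50)

# Inspect the labels first: they are strings here, so no column is labelled 3.
print(list(df.columns))

try:
    df[3]  # label-based lookup, fails because 3 is not a column label
    label_lookup_failed = False
except KeyError as err:
    label_lookup_failed = True
    print("KeyError:", err)

by_position = df.iloc[:, 3]  # fourth column, regardless of its label
by_name = df["D"]            # the same column, selected by its label
same = by_position.equals(by_name)
print(same)
```

Either Series can be passed to plt.scatter. Selecting by name is usually easier to read, while iloc keeps working if the columns are later renamed.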