python dataframe matplotlib jupyter-notebook scatter3d

How to plot a 3D scatter of the same data grouped by class in Python


I have a table like the one below.

[Image: original table]

Using the table above, I want to plot the distribution of the features in three dimensions. The data contains three classes: normal, hyper, and hypo. I wrote the following code for this:

%matplotlib inline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

labels = df.class_type
for l in labels:
    attr1 = df.X1
    attr2 = df.X2
    attr3 = df.X3
    ax.scatter(xs = attr1, ys = attr2, zs = attr3, label = "normal")
    ax.scatter(xs = attr1, ys = attr2, zs = attr3, label = "hyper")
    ax.scatter(xs = attr1, ys = attr2, zs = attr3, label = "hypo")

ax.set_title("1.Grup")
ax.set_xlabel("atr1")
ax.set_ylabel("atr2")
ax.set_zlabel("atr3")

plt.show()

[Image: current output]

But I want a plot like the one below. How can I do it? Thanks in advance.

[Image: desired output]


Solution

  • I found the answer. Since I don't have your data, I made a small example DataFrame instead. First, create the DataFrame.

    %matplotlib inline
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    import pandas as pd
    import itertools
    
    df = pd.DataFrame({'class_att': [1, 1, 2, 2, 3, 3],
                       'X1': [100, 110, 120, 130, 140, 150],
                       'X2': [10, 20, 30, 40, 50, 60],
                       'X3': [50, 60, 70, 80, 90, 100],
                       'class_type': ['normal', 'normal', 'hyper', 'hyper', 'hypo', 'hypo']})
    

    Then group the rows by class with groupby().

    groups = df.groupby('class_type')
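To see what the plotting loop receives, here is a minimal sketch (reusing the same toy data) of iterating over a `GroupBy` object. Each step yields the class name and the sub-DataFrame holding only that class's rows, and by default pandas sorts the keys, so the order is hyper, hypo, normal rather than the order in the table.

```python
import pandas as pd

# Same toy data as the example above (not the asker's real table)
df = pd.DataFrame({'X1': [100, 110, 120, 130, 140, 150],
                   'X2': [10, 20, 30, 40, 50, 60],
                   'X3': [50, 60, 70, 80, 90, 100],
                   'class_type': ['normal', 'normal', 'hyper', 'hyper', 'hypo', 'hypo']})

groups = df.groupby('class_type')

# Iterating a GroupBy yields (key, sub-DataFrame) pairs; keys are sorted by default
summary = {name: len(group) for name, group in groups}
print(summary)  # {'hyper': 2, 'hypo': 2, 'normal': 2}

# A single group can also be pulled out directly by its key
print(groups.get_group('hyper')['X1'].tolist())  # [120, 130]
```

Because the keys are sorted, the `itertools.cycle` colors in the answer are assigned in that sorted order: hyper gets "r", hypo gets "b", and normal gets "g".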
    

    Finally, draw one scatter call per group and add a legend.

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    
    # Cycle through one color per class
    colors = itertools.cycle(["r", "b", "g"])
    for name, group in groups:
        # One scatter call per class, so each class gets its own color and legend entry
        ax.scatter(xs=group.X1, ys=group.X2, zs=group.X3, label=name, color=next(colors), alpha=1)
    
    ax.legend()
    plt.show()
    

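If you want each class to keep a fixed color no matter what order the groups come in, one variant is to map class names to colors yourself and select each class's rows with a boolean mask. This is a sketch, not the only way to do it: the `class_colors` mapping and the output file name are arbitrary choices, and the `Agg` backend line is only there so the snippet runs outside Jupyter (drop it in a notebook and use `plt.show()` instead of `savefig`).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs outside a notebook
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (only needed on old Matplotlib)
import pandas as pd

# Same toy data as the example above
df = pd.DataFrame({'X1': [100, 110, 120, 130, 140, 150],
                   'X2': [10, 20, 30, 40, 50, 60],
                   'X3': [50, 60, 70, 80, 90, 100],
                   'class_type': ['normal', 'normal', 'hyper', 'hyper', 'hypo', 'hypo']})

# Arbitrary color choice: one fixed color per class, independent of group order
class_colors = {"normal": "g", "hyper": "r", "hypo": "b"}

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

for name, color in class_colors.items():
    subset = df[df["class_type"] == name]  # keep only the rows of this class
    ax.scatter(subset["X1"], subset["X2"], subset["X3"], color=color, label=name)

ax.set_xlabel("atr1")
ax.set_ylabel("atr2")
ax.set_zlabel("atr3")
ax.legend()
fig.savefig("grouped_scatter_3d.png")  # hypothetical output path
```

The loop order of `class_colors` also fixes the legend order (normal, hyper, hypo), whereas the `groupby` version lists the classes in sorted order.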