python plotly plotly-python

How to assign colors to a scatter plot by group?


I'm trying to assign a color to each class in my DataFrame in Plotly. Here is my code:

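import pandas as pd
import plotly.graph_objects as go
import plotly.offline as pyo
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# X (the four Iris feature columns) and y (the class labels) are assumed to be loaded already
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)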

knn = KNeighborsClassifier(n_neighbors=7)

# fitting the model
knn.fit(X_train, y_train)

# predict the response
pred = knn.predict(X_test)

dfp = pd.DataFrame(X_test)
dfp.columns = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
dfp["PClass"] = pred

pyo.init_notebook_mode()
data = [go.Scatter(x=dfp['SepalLengthCm'], y=dfp['SepalWidthCm'], 
                   text=dfp['PClass'],
                   mode='markers',
                   marker=dict(
                    color=dfp['PClass']))]

layout = go.Layout(title='Chart', hovermode='closest')
fig = go.Figure(data=data, layout=layout)

# pass fig rather than data so the layout (title, hovermode) is applied
pyo.iplot(fig)

And here is what my DataFrame looks like:

SepalLengthCm   SepalWidthCm   PetalLengthCm   PetalWidthCm   PClass
          6.1            2.8             4.7            1.2   Iris-versicolor
          5.7            3.8             1.7            0.3   Iris-setosa
          7.7            2.6             6.9            2.3   Iris-virginica

The problem is that the points are not colored according to the dfp['PClass'] column: every point on the plot is the same color (black), even though hovering over a point shows its correct class label. Any ideas why it's not working?


Solution

  • In your code sample, you are trying to assign colors to your categorical groups with color=dfp['PClass']. This is the logic used by, for example, ggplot, as in ggplot(mtcars, aes(x=wt, y=mpg, shape=cyl, color=cyl, size=cyl)), where cyl is a categorical variable. You'll see an example a bit down the page here.

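    But in plotly, this won't work. marker.color in go.Scatter accepts a single color, an array of valid colors (such as 'red' or '#1f77b4'), or an array of numbers that are mapped onto a colorscale, as in an example with color = np.random.randn(500). Class names like 'Iris-setosa' are none of these, so they are not treated as groups.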


    To get the result you want, build the plot from multiple traces, one go.Scatter trace per class, so that each class gets its own color and legend entry.
