I'm trying to assign a color to each class in my dataframe in plotly. Here is my code:
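import pandas as pd
import plotly.offline as pyo
import plotly.graph_objects as go
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# X holds the four iris measurements and y the species labels (defined earlier)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
knn = KNeighborsClassifier(n_neighbors=7)
# fitting the model
knn.fit(X_train, y_train)
# predict the response
pred = knn.predict(X_test)
dfp = pd.DataFrame(X_test)
dfp.columns = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
dfp["PClass"] = pred
pyo.init_notebook_mode()
data = [go.Scatter(x=dfp['SepalLengthCm'], y=dfp['SepalWidthCm'],
                   text=dfp['PClass'],
                   mode='markers',
                   marker=dict(
                       color=dfp['PClass']))]
layout = go.Layout(title='Chart', hovermode='closest')
fig = go.Figure(data=data, layout=layout)
# pass the figure (not just the data) so the layout is applied
pyo.iplot(fig)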
And here is how my df looks:
SepalLengthCm  SepalWidthCm  PetalLengthCm  PetalWidthCm  PClass
6.1            2.8           4.7            1.2           Iris-versicolor
5.7            3.8           1.7            0.3           Iris-setosa
7.7            2.6           6.9            2.3           Iris-virginica
So the problem is that it's not assigning colors based on the dfp['PClass'] column: every point on the plot is the same color, black. Yet when hovering, every point is correctly labeled with its class.
Any ideas why it's not working correctly?
In your code sample, you are trying to assign colors to your categorical groups using color=dfp['PClass']. This is the logic applied by, for example, ggplot, with ggplot(mtcars, aes(x=wt, y=mpg, shape=cyl, color=cyl, size=cyl)), where cyl is a categorical variable. You'll see an example a bit down the page here.
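But for plotly, this won't work. The marker color in go.Scatter accepts either numerical values, which are mapped onto a colorscale (as in this example with color = np.random.randn(500)), or valid color strings such as 'red' or '#1f77b4'. Class labels like 'Iris-setosa' are neither, so plotly cannot turn them into distinct colors.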
In order to achieve your desired result, you'll have to build your plot using multiple traces like in this example: