I want to create a 3D plot that visualizes the correlation between `hidden_layer_sizes`, `max_iter` and `Score`. I had to google around a bit to get the plot I expected, but now I'm facing some problems with the legend:
My goal is to move the legend on the right to the bottom, but it doesn't work — I'm not even able to remove it. If I set `showlegend=False`, only the small highlighted legend disappears; the legend on the right remains.
I'm sure it's just my lack of Plotly experience. I'd appreciate any help.
Data
import pandas as pd
# Results of 30 runs: each hidden_layer_sizes value (25, 50, ..., 150) is
# paired with each max_iter value (100, 200, ..., 500), rows keyed 0..29.
# Score and Time are the per-run measurements exactly as reported.
df = pd.DataFrame({
    "hidden_layer_sizes": {row: 25 * (row // 5 + 1) for row in range(30)},
    "max_iter": {row: 100 * (row % 5 + 1) for row in range(30)},
    "Score": dict(enumerate([
        0.9270832984321359, 0.9172223807360554, 0.9202868292420568,
        0.9187318693456508, 0.9263589700182026, 0.9325454241272417,
        0.9351742112383672, 0.934706441722599, 0.9350294733755595,
        0.9334167352798914, 0.9355533396303661, 0.9327821227628682,
        0.9333376163633981, 0.9322875868305249, 0.9345524934883098,
        0.9341786678949748, 0.9306931295155753, 0.9332227354795629,
        0.9312008571438402, 0.9335295484755572, 0.9333167395841182,
        0.9315595511169302, 0.9301811416101524, 0.9314818362895073,
        0.9308551601915486, 0.9296559215457606, 0.9284091216867709,
        0.9318823563281231, 0.9295666150206443, 0.9291284919738931,
    ])),
    "Time": dict(enumerate([
        119.91294360160828, 256.4710912704468, 266.6792154312134,
        326.7445312023163, 256.8881601810455, 183.77022705078124,
        359.7090343952179, 383.6012378692627, 416.3133870601654,
        425.7837643623352, 225.39801173210145, 516.9914848804474,
        562.7134436607361, 585.6752841472626, 560.5802517414093,
        267.22873797416685, 646.1253435134888, 811.1979314804078,
        780.6058969974517, 789.9369702339172, 394.0711458206177,
        890.7988158226013, 1065.5482338428496, 996.5119229316712,
        1096.0208141803741, 524.0947244644165, 1182.684538602829,
        1348.3343998908997, 1356.0255290508271, 1053.8607951164245,
    ])),
})
Code for creating the plot
import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata
import plotly.io as pio
# Resample the scattered (hidden_layer_sizes, max_iter) -> Score samples onto
# a regular 100 x 100 grid with cubic interpolation so they form a surface.
hls_col, iter_col = df["hidden_layer_sizes"], df["max_iter"]
hls_axis = np.linspace(min(hls_col), max(hls_col), num=100)
iter_axis = np.linspace(min(iter_col), max(iter_col), num=100)
x_grid, y_grid = np.meshgrid(hls_axis, iter_axis)
z_grid = griddata((hls_col, iter_col), df["Score"], (x_grid, y_grid), method="cubic")

surface = go.Surface(x=x_grid, y=y_grid, z=z_grid, showlegend=True)
fig = go.Figure(surface)
layout_settings = {
    "title": "Test",
    "width": 600,
    "height": 600,
    "template": "none",
    # legend.orientation only affects the layout legend; the trace's
    # colorbar is configured separately, on the trace itself.
    "legend": {"orientation": "h"},
}
fig.update_layout(**layout_settings)
fig.show()
You're talking about two different things here: the `legend` and the `colorbar`. The former is an attribute of the figure layout, while the latter is an attribute of the figure data (the traces). To get what you're aiming for, just include this:
# Layout legend: horizontal, placed below and to the left of the plot.
fig.update_layout(legend_orientation="h", legend_x=-0.25, legend_y=-0.10)
# The colour scale is the trace's colorbar, so it is moved via update_traces.
fig.update_traces(colorbar_orientation="h", colorbar_x=0.5, colorbar_y=-0.25)
That is, if you'd like to keep the "small" legend at all. If not, just use:
fig.layout.showlegend = False  # hide the layout legend entirely
import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata
import plotly.io as pio
import pandas as pd
# Results of 30 runs: each hidden_layer_sizes value (25, 50, ..., 150) is
# paired with each max_iter value (100, 200, ..., 500), rows keyed 0..29.
# Score and Time are the per-run measurements exactly as reported.
df = pd.DataFrame({
    "hidden_layer_sizes": {row: 25 * (row // 5 + 1) for row in range(30)},
    "max_iter": {row: 100 * (row % 5 + 1) for row in range(30)},
    "Score": dict(enumerate([
        0.9270832984321359, 0.9172223807360554, 0.9202868292420568,
        0.9187318693456508, 0.9263589700182026, 0.9325454241272417,
        0.9351742112383672, 0.934706441722599, 0.9350294733755595,
        0.9334167352798914, 0.9355533396303661, 0.9327821227628682,
        0.9333376163633981, 0.9322875868305249, 0.9345524934883098,
        0.9341786678949748, 0.9306931295155753, 0.9332227354795629,
        0.9312008571438402, 0.9335295484755572, 0.9333167395841182,
        0.9315595511169302, 0.9301811416101524, 0.9314818362895073,
        0.9308551601915486, 0.9296559215457606, 0.9284091216867709,
        0.9318823563281231, 0.9295666150206443, 0.9291284919738931,
    ])),
    "Time": dict(enumerate([
        119.91294360160828, 256.4710912704468, 266.6792154312134,
        326.7445312023163, 256.8881601810455, 183.77022705078124,
        359.7090343952179, 383.6012378692627, 416.3133870601654,
        425.7837643623352, 225.39801173210145, 516.9914848804474,
        562.7134436607361, 585.6752841472626, 560.5802517414093,
        267.22873797416685, 646.1253435134888, 811.1979314804078,
        780.6058969974517, 789.9369702339172, 394.0711458206177,
        890.7988158226013, 1065.5482338428496, 996.5119229316712,
        1096.0208141803741, 524.0947244644165, 1182.684538602829,
        1348.3343998908997, 1356.0255290508271, 1053.8607951164245,
    ])),
})
# Interpolate the scattered (hidden_layer_sizes, max_iter) -> Score samples
# onto a regular 100 x 100 grid (cubic) so they can be drawn as a surface.
xi = np.linspace(min(df["hidden_layer_sizes"]), max(df["hidden_layer_sizes"]), num=100)
yi = np.linspace(min(df["max_iter"]), max(df["max_iter"]), num=100)
x_grid, y_grid = np.meshgrid(xi, yi)
z_grid = griddata(
    (df["hidden_layer_sizes"], df["max_iter"]),
    df["Score"],
    (x_grid, y_grid),
    method="cubic",
)

fig = go.Figure(go.Surface(x=x_grid, y=y_grid, z=z_grid))
# The small layout legend is not wanted, so it is switched off once here,
# rather than enabling it on the trace, positioning it, and then hiding it.
fig.update_layout(
    title="Test",
    width=600,
    height=600,
    template="none",
    showlegend=False,
)
# The colour scale is the trace's colorbar (not the layout legend), so it is
# moved via update_traces: horizontal and centred below the plot.
fig.update_traces(colorbar=dict(orientation="h", x=0.5, y=-0.25))
fig.show()