Tags: python, plotly, plotly-python

3D Plot: How to set legend and colorbar orientation and position?


I want to create a 3D plot that visualizes the correlation between hidden_layer_sizes, max_iter and Score. I had to google around a bit to get the plot I expected, but now I'm facing some problems with the legend:

  1. There are two legends.
  2. The second legend is very tiny.

My goal is to move the right-hand legend to the bottom, but that doesn't work; I can't even remove it. If I set showlegend=False, only the small, highlighted legend disappears and the right-hand legend remains.

I'm sure it's just my lack of Plotly experience. I'd appreciate any help.


MWE

Data

import pandas as pd

df = pd.DataFrame({'hidden_layer_sizes': {0: 25,
  1: 25,  2: 25,  3: 25,  4: 25,  5: 50,  6: 50,  7: 50,  8: 50,  9: 50,  10: 75,
  11: 75,  12: 75,  13: 75,  14: 75,  15: 100,  16: 100,  17: 100,  18: 100,  19: 100,  20: 125,
  21: 125,  22: 125,  23: 125,  24: 125,  25: 150,  26: 150,  27: 150,  28: 150,  29: 150}, 
'max_iter': {0: 100,  1: 200,  2: 300,  3: 400,  4: 500,  5: 100,  6: 200,  7: 300,  8: 400,  9: 500,  
10: 100,  11: 200,  12: 300,  13: 400,  14: 500,  15: 100, 16: 200,  17: 300,  18: 400,  19: 500,  
20: 100,  21: 200,  22: 300,  23: 400,  24: 500,  25: 100,  26: 200,  27: 300,  28: 400,  29: 500}, 
'Score': {0: 0.9270832984321359,  1: 0.9172223807360554,  2: 0.9202868292420568,  3: 0.9187318693456508,
  4: 0.9263589700182026,  5: 0.9325454241272417,  6: 0.9351742112383672,  7: 0.934706441722599,
  8: 0.9350294733755595,  9: 0.9334167352798914,  10: 0.9355533396303661,  11: 0.9327821227628682,
  12: 0.9333376163633981,  13: 0.9322875868305249,  14: 0.9345524934883098,  15: 0.9341786678949748,
  16: 0.9306931295155753,  17: 0.9332227354795629,  18: 0.9312008571438402,  19: 0.9335295484755572,
  20: 0.9333167395841182,  21: 0.9315595511169302,  22: 0.9301811416101524,  23: 0.9314818362895073,
  24: 0.9308551601915486,  25: 0.9296559215457606,  26: 0.9284091216867709,  27: 0.9318823563281231,
  28: 0.9295666150206443,  29: 0.9291284919738931},
 'Time': {0: 119.91294360160828,  1: 256.4710912704468,  2: 266.6792154312134,  3: 326.7445312023163,
  4: 256.8881601810455,  5: 183.77022705078124,  6: 359.7090343952179,  7: 383.6012378692627,
  8: 416.3133870601654,  9: 425.7837643623352,  10: 225.39801173210145,  11: 516.9914848804474,
  12: 562.7134436607361,  13: 585.6752841472626,  14: 560.5802517414093,  15: 267.22873797416685,
  16: 646.1253435134888,  17: 811.1979314804078,  18: 780.6058969974517,  19: 789.9369702339172,
  20: 394.0711458206177,  21: 890.7988158226013,  22: 1065.5482338428496,  23: 996.5119229316712,
  24: 1096.0208141803741,  25: 524.0947244644165,  26: 1182.684538602829,  27: 1348.3343998908997,
  28: 1356.0255290508271,  29: 1053.8607951164245}})

Code for creating the plot

import numpy as np
import plotly.graph_objects as go 
from scipy.interpolate import griddata
import plotly.io as pio

xi = np.linspace(min(df["hidden_layer_sizes"]), max(df["hidden_layer_sizes"]), num=100)
yi = np.linspace(min(df["max_iter"]), max(df["max_iter"]), num=100)

x_grid, y_grid = np.meshgrid(xi,yi)
z_grid = griddata((df["hidden_layer_sizes"],df["max_iter"]),df["Score"],(x_grid,y_grid),method="cubic")

fig = go.Figure(go.Surface(x=x_grid, y=y_grid, z=z_grid, showlegend=True))
fig.update_layout(title="Test",
                  width=600, height=600, template="none",
                  legend=dict(orientation="h"))

fig.show()

Solution

  • You're talking about two different things here: the legend and the colorbar. The legend is an attribute of the figure layout, while the colorbar is an attribute of the trace (here, the go.Surface trace). The large bar on the right of your plot is the colorbar, which is why showlegend=False doesn't remove it. To get what you're aiming for, just include this:

    fig.update_layout(legend = dict(orientation="h", x = -0.25, y = -0.10))
    fig.update_traces(colorbar = dict(orientation='h', y = -0.25, x = 0.5))
    

    Plot 1


    That is, if you'd like to keep the "small" legend at all. If not, just use:

    fig.update_layout(showlegend = False)
    

    Plot 2


    Complete code:

    import numpy as np
    import plotly.graph_objects as go 
    from scipy.interpolate import griddata
    import plotly.io as pio
    
    import pandas as pd
    
    df = pd.DataFrame({'hidden_layer_sizes': {0: 25,
      1: 25,  2: 25,  3: 25,  4: 25,  5: 50,  6: 50,  7: 50,  8: 50,  9: 50,  10: 75,
      11: 75,  12: 75,  13: 75,  14: 75,  15: 100,  16: 100,  17: 100,  18: 100,  19: 100,  20: 125,
      21: 125,  22: 125,  23: 125,  24: 125,  25: 150,  26: 150,  27: 150,  28: 150,  29: 150}, 
    'max_iter': {0: 100,  1: 200,  2: 300,  3: 400,  4: 500,  5: 100,  6: 200,  7: 300,  8: 400,  9: 500,  
    10: 100,  11: 200,  12: 300,  13: 400,  14: 500,  15: 100, 16: 200,  17: 300,  18: 400,  19: 500,  
    20: 100,  21: 200,  22: 300,  23: 400,  24: 500,  25: 100,  26: 200,  27: 300,  28: 400,  29: 500}, 
    'Score': {0: 0.9270832984321359,  1: 0.9172223807360554,  2: 0.9202868292420568,  3: 0.9187318693456508,
      4: 0.9263589700182026,  5: 0.9325454241272417,  6: 0.9351742112383672,  7: 0.934706441722599,
      8: 0.9350294733755595,  9: 0.9334167352798914,  10: 0.9355533396303661,  11: 0.9327821227628682,
      12: 0.9333376163633981,  13: 0.9322875868305249,  14: 0.9345524934883098,  15: 0.9341786678949748,
      16: 0.9306931295155753,  17: 0.9332227354795629,  18: 0.9312008571438402,  19: 0.9335295484755572,
      20: 0.9333167395841182,  21: 0.9315595511169302,  22: 0.9301811416101524,  23: 0.9314818362895073,
      24: 0.9308551601915486,  25: 0.9296559215457606,  26: 0.9284091216867709,  27: 0.9318823563281231,
      28: 0.9295666150206443,  29: 0.9291284919738931},
     'Time': {0: 119.91294360160828,  1: 256.4710912704468,  2: 266.6792154312134,  3: 326.7445312023163,
      4: 256.8881601810455,  5: 183.77022705078124,  6: 359.7090343952179,  7: 383.6012378692627,
      8: 416.3133870601654,  9: 425.7837643623352,  10: 225.39801173210145,  11: 516.9914848804474,
      12: 562.7134436607361,  13: 585.6752841472626,  14: 560.5802517414093,  15: 267.22873797416685,
      16: 646.1253435134888,  17: 811.1979314804078,  18: 780.6058969974517,  19: 789.9369702339172,
      20: 394.0711458206177,  21: 890.7988158226013,  22: 1065.5482338428496,  23: 996.5119229316712,
      24: 1096.0208141803741,  25: 524.0947244644165,  26: 1182.684538602829,  27: 1348.3343998908997,
      28: 1356.0255290508271,  29: 1053.8607951164245}})
    
    
    xi = np.linspace(min(df["hidden_layer_sizes"]), max(df["hidden_layer_sizes"]), num=100)
    yi = np.linspace(min(df["max_iter"]), max(df["max_iter"]), num=100)
    
    x_grid, y_grid = np.meshgrid(xi,yi)
    z_grid = griddata((df["hidden_layer_sizes"],df["max_iter"]),df["Score"],(x_grid,y_grid),method="cubic")
    
    fig = go.Figure(go.Surface(x=x_grid, y=y_grid, z=z_grid, showlegend=True))
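    fig.update_layout(title="Test",
                      width=600, height=600, template="none")
    
    # Legend (layout attribute): horizontal, below the plot
    fig.update_layout(legend = dict(orientation="h", x = -0.25, y = -0.10))
    # Colorbar (trace attribute): horizontal, below the plot
    fig.update_traces(colorbar = dict(orientation='h', y = -0.25, x = 0.5))
    # Hide the small legend (Plot 2); drop this line to keep it (Plot 1)
    fig.update_layout(showlegend = False)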
    
    fig.show()