python, plotly, py-shiny

How to correctly render a 3D Surface date axis?


I'm having trouble rendering date axes correctly in Shiny for Python with Plotly surface plots. In particular, axes of type date are rendered as floats.

Here is an example in the Shiny playground.

Note that the exact same code works if I render the figure with fig.show() outside of Shiny for Python (i.e., the x axis renders as a date, not as a float).
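
A plausible but unconfirmed explanation for the floats is that, somewhere on the widget rendering path, the NumPy `datetime64[ns]` values produced by `np.meshgrid` are turned into plain numbers instead of date strings. The sketch below only demonstrates that conversion with NumPy; it does not inspect what shinywidgets or Plotly actually send to the browser, and the sample dates and maturities are made up for illustration.

```python
import numpy as np

# Yearly dates at nanosecond resolution, the unit pandas has
# traditionally used for a DatetimeIndex like the one in the question
dates = np.array(["2020-01-01", "2021-01-01"], dtype="datetime64[ns]")
maturities = np.array(["3m", "6m"])

# np.meshgrid keeps the datetime64[ns] dtype for the date grid
X, Y = np.meshgrid(dates, maturities)

# Python's datetime cannot hold nanoseconds, so .tolist() falls back to
# plain integers: nanoseconds since 1970-01-01
as_numbers = X.tolist()
print(as_numbers[0][0])  # 1577836800000000000

# An explicit conversion keeps the values recognizable as dates
as_strings = np.datetime_as_string(X, unit="D")
print(as_strings[0][0])  # 2020-01-01
```

If the axis in the broken render shows very large numbers of this magnitude, that would support this explanation; otherwise the cause lies elsewhere.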

I already tried to explicitly set the axis type in the figure layout:

fig.update_layout(
    scene=dict(
        xaxis=dict(
            type="date",
            tickformat="%Y"
        )
    )
)

But the result is even worse (the figure does not render at all).
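
Instead of forcing `type="date"` on the scene axis, one workaround that may be worth trying (an assumption, not tested in the question or the answer) is to pass the dates to Plotly as ISO strings, which plotly.js normally recognizes as dates on its own. The helper `date_surface_grids` below is hypothetical; it mirrors the grid construction in the answer's `plotSurface` but formats the DatetimeIndex as strings before calling `np.meshgrid`.

```python
import numpy as np
import pandas as pd


def date_surface_grids(plot_df: pd.DataFrame):
    """Hypothetical helper: build X/Y/Z grids for go.Surface with the
    dates encoded as ISO strings ("YYYY-MM-DD") instead of datetime64."""
    if not isinstance(plot_df.index, pd.DatetimeIndex):
        raise TypeError("plot_df must have a DatetimeIndex")
    # Format the dates first, then build the grid from plain strings
    date_labels = plot_df.index.strftime("%Y-%m-%d").to_numpy()
    X, Y = np.meshgrid(date_labels, plot_df.columns.to_numpy())
    # Rows of Z follow the maturities (Y), columns follow the dates (X)
    Z = plot_df.to_numpy().T
    return X, Y, Z


df = pd.DataFrame(
    index=pd.to_datetime(["2020-01-01", "2021-01-01", "2022-01-01"]),
    columns=["3m", "6m"],
    data=1,
)
X, Y, Z = date_surface_grids(df)
print(X[0])     # ['2020-01-01' '2021-01-01' '2022-01-01']
print(Z.shape)  # (2, 3)
# These arrays could then be passed as go.Surface(x=X, y=Y, z=Z, ...)
```

Whether this keeps the axis readable under `@render_widget` has not been verified here; it simply removes `datetime64` values from what Plotly receives.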


Solution

  • Replace the @render_widget decorator with express.render.ui, and replace return fig with return ui.HTML(fig.to_html()) (see express.ui.HTML and plotly.io.to_html).


    Link to playground

    import plotly.graph_objects as go
    import pandas as pd
    import numpy as np
    from shiny.express import render, ui
    
    def plotSurface(plot_dfs: list, names: list, title: str, **kwargs):
        """Helper function to plot multiple surfaces with ESM color themes.
    
        Args:
            plot_dfs (list): list of dataframes containing the matrix. The index of each df is in datetime format and the columns are maturities
            names (list): list of names for the surfaces in plot_dfs
            title (str): title of the plot
    
        Raises:
            TypeError: if plot_dfs or names is not a list
            TypeError: if a dataframe in plot_dfs does not have a DatetimeIndex
            ValueError: if plot_dfs and names have different lengths
    
        Returns:
            Figure: plotly figure
        """
        # Validate the container types first, before iterating over plot_dfs
        if not (isinstance(plot_dfs, list) and isinstance(names, list)):
            raise TypeError(f"both plot_dfs and names should be lists. Instead got {type(plot_dfs)} and {type(names)}")
        if len(plot_dfs) != len(names):
            raise ValueError(f"plot_dfs and names should have the same length but got {len(plot_dfs)} != {len(names)}")
        for i, plot_df in enumerate(plot_dfs):
            if not isinstance(plot_df.index, pd.DatetimeIndex):
                raise TypeError(f"plot_df number {i} in plot_dfs should have an index of type datetime but got {type(plot_df.index)}")
        
        fig = go.Figure()
    
        # stack surfaces. The last one will overwrite the first one when values are equal
        for i, (plot_df, name) in enumerate(zip(plot_dfs, names)):
            X, Y = np.meshgrid(plot_df.index, plot_df.columns)
            Z = plot_df.values.T
            fig.add_trace(go.Surface(z=Z, x=X, y=Y, name=name, showscale=False, showlegend=True, opacity=0.9))
        
        # Update layout for better visualization and custom template
        fig.update_layout(
            title=title,
            title_x=0.5,
            scene=dict(
                xaxis_title='Date',
                yaxis_title='Maturity',
                zaxis_title='Value',
            ),
            margin=dict(l=30, r=30, b=30, t=50),
            # template=esm_theme,
            legend=dict(title="Legend"),
        )
    
        return fig
    
    
    @render.ui
    def plot_1():
        plot_dfs = [
            pd.DataFrame(
                index=pd.to_datetime([f"{y}/01/01" for y in range(2020, 2100)]),
                columns=["3m", "6m", "9m"] + [f"{y}Y" for y in range(1, 31)],
                data=1,
            )
        ]
    
        fig = plotSurface(plot_dfs, names=["t"], title=" ")
        return ui.HTML(fig.to_html())