Search code examples
Tags: python, matplotlib, colors, legend, gluonts

Modify legend color of a matplotlib plot created by gluonts


I'm using gluonts and plotting a forecast (code from DeepVaR notebook). The code is the following:

def plot_prob_forecasts(ts_entry, forecast_entry, asset_name, plot_length=20):
    """
    Plot the tail of an observed series together with its forecast.

    Creates a new 10x7 figure, draws the last ``plot_length`` entries of
    ``ts_entry`` on it, calls ``forecast_entry.plot`` with the 0.95 and 0.99
    intervals in green, then adds a grid, a legend and a title containing
    ``asset_name`` and shows the figure.
    """
    prediction_intervals = (0.95, 0.99)
    # Legend texts: the two line labels followed by the interval labels in
    # reversed order (0.99 first, then 0.95).
    # NOTE(review): the intervals are fractions, so the generated labels read
    # "0.99% prediction interval" and "0.95% prediction interval".
    legend = ["observations", "median prediction"] + [f"{k}% prediction interval" for k in prediction_intervals][::-1]
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    ts_entry[-plot_length:].plot(ax=ax)  # plot the time series
    # NOTE(review): ``ax`` is not passed to the forecast plot, so it draws on
    # whatever axes it selects by default -- presumably the current axes
    # created above; confirm against the forecast class.
    forecast_entry.plot( intervals=prediction_intervals, color='g')    
    plt.grid(which="both")
    # Only label strings are supplied (no handles), so pairing the labels with
    # the plotted artists is left to matplotlib.
    plt.legend(legend, loc="upper left")
    plt.title(f'Forecast of {asset_name} series Returns')
    plt.show()

and produces the following plot: enter image description here

The colors in the legend for the confidence intervals are incorrect, but I cannot figure out how to fix them. Calling plt.gca().get_legend_handles_labels() returns just the first line (observations), and the result is the same whether I call it before or after legend(). The code from gluonts is:

    def plot(
        self,
        *,
        intervals=(0.5, 0.9),
        ax=None,
        color=None,
        name=None,
        show_label=False,
    ):
        """
        Plot median forecast and prediction intervals using ``matplotlib``.

        By default the `0.5` and `0.9` prediction intervals are plotted. Other
        intervals can be chosen by setting `intervals`.

        This plots to the current axes object (via ``plt.gca()``), or to ``ax``
        if provided. Similarly, the color is using matplotlib's internal color
        cycle, if no explicit ``color`` is set.

        One can set ``name`` to use it as the ``label`` for the median
        forecast. Intervals are not labeled, unless ``show_label`` is set to
        ``True``.

        Parameters
        ----------
        intervals
            Interval widths to draw. For each value a band is filled between
            the quantiles ``(1 - interval) / 2`` and
            ``1 - (1 - interval) / 2``.
        ax
            Axes to draw on; ``plt.gca()`` is used when not provided.
        color
            Color of the median line and face color of every band; taken from
            the axes' color cycle when not provided.
        name
            Label of the median line.
        show_label
            If ``True``, each band is labeled ``"{name}: {interval}"``, or
            with the bare interval value when ``name`` is ``None``.
        """
        import matplotlib.pyplot as plt

        # Get current axes (gca), if not provided explicitly.
        ax = maybe.unwrap_or_else(ax, plt.gca)

        # If no color is provided, we use matplotlib's internal color cycle.
        # Note: This is an internal API and might change in the future.
        color = maybe.unwrap_or_else(
            color, lambda: ax._get_lines.get_next_color()
        )

        # Plot median forecast
        ax.plot(
            self.index.to_timestamp(),
            self.quantile(0.5),
            color=color,
            label=name,
        )

        # Plot prediction intervals
        for interval in intervals:
            # Bands only get a legend label when explicitly requested.
            if show_label:
                if name is not None:
                    label = f"{name}: {interval}"
                else:
                    label = interval
            else:
                label = None

            # Translate interval to low and high quantile levels. E.g. for
            # `0.9` we get `low = 0.05` and `high = 1 - low = 0.95`
            # (`high - low == interval`).
            # Larger interval values get lower alpha values, i.e. their bands
            # are drawn more transparently.
            low = (1 - interval) / 2
            ax.fill_between(
                # TODO: `index` currently uses `pandas.Period`, but we need
                # to pass a timestamp value to matplotlib. In the future this
                # will use ``zebras.Periods`` and thus needs to be adapted.
                self.index.to_timestamp(),
                self.quantile(low),
                self.quantile(1 - low),
                # Linear in `interval` (no actual clamping): alpha stays
                # between ~17% and 50% for intervals in (0, 1).
                alpha=0.5 - interval / 3,
                facecolor=color,
                label=label,
            )

If I set color=None, I get an error from matplotlib. Setting show_label=True and passing the name does not work either. Any idea how to fix it?

python=3.9.18

matplotlib=3.8.0

gluonts=0.13.2


Solution

  • plt.legend typically uses the "labeled" matplotlib elements encountered in the plot. In this case, the dark green area consists of two superimposed semitransparent layers. The default behavior just shows the semitransparent layers separately. You can use a tuple of the handles to show one on top of the other.

    Here is some simplified standalone code to simulate your situation.

    # Standalone demo: build a legend whose interval entries match the layered,
    # semitransparent bands drawn by repeated ``fill_between`` calls.
    import matplotlib.pyplot as plt
    import numpy as np

    # create some dummy test data
    x = np.arange(30)
    y = np.random.normal(0.1, 1, size=x.size).cumsum()

    fig, ax = plt.subplots(1, 1, figsize=(10, 7))

    # simulate plotting the observations
    ax.plot(x, y)

    prediction_intervals = (0.95, 0.99)
    # Interval labels are reversed: the 0.99 entry comes before the 0.95 entry.
    legend = ["observations", "median prediction"] + [f"{k}% prediction interval" for k in prediction_intervals][::-1]
    color = 'g'

    xp = np.arange(10, 30)
    yp = y[-xp.size:] + np.random.normal(0.03, 0.2, size=xp.size).cumsum()
    # simulate plotting median forecast (unlabeled, like the forecast plot)
    ax.plot(xp, yp, color=color, label=None)

    # simulate plotting the prediction intervals: one semitransparent band per
    # interval, drawn on top of each other (0.95 first, then 0.99)
    for interval in prediction_intervals:
        ax.fill_between(xp, yp - interval ** 2 * 5, yp + interval ** 2 * 5,
                        alpha=0.5 - interval / 3,
                        facecolor=color,
                        label=None)

    # the matplotlib elements for the two curves
    handle_l1, handle_l2 = ax.lines[:2]
    # the matplotlib elements for the two filled areas, in drawing order:
    # handle_c1 is the 0.95 band, handle_c2 is the 0.99 band
    handle_c1, handle_c2 = ax.collections[:2]

    # The outer (0.99-only) region is covered by the 0.99 band alone, so its
    # entry uses handle_c2. The inner (0.95) region is covered by both bands,
    # so its entry uses a tuple of handles, which the legend draws on top of
    # each other.
    ax.legend(handles=[handle_l1, handle_l2, handle_c2, (handle_c1, handle_c2)],
              labels=legend, loc="upper left")
    ax.grid(which="both")

    plt.tight_layout()
    plt.show()
    

    legend with layered colored areas