Say I have data about 3 trading strategies, each with and without transaction costs. I want to plot, on the same axes, the time series of each of the 6 variants (3 strategies * 2 trading costs). I would like the "with transaction cost" lines to be plotted with alpha=1 and linewidth=1, while I want the "no transaction costs" lines to be plotted with alpha=0.25 and linewidth=5. But I would like the color to be the same for both versions of each strategy.
I would like something along the lines of:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
for c in with_transaction_frame.columns:
    ax.plot(with_transaction_frame[c], label=c, alpha=1, linewidth=1)
# **** SOME MAGIC GOES HERE TO RESET THE COLOR CYCLE ****
for c in no_transaction_frame.columns:
    ax.plot(no_transaction_frame[c], label=c, alpha=0.25, linewidth=5)
ax.legend()
What is the appropriate code to put on the indicated line to reset the color cycle so it is "back to the start" when the second loop is invoked?
In Matplotlib < 1.5.0, you can reset the color cycle to the original with Axes.set_color_cycle. Looking at the code for this, there is a function that does the actual work:
def set_color_cycle(self, clist=None):
    if clist is None:
        clist = rcParams['axes.color_cycle']
    self.color_cycle = itertools.cycle(clist)
And a method on the Axes which uses it:
def set_color_cycle(self, clist):
    """
    Set the color cycle for any future plot commands on this Axes.

    *clist* is a list of mpl color specifiers.
    """
    self._get_lines.set_color_cycle(clist)
    self._get_patches_for_fill.set_color_cycle(clist)
This means you can call set_color_cycle with None as the only argument, and the cycle will be replaced with the default one found in rcParams['axes.color_cycle'], starting again from its first color.
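To see why passing None puts the cycle "back to the start", here is a minimal pure-Python sketch of the mechanism in the Matplotlib code above. ColorCycler, next_color and DEFAULT_COLORS are hypothetical names, not Matplotlib API, and the default list is just a placeholder standing in for rcParams['axes.color_cycle']. The key point is that set_color_cycle builds a fresh itertools.cycle, and a fresh iterator always begins at the first element.

```python
import itertools

# Hypothetical placeholder standing in for rcParams['axes.color_cycle']
DEFAULT_COLORS = ["b", "g", "r", "c", "m", "y", "k"]


class ColorCycler:
    """Hypothetical stand-in for the object that owns the color cycle."""

    def __init__(self):
        self.set_color_cycle(None)

    def set_color_cycle(self, clist=None):
        # None means "use the default list", as in the Matplotlib code above
        if clist is None:
            clist = DEFAULT_COLORS
        # A brand-new iterator always starts from the first color
        self.color_cycle = itertools.cycle(clist)

    def next_color(self):
        return next(self.color_cycle)


cycler = ColorCycler()
first_pass = [cycler.next_color() for _ in range(3)]
cycler.set_color_cycle(None)  # reset: rebuild the cycle from the default list
second_pass = [cycler.next_color() for _ in range(3)]
print(first_pass, second_pass)  # ['b', 'g', 'r'] ['b', 'g', 'r']
```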
I tried this with the following code and got the expected result:
import matplotlib.pyplot as plt
import numpy as np

for i in range(3):
    plt.plot(np.arange(10) + i)

# for Matplotlib version < 1.5 (this method is not available in current releases)
# plt.gca().set_color_cycle(None)
# for Matplotlib version >= 1.5
plt.gca().set_prop_cycle(None)

for i in range(3):
    plt.plot(np.arange(10, 1, -1) + i)

plt.show()
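Applied to the original question, a sketch might look like the following. It assumes Matplotlib >= 1.5 and, to avoid a pandas dependency, uses NumPy arrays in plain dicts filled with made-up random data instead of the asker's with_transaction_frame and no_transaction_frame DataFrames; the strategy names and data are hypothetical. ax.plot accepts a pandas Series in the same way, so the DataFrame loops from the question work unchanged. The last lines check that each strategy's two lines share a color.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical data: 3 strategies, each with and without transaction costs
rng = np.random.default_rng(0)
strategies = ["strat_a", "strat_b", "strat_c"]
no_costs = {s: np.cumsum(rng.normal(0.1, 1.0, 100)) for s in strategies}
with_costs = {s: no_costs[s] - 0.05 * np.arange(100) for s in strategies}

fig, ax = plt.subplots(1, 1, figsize=(10, 10))

# Thin, opaque lines for the "with transaction costs" variants
with_lines = [ax.plot(y, label=f"{s} (costs)", alpha=1, linewidth=1)[0]
              for s, y in with_costs.items()]

# Reset to the default rcParams['axes.prop_cycle'] so colors start over
ax.set_prop_cycle(None)

# Thick, translucent lines for the "no transaction costs" variants
no_lines = [ax.plot(y, label=f"{s} (no costs)", alpha=0.25, linewidth=5)[0]
            for s, y in no_costs.items()]
ax.legend()

same = [a.get_color() == b.get_color() for a, b in zip(with_lines, no_lines)]
print(same)  # [True, True, True]
```

Because set_prop_cycle(None) resets the whole property cycle rather than only the color, any other properties in a custom cycler (a linestyle cycle, for example) also restart from the beginning.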