python, pandas, matplotlib, plot, scatter

multi-index dataframe causes wide separation between plotted data


I have the following plot:

enter image description here

My pandas dataset uses a multi-index, like this:

enter image description here

Below is my code:

ax = plt.gca()

df['adjClose'].plot(ax=ax, figsize=(12,4), rot=9, grid=True, label='price', color='orange')
df['ma5'].plot(ax=ax, label='ma5', color='yellow')
df['ma100'].plot(ax=ax, label='ma100', color='green')

# df.plot.scatter(x=df.index, y='buy')
x = pd.to_datetime(df.unstack(level=0).index, format='%Y/%m/%d')

# plt.scatter(x, df['buy'].values)
ax.scatter(x, y=df['buy'].values, label='buy', marker='^', color='red')
ax.scatter(x, y=df['sell'].values, label='sell', marker='v', color='green')

plt.show()

Data from .csv

symbol,date,close,high,low,open,volume,adjClose,adjHigh,adjLow,adjOpen,adjVolume,divCash,splitFactor,ma5,ma100,buy,sell
601398,2020-01-01 00:00:00+00:00,5.88,5.88,5.88,5.88,0,5.2991971571,5.2991971571,5.2991971571,5.2991971571,0,0.0,1.0,,,,
601398,2020-01-02 00:00:00+00:00,5.97,6.03,5.91,5.92,234949400,5.3803073177,5.4343807581,5.3262338773,5.3352461174,234949400,0.0,1.0,,,,
601398,2020-01-03 00:00:00+00:00,5.99,6.02,5.96,5.97,152213050,5.3983317978,5.425368518,5.3712950777,5.3803073177,152213050,0.0,1.0,,,,
601398,2020-01-06 00:00:00+00:00,5.97,6.05,5.95,5.96,226509710,5.3803073177,5.4524052382,5.3622828376,5.3712950777,226509710,0.0,1.0,,,,

The data above is what the dataframe looks like after saving it to CSV. After reloading the file, the dataframe loses its original structure.


Solution

    • The issue, as can be seen in the plot, is that the first three lines are plotted against the dataframe index, which presents as tuples. The scatter plots are plotted against the datetime values in x, which are not values on the x-axis of ax, so they're plotted far to the right.
      • The x-axis tick labels are a bunch of stacked tuples.
    • Don't convert the dataframe to a multi-index. If some step in your workflow creates the multi-index, use df.reset_index(level=x, inplace=True), where x is the level of 'symbol' in the multi-index.
      • After removing 'symbol' from the index, parse the 'date' index as datetime and keep only the date component with df.index = pd.to_datetime(df.index).date
    • Presumably, there's more than one unique 'symbol' in the dataframe, so a separate plot should be drawn for each.
    • Tested in pandas 1.3.1, python 3.8, and matplotlib 3.4.2
    import pandas as pd
    import matplotlib.pyplot as plt
    import numpy as np
    
    # load the data from the csv
    df = pd.read_csv('file.csv')
    
    # convert date to a datetime format and extract only the date component
    df.date = pd.to_datetime(df.date).dt.date
    
    # set date as the index
    df.set_index('date', inplace=True)
    
    # this is what the dataframe should look like before plotting
    #             symbol  close  high   low  open     volume  adjClose  adjHigh  adjLow  adjOpen  adjVolume  divCash  splitFactor  ma5  ma100  buy  sell
    # date
    # 2020-01-01  601398   5.88  5.88  5.88  5.88          0      5.30     5.30    5.30     5.30          0      0.0          1.0  NaN    NaN  NaN   NaN
    # 2020-01-02  601398   5.97  6.03  5.91  5.92  234949400      5.38     5.43    5.33     5.34  234949400      0.0          1.0  NaN    NaN  NaN   NaN
    # 2020-01-03  601398   5.99  6.02  5.96  5.97  152213050      5.40     5.43    5.37     5.38  152213050      0.0          1.0  NaN    NaN  NaN   NaN
    # 2020-01-06  601398   5.97  6.05  5.95  5.96  226509710      5.38     5.45    5.36     5.37  226509710      0.0          1.0  NaN    NaN  NaN   NaN
    
    # extract the unique symbols
    symbols = df.symbol.unique()
    
    # get the number of unique symbols
    sym_len = len(symbols)
    
    # create a number of subplots based on the number of unique symbols in df
    fig, axes = plt.subplots(nrows=sym_len, ncols=1, figsize=(12, 4*sym_len))
    
    # if there's only 1 symbol, plt.subplots returns a single Axes instead of an array, so we put it in a list
    if not isinstance(axes, np.ndarray):
        axes = [axes]
    
    # iterate through each symbol and plot the relevant data to an axes
    for ax, sym in zip(axes, symbols):
        
        # select the data for the relevant symbol
        data = df[df.symbol.eq(sym)]
        
        # plot data
        data[['adjClose', 'ma5', 'ma100']].plot(ax=ax, title=f'Data for Symbol: {sym}', ylabel='Value')
        ax.scatter(data.index, y=data['buy'], label='buy', marker='^', color='red')
        ax.scatter(data.index, y=data['sell'], label='sell', marker='v', color='green')
        ax.legend(bbox_to_anchor=(1, 1.02), loc='upper left')
        
    fig.tight_layout()
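
As an alternative to wrapping a single Axes in a list, plt.subplots accepts squeeze=False, which always returns a 2D array of Axes regardless of the grid size. The following is a minimal sketch of that option, not part of the tested code above. The value of sym_len is a hypothetical stand-in, and the Agg backend is selected only so the snippet runs without a display.

```python
import matplotlib
matplotlib.use('Agg')  # assumption: non-interactive backend, so no display is needed
import matplotlib.pyplot as plt

sym_len = 1  # hypothetical: the dataframe contains a single unique symbol

# squeeze=False always returns a 2D array of Axes with shape (nrows, ncols)
fig, axes = plt.subplots(nrows=sym_len, ncols=1, figsize=(12, 4 * sym_len), squeeze=False)

# flatten to 1D so the array can be zipped with the symbols
axes = axes.ravel()
```

With this, for ax, sym in zip(axes, symbols) works the same way for one symbol or many, and the type check is no longer needed.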
    
    • In the example plot, data.high and data.low are used for the scatter plots in place of data.buy and data.sell, since data.buy and data.sell are all np.nan in the test data.

    enter image description here

    • df can be conveniently created with:
    sample = {'symbol': [601398, 601398, 601398, 601398], 'date': ['2020-01-01 00:00:00+00:00', '2020-01-02 00:00:00+00:00', '2020-01-03 00:00:00+00:00', '2020-01-06 00:00:00+00:00'], 'close': [5.88, 5.97, 5.99, 5.97], 'high': [5.88, 6.03, 6.02, 6.05], 'low': [5.88, 5.91, 5.96, 5.95], 'open': [5.88, 5.92, 5.97, 5.96], 'volume': [0, 234949400, 152213050, 226509710], 'adjClose': [5.2991971571, 5.3803073177, 5.3983317978, 5.3803073177], 'adjHigh': [5.2991971571, 5.4343807581, 5.425368518, 5.4524052382], 'adjLow': [5.2991971571, 5.3262338773, 5.3712950777, 5.3622828376], 'adjOpen': [5.2991971571, 5.3352461174, 5.3803073177, 5.3712950777], 'adjVolume': [0, 234949400, 152213050, 226509710], 'divCash': [0.0, 0.0, 0.0, 0.0], 'splitFactor': [1.0, 1.0, 1.0, 1.0], 'ma5': [np.nan, np.nan, np.nan, np.nan], 'ma100': [np.nan, np.nan, np.nan, np.nan], 'buy': [np.nan, np.nan, np.nan, np.nan], 'sell': [np.nan, np.nan, np.nan, np.nan]}
    df = pd.DataFrame(sample)
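
If the dataframe already has a multi-index, the reset_index step described in the bullets above might look like the following sketch. It assumes the index levels are named 'symbol' and 'date', which is a guess based on the csv columns, and it uses a trimmed, hypothetical copy of the sample data with only the adjClose column.

```python
import pandas as pd

# hypothetical stand-in for a dataframe that already has a ('symbol', 'date') multi-index
df = pd.DataFrame(
    {
        'symbol': [601398, 601398, 601398, 601398],
        'date': ['2020-01-01 00:00:00+00:00', '2020-01-02 00:00:00+00:00',
                 '2020-01-03 00:00:00+00:00', '2020-01-06 00:00:00+00:00'],
        'adjClose': [5.2991971571, 5.3803073177, 5.3983317978, 5.3803073177],
    }
).set_index(['symbol', 'date'])

# move 'symbol' out of the index and back into a regular column;
# the level can be given by name (assumed here) or by position
df = df.reset_index(level='symbol')

# the remaining index holds the date strings; parse them and keep only the date component
df.index = pd.to_datetime(df.index).date

# assigning a plain array drops the index name, so restore it
df.index.name = 'date'
```

After this, df has a single-level index of dates and 'symbol' is an ordinary column, which is the shape the plotting loop above expects.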