I have a dataframe of stock prices with the stocks as columns and dates as the index. Say:
df = pd.DataFrame(np.arange(9).reshape((3,3)),columns=['stock1','stock2','stock3'],index=['2023-01-01','2023-01-02','2023-01-03'])
            stock1  stock2  stock3
2023-01-01       0       1       2
2023-01-02       3       4       5
2023-01-03       6       7       8
I also have a JSON file with a list of stock names for each date:
{
"2023-01-01": ["stock1", "stock3", "stock4", "stock5"],
"2023-01-02": ["stock2", "stock4"],
"2023-01-03": ["stock1", "stock2", "stock3", "stock5"]
}
I need to filter this dataframe so that, for each date, it keeps only the values of the stocks that appear in the JSON list for that date.
For example, if stock1 and stock3 are in the list on 1 January but only stock2 is on 2 January, then I would want the dataframe to look like this:
            stock1  stock2  stock3
2023-01-01       0      NA       2
2023-01-02      NA       4      NA
2023-01-03       6       7       8
It is important to note that the dataframe may not contain all the elements in each list.
I have tried numerous ways to achieve this, for example using the DataFrame.mask and DataFrame.where methods. The only method that I have found to work is iterating over the rows with df.T.items(), finding the common values between df.columns and the list at each date (set(df.columns) & set(list)), and then applying lambda x: pd.NA if x not in list else x. This, however, is extremely slow for large datasets, and my program will have to do this operation constantly with different dataframes and filters.
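For reference, the row-by-row approach described above might look roughly like the following sketch. The names allowed, rows and slow are illustrative only, and Series.where stands in for the lambda, so this is one possible reading of the loop rather than the exact code.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame(
    np.arange(9).reshape((3, 3)),
    columns=["stock1", "stock2", "stock3"],
    index=["2023-01-01", "2023-01-02", "2023-01-03"],
)
# Hypothetical in-memory copy of the JSON content
allowed = {
    "2023-01-01": ["stock1", "stock3", "stock4", "stock5"],
    "2023-01-02": ["stock2", "stock4"],
    "2023-01-03": ["stock1", "stock2", "stock3", "stock5"],
}

rows = {}
# df.T.items() yields (date, Series of that date's prices indexed by stock name)
for date, row in df.T.items():
    # Stocks that are both in the dataframe and in the list for this date
    keep = set(df.columns) & set(allowed.get(date, []))
    # Blank out every stock that is not in the list for this date
    rows[date] = row.where(row.index.isin(list(keep)))

# Reassemble the per-date Series into a dataframe with dates as the index
slow = pd.DataFrame(rows).T
```

The Python-level loop runs once per date, which is where the cost comes from on large frames.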
Is there a better way to do this?
Many thanks in advance!
You can convert the JSON into a boolean dataframe that has the dates as its index and the stock names as its column headers, where each cell says whether that stock is listed on that date. Then use that dataframe as a mask on the original one.
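import pandas as pd

d = {
'2023-01-01':['stock1','stock3','stock4', 'stock5'],
'2023-01-02':['stock2','stock4'],
'2023-01-03':['stock1','stock2','stock3','stock5'],
}
# One entry per (date, position) pair, holding the stock name
data = pd.DataFrame.from_dict(d, orient='index').stack()
# Boolean table: True where the stock is listed for that date
m = pd.crosstab(data.index.get_level_values(0), data).astype(bool)
# Keep the values where m is True; everything else becomes NaN
out = df.where(m)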
$ print(data)
2023-01-01  0    stock1
            1    stock3
            2    stock4
            3    stock5
2023-01-02  0    stock2
            1    stock4
2023-01-03  0    stock1
            1    stock2
            2    stock3
            3    stock5
dtype: object
$ print(m)
col_0       stock1  stock2  stock3  stock4  stock5
row_0
2023-01-01    True   False    True    True    True
2023-01-02   False    True   False    True   False
2023-01-03    True    True    True   False    True
$ print(out)
            stock1  stock2  stock3
2023-01-01     0.0     NaN     2.0
2023-01-02     NaN     4.0     NaN
2023-01-03     6.0     7.0     8.0
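A self-contained variant of the same idea is sketched below. It builds the (date, stock) pairs with Series.explode instead of stack, and it adds an explicit reindex so the mask has exactly the shape of df. The names pairs and mask and the reindex step are my own additions for illustration, not part of the code above.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame(
    np.arange(9).reshape((3, 3)),
    columns=["stock1", "stock2", "stock3"],
    index=["2023-01-01", "2023-01-02", "2023-01-03"],
)
d = {
    "2023-01-01": ["stock1", "stock3", "stock4", "stock5"],
    "2023-01-02": ["stock2", "stock4"],
    "2023-01-03": ["stock1", "stock2", "stock3", "stock5"],
}

# One row per (date, stock) pair: explode() repeats the date for each list element
pairs = (
    pd.Series(d, name="stock")
    .explode()
    .rename_axis("date")
    .reset_index()
)

# Count each date/stock pair, convert counts to booleans, then match df's shape.
# Dates or stocks missing from d become False; stocks missing from df are dropped.
mask = (
    pd.crosstab(pairs["date"], pairs["stock"])
    .astype(bool)
    .reindex(index=df.index, columns=df.columns, fill_value=False)
)

# Keep values where the mask is True; everything else becomes NaN
out = df.where(mask)
```

The reindex is arguably redundant here, since where aligns the mask on labels anyway, but it makes the handling of stocks like stock4 and stock5 (in the lists, not in df) explicit.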