I want to use pool.impa to use multiprocessing for the following code:
# RSI
def predict(stock_symbol,date_confidences):
stock_data = pd.read_csv(f'data/{stock_symbol}')
# Calculate the RSI change from the previous row
stock_data['rsi_change'] = stock_data['RSI'].diff()
# Initialize the 'label' column with zeros
stock_data['predict'] = 0
# Set 'label' to 1 for rows where RSI is below 30 and has increased from the previous row
stock_data.loc[(stock_data['RSI'] < 30) & (stock_data['rsi_change'] > 0), 'predict'] = 1
# Drop the 'rsi_change' column if no longer needed
stock_data.drop(columns=['rsi_change'], inplace=True)
for ind, row in stock_data.iterrows():
date_confidences[date] += [(1,row['predict'],stock_symbol,None,row['Volume'])]
if __name__ == '__main__':
stock_symbols = [s.split('/')[-1].replace('.csv','') for s in sorted(glob.glob('../../data/*.csv'))]
years = [1962,2023]
current_date = datetime(years[0], 1, 1)
date_list = []
while current_date.year <= years[-1]:
date_list.append(current_date.strftime('%Y-%m-%d'))
current_date += timedelta(days=1)
# date_confidences = collections.defaultdict(list)
manager = Manager()
date_confidences = manager.dict()
for date in date_list:
date_confidences[date] = []
with Pool(processes=os.cpu_count()) as pool:
args_list = [(stock_symbol, date_confidences) for stock_symbol in stock_symbols]
list(tqdm(pool.imap(predict, args_list), total=len(stock_symbols)))
with open('cache/date_confidences.pkl','wb') as file:
pickle.dump(date_confidences,file)
However, I get this error:
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
File "/disk/nouri/environment/transformer3/lib/python3.9/multiprocessing/pool.py", line 125, in worker
result = (True, func(*args, **kwds))
TypeError: predict() missing 1 required positional argument: 'date_confidences'
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/disk/nouri/stock/chatgpt/crossvalidation/day_estimator/indicators/analyze.py", line 102, in <module>
list(tqdm(pool.imap(predict, args_list), total=len(stock_symbols)))
File "/disk/nouri/environment/transformer3/lib/python3.9/site-packages/tqdm/std.py", line 1195, in __iter__
for obj in iterable:
File "/disk/nouri/environment/transformer3/lib/python3.9/multiprocessing/pool.py", line 870, in next
raise value
TypeError: predict() missing 1 required positional argument: 'date_confidences'
In my code, I'm passing both arguments but still get the error for missing arguments. I need to pass both arguments to the function predict
. Is it possible to send these two arguments with imap or I should use other methods?
You can define a new function where you pass one single argument, expand it inside the function, and send it to `predict
# RSI
def predict(stock_symbol,date_confidences):
stock_data = pd.read_csv(f'data/{stock_symbol}')
# Calculate the RSI change from the previous row
stock_data['rsi_change'] = stock_data['RSI'].diff()
# Initialize the 'label' column with zeros
stock_data['predict'] = 0
# Set 'label' to 1 for rows where RSI is below 30 and has increased from the previous row
stock_data.loc[(stock_data['RSI'] < 30) & (stock_data['rsi_change'] > 0), 'predict'] = 1
# Drop the 'rsi_change' column if no longer needed
stock_data.drop(columns=['rsi_change'], inplace=True)
for ind, row in stock_data.iterrows():
date_confidences[date] += [(1,row['predict'],stock_symbol,None,row['Volume'])]
def process_stock(args):
stock_symbol, date_confidences = args
return predict(stock_symbol, date_confidences)
if __name__ == '__main__':
stock_symbols = [s.split('/')[-1].replace('.csv','') for s in sorted(glob.glob('../../data/*.csv'))]
years = [1962,2023]
current_date = datetime(years[0], 1, 1)
date_list = []
while current_date.year <= years[-1]:
date_list.append(current_date.strftime('%Y-%m-%d'))
current_date += timedelta(days=1)
# date_confidences = collections.defaultdict(list)
manager = Manager()
date_confidences = manager.dict()
for date in date_list:
date_confidences[date] = []
with Pool(processes=os.cpu_count()) as pool:
args_list = [(stock_symbol, date_confidences) for stock_symbol in stock_symbols]
list(tqdm(pool.imap(process_stock, args_list), total=len(stock_symbols)))
with open('cache/date_confidences.pkl','wb') as file:
pickle.dump(date_confidences,file)