Search code examples
pythonmultiprocessing

How to send multiple arguments to a function when using imap?


I want to use `pool.imap` to parallelize the following code with multiprocessing:


# RSI
def predict(stock_symbol,date_confidences):
    """Flag oversold-and-recovering rows of one stock and record every row.

    Reads ``data/{stock_symbol}`` as CSV, sets the ``predict`` column to 1 on
    rows whose RSI is below 30 and higher than on the previous row (0 on all
    other rows), then appends one
    ``(1, predict, stock_symbol, None, Volume)`` tuple per row to
    ``date_confidences[date]``.

    NOTE(review): ``date`` is neither a parameter nor a local of this
    function; it is looked up as a global when the loop runs. Confirm which
    date each row is supposed to be filed under.
    """
    frame = pd.read_csv(f'data/{stock_symbol}')

    # Row-over-row RSI difference, held in a temporary column.
    frame['rsi_change'] = frame['RSI'].diff()

    # Default every row to 0, then mark the rows that satisfy both conditions.
    frame['predict'] = 0
    is_oversold = frame['RSI'] < 30
    is_rising = frame['rsi_change'] > 0
    frame.loc[is_oversold & is_rising, 'predict'] = 1

    # The temporary column has served its purpose.
    frame.drop(columns=['rsi_change'], inplace=True)

    for _, record in frame.iterrows():
        entry = (1,record['predict'],stock_symbol,None,record['Volume'])
        date_confidences[date] += [entry]


if __name__ == '__main__':
    # Stock symbols: file names matched under ../../data with '.csv' removed.
    stock_symbols = [s.split('/')[-1].replace('.csv','') for s in sorted(glob.glob('../../data/*.csv'))]

    # Build one 'YYYY-MM-DD' string per calendar day, starting on Jan 1 of
    # years[0] and stopping once the year moves past years[-1].
    years = [1962,2023]
    current_date = datetime(years[0], 1, 1)
    date_list = []
    while current_date.year <= years[-1]:
        date_list.append(current_date.strftime('%Y-%m-%d'))
        current_date += timedelta(days=1)

    # date_confidences = collections.defaultdict(list)
    # A Manager-backed dict is created here and handed to every worker task below.
    manager = Manager()
    date_confidences = manager.dict()

    # Start every date with an empty list.
    # NOTE(review): this loop also leaves `date` bound at module level, which
    # is the only place the `date` name used inside predict() is assigned.
    for date in date_list:
        date_confidences[date] = []

    with Pool(processes=os.cpu_count()) as pool:
        # One (symbol, shared dict) tuple per stock.
        args_list = [(stock_symbol, date_confidences) for stock_symbol in stock_symbols]
        # imap passes each element of args_list to predict as a single
        # positional argument, so the whole tuple lands in `stock_symbol`;
        # this is the call that raises the TypeError shown in the traceback.
        # list(...) drains the lazy iterator so tqdm can advance.
        list(tqdm(pool.imap(predict, args_list), total=len(stock_symbols)))

    # NOTE(review): date_confidences is a manager proxy here, not a plain
    # dict; confirm the pickle ends up holding the data itself.
    with open('cache/date_confidences.pkl','wb') as file:
        pickle.dump(date_confidences,file)

However, I get this error:

multiprocessing.pool.RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/disk/nouri/environment/transformer3/lib/python3.9/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
TypeError: predict() missing 1 required positional argument: 'date_confidences'
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/disk/nouri/stock/chatgpt/crossvalidation/day_estimator/indicators/analyze.py", line 102, in <module>
    list(tqdm(pool.imap(predict, args_list), total=len(stock_symbols)))
  File "/disk/nouri/environment/transformer3/lib/python3.9/site-packages/tqdm/std.py", line 1195, in __iter__
    for obj in iterable:
  File "/disk/nouri/environment/transformer3/lib/python3.9/multiprocessing/pool.py", line 870, in next
    raise value
TypeError: predict() missing 1 required positional argument: 'date_confidences'

In my code, I'm passing both arguments, but I still get the error about a missing argument. I need to pass both arguments to the function `predict`. Is it possible to send these two arguments with `imap`, or should I use another method?


Solution

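  • `imap` passes each item of the iterable to the function as a single argument, so every `(stock_symbol, date_confidences)` tuple arrives as one argument. You can define a new function that takes that single argument, unpacks it inside the function, and forwards the pieces to `predict`: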

    # RSI
    def predict(stock_symbol,date_confidences):
        """Flag oversold-and-recovering rows of one stock and record every row.

        Reads ``data/{stock_symbol}`` as CSV, sets the ``predict`` column to 1
        on rows whose RSI is below 30 and higher than on the previous row
        (0 on all other rows), then appends one
        ``(1, predict, stock_symbol, None, Volume)`` tuple per row to
        ``date_confidences[date]``.

        Args:
            stock_symbol: File name (inside ``data/``) of the CSV to read.
            date_confidences: Mapping whose values are lists; updated in place.

        Returns:
            None. The results are written into ``date_confidences``.

        NOTE(review): ``date`` is neither a parameter nor a local of this
        function; it is looked up as a global when the loop runs. Confirm
        which date each row is supposed to be filed under.
        """
        stock_data = pd.read_csv(f'data/{stock_symbol}')

        # Calculate the RSI change from the previous row
        stock_data['rsi_change'] = stock_data['RSI'].diff()

        # Initialize the 'predict' column with zeros
        stock_data['predict'] = 0

        # Set 'predict' to 1 for rows where RSI is below 30 and has increased
        # from the previous row
        stock_data.loc[(stock_data['RSI'] < 30) & (stock_data['rsi_change'] > 0), 'predict'] = 1

        # Drop the temporary 'rsi_change' column now that the flags are set
        stock_data.drop(columns=['rsi_change'], inplace=True)

        # `x[key] += [...]` reads the list, extends it, and assigns it back,
        # which is what makes the update visible through a Manager dict proxy.
        for ind, row in stock_data.iterrows():
            date_confidences[date] += [(1,row['predict'],stock_symbol,None,row['Volume'])]
    
    def process_stock(args):
        """Single-argument adapter around ``predict`` for use with ``imap``.

        ``args`` is a two-item ``(stock_symbol, date_confidences)`` sequence;
        it is unpacked and forwarded, and whatever ``predict`` returns is
        passed back to the caller.
        """
        symbol, shared_results = args
        outcome = predict(symbol, shared_results)
        return outcome
    
    if __name__ == '__main__':
        # Stock symbols: file names matched under ../../data with '.csv' removed.
        stock_symbols = [s.split('/')[-1].replace('.csv','') for s in sorted(glob.glob('../../data/*.csv'))]

        # Build one 'YYYY-MM-DD' string per calendar day, starting on Jan 1 of
        # years[0] and stopping once the year moves past years[-1].
        years = [1962,2023]
        current_date = datetime(years[0], 1, 1)
        date_list = []
        while current_date.year <= years[-1]:
            date_list.append(current_date.strftime('%Y-%m-%d'))
            current_date += timedelta(days=1)

        # A Manager-backed dict is handed to every worker task below.
        manager = Manager()
        date_confidences = manager.dict()

        # Start every date with an empty list. This is kept as a plain `for`
        # loop on purpose: it is also the only place the module-level name
        # `date`, which predict() reads, gets assigned.
        for date in date_list:
            date_confidences[date] = []

        with Pool(processes=os.cpu_count()) as pool:
            # imap delivers one item per call, so each task is a single
            # (symbol, shared dict) tuple that process_stock unpacks.
            args_list = [(stock_symbol, date_confidences) for stock_symbol in stock_symbols]
            # list(...) drains the lazy iterator so tqdm can advance.
            list(tqdm(pool.imap(process_stock, args_list), total=len(stock_symbols)))

        # Pickling the proxy itself would store a reference to the manager
        # process rather than the data, so snapshot it into a plain dict first.
        with open('cache/date_confidences.pkl','wb') as file:
            pickle.dump(date_confidences.copy(),file)