python list numpy xgboost argmax

Optimising finding the index of the highest value in a list


I have a long list of machine learning prediction probabilities for multiple classes, and I'm trying to find the index of the highest probability for each prediction. The method below works, but it takes a long time (~10 minutes for this step alone) on our typical dataset, which is on the order of 100 million predictions.

The prediction output (from XGBoost) is of this form:

[[9.9696952e-01, 1.9961601e-06, 1.1957183e-03, 2.4479270e-05, 1.8083032e-03],
 [9.9696952e-01, 1.9961601e-06, 1.1957183e-03, 2.4479270e-05, 1.8083032e-03],
 [9.9696952e-01, 1.9961601e-06, 1.1957183e-03, 2.4479270e-05, 1.8083032e-03],
 ...,
]

I'm currently getting the index of the highest value in each sub-list and appending it to a separate list, as below:

# Find maximum class for each pixel and append to list
vals = []
for pred in predictions:
    max_index = np.where(pred == np.amax(pred))[0][0]
    vals.append(max_index)

I've run timeit on the numpy steps and got this response:

In [37]: %timeit np.where(predictions[0] == np.amax(predictions[0]))[0][0]
6.82 µs ± 25.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

I've looked for other functions that can do this but haven't found any. I also switched from a for loop to a list comprehension to see if that would speed it up, but it made no difference to the run time of this step.

Any suggestions of faster ways to implement this?


Solution

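  • Use np.argmax with axis=1 on the whole 2-D array instead of looping over the rows. For the three example rows this returns array([0, 0, 0]). The loop pays Python and NumPy call overhead on every row: at ~6.82 µs per row, 100 million rows take roughly 680 s, which matches the ~10 minutes observed.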

    import numpy as np
    
    predictions = np.array([[9.9696952e-01, 1.9961601e-06, 1.1957183e-03, 2.4479270e-05, 1.8083032e-03],
                            [9.9696952e-01, 1.9961601e-06, 1.1957183e-03, 2.4479270e-05, 1.8083032e-03],
                            [9.9696952e-01, 1.9961601e-06, 1.1957183e-03, 2.4479270e-05, 1.8083032e-03]
                            ])
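
    # One vectorised call: index of the largest value in each row
    output = np.argmax(predictions, axis=1)

    # Original loop-based approach, kept to check that the results agree
    vals = []
    for pred in predictions:
        max_index = np.where(pred == np.amax(pred))[0][0]
        vals.append(max_index)

    all(output == vals)  # True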
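
To see the difference on something larger than three rows, a rough comparison along these lines may help. This is only a sketch: the random array is a hypothetical stand-in for the real XGBoost output, the function names `argmax_loop` and `argmax_vectorised` are made up for illustration, and the timings will vary by machine. It also checks that both approaches pick the first index when a row contains a tie.

```python
import time

import numpy as np

# Hypothetical stand-in for the XGBoost output: 100,000 rows of 5 class
# probabilities, each row normalised to sum to 1.
rng = np.random.default_rng(0)
predictions = rng.random((100_000, 5))
predictions /= predictions.sum(axis=1, keepdims=True)


def argmax_loop(preds):
    """Per-row approach from the question: np.where/np.amax on each row."""
    return [np.where(p == np.amax(p))[0][0] for p in preds]


def argmax_vectorised(preds):
    """Single vectorised call over the whole 2-D array."""
    # np.asarray converts a list of lists once; an ndarray is passed through.
    return np.argmax(np.asarray(preds), axis=1)


start = time.perf_counter()
loop_result = argmax_loop(predictions)
loop_seconds = time.perf_counter() - start

start = time.perf_counter()
vec_result = argmax_vectorised(predictions)
vec_seconds = time.perf_counter() - start

print(f"loop: {loop_seconds:.3f} s, vectorised: {vec_seconds:.3f} s")

# Both approaches should produce the same indices.
same = np.array_equal(loop_result, vec_result)
print("results match:", same)

# On a tie, both return the first index at which the maximum occurs.
tied = np.array([[0.4, 0.4, 0.2]])
tie_index = argmax_vectorised(tied)[0]
loop_tie_index = argmax_loop(tied)[0]
```

If the predictions arrive as a plain Python list of lists rather than an ndarray, the `np.asarray` call converts them once up front; that conversion has its own cost, so it is worth keeping the data as a NumPy array end to end where possible.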