Tags: python, string, list-comprehension, numba

Python: a faster method of finding indexes in a list of 2 million+ items that match a string condition


##Mock Data##

my_list = list(range(1700))

import itertools
cross_product = list(itertools.product(my_list,my_list))
station_combinations = ["_".join([str(i),str(b)]) for i,b in cross_product if i != b]
###############

from time import time,sleep
station_name = "5"

start = time()

for h in range(10):
    reverse_indexes = [count for count,j in enumerate(station_combinations) if j.split("_")[1] == station_name ]
    regular_indexes = [count for count,j in enumerate(station_combinations) if j.split("_")[0] == station_name ]

print(time() - start )

Hello, I have shared the reproducible code above.

Background: Let me quickly introduce my data. station_combinations is the cross product of "my_list" with itself (excluding pairs of identical items), with the two items of each pair joined by "_". You can think of each entry as a route between "my_list" items, so 1_2 would be going from 1 to 2, whereas 2_1 would be going from 2 to 1.

I will refer to each entry as "a_b". For "reverse_indexes", I am trying to find, among all the combinations, the indexes of the elements where b is equal to "station_name", i.e. the destination is equal to the station name. For "regular_indexes", I want the indexes of the elements where a is equal to "station_name", i.e. the source is equal to the station name.

Problem: The code that I have works; however, it is very slow. If you look at the for loop (with loop variable h), I iterate 10 times, but in the original code it is supposed to be approx. 2000 iterations. Even with 10 iterations it takes approx. 8 seconds on my computer. I am looking for ways to improve the speed significantly. I have tried the numba library; however, because I actually get some of the data from a data frame, I wasn't able to make it work with "@njit". Would anyone be able to help?
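
As a rough back-of-the-envelope check of why the loop is slow, the sketch below counts how many `split` calls it makes. It only assumes what the mock data shows: 1700 stations, all ordered pairs with a != b, and two list comprehensions per iteration. The variable names are made up for this illustration.

```python
# Rough count of str.split calls made by the loop in the question.
n_stations = 1700
n_items = n_stations * (n_stations - 1)  # ordered pairs with a != b
splits_per_iteration = 2 * n_items       # one full pass per list comprehension

print(n_items)                      # 2888300
print(splits_per_iteration * 10)    # 57766000
print(splits_per_iteration * 2000)  # 11553200000
```

So every iteration re-splits all of the roughly 2.9 million strings twice, even though the strings never change between iterations.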


Solution

  • One solution is to precompute lookup indexes once, in this case two of them: one keyed by a and one keyed by b. For example:

    my_list = list(range(1700))
    
    import itertools
    
    cross_product = list(itertools.product(my_list, my_list))
    station_combinations = [
        "_".join([str(i), str(b)]) for i, b in cross_product if i != b
    ]
    
    # precompute indexes:
    index_a = {}
    index_b = {}
    for i, s in enumerate(station_combinations):
        a, b = s.split("_")
        index_a.setdefault(a, []).append(i)
        index_b.setdefault(b, []).append(i)
    
    
    from time import time, sleep
    
    station_name = "5"
    
    start = time()
    
    for h in range(10):
        reverse_indexes_new = index_b.get(station_name, [])
        regular_indexes_new = index_a.get(station_name, [])
    
    print(time() - start)
    

    Prints on my machine:

    7.62939453125e-06
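
A quick way to gain confidence in this approach is to check that the precomputed indexes return the same positions as the original list comprehensions. The sketch below does that on a smaller, hypothetical data set of 50 stations so it runs quickly. The helper name `build_indexes` and the use of `collections.defaultdict` are choices made for this illustration, not part of the answer above.

```python
import itertools
from collections import defaultdict


def build_indexes(combinations):
    """Map each source and each destination name to the positions where it occurs."""
    index_a = defaultdict(list)  # source name -> positions
    index_b = defaultdict(list)  # destination name -> positions
    for i, s in enumerate(combinations):
        a, b = s.split("_")
        index_a[a].append(i)
        index_b[b].append(i)
    return index_a, index_b


# Smaller hypothetical data set (50 stations) so the check runs quickly.
stations = list(range(50))
combos = [f"{a}_{b}" for a, b in itertools.product(stations, stations) if a != b]
index_a, index_b = build_indexes(combos)

station_name = "5"

# The original, slow approach from the question, used here as the reference.
slow_regular = [i for i, s in enumerate(combos) if s.split("_")[0] == station_name]
slow_reverse = [i for i, s in enumerate(combos) if s.split("_")[1] == station_name]

# .get() avoids creating empty entries for unknown station names.
assert index_a.get(station_name, []) == slow_regular
assert index_b.get(station_name, []) == slow_reverse
print(len(slow_regular), len(slow_reverse))  # 49 49
```

Note that the timing in the answer covers only the lookups; the one-time cost of building the two indexes sits outside the timed loop and is paid once, however many stations are queried afterwards.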