
How to sort a dictionary based on a list in python


I have a dictionary:

a = {'ground': obj1, 'floor 1': obj2, 'basement': obj3}

I also have a list:

a_list = ['floor 1', 'ground', 'basement']

I want to sort dictionary a by its keys, in the order given by the list. Is that possible?

i.e.:

sort(a).based_on(a_list)  # this is wrong, but I want something like this

The output doesn't have to be another dictionary; I don't mind converting the dictionary to a list of tuples and sorting those.


Solution

  • The naive way: sort the list of (key, value) tuples with the sorted() function and a custom sort key, which is called once for each (key, value) pair produced by dict.items():

    sorted(a.items(), key=lambda pair: a_list.index(pair[0]))
    

    The faster way: create an index map first:

    index_map = {v: i for i, v in enumerate(a_list)}
    sorted(a.items(), key=lambda pair: index_map[pair[0]])
    

    This is faster because a lookup in index_map takes O(1) constant time on average, while each a_list.index() call has to scan through the list, which takes O(N) linear time. Since sorted() calls the key function once for each key-value pair in the dictionary, the naive option takes O(N^2) quadratic time overall, while the index map (built once in O(N) time) keeps the sort at O(N log N), linearithmic time.

    Both approaches assume that a_list contains every key found in a; otherwise a_list.index() raises ValueError and index_map[...] raises KeyError. If that assumption holds, you may as well invert the lookup and just retrieve the keys in order:

    [(key, a[key]) for key in a_list if key in a]
    

    which takes O(N) linear time and tolerates extra keys in a_list that don't exist in a.

    To be explicit: O(N) is better than O(N log N), which in turn is better than O(N^2); see this cheat sheet for reference.

    Demo:

    >>> a = {'ground': 'obj1', 'floor 1': 'obj2', 'basement': 'obj3'}
    >>> a_list = ('floor 1', 'ground', 'basement')
    >>> sorted(a.items(), key=lambda pair: a_list.index(pair[0]))
    [('floor 1', 'obj2'), ('ground', 'obj1'), ('basement', 'obj3')]
    >>> index_map = {v: i for i, v in enumerate(a_list)}
    >>> sorted(a.items(), key=lambda pair: index_map[pair[0]])
    [('floor 1', 'obj2'), ('ground', 'obj1'), ('basement', 'obj3')]
    >>> [(key, a[key]) for key in a_list if key in a]
    [('floor 1', 'obj2'), ('ground', 'obj1'), ('basement', 'obj3')]
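The question says the output doesn't have to be a dictionary, but if you do want one back: since Python 3.7, regular dicts preserve insertion order, so feeding the sorted pairs to dict() gives a dictionary whose iteration order follows a_list. A minimal sketch, using placeholder strings instead of obj1 etc.; the names ordered and ordered_alt are just illustrative:

```python
a = {'ground': 'obj1', 'floor 1': 'obj2', 'basement': 'obj3'}
a_list = ['floor 1', 'ground', 'basement']

# Map each key to its position in a_list
index_map = {key: i for i, key in enumerate(a_list)}

# dict() keeps the order in which the pairs are supplied (Python 3.7+)
ordered = dict(sorted(a.items(), key=lambda pair: index_map[pair[0]]))

# Equivalent result using the inverted lookup (drops a_list keys missing from a)
ordered_alt = {key: a[key] for key in a_list if key in a}

print(list(ordered))  # ['floor 1', 'ground', 'basement']
```

Note that comparing two dicts with == ignores order, so check the order with list(ordered) rather than by comparing against a.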
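The sorting approaches above assume every key of a appears in a_list. If that may not be true, one option (an assumption about what you want, not the only choice) is to give unlisted keys a rank past every listed one with dict.get(), so they sort last. Because Python's sort is stable, unlisted keys keep their original relative order. The extra keys 'roof' and 'attic' and the name unknown_rank are hypothetical:

```python
# 'roof' and 'attic' are hypothetical keys that are NOT in a_list
a = {'ground': 'obj1', 'roof': 'obj4', 'floor 1': 'obj2',
     'attic': 'obj5', 'basement': 'obj3'}
a_list = ['floor 1', 'ground', 'basement']

index_map = {key: i for i, key in enumerate(a_list)}

# len(a_list) is larger than any index in index_map, even if a_list has duplicates
unknown_rank = len(a_list)

result = sorted(a.items(), key=lambda pair: index_map.get(pair[0], unknown_rank))
print(result)
# [('floor 1', 'obj2'), ('ground', 'obj1'), ('basement', 'obj3'), ('roof', 'obj4'), ('attic', 'obj5')]
```

If you would rather drop unlisted keys entirely, the list comprehension over a_list already does that.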
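To see the complexity argument in practice, here is a rough timeit comparison of the three approaches on a larger, made-up dictionary. The key names, the size n and the function names are arbitrary assumptions, and absolute timings depend on the machine and Python version. If you increase n, the naive version's time should grow roughly quadratically, while the other two should grow much more slowly:

```python
import timeit

# Hypothetical test data: n keys, with a_list in reverse insertion order
n = 2000
a = {f'key{i}': i for i in range(n)}
a_list = list(reversed(list(a)))

def naive():
    # list.index() scans a_list for every key: O(N) per key, O(N^2) overall
    return sorted(a.items(), key=lambda pair: a_list.index(pair[0]))

def with_index_map():
    # Building the map is O(N); each lookup afterwards is O(1) on average
    index_map = {key: i for i, key in enumerate(a_list)}
    return sorted(a.items(), key=lambda pair: index_map[pair[0]])

def inverted():
    # Walk a_list once and look each key up in a: O(N)
    return [(key, a[key]) for key in a_list if key in a]

# All three must agree before their timings are worth comparing
assert naive() == with_index_map() == inverted()

for func in (naive, with_index_map, inverted):
    seconds = timeit.timeit(func, number=5)
    print(f'{func.__name__}: {seconds:.4f}s for 5 runs')
```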