Tags: python, list, list-comprehension

How to get a flat list without creating a nested list in the first place?


My goal

My question is about a list comprehension that does not put elements into the resulting list as they are (which would result in a nested list), but instead extends the result into a flat list. So my question is not about flattening a nested list, but about how to get a flat list without creating a nested list in the first place.

Example

Suppose I have class instances with an attribute that contains a list of integers:

class Foo:
    def __init__(self, l):
        self.l = l

foo_0 = Foo([1, 2, 3])
foo_1 = Foo([4, 5])
list_of_foos = [foo_0, foo_1]

Now I want a single list of all integers from all instances of Foo. My best solution so far uses extend:

result = []
for f in list_of_foos:
    result.extend(f.l)

As desired, result is now [1, 2, 3, 4, 5].

Is there something better, for example a list comprehension?

Since I expect a list comprehension to be faster, I'm looking for a Pythonic way to get the desired result with one. My best approach builds a list of lists (a "nested list") and then flattens it again, which seems quirky:

result = [item for sublist in [f.l for f in list_of_foos] for item in sublist]
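A small aside on the comprehension above: the inner `[f.l for f in list_of_foos]` is fully built as a list before flattening starts. The sketch below (reusing the question's `Foo` class and sample data) shows that swapping those inner square brackets for parentheses makes the inner part a generator expression, so no intermediate list is created and the flat result is the same. Since the intermediate list only holds references to the existing `f.l` lists rather than copies, the saving is likely modest.

```python
class Foo:
    def __init__(self, l):
        self.l = l


list_of_foos = [Foo([1, 2, 3]), Foo([4, 5])]

# Square brackets: the inner comprehension builds a throwaway list of lists first.
eager = [item for sublist in [f.l for f in list_of_foos] for item in sublist]

# Parentheses: the inner part is a generator expression, consumed lazily,
# so no intermediate list of lists is built.
lazy = [item for sublist in (f.l for f in list_of_foos) for item in sublist]

print(eager)  # [1, 2, 3, 4, 5]
print(lazy)   # [1, 2, 3, 4, 5]
```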

What functionality I'm looking for

result = some_module.list_extends(f.l for f in list_of_foos)
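For what it's worth, the standard library's `itertools.chain.from_iterable` behaves very much like the hypothetical `some_module.list_extends` above: it takes an iterable of iterables and yields their elements one after another. It is not a list comprehension, but wrapping it in `list()` gives a flat list without building a nested one. A minimal sketch, reusing the question's `Foo` class and data:

```python
from itertools import chain


class Foo:
    def __init__(self, l):
        self.l = l


list_of_foos = [Foo([1, 2, 3]), Foo([4, 5])]

# chain.from_iterable lazily yields every element of each inner list in order;
# list() collects them into one flat list, with no nested list built along the way.
result = list(chain.from_iterable(f.l for f in list_of_foos))
print(result)  # [1, 2, 3, 4, 5]
```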

Questions and Answers I read before

I was fairly sure this had been answered before, but my search only turned up "list.extend and list comprehension", where the nested list arises for a different reason, and "python list comprehensions; compressing a list of lists?", where the answers are about avoiding the nested list or about how to flatten it.


Solution

  • You can use multiple fors in a single comprehension:

    result = [
        n
        for foo in list_of_foos
        for n in foo.l
    ]
    

    Note that the for clauses are ordered from the outside in, the same as if you wrote a nested for loop:

    for foo in list_of_foos:
        for n in foo.l:
            print(n)
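The question assumes a comprehension will be faster than the `extend` loop; that is worth measuring rather than assuming. Below is a rough sketch using the standard `timeit` module to compare the three approaches on this page. The function names and the 1000-instance test data are hypothetical choices for illustration, and no timings are shown because they depend on the machine, Python version, and data shape.

```python
import timeit


class Foo:
    def __init__(self, l):
        self.l = l


# Hypothetical test data: 1000 instances, each holding 10 integers.
list_of_foos = [Foo(list(range(i, i + 10))) for i in range(0, 10_000, 10)]


def with_extend(foos):
    # The explicit loop from the question.
    result = []
    for f in foos:
        result.extend(f.l)
    return result


def with_nested_comprehension(foos):
    # Builds an intermediate list of lists, then flattens it.
    return [item for sublist in [f.l for f in foos] for item in sublist]


def with_double_for(foos):
    # Two for clauses in one comprehension, outer loop first.
    return [n for foo in foos for n in foo.l]


# Sanity check: all three produce the same flat list.
assert with_extend(list_of_foos) == with_nested_comprehension(list_of_foos)
assert with_extend(list_of_foos) == with_double_for(list_of_foos)

for func in (with_extend, with_nested_comprehension, with_double_for):
    # The lambda is called by timeit within this iteration, so it uses the current func.
    seconds = timeit.timeit(lambda: func(list_of_foos), number=200)
    print(f"{func.__name__}: {seconds:.4f} s for 200 runs")
```

Note that `extend` copies each sublist in one C-level call, so it is not obvious in advance which version wins; the measurement is the point of the sketch.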