Tags: python, python-3.x, list, list-comprehension, nameerror

List comprehension to combine elements of sublists by column


I have a list of lists like so:

allrows = [['NEPW46486', 'NEPW46550', 'sersic', 20.04, 21.12],
['NEPW89344', 'NEPW89346', 'sersic', 20.33, 19.66], ...]

And I'd like to create a new list of lists, where each list corresponds to one "column". My desired output is:

cols = [['NEPW46486', 'NEPW89344', ...], ['NEPW46550', 'NEPW89346', ...], ['sersic', 'sersic', ...], [20.04, 20.33, ...], [21.12, 19.66, ...]]

I figured I could accomplish this with a list comprehension, like this:

cols = [[row[n] for row in allrows] for n in range(len(row))]

But I get a NameError saying that row is not defined. I also tried switching the order of the two for clauses, but that did not give me the desired output (it gave me back exactly what I started with). How can I achieve my desired output with a list comprehension?


Solution

  • That's the purpose of the built-in zip function. You just need to unpack the outer list with * when you call it, so that each row is passed to zip as a separate argument. Something like:

    allrows = [['NEPW46486', 'NEPW46550', 'sersic', 20.04, 21.12],
               ['NEPW89344', 'NEPW89346', 'sersic', 20.33, 19.66]]
    
    for item in zip(*allrows): # unpack with *allrows
        print(item)
    

    This prints one tuple per column:

    ('NEPW46486', 'NEPW89344')
    ('NEPW46550', 'NEPW89346')
    ('sersic', 'sersic')
    (20.04, 20.33)
    (21.12, 19.66)
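zip() yields tuples, while the question asks for a list of lists. A minimal sketch of one way to get exactly that shape is to convert each tuple with list(). For this two-row example, zip(*allrows) is equivalent to zip(allrows[0], allrows[1]). Note that zip() stops at the shortest row, so this assumes every row has the same length, as in the question.

```python
allrows = [['NEPW46486', 'NEPW46550', 'sersic', 20.04, 21.12],
           ['NEPW89344', 'NEPW89346', 'sersic', 20.33, 19.66]]

# zip(*allrows) groups the i-th element of every row into one tuple;
# list() turns each of those column tuples into a list.
cols = [list(col) for col in zip(*allrows)]
print(cols)
# [['NEPW46486', 'NEPW89344'], ['NEPW46550', 'NEPW89346'], ['sersic', 'sersic'], [20.04, 20.33], [21.12, 19.66]]
```

list(map(list, zip(*allrows))) builds the same result; which one to use is mostly a matter of taste.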
    

    If zip() is unsuitable for some reason, the easiest way to build a list comprehension is usually to write out the equivalent nested loops first and then condense them. Starting with:

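    cols = []
    for index, item in enumerate(allrows[0]):  # one pass per column; the first row tells us how many columns there are
        col = []
        for row in allrows:  # collect the value at this column index from every row
            col.append(row[index])
        cols.append(col)
    print(cols)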
    

    This gives the desired output:

    [['NEPW46486', 'NEPW89344'], ['NEPW46550', 'NEPW89346'], ['sersic', 'sersic'], [20.04, 20.33], [21.12, 19.66]]

    We can then condense it into a single line, with the inner loop becoming the inner comprehension and the outer loop becoming the outer one:

    cols = [[row[index] for row in allrows] for index, item in enumerate(allrows[0])]
    print(cols)
    

    This again yields:

    [['NEPW46486', 'NEPW89344'], ['NEPW46550', 'NEPW89346'], ['sersic', 'sersic'], [20.04, 20.33], [21.12, 19.66]]
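The NameError in the question comes from evaluation order: in [[row[n] for row in allrows] for n in range(len(row))], the outer clause range(len(row)) is evaluated first, before the inner comprehension has bound row, and in Python 3 comprehension variables are local to the comprehension anyway. A minimal sketch of the smallest fix to the original attempt is to take the column count from a row that already exists, such as the first one. This assumes allrows is non-empty (allrows[0] raises IndexError on an empty list, whereas zip(*allrows) simply yields nothing) and that all rows have the same length.

```python
allrows = [['NEPW46486', 'NEPW46550', 'sersic', 20.04, 21.12],
           ['NEPW89344', 'NEPW89346', 'sersic', 20.33, 19.66]]

# The original attempt: the outer `for n in range(len(row))` runs before `row` exists.
try:
    cols = [[row[n] for row in allrows] for n in range(len(row))]
except NameError as exc:
    print("Original attempt fails:", exc)

# Fix: count the columns using the first row, which is defined before the comprehension runs.
cols = [[row[n] for row in allrows] for n in range(len(allrows[0]))]
print(cols)
```

This is the same comprehension as the one above, just using range(len(allrows[0])) instead of enumerate(allrows[0]) because only the index is needed.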