
Can this list comprehension be improved?


Let's say I have this list:

input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 5}]

I want to create output_list based on input_list:

output_list = []
for dic in input_list:
    new_dic = {}
    ab_sum = sum([dic['a'], dic['b']])
    if ab_sum % 2 == 0:
        new_dic['c'] = ab_sum
        new_dic['d'] = ab_sum ** 2
        new_dic['e'] = ab_sum ** 4
        output_list.append(new_dic)

Result:

[{'c': 8, 'd': 64, 'e': 4096}]

The actual dictionary is much bigger, and this gets messy. A more readable solution might be a list comprehension:

output_list = [{'c': ab_sum,
                'd': (ab_sq := ab_sum ** 2),
                'e': ab_sq ** 2}
               for dic in input_list
               if (ab_sum := sum([dic['a'], dic['b']])) % 2 == 0]

This seems inconsistent, as I assign to variables both in the filter and within the dictionary. Is there a more elegant solution to this kind of problem, or am I overthinking it?
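For comparison, one idiom that keeps the assignment in a single place is a one-element inner loop, `for ab_sum in [...]`: the filter and the dict display both read `ab_sum`, and nothing is assigned inside the dict. A sketch, assuming only `'a'` and `'b'` contribute to the sum:

```python
input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 5}]

# The inner "for ab_sum in [...]" runs exactly once per dict and only
# binds ab_sum, so the filter and the dict body share one name.
output_list = [
    {'c': ab_sum, 'd': ab_sum ** 2, 'e': ab_sum ** 4}
    for dic in input_list
    for ab_sum in [dic['a'] + dic['b']]
    if ab_sum % 2 == 0
]

print(output_list)  # [{'c': 8, 'd': 64, 'e': 4096}]
```

Unlike a `:=` target, `ab_sum` here stays local to the comprehension, and CPython 3.9 and later optimize this idiom so it is as fast as a plain assignment.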


Solution

  • Here is an alternative version, using a helper function, and taking advantage of the walrus operator :=, introduced in Python 3.8 (I changed the variable names just to make the code easier to read):

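    def get_cde_from(pair):
        # Sums every value in the dict; use pair['a'] + pair['b'] if it has other keys
        ab = sum(pair.values())
        if not (ab % 2):
            # Map 'c', 'd', 'e' to the sum, its square and its fourth power
            return dict(zip('cde', [ab, ab * ab, ab ** 4]))
        # Odd sums fall through and implicitly return None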
    
    
    pairs = [
        {'a': 1, 'b': 2},
        {'a': 3, 'b': 5}
    ]
    
    triplets = []
    
    for pair in pairs:
        if cde := get_cde_from(pair):
            triplets.append(cde)
    
    print(triplets)
    

    The output of the above code is the following:

    [{'c': 8, 'd': 64, 'e': 4096}]
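Since the real dictionary is much bigger, a table of key-to-function pairs can keep the helper from growing by one line per key. A sketch; `TRANSFORMS` and `build_output` are illustrative names, not from the original, and the sum is assumed to use only `'a'` and `'b'`:

```python
# Hypothetical mapping from output key to a function of the a + b sum;
# adding a key means adding one entry here, not another line of code.
TRANSFORMS = {
    'c': lambda ab: ab,
    'd': lambda ab: ab ** 2,
    'e': lambda ab: ab ** 4,
}


def build_output(pair):
    """Return the derived dict for pair, or None when a + b is odd."""
    ab = pair['a'] + pair['b']
    if ab % 2:
        return None
    return {key: func(ab) for key, func in TRANSFORMS.items()}


pairs = [{'a': 1, 'b': 2}, {'a': 3, 'b': 5}]
triplets = [cde for pair in pairs if (cde := build_output(pair))]
print(triplets)  # [{'c': 8, 'd': 64, 'e': 4096}]
```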
    

    You can also populate triplets (output_list, in your sample code) with a comprehension combined with the walrus operator:

    triplets = [cde for pair in pairs if (cde := get_cde_from(pair))]
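One side effect worth knowing: per PEP 572, a name bound with `:=` inside a comprehension is bound in the enclosing scope, so `cde` is still defined afterwards and holds whatever the last call returned, including `None`. A small sketch, with the input order swapped so the odd pair comes last:

```python
def get_cde_from(pair):
    ab = sum(pair.values())
    if not (ab % 2):
        return dict(zip('cde', [ab, ab * ab, ab ** 4]))
    # Odd sums implicitly return None


pairs = [
    {'a': 3, 'b': 5},
    {'a': 1, 'b': 2},  # odd sum, placed last on purpose
]

triplets = [cde for pair in pairs if (cde := get_cde_from(pair))]

print(triplets)  # [{'c': 8, 'd': 64, 'e': 4096}]
print(cde)       # None: the := target is visible outside the comprehension
```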
    

    Another option is to use a generator instead:

    def gen_cde_from(pairs):
        for pair in pairs:
            ab = sum(pair.values())
            if not (ab % 2):
                yield dict(zip('cde', [ab, ab * ab, ab**4]))
    
    
    pairs = [
        {'a': 1, 'b': 2},
        {'a': 3, 'b': 5}
    ]
    
    triplets = [*gen_cde_from(pairs)]
    
    print(triplets)
    

    This also results in:

    [{'c': 8, 'd': 64, 'e': 4096}]
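A variation on the generator approach splits the work into two stages: a generator expression that yields the sums, and a comprehension that filters them and builds each dict. No name is assigned inside a condition, and each sum is computed once, lazily. A sketch, again assuming only `'a'` and `'b'` feed the sum:

```python
pairs = [
    {'a': 1, 'b': 2},
    {'a': 3, 'b': 5},
]

# Stage 1: lazily compute a + b for each input dict.
sums = (pair['a'] + pair['b'] for pair in pairs)

# Stage 2: keep even sums and build one output dict per sum.
triplets = [{'c': ab, 'd': ab ** 2, 'e': ab ** 4} for ab in sums if ab % 2 == 0]

print(triplets)  # [{'c': 8, 'd': 64, 'e': 4096}]
```

Note that `sums` is a one-shot generator: it is exhausted after the comprehension consumes it.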