
Replacing specific row of array inside a dict


I would like to replace the nth row of an array stored in a dictionary. The arrays in the dictionary are all created with the same shape. In the example below, the dict has keys [10, 20, 30, 40] and each key maps to a 5x2 array. When I try to replace the 2nd row of the array under key 20 with the list [2, -4], the code below replaces the 2nd row of every array in the dict. How can I replace the row only in the array for key 20?

import itertools
import numpy as np
my_dict = dict(zip([10, 20, 30, 40], itertools.repeat(np.zeros((5, 2)))))
my_dict[20][1] = [ 2, -4]
print(my_dict)
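A quick way to see why this happens, added here as an illustration rather than taken from the original post, is to check whether the dict values are distinct objects. The sketch below reuses the question's setup and compares the values with `is` and `id()`:

```python
import itertools
import numpy as np

# Same setup as the question: np.zeros((5, 2)) is evaluated once,
# and itertools.repeat yields that single array object for every key.
my_dict = dict(zip([10, 20, 30, 40], itertools.repeat(np.zeros((5, 2)))))

values = list(my_dict.values())

# Every value is the very same object, so there is only one array in memory.
all_same = all(v is values[0] for v in values)
distinct_ids = len({id(v) for v in values})
print(all_same)      # True
print(distinct_ids)  # 1
```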

Solution

  • Here is one of many possible solutions:

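    import numpy as np
    
    # The dict comprehension calls np.zeros once per key,
    # so every value is a separate array.
    my_dict = {key: np.zeros((5, 2)) for key in [10, 20, 30, 40]}
    my_dict[20][1] = [2, -4]  # modifies only the array stored under key 20
    print(my_dict)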
    

    The problem is that np.zeros((5, 2)) is evaluated only once, and itertools.repeat yields that same array object four times. All four keys therefore refer to one array in memory, so a modification made through any key is visible through all of them.

    In contrast, {key: np.zeros((5, 2)) ...} evaluates np.zeros((5, 2)) once per key, creating a new zeros array for each dict entry.
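If you prefer to keep the original `zip` form, one possible variation (not part of the answer above) is to replace `itertools.repeat` with a generator expression, which calls `np.zeros` once per key. The sketch also shows that `dict.fromkeys(keys, np.zeros((5, 2)))` has the same sharing pitfall as the question's code, because its default value is evaluated only once:

```python
import numpy as np

keys = [10, 20, 30, 40]

# The generator expression evaluates np.zeros((5, 2)) once per key,
# so each dict value is an independent array.
my_dict = dict(zip(keys, (np.zeros((5, 2)) for _ in keys)))
my_dict[20][1] = [2, -4]

# Only the array under key 20 changed.
print(my_dict[20][1])  # [ 2. -4.]
print(my_dict[10][1])  # [0. 0.]

# Contrast: dict.fromkeys stores one shared default array under every key,
# so writing through key 20 also changes what key 10 sees.
shared = dict.fromkeys(keys, np.zeros((5, 2)))
shared[20][1] = [2, -4]
print(shared[10][1])  # [ 2. -4.]
```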