
Function in map detects only one argument


I want to compute the distance between all combinations of elements from two sets.

descriptor_1 (resp. descriptor_2) is a list of N1 (resp. N2) 2D arrays, one 2D array per element.

To compute all combinations between these two sets, I use:

combi = list(itertools.product(descriptor_1, descriptor_2))

which yields a list of N1*N2 2-tuples.

And to compute the distances:

dist = map(chi2_dist, combi)

where:

def chi2_dist(a, b):
    a = a.flatten()
    b = b.flatten()

    dist = (1/2) * np.sum( (a-b)**2 / (a+b+EPS))

    return dist

However I get the following error:

TypeError: chi2_dist() takes exactly 2 arguments (1 given)

But, as my tuples contain 2 elements, I do not understand the error.


Solution

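  • When map is given a single iterable, it calls the function with one argument per item, so each 2-tuple in combi arrives as a single argument. Your function should accept the tuple and unpack it:

    def chi2_dist(ab):
        a, b = ab  # map passes each (a, b) pair as one argument
        a = a.flatten()
        b = b.flatten()
        # 0.5 rather than (1/2), which is 0 under Python 2 integer division
        return 0.5 * np.sum((a - b)**2 / (a + b + EPS))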
    

    BTW, it is more efficient to pass the product iterator directly to map:

    map(chi2_dist, itertools.product(descriptor_1, descriptor_2))
    

    This avoids building an intermediate list.
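
If you would rather keep the original two-argument signature chi2_dist(a, b), itertools.starmap may be a better fit than map, since it unpacks each tuple into positional arguments before calling the function. Below is a minimal sketch, not a definitive implementation. It assumes EPS is a small positive constant (the question does not show its value), and the toy arrays are made up to stand in for the real descriptors. It is written for Python 3, where map and starmap return lazy iterators, hence the list() call.

```python
import itertools
import numpy as np

EPS = 1e-10  # assumed value; the question does not show how EPS is defined


def chi2_dist(a, b):
    """Chi-squared distance between two arrays (two-argument form from the question)."""
    a = a.flatten()
    b = b.flatten()
    # 0.5 instead of (1/2), which evaluates to 0 under Python 2 integer division
    return 0.5 * np.sum((a - b) ** 2 / (a + b + EPS))


# Hypothetical toy data standing in for descriptor_1 and descriptor_2
descriptor_1 = [np.ones((2, 2)), np.full((2, 2), 3.0)]
descriptor_2 = [np.ones((2, 2)), np.full((2, 2), 2.0), np.zeros((2, 2))]

pairs = itertools.product(descriptor_1, descriptor_2)
# starmap calls chi2_dist(*pair), so each 2-tuple is unpacked into (a, b)
dist = list(itertools.starmap(chi2_dist, pairs))

# Reshape to an N1 x N2 matrix: row i is descriptor_1[i], column j is descriptor_2[j]
dist_matrix = np.array(dist).reshape(len(descriptor_1), len(descriptor_2))
```

Because itertools.product iterates over its last argument fastest, the flat list reshapes directly into an N1 x N2 matrix, which is usually easier to index than a list of length N1*N2.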