Tags: python, python-2.7, set, python-itertools

Manipulating a multidimensional set: marginalization, extension, membership


I'm writing a Python module that requires a lot of set manipulation. I have tried many approaches but haven't succeeded. First, I'm dealing with sets in three dimensions, and I use the itertools product function to compute the Cartesian product:

from itertools import product

# Marginal sets
x1 = set(['u1', 'd1'])
x2 = set(['u2', 'd2'])
x3 = set(['u3', 'd3'])

# Cartesian product of x1,x2,x3
X = set(product(x1,x2,x3))

>>> X
set([('d1', 'u2', 'u3'), ('u1', 'u2', 'u3'), ('u1', 'u2', 'd3'), ('d1', 'd2', 'u3'), ('u1', 'd2', 'd3'), ('d1', 'd2', 'd3'), ('u1', 'd2', 'u3'), ('d1', 'u2', 'd3')])

I would like to extend set intersection such that it accepts the following:

>>> set(['d1']) & X # Return elements of X containing 'd1'
set([('d1', 'u2', 'u3'), ('d1', 'd2', 'u3'), ('d1', 'd2', 'd3'), ('d1', 'u2', 'd3')])

And also the following:

>>> set(['d1','u2']) & X # Return elements of X containing both 'd1' and 'u2'
set([('d1', 'u2', 'u3'), ('d1', 'u2', 'd3')])

I would also like to extend the Cartesian product by adding a new dimension, e.g.

x4 = set(['u4', 'd4'])
Y = cartesianproduct(X,x4)
>>>Y
set([('u1', 'u2', 'u3', 'd4'), ('d1', 'd2', 'u3', 'd4'), ('u1', 'u2', 'u3', 'u4'), ('u1', 'u2', 'd3', 'u4'), ('d1', 'u2', 'u3', 'u4'), ('d1', 'd2', 'd3', 'u4'), ('d1', 'd2', 'u3', 'u4'), ('d1', 'd2', 'd3', 'd4'), ('u1', 'u2', 'd3', 'd4'), ('u1', 'd2', 'd3', 'd4'), ('d1', 'u2', 'd3', 'd4'), ('d1', 'u2', 'd3', 'u4'), ('u1', 'd2', 'u3', 'u4'), ('d1', 'u2', 'u3', 'd4'), ('u1', 'd2', 'd3', 'u4'), ('u1', 'd2', 'u3', 'd4')])

Finally, I would like to remove a dimension:

Z = remove(x3,X)
>>>Z
set([('d1', 'd2'), ('u1', 'u2'), ('u1', 'd2'), ('d1', 'u2')])

I'm new to Python and I feel like I'm doing it totally wrong, perhaps mixing lists, tuples and sets in the wrong ways...

Please be gentle :)


Solution

  • I would work with a list of sets instead:

    X = map(set,product(x1,x2,x3))
    
    def has_element(X,Y):
        # Keep the sets in Y that share at least one element with X
        return [y for y in Y if len(y.intersection(X))]
    
    print has_element(['d1','u1'],X)
    
    >>> [set(['d2', 'u1', 'd3']), set(['d2', 'u1', 'u3']), set(['u1', 'd3', 'u2']), set(['u1', 'u3', 'u2']), set(['d2', 'd3', 'd1']), set(['d2', 'u3', 'd1']), set(['u2', 'd3', 'd1']), set(['u2', 'u3', 'd1'])]
    

    For your second function:

    def new_product(X,Y):
        Z = []
        for a,b in product(X,Y):
            ab = b.copy()
            ab.add(a)
            Z.append(ab)
        return Z
    
    print new_product(set(['d4','u4']),X)
    
    >>> [set(['u1', 'd2', 'd3', 'u4']), set(['u1', 'd2', 'u3', 'u4']), set(['u4', 'u1', 'd3', 'u2']), set(['u4', 'u1', 'u3', 'u2']), set(['u4', 'd2', 'd3', 'd1']), set(['u4', 'd2', 'u3', 'd1']), set(['u4', 'd1', 'd3', 'u2']), set(['u4', 'd1', 'u3', 'u2']), set(['u1', 'd4', 'd2', 'd3']), set(['u1', 'd4', 'd2', 'u3']), set(['d4', 'u1', 'd3', 'u2']), set(['d4', 'u1', 'u3', 'u2']), set(['d4', 'd2', 'd3', 'd1']), set(['d4', 'd2', 'u3', 'd1']), set(['d4', 'd1', 'd3', 'u2']), set(['d4', 'd1', 'u3', 'u2'])]
    

    The final function:

    def remove(X,Y):
        Z = [set(y) for y in Y] # Copy each set so the originals in Y are not modified
        for z in Z:
            for x in X:
                z.discard(x) # discard() does nothing if x is absent
        return Z
    
    print remove(x3,X)
    
    >>> [set(['d2', 'd1']), set(['d2', 'd1']), set(['u2', 'd1']), set(['u2', 'd1'])]
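
If you would rather keep the tuples from the question (so that each value stays in its position and the result is a set with no duplicates), the three operations could be sketched roughly as follows. The function names `containing`, `extend` and `marginalize` are made up for illustration, and `marginalize` assumes that labels are unique to their dimension, as they are in the question's data.

```python
from itertools import product

# Marginal sets and their Cartesian product, as in the question
x1 = set(['u1', 'd1'])
x2 = set(['u2', 'd2'])
x3 = set(['u3', 'd3'])
x4 = set(['u4', 'd4'])
X = set(product(x1, x2, x3))

def containing(values, X):
    """Return the tuples of X that contain every value in `values`."""
    values = set(values)
    return set(t for t in X if values.issubset(t))

def extend(X, new_dim):
    """Append one value from `new_dim` to every tuple of X."""
    return set(t + (v,) for t, v in product(X, new_dim))

def marginalize(X, dim):
    """Drop the values belonging to `dim` from every tuple of X.

    Assumes labels are unique per dimension; tuples that become
    identical collapse into one because the result is a set.
    """
    return set(tuple(v for v in t if v not in dim) for t in X)

only_d1 = containing(['d1'], X)           # tuples containing 'd1'
d1_and_u2 = containing(['d1', 'u2'], X)   # tuples containing both values
Y = extend(X, x4)                         # four-dimensional product
Z = marginalize(X, x3)                    # third dimension removed
```

Unlike the list-of-sets version, this one requires all of the given values to be present (not just any of them), which matches the two intersection examples in the question.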