python, numpy, python-itertools

Function call and code in function give different result


There is a bug in my program, and I am trying to track it down with PyCharm's debugger and its Evaluate Expression tool. I found where my code goes wrong, but I don't understand how this can happen. When I call my function marginalizedTransition with some arguments I get 1.0 as a result; however, when I evaluate the body of that function with the arguments substituted in, I get 0.00101001196095. The two should be identical, since it is the same code.
Here player = 1, world.amountOfPlayers() = 2 and world.amountOfPossibleActions = 5.

marginalizedTransitionTest(Ps, [(3, 1), (5, 1)], [(2, 1), (5, 1)], 0, world, 1)

where

def marginalizedTransitionTest(Ps, sis, sfs, action, world, player):
    otherPlayers = filter(lambda x: x != player, range(world.amountOfPlayers()))
    return sum(np.prod([Ps[otherPlayer][tuple(it.chain(*sis))]
                   [np.insert(actions, player, action)[otherPlayer]]
                   for otherPlayer in otherPlayers])
           for actions in it.product(range(world.amountOfPossibleActions), repeat=world.amountOfPlayers() - 1))

gives 4.05514478458 and

sum(np.prod([Ps[otherPlayer][tuple(it.chain(*[(3, 1), (5, 1)]))]
               [np.insert(actions, 1, 0)[otherPlayer]]
               for otherPlayer in filter(lambda x: x != 1, range(world.amountOfPlayers()))])
       for actions in it.product(range(world.amountOfPossibleActions), repeat=world.amountOfPlayers() - 1))

gives 1.0.

I don't know if this is a common problem. If you need more information about the functions used, I can provide it, but this post is already getting ugly.


Solution

  • This is due to how filter behaves in Python 3: it returns a lazy, one-shot iterator, not a list. In the function you create the iterator once, then run a for loop over it multiple times. However, if you try running just this:

    otherPlayers = filter(lambda x: x != player, range(world.amountOfPlayers()))
    # Don't reuse `player` as the loop variable: the lambda looks it up lazily.
    for other in otherPlayers:
        print(f'1: {other}')
    for other in otherPlayers:
        print(f'2: {other}')
    

    With player = 1 and two players, the result will be:

    1: 0
    

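    As you can see, once you've run a for loop over the filter, it's exhausted and won't yield anything on later passes.
    When you evaluated the expression directly, the filter call sat inside the comprehension, so a fresh iterator was created for every value of actions and everything worked as intended.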
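
    The exhausted iterator raises no error, which is why the bug is silent: the inner list comes out empty, and np.prod of an empty list is 1.0. A minimal sketch of that effect, assuming two players and borrowing the probability from the dummy data further down (the names other_players, first_term and later_term are only illustrative):

```python
import numpy as np

player = 1
# In Python 3, filter() returns a one-shot iterator, not a list.
other_players = filter(lambda x: x != player, range(2))

first_pass = list(other_players)   # consumes the iterator
second_pass = list(other_players)  # nothing left to yield

# Probability borrowed from the dummy data (an assumption for illustration).
prob = 0.055144784579474179

# One real factor on the first pass, an empty product on every later pass.
first_term = np.prod([prob for _ in first_pass])
later_term = np.prod([prob for _ in second_pass])

print(first_pass, second_pass)  # [0] []
print(first_term, later_term)

# One correct term plus four empty products of 1.0 each.
total = first_term + 4 * later_term
print(total)
```

    Under these assumptions the total comes to roughly 4.0551, which is in line with the 4.05514478458 reported for marginalizedTransitionTest in the question.
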
    Because of this, the fix is simple: create the filter inside the for, so that every pass gets a fresh iterator:

    def marginalizedTransition(Ps, sis, sfs, action, world, player):
        return sum(world.joinedTransition(sis, sfs, np.insert(actions, player, action)) *
                   np.prod([Ps[otherPlayer][tuple(it.chain(*sis))][np.insert(actions, player, action)[otherPlayer]]
                            for otherPlayer in filter(lambda x: x != player, range(world.amountOfPlayers()))])
                   for actions in it.product(range(world.amountOfPossibleActions), repeat=world.amountOfPlayers() - 1))
    

    Alternatively, as you pointed out in chat, you can just convert the iterator into a list once:

    def marginalizedTransition(Ps, sis, sfs, action, world, player):
        otherPlayers = list(filter(lambda x: x != player, range(world.amountOfPlayers())))
        return sum(world.joinedTransition(sis, sfs, np.insert(actions, player, action)) *
                   np.prod([Ps[otherPlayer][tuple(it.chain(*sis))][np.insert(actions, player, action)[otherPlayer]]
                            for otherPlayer in otherPlayers])
                   for actions in it.product(range(world.amountOfPossibleActions), repeat=world.amountOfPlayers() - 1))
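
    A variation on the list-based fix is to skip filter altogether and build the list with a comprehension, which cannot be exhausted. This is only a sketch; the helper name other_players and the three-player example are assumptions, not part of the original code:

```python
def other_players(player, n_players):
    """Return every player index except `player` as a reusable list."""
    return [p for p in range(n_players) if p != player]


others = other_players(player=1, n_players=3)

# A list can be iterated as often as needed; each pass sees all elements.
passes = [[p for p in others] for _ in range(2)]

print(others)  # [0, 2]
print(passes)  # [[0, 2], [0, 2]]
```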
    

    For SO users who want to reproduce the issue, I've created these dummy objects based on the chat discussion and unrolled the comprehension into explicit loops like so:

    import numpy as np
    import itertools as it
    
    
    class World:
        amountOfPossibleActions = 5
    
        def amountOfPlayers(self):
            return 2
    
        def joinedTransition(self, a, b, insert):
            return 1 if insert[0] == 3 else 0
    
    
    world = World()
    player = 1
    
    Ps = {0: {(3, 1, 5, 1): [[0.055144784579474179], [0.055144784579474179], [0.055144784579474179], [0.055144784579474179],
                             [0.055144784579474179]]}}
    
    
    def marginalizedTransition(Ps, sis, sfs, action, world, player):
        # Bug under study: this one-shot iterator is reused on every pass of the outer loop.
        otherPlayers = filter(lambda x: x != player, range(world.amountOfPlayers()))
        s = 0
        for actions in it.product(range(world.amountOfPossibleActions), repeat=world.amountOfPlayers() - 1):
            w = world.joinedTransition(sis, sfs, np.insert(actions, player, action))
            a = []
            for otherPlayer in otherPlayers:
                b = np.insert(actions, player, action)[otherPlayer]
                v = Ps[otherPlayer][tuple(it.chain(*sis))][b]
                a.append(v)
            p = np.prod(a)
            s += w * p
        return s
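
    To compare the buggy and fixed behaviour side by side, the reproduction can be condensed into one runnable script. This is a sketch under assumptions: the function name marginalized and its materialize flag are hypothetical, and Ps holds flat floats rather than the one-element lists used above.

```python
import itertools as it

import numpy as np


class World:
    # Dummy world mirroring the reproduction: 2 players, 5 actions.
    amountOfPossibleActions = 5

    def amountOfPlayers(self):
        return 2

    def joinedTransition(self, sis, sfs, joint_actions):
        # Only joint actions whose first entry is 3 get non-zero weight.
        return 1 if joint_actions[0] == 3 else 0


P = 0.055144784579474179
# Assumption: flat floats instead of the one-element lists in the dummy data.
Ps = {0: {(3, 1, 5, 1): [P] * 5}}


def marginalized(Ps, sis, sfs, action, world, player, materialize):
    others = filter(lambda x: x != player, range(world.amountOfPlayers()))
    if materialize:
        others = list(others)  # the fix: a list can be iterated repeatedly
    total = 0.0
    n_free = world.amountOfPlayers() - 1
    for actions in it.product(range(world.amountOfPossibleActions), repeat=n_free):
        joint = np.insert(actions, player, action)
        weight = world.joinedTransition(sis, sfs, joint)
        # With an exhausted filter this list is empty and np.prod gives 1.0.
        probs = [Ps[o][tuple(it.chain(*sis))][joint[o]] for o in others]
        total += weight * np.prod(probs)
    return total


world = World()
args = (Ps, [(3, 1), (5, 1)], [(2, 1), (5, 1)], 0, world, 1)
buggy = marginalized(*args, materialize=False)
fixed = marginalized(*args, materialize=True)
print(buggy)  # 1.0
print(fixed)
```

    With this dummy data the buggy variant returns 1.0, the same symptom as in the question, while the fixed variant returns the single probability P.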
    

    On a personal note: comprehensions are great for keeping code compact, but they can make it really difficult to read and debug. Keep that in mind next time you want to nest them. I was only able to find out what was wrong after unrolling the comprehension as seen above.