Tags: pytorch, reinforcement-learning, softmax

Can I apply softmax only on specific output neurons?


I am building an Actor-Critic neural network model in PyTorch in order to train an agent to play the game of Quoridor (hopefully). For this reason, I have a neural network with two heads: one for the actor output, which applies a softmax over all the possible moves, and one for the critic output, which is just a single neuron (regressing the value of the input state).
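For concreteness, a two-headed network of this kind might look roughly like the sketch below. The layer sizes, `STATE_DIM`, `NUM_MOVES` and the `ActorCritic` class are hypothetical placeholders, not values derived from Quoridor. The actor head deliberately returns raw logits rather than probabilities, which leaves room to mask illegal moves before any softmax is applied.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration (assumptions, not Quoridor facts).
STATE_DIM = 64    # length of the encoded board state
NUM_MOVES = 140   # one actor output per possible move

class ActorCritic(nn.Module):
    """Shared trunk with an actor head (move logits) and a critic head (state value)."""

    def __init__(self, state_dim: int = STATE_DIM, num_moves: int = NUM_MOVES, hidden: int = 128):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.actor = nn.Linear(hidden, num_moves)  # one logit per move
        self.critic = nn.Linear(hidden, 1)         # single value neuron

    def forward(self, state: torch.Tensor):
        features = self.trunk(state)
        logits = self.actor(features)              # raw scores; softmax is applied later
        value = self.critic(features).squeeze(-1)  # shape: (batch,)
        return logits, value

model = ActorCritic()
logits, value = model(torch.randn(2, STATE_DIM))  # a batch of 2 random states
print(logits.shape, value.shape)
```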

Now, in Quoridor, most of the time not all moves will be legal, so I am wondering if I can exclude the output neurons of the actor head that correspond to illegal moves for the input state, e.g. by passing a list of the indices of all neurons that correspond to legal moves. In other words, I want the illegal outputs left out of the sum in the softmax denominator.
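A minimal sketch of exactly this idea, assuming a single (unbatched) logit vector and a Python list of legal indices. `softmax_over_legal` is a hypothetical helper, not a PyTorch function: only the legal logits go into the softmax, and the result is copied back into a full-size vector that is zero everywhere else.

```python
from typing import Sequence

import torch

def softmax_over_legal(logits: torch.Tensor, legal_idx: Sequence[int]) -> torch.Tensor:
    """Softmax of a 1-D logit vector restricted to `legal_idx` (hypothetical helper).

    Only the legal logits enter the denominator; every other
    position gets probability exactly 0.
    """
    idx = torch.tensor(legal_idx, dtype=torch.long, device=logits.device)
    legal_probs = torch.softmax(logits[idx], dim=-1)  # normalise over legal moves only
    probs = torch.zeros_like(logits)
    return probs.index_copy(0, idx, legal_probs)      # out-of-place copy back to full size

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
probs = softmax_over_legal(logits, legal_idx=[1, 3])
print(probs)  # roughly tensor([0.0000, 0.8808, 0.0000, 0.1192])
```

Because each state in a batch generally has a different set of legal moves, this per-sample indexing does not batch well; a boolean mask over the full logit tensor is usually easier to vectorise.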

Is there functionality like this in PyTorch (because I cannot find any)? Should I attempt to implement such a softmax myself (I'm kind of scared to; PyTorch probably knows best, and I've been advised to use LogSoftmax as well)?

Furthermore, do you think this approach to dealing with illegal moves is good? Or should I just let the agent pick illegal moves and penalize it (with a negative reward) in the hope that it eventually stops picking them?

Or should I let the softmax run over all the outputs and then just set the illegal ones to zero? The rest won't sum to 1, but maybe I can fix that with plain normalization (i.e. dividing by their sum)?
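This idea can be checked numerically on made-up logits. Taking the full softmax, zeroing the illegal entries and dividing by the sum of what remains (the L1 norm) gives the same probabilities as masking the logits before the softmax, because the full-softmax denominator cancels out. Dividing by the L2 norm instead does not yield a distribution. The sketch masks with `float("-inf")`, which is an assumption here; note that it produces NaN if every move in a row is masked.

```python
import torch

logits = torch.tensor([[0.9357, 0.2386, 0.3264],
                       [0.0179, 0.8989, 0.9156]])
legal = torch.tensor([[False, True, True],
                      [True,  True, True]])

# Option A: full softmax, zero the illegal moves, renormalise by the sum (L1 norm).
p = torch.softmax(logits, dim=-1) * legal          # bool mask promotes to float
p_renorm = p / p.sum(dim=-1, keepdim=True)

# Option B: mask the logits first, then take the softmax.
p_masked = torch.softmax(logits.masked_fill(~legal, float("-inf")), dim=-1)

print(torch.allclose(p_renorm, p_masked))  # True

# Dividing by the L2 norm instead: each row sums to more than 1, so it is not a distribution.
p_l2 = p / p.norm(dim=-1, keepdim=True)
print(p_l2.sum(dim=-1))
```

Masking the logits first (option B) is usually preferable, since it also gives a clean log-softmax for the policy loss without taking the log of renormalised probabilities.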


Solution

  • An easy solution is to mask out illegal moves by filling their logits with a large negative value before the softmax; this drives their (log-)softmax values so low that they receive effectively zero probability (example below).

    >>> import torch
    # 3 dummy actions for a batch size of 2
    >>> actions = torch.rand(2, 3)
    >>> actions
    tensor([[0.9357, 0.2386, 0.3264],
            [0.0179, 0.8989, 0.9156]])
    # dummy mask assigning 0 to valid actions and 1 to invalid ones
    >>> mask = torch.randint(low=0, high=2, size=(2, 3))
    >>> mask
    tensor([[1, 0, 0],
            [0, 0, 0]])
    # set actions marked as invalid to a very large negative value
    # (out-of-place masked_fill avoids modifying the network output in place)
    >>> actions = actions.masked_fill(mask.eq(1), value=-1e10)
    >>> actions
    tensor([[-1.0000e+10,  2.3862e-01,  3.2636e-01],
            [ 1.7921e-02,  8.9890e-01,  9.1564e-01]])
    # softmax assigns no probability mass to illegal actions
    >>> actions.softmax(dim=-1)
    tensor([[0.0000, 0.4781, 0.5219],
            [0.1704, 0.4113, 0.4183]])
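In an actor-critic setup, the masked logits can be fed straight into `torch.distributions.Categorical`, which handles sampling and `log_prob` for the policy-gradient loss. The sketch below assumes, hypothetically, that legal moves arrive as one list of indices per state, and reuses the answer's `-1e10` fill value; `legal_mask` and `masked_policy` are illustrative helpers, not PyTorch APIs.

```python
import torch
from torch.distributions import Categorical

def legal_mask(legal_moves, num_moves):
    """Boolean mask of shape (batch, num_moves): True where the move is legal.

    `legal_moves` holds one list of legal move indices per state
    (a hypothetical input format).
    """
    mask = torch.zeros(len(legal_moves), num_moves, dtype=torch.bool)
    for row, idx in enumerate(legal_moves):
        mask[row, idx] = True
    return mask

def masked_policy(logits, mask, fill_value=-1e10):
    """Categorical policy that puts (effectively) zero probability on illegal moves."""
    masked_logits = logits.masked_fill(~mask, fill_value)  # out-of-place, graph stays intact
    return Categorical(logits=masked_logits)

torch.manual_seed(0)
logits = torch.randn(2, 5, requires_grad=True)  # stand-in for the actor head output
mask = legal_mask([[1, 3], [0, 2, 4]], num_moves=5)

dist = masked_policy(logits, mask)
actions = dist.sample()             # only legal moves can be drawn
log_probs = dist.log_prob(actions)  # feeds the actor (policy-gradient) loss
print(actions, dist.probs)
```

Since `masked_fill` passes zero gradient to the filled positions, the illegal logits receive no gradient from the policy loss; the critic head does not need the mask at all.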