Tags: machine-learning, tensorflow, neural-network, recurrent-neural-network, calculus

How do I perform a differentiable operation selection in TensorFlow?


I am trying to build a neural network model that selects a mathematical operation based on a scalar input. The operation is chosen from the softmax output produced by the network, and the selected operation is then applied to the scalar input to produce the final output. So far I have applied argmax and one_hot to the softmax output to produce a mask, which is then applied to the concatenated matrix of values from all the possible operations (as shown in the pseudocode below). The issue is that neither argmax nor one_hot appears to be differentiable. I am new to this, so any help would be highly appreciated. Thanks in advance.

    import tensorflow as tf

    # current_input, W, b, num_of_operations and tf_op_1..tf_op_3 are defined elsewhere

    # perform softmax
    logits = tf.matmul(current_input, W) + b
    softmax = tf.nn.softmax(logits)

    # perform all possible operations on the input
    op_1_val = tf_op_1(current_input)
    op_2_val = tf_op_2(current_input)
    op_3_val = tf_op_3(current_input)
    values = tf.concat([op_1_val, op_2_val, op_3_val], 1)

    # create a mask
    argmax = tf.argmax(softmax, 1)
    mask = tf.one_hot(argmax, num_of_operations)

    # produce the output by masking out the results of the operations that were not selected
    output = values * mask
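To see the problem concretely, here is a minimal TensorFlow 2 sketch (eager mode, `tf.GradientTape`) of the masking approach above. The three operations (`x + 1`, `2 * x`, `x ** 2`) in the hypothetical helper `apply_all_ops` are made-up stand-ins for `tf_op_1`..`tf_op_3`, and the input shape (a batch of scalars, `[batch, 1]`) is an assumption. The gradient of the masked output with respect to `W` and `b` comes back as `None`, because `tf.argmax` returns integer indices and has no gradient, so nothing connects the mask back to the logits.

```python
import tensorflow as tf

num_of_operations = 3  # assumption: three candidate operations

def apply_all_ops(x):
    # Hypothetical stand-ins for tf_op_1, tf_op_2, tf_op_3; x has shape [batch, 1].
    return tf.concat([x + 1.0, 2.0 * x, tf.square(x)], axis=1)

current_input = tf.constant([[2.0], [-1.0]])
W = tf.Variable(tf.random.normal([1, num_of_operations], seed=0))
b = tf.Variable(tf.zeros([num_of_operations]))

with tf.GradientTape() as tape:
    logits = tf.matmul(current_input, W) + b
    softmax = tf.nn.softmax(logits)
    values = apply_all_ops(current_input)
    # Hard selection: argmax yields int64 indices, so the tape cannot
    # trace a gradient from the mask back to the logits.
    mask = tf.one_hot(tf.argmax(softmax, axis=1), num_of_operations)
    output = tf.reduce_sum(values * mask, axis=1)  # one value per example, shape [batch]

grads = tape.gradient(output, [W, b])
print(grads)  # [None, None]
```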

Solution

  • I believe this is not possible with argmax and one_hot, since neither passes a gradient back to the softmax. This is similar to hard attention as described in this paper. Hard attention is used in image captioning to let the model focus on only a certain part of the image at each step. Hard attention is not differentiable, but there are two ways around this:

    1- Use reinforcement learning (RL): RL is designed to train models that make decisions. Even though the loss function won't back-propagate any gradients to the softmax used for the decision, you can use RL techniques to optimize the decision. As a simplified example, you can treat the loss as a penalty and send a policy gradient proportional to that penalty to the node with the maximum value in the softmax layer, so that the score of that decision decreases if it was bad (i.e. it resulted in a high loss).

    2- Use something like soft attention: instead of picking only one operation, mix them with weights given by the softmax. So instead of:

    output = values * mask
    

    Use:

    output = values * softmax
    

    Now each operation's result is scaled by its softmax probability, so operations the softmax assigns little weight to contribute little to the output, and gradients flow back through the softmax to W and b. This is easier to train than RL, but it won't work if the non-selected operations must be removed from the final result completely (set exactly to zero).

    Here is another answer about hard and soft attention that you may find helpful: https://stackoverflow.com/a/35852153/6938290
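Option 1 can be sketched with REINFORCE, the basic score-function policy gradient. This is one standard RL technique, not necessarily the exact scheme described above: instead of always taking the maximum softmax entry, it samples an operation with `tf.random.categorical`, which keeps the selection hard (non-selected operations are exactly zero) while still letting the model explore. The penalty (squared error) is treated as a constant, and minimizing the surrogate `mean(advantage * log_prob)` lowers the probability of operations that produced a high penalty. Subtracting the batch-mean penalty as a baseline is a common variance-reduction trick, not a requirement. The operations, the toy target `x ** 2`, the data, and the learning rate are all hypothetical.

```python
import tensorflow as tf

tf.random.set_seed(0)
num_of_operations = 3

def apply_all_ops(x):
    # Hypothetical stand-ins for tf_op_1, tf_op_2, tf_op_3; x has shape [batch, 1].
    return tf.concat([x + 1.0, 2.0 * x, tf.square(x)], axis=1)

x = tf.constant([[-1.5], [-0.5], [0.5], [1.5]])
target = tf.square(x)[:, 0]  # toy target (assumption): the third operation

W = tf.Variable(tf.random.normal([1, num_of_operations]))
b = tf.Variable(tf.zeros([num_of_operations]))

def reinforce_step(learning_rate=0.05):
    with tf.GradientTape() as tape:
        logits = tf.matmul(x, W) + b
        # Sample one operation per example (a hard, non-differentiable choice).
        actions = tf.squeeze(tf.random.categorical(logits, num_samples=1), axis=1)
        mask = tf.one_hot(actions, num_of_operations)
        # Non-selected operations contribute exactly zero.
        output = tf.reduce_sum(apply_all_ops(x) * mask, axis=1)
        # The penalty is a constant as far as the policy gradient is concerned.
        penalty = tf.stop_gradient(tf.square(output - target))
        advantage = penalty - tf.reduce_mean(penalty)  # baseline for lower variance
        # log pi(action) = -cross_entropy(labels=action, logits=logits)
        log_prob = -tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=actions, logits=logits)
        # Descending this surrogate lowers the probability of high-penalty choices.
        surrogate = tf.reduce_mean(advantage * log_prob)
    grads = tape.gradient(surrogate, [W, b])
    for var, grad in zip([W, b], grads):
        var.assign_sub(learning_rate * grad)
    return output, penalty, grads

for _ in range(300):
    output, penalty, grads = reinforce_step()

# Per-example probabilities of each operation after training.
print(tf.nn.softmax(tf.matmul(x, W) + b).numpy().round(2))
```

Because each update only sees a few sampled choices, these gradients are noisy, which is part of why the answer calls the soft approach easier to train.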
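Option 2 can be sketched as follows, assuming TensorFlow 2 in eager mode, a batch of scalar inputs of shape `[batch, 1]`, and the same made-up operations (`x + 1`, `2 * x`, `x ** 2`) in a hypothetical `apply_all_ops` helper. Here the softmax-weighted values are also summed over the operation axis, so each example gets a single scalar output (a weighted average of the operation results). Because that output depends smoothly on the softmax, `tape.gradient` returns real gradients for `W` and `b`, and plain gradient descent can learn to favor the operation that fits a toy target, here `x ** 2`.

```python
import tensorflow as tf

num_of_operations = 3

def apply_all_ops(x):
    # Hypothetical stand-ins for tf_op_1, tf_op_2, tf_op_3; x has shape [batch, 1].
    return tf.concat([x + 1.0, 2.0 * x, tf.square(x)], axis=1)

# Toy data (assumption): the target is the third operation, x ** 2.
x = tf.constant([[-1.5], [-0.5], [0.5], [1.5]])
target = tf.square(x)[:, 0]

W = tf.Variable(tf.random.normal([1, num_of_operations], seed=0))
b = tf.Variable(tf.zeros([num_of_operations]))

def soft_loss():
    logits = tf.matmul(x, W) + b
    softmax = tf.nn.softmax(logits)
    # Soft selection: weight every operation by its probability and sum,
    # instead of masking with one_hot(argmax(...)).
    output = tf.reduce_sum(apply_all_ops(x) * softmax, axis=1)
    return tf.reduce_mean(tf.square(output - target))

initial_loss = float(soft_loss())
learning_rate = 0.05
for _ in range(500):
    with tf.GradientTape() as tape:
        loss = soft_loss()
    grads = tape.gradient(loss, [W, b])
    # Plain gradient descent; both gradients are real tensors, not None.
    for var, grad in zip([W, b], grads):
        var.assign_sub(learning_rate * grad)

final_loss = float(soft_loss())
print(initial_loss, final_loss)
# Per-example probabilities of each operation after training.
print(tf.nn.softmax(tf.matmul(x, W) + b).numpy().round(2))
```

The printed probabilities show how much weight each operation still gets; unless one of them saturates at 1, the other operations keep a small nonzero share of the output, which is the limitation the answer mentions.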