
How to convert this numpy to tf.function compatible code?


I'm trying to convert some NumPy code into equivalent TensorFlow code so that it works inside a tf.function.

Given a (32, 6) NumPy array target_values that looks like this:

array([[-0.01656106,  0.04762066,  0.05735449, -0.0284767 , -0.02237438,
        -0.00042562],
       [-0.01420249,  0.0477839 ,  0.0563598 , -0.02971786, -0.02367548,
         0.00001262],
       [-0.01695916,  0.04826669,  0.05893629, -0.03067053, -0.02261235,
         0.00345904],
       [-0.01953977,  0.04540274,  0.05829531, -0.02759781, -0.02390759,
        -0.00487727],
       [-0.01708016,  0.04894669,  0.0606699 , -0.02576046, -0.02461138,
        -0.00068538],
       [-0.01604217,  0.04770135,  0.05761468, -0.02858265, -0.02624938,
        -0.00084356],
       [-0.01527106,  0.04699571,  0.05959677, -0.02956396, -0.02510098,
        -0.00223234],
       [-0.01448676,  0.04620824,  0.05775366, -0.03008122, -0.02655901,
        -0.00159649],
       [-0.0172577 ,  0.04814827,  0.05807308, -0.02916523, -0.02367857,
        -0.00100602],
       [-0.01690523,  0.0484785 ,  0.05807881, -0.02960616, -0.02560546,
        -0.00065042],
       [-0.0166171 ,  0.0488232 ,  0.05776291, -0.03231864, -0.02132723,
        -0.00033605],
       [-0.01541627,  0.04840397,  0.0580376 , -0.02927143, -0.02461101,
         0.00121263],
       [-0.01685588,  0.047661  ,  0.05873172, -0.02989979, -0.02574112,
        -0.00126612],
       [-0.01333553,  0.05043796,  0.05915743, -0.02990219, -0.02657976,
        -0.0007656 ],
       [-0.01531163,  0.04781894,  0.05637252, -0.02968849, -0.02225551,
        -0.00151382],
       [-0.01357749,  0.04807179,  0.05955081, -0.02748637, -0.02498721,
        -0.00040934],
       [-0.01606943,  0.04768877,  0.05455931, -0.03136749, -0.02475093,
         0.00245846],
       [-0.01609829,  0.04687681,  0.05982678, -0.02886578, -0.02608151,
         0.00015348],
       [-0.01503662,  0.04740106,  0.05958583, -0.03141545, -0.02522127,
        -0.00063602],
       [-0.01697148,  0.04910276,  0.05744712, -0.02858391, -0.02481578,
        -0.00072039],
       [-0.01503395,  0.04843756,  0.05773868, -0.03061879, -0.02586869,
        -0.00025573],
       [-0.0152991 ,  0.04847359,  0.05739099, -0.0299796 , -0.02552593,
        -0.00334571],
       [-0.01324895,  0.04529134,  0.05534273, -0.03109139, -0.02304241,
        -0.00143186],
       [-0.01280282,  0.05004944,  0.05856398, -0.0314032 , -0.02394999,
        -0.00030306],
       [-0.01677033,  0.04876196,  0.05794405, -0.02888608, -0.02658239,
        -0.00015171],
       [-0.01572544,  0.04779808,  0.05939355, -0.03048976, -0.02896303,
        -0.00090334],
       [-0.01542805,  0.04709881,  0.05839922, -0.02894112, -0.02240603,
        -0.00188624],
       [-0.01493233,  0.0476524 ,  0.0581631 , -0.0297201 , -0.02485022,
        -0.00087418],
       [-0.01804641,  0.04739738,  0.06070606, -0.02981704, -0.02543145,
        -0.00115484],
       [-0.01518638,  0.04843838,  0.05744548, -0.02980216, -0.02420005,
         0.00036349],
       [-0.01442349,  0.04673778,  0.05804737, -0.03062913, -0.02476445,
        -0.00066772],
       [-0.01598305,  0.04622466,  0.0588723 , -0.03096713, -0.02364032,
        -0.00005574]])

And given a second (32,) array of column indices, actions, whose values are in the range 0 to 5 inclusive (one per column of target_values):

array([0, 2, 5, 5, 1, 1, 3, 4, 0, 5, 4, 3, 4, 5, 1, 0, 3, 0, 0, 2, 2, 2,
       0, 1, 4, 1, 4, 4, 0, 4, 1, 0])

I'm expecting this result:

array([-0.01656106,  0.0563598 ,  0.00345904, -0.00487727,  0.04894669,
        0.04770135, -0.02956396, -0.02655901, -0.0172577 , -0.00065042,
       -0.02132723, -0.02927143, -0.02574112, -0.0007656 ,  0.04781894,
       -0.01357749, -0.03136749, -0.01609829, -0.01503662,  0.05744712,
        0.05773868,  0.05739099, -0.01324895,  0.05004944, -0.02658239,
        0.04779808, -0.02240603, -0.02485022, -0.01804641, -0.02420005,
        0.04673778, -0.01598305], dtype=float32)

With self.batch_size == 32, I can get what I need in NumPy using:

state_action_values = target_values[np.arange(self.batch_size), actions]
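The NumPy expression above uses integer-array (advanced) indexing: the row-index array and the column-index array are paired element by element, so output element i is target_values[i, actions[i]]. A minimal sketch on a small made-up array (the names values and acts are illustrative only):

```python
import numpy as np

# Hypothetical 3x4 array standing in for target_values:
# [[ 0.  1.  2.  3.]
#  [ 4.  5.  6.  7.]
#  [ 8.  9. 10. 11.]]
values = np.arange(12, dtype=np.float32).reshape(3, 4)
acts = np.array([2, 0, 3])  # one column index per row

# Row i is paired with column acts[i]: picks values[0, 2], values[1, 0], values[2, 3].
picked = values[np.arange(len(acts)), acts]
print(picked)  # [ 2.  4. 11.]
```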

Then, with target_value_update being another (32,) array of new values, I need to write those values back into the same positions using:

target_values[np.arange(self.batch_size), actions] = target_value_update
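On the left-hand side of an assignment the same index pair acts as an in-place scatter: NumPy writes update i into position (i, actions[i]) of the existing array. A small sketch with made-up values (names are illustrative only):

```python
import numpy as np

values = np.zeros((3, 4), dtype=np.float32)
acts = np.array([2, 0, 3])
updates = np.array([1.5, -2.0, 7.0], dtype=np.float32)

# In-place scatter: values[i, acts[i]] = updates[i] for each i; other entries stay 0.
values[np.arange(len(acts)), acts] = updates
print(values)
```

This in-place mutation is the part with no direct counterpart on a plain tf.Tensor, since TensorFlow tensors are immutable.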

However, inside a tf.function this indexing fails with the following error:

TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])

So I tried:

target_values = tf.Variable(target_values)
state_action_values = tf.gather(target_values, actions, axis=1)

However, state_action_values comes out with shape (32, 32) instead of the expected (32,):

Tensor("GatherV2:0", shape=(32, 32), dtype=float32)
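The (32, 32) shape comes from how tf.gather treats axis=1: it selects whole columns, so for each of the 32 entries in actions it returns that column across all 32 rows. Passing batch_dims=1 instead pairs row i with actions[i] only, which yields a (32,) result. A sketch on a small made-up tensor, assuming TensorFlow 2.x:

```python
import tensorflow as tf

# Hypothetical 3x4 tensor standing in for target_values.
values = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
acts = tf.constant([2, 0, 3])

# axis=1 alone: every row is indexed with every action -> shape (3, 3).
all_pairs = tf.gather(values, acts, axis=1)

# batch_dims=1: row i is indexed only with acts[i] -> shape (3,).
per_row = tf.gather(values, acts, axis=1, batch_dims=1)

print(all_pairs.shape)  # (3, 3)
print(per_row.numpy())  # [ 2.  4. 11.]
```

The per-row result is the diagonal of the all-pairs result, so batch_dims=1 avoids computing the other entries.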

Solution

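  • Use tf.gather_nd(), passing an explicit (row, action) index pair for each element to pick:

    import tensorflow as tf

    actions = tf.cast(actions, tf.int32)  # tf.range yields int32; tf.concat needs matching dtypes
    a = tf.range(tf.shape(actions)[0])[:, tf.newaxis]  # row indices, shape (batch, 1)
    a = tf.concat((a, actions[:, tf.newaxis]), -1)  # (row, action) pairs, shape (batch, 2)
    output = tf.gather_nd(target_values, a)  # one value per row, shape (batch,)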
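The write-back from the question (target_values[...] = target_value_update) can reuse the same index pairs. A minimal sketch, assuming TensorFlow 2.x: tf.tensor_scatter_nd_update returns a new tensor with the updates applied, while a tf.Variable can be modified in place with its scatter_nd_update method. The helper names make_indices and select_and_replace are hypothetical, and the data is made up.

```python
import numpy as np
import tensorflow as tf


def make_indices(actions):
    """Hypothetical helper: build (row, action) pairs of shape (batch, 2)."""
    actions = tf.cast(actions, tf.int32)
    rows = tf.range(tf.shape(actions)[0])
    return tf.stack([rows, actions], axis=1)


@tf.function
def select_and_replace(target_values, actions, target_value_update):
    indices = make_indices(actions)
    # Equivalent of target_values[np.arange(batch), actions]
    state_action_values = tf.gather_nd(target_values, indices)
    # Equivalent of target_values[np.arange(batch), actions] = target_value_update,
    # except that a new tensor is returned because tf.Tensor is immutable.
    updated = tf.tensor_scatter_nd_update(target_values, indices, target_value_update)
    return state_action_values, updated


values = np.arange(12, dtype=np.float32).reshape(3, 4)
acts = np.array([2, 0, 3])  # NumPy default integer dtype; cast inside make_indices
new = np.array([1.5, -2.0, 7.0], dtype=np.float32)

picked, updated = select_and_replace(tf.constant(values), tf.constant(acts), tf.constant(new))
print(picked.numpy())  # [ 2.  4. 11.]

# In-place alternative when target_values is held in a tf.Variable.
var = tf.Variable(values)
var.scatter_nd_update(make_indices(acts), new)
```

tf.stack on the two 1-D index vectors builds the same (batch, 2) index tensor as the expand-and-concat version in the answer.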