python, tensorflow, slice, argmax

TensorFlow: tf.argmax and slicing


I would like to design this loss function:

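sum((y[argmax(y_)] - y_[argmax(y_)])²)

that is, for each row i, take the column k_i = argmax(y_[i]) and sum (y[i, k_i] - y_[i, k_i])² over all rows.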
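To make the intended semantics checkable, here is a plain NumPy reference version of this loss with an explicit loop over rows. It reads argmax as taken per row of y_ (which is how the answer below treats it). It is not TensorFlow code, and the function name `reference_loss` is hypothetical.

```python
import numpy as np

def reference_loss(y, y_):
    """Sum over rows i of (y[i, k_i] - y_[i, k_i])**2, with k_i = argmax(y_[i])."""
    total = 0.0
    for i in range(y_.shape[0]):
        k_i = int(np.argmax(y_[i]))  # column holding the largest target in row i
        total += (y[i, k_i] - y_[i, k_i]) ** 2
    return total

y = np.array([[0.5, 0.1, 0.2],
              [0.3, 0.9, 0.0]])   # predictions
y_ = np.array([[1.0, 0.0, 0.0],
               [0.0, 0.0, 1.0]])  # targets
print(reference_loss(y, y_))  # 1.25
```

Row 0 compares column 0 (0.5 vs 1.0) and row 1 compares column 2 (0.0 vs 1.0), giving 0.25 + 1.0.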

I can't find a way to compute y[argmax(y_)]. I tried y[k], y[:,k] and y[None,k], but none of these work. This is my code:

    import tensorflow as tf

    Na = 3
    x = tf.placeholder(tf.float32, [None, 2])
    W = tf.Variable(tf.zeros([2, Na]))
    b = tf.Variable(tf.zeros([Na]))
    y = tf.nn.relu(tf.matmul(x, W) + b)
    y_ = tf.placeholder(tf.float32, [None, 3])
    k = tf.argmax(y_, 1)
    diff = y[k] - y_[k]  # this line raises the ValueError below
    loss = tf.reduce_sum(tf.square(diff))

And the error:

  File "/home/ncarrara/phd/code/cython/robotnavigation/ftq/cftq19.py", line 156, in <module>
    diff = y[k] - y_[k]
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 499, in _SliceHelper
    name=name)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 663, in strided_slice
    shrink_axis_mask=shrink_axis_mask)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3515, in strided_slice
    shrink_axis_mask=shrink_axis_mask, name=name)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2508, in create_op
    set_shapes_for_outputs(ret)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1873, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1823, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 676, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice' (op: 'StridedSlice') with input shapes: [?,3], [1,?], [1,?], [1].
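As far as I can tell from the input shapes in the message ([?,3], [1,?], [1,?], [1]), plain `y[k]` is turned into a `StridedSlice` op whose begin, end and strides arguments must be rank-1 vectors; wrapping the rank-1 tensor `k` produces a rank-2 `[1, ?]` argument, hence "Shape must be rank 1 but is rank 2". Even where indexing with an index vector is allowed, as in NumPy, `y[k]` and `y[:, k]` would not select one element per row. The sketch below uses NumPy arrays purely as stand-ins for the tensors:

```python
import numpy as np

y = np.array([[10, 11, 12],
              [20, 21, 22],
              [30, 31, 32]])
k = np.array([2, 0, 1])  # per-row column indices, like argmax(y_, axis=1) would give

rows = y[k]                        # whole rows 2, 0, 1 -> shape (3, 3)
cols = y[:, k]                     # columns 2, 0, 1 of every row -> shape (3, 3)
picked = y[np.arange(len(k)), k]   # one element per row: y[0,2], y[1,0], y[2,1]

print(rows.shape, cols.shape)  # (3, 3) (3, 3)
print(picked)                  # [12 20 31]
```

Only the last form pairs each row index with its own column index, which is exactly the pairing the solution below builds explicitly for `tf.gather_nd`.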

Solution

  • That can be done using tf.gather_nd:

    import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)
    
    Na = 3
    x = tf.placeholder(tf.float32, [None, 2])
    W = tf.Variable(tf.zeros([2, Na]))
    b = tf.Variable(tf.zeros([Na]))
    y = tf.nn.relu(tf.matmul(x, W) + b)           # predictions, shape [batch, Na]
    y_ = tf.placeholder(tf.float32, [None, Na])   # targets, shape [batch, Na]
    k = tf.argmax(y_, 1)                          # per-row column of the max target (int64)
    # Make index tensor with row and column indices
    num_examples = tf.cast(tf.shape(x)[0], dtype=k.dtype)
    idx = tf.stack([tf.range(num_examples), k], axis=-1)  # shape [batch, 2]: (row, col) pairs
    diff = tf.gather_nd(y, idx) - tf.gather_nd(y_, idx)   # one value per row
    loss = tf.reduce_sum(tf.square(diff))
    
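A different technique that avoids building coordinates at all is to multiply by a one-hot mask of `k` and sum each row. In TensorFlow this would presumably be written along the lines of `tf.reduce_sum(y * tf.one_hot(k, Na), axis=1)`; that variant is my suggestion, not part of the answer above. The sketch checks the idea in NumPy, with a hypothetical `one_hot` helper built from `np.eye`:

```python
import numpy as np

def one_hot(k, depth):
    """Hypothetical helper: row i is all zeros except a 1 at column k[i]."""
    return np.eye(depth)[k]

y = np.array([[0.5, 0.1, 0.2],
              [0.3, 0.9, 0.0]])   # predictions
y_ = np.array([[1.0, 0.0, 0.0],
               [0.0, 0.0, 1.0]])  # targets
k = np.argmax(y_, axis=1)         # [0, 2]
mask = one_hot(k, y.shape[1])     # 1 only at the argmax column of each row

# Multiplying by the mask zeroes the other columns; summing each row keeps the picked value.
pred_picked = np.sum(y * mask, axis=1)     # [0.5, 0.0]
target_picked = np.sum(y_ * mask, axis=1)  # [1.0, 1.0]
loss = np.sum(np.square(pred_picked - target_picked))
print(loss)  # 1.25
```

One caveat of the mask approach: a NaN or infinity anywhere in a row contaminates that row's sum, because it is multiplied by zero rather than skipped, whereas `gather_nd` only reads the selected element.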

    Explanation:

    Here, tf.gather_nd takes an index matrix (a two-dimensional tensor) in which each row holds the row and column coordinates of one element to pick from the input, and it returns one value per index row. For example, if I have a matrix a containing:

    | 1 2 3 |
    | 4 5 6 |
    | 7 8 9 |
    

    And a matrix i containing:

    | 1 2 |
    | 0 1 |
    | 2 2 |
    | 1 0 |
    

    Then the result of tf.gather_nd(a, i) would be the vector (one-dimensional tensor) containing:

    | 6 |
    | 2 |
    | 9 |
    | 4 |
    
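The same lookup can be reproduced in NumPy, which has no `gather_nd` function but supports the equivalent through integer-array indexing: index `a` with the first column of `i` as row numbers and the second column as column numbers. `gather_nd_2d` is a hypothetical name, and this sketch only covers a 2-D input with full (row, column) coordinates:

```python
import numpy as np

def gather_nd_2d(a, idx):
    """Mimic tf.gather_nd for a 2-D `a` and an (n, 2) array of (row, col) pairs."""
    return a[idx[:, 0], idx[:, 1]]

a = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]])
i = np.array([[1, 2],
              [0, 1],
              [2, 2],
              [1, 0]])
print(gather_nd_2d(a, i))  # [6 2 9 4]
```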

    In this case, the column indices come from tf.argmax and are stored in k: for every row of y_, k gives the column holding the highest value. The first element of k is the column index for row 0, the next one is for row 1, and so on, so each element only needs to be paired with its row index. num_examples is the number of rows in the batch (the first dimension of x, which is also that of y and y_), and tf.range(num_examples) produces the row indices 0, 1, ..., num_examples - 1. tf.stack then pairs these row indices with k, and the result, idx, is the index matrix passed to tf.gather_nd.
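Here is the construction of idx traced with concrete numbers, using NumPy stand-ins for the TensorFlow ops (`np.argmax` for `tf.argmax`, `np.arange` for `tf.range`, `np.stack` for `tf.stack`). The values of `y_` are made up for illustration:

```python
import numpy as np

y_ = np.array([[0.1, 0.7, 0.2],
               [0.9, 0.05, 0.05],
               [0.3, 0.3, 0.4]])
k = np.argmax(y_, axis=1)     # column of the max in each row: [1 0 2]
num_examples = y_.shape[0]    # 3 rows
idx = np.stack([np.arange(num_examples), k], axis=-1)  # (row, col) pairs, shape (3, 2)
print(idx)
# [[0 1]
#  [1 0]
#  [2 2]]
print(y_[idx[:, 0], idx[:, 1]])  # [0.7 0.9 0.4]
```

The `axis=-1` matters: stacking along the last axis pairs the i-th row index with the i-th column index, giving shape (n, 2), whereas `axis=0` would give a (2, n) array, which is not the layout `gather_nd` expects (its last index dimension is read as one coordinate tuple).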