python tensorflow keras tensor torch

List of tensors and just tensors


I am updating code from TensorFlow 1.x to 2.1.0.

I changed this TensorFlow 1.x code

labels = tf.cast(labels, tf.int64)
predict = tf.argmax(input=logits, axis=1)
tf.metrics.accuracy(labels=labels, predictions=predict)

to this TensorFlow 2.1.0 code:

labels = tf.cast(labels, tf.int64)
predict = tf.argmax(input=logits, axis=1)
tf.keras.metrics.Accuracy.update_state(labels, predict) #updated code
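Both snippets compute the same quantity: the predicted class is the index of the largest logit in each row, and accuracy is the fraction of positions where that index equals the label. A minimal pure-Python sketch of that arithmetic, with no TensorFlow involved; the helper names argmax_rows and accuracy and the sample values are hypothetical, chosen only for illustration.

```python
def argmax_rows(logits):
    """Index of the largest value in each row (the role of tf.argmax(..., axis=1))."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def accuracy(labels, predictions):
    """Fraction of positions where the label equals the prediction."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    matches = sum(1 for y, p in zip(labels, predictions) if y == p)
    return matches / len(labels)


logits = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]  # made-up example values
labels = [0, 1, 1]
predict = argmax_rows(logits)        # [0, 1, 0]
print(accuracy(labels, predict))     # 0.6666666666666666 (2 of 3 correct)
```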

But when I ran the updated code, I got the following error:

TypeError: update_state() missing 1 required positional argument: 'y_pred'
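This error comes from ordinary Python method binding rather than from the tensor types. The sketch below uses a hypothetical ToyAccuracy class (a stand-in, not the Keras implementation) to reproduce it: calling update_state on the class instead of on an instance makes the first argument fill self and the second fill y_true, so y_pred is never supplied.

```python
class ToyAccuracy:
    """Hypothetical stand-in for a metric class; not the Keras source."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update_state(self, y_true, y_pred):
        # Count matching positions and how many positions were seen.
        self.correct += sum(1 for t, p in zip(y_true, y_pred) if t == p)
        self.total += len(y_true)


labels, predict = [0, 1], [0, 1]

# Called on the class: `labels` is bound to `self`, `predict` to `y_true`,
# and `y_pred` is left missing, so Python raises TypeError.
message = None
try:
    ToyAccuracy.update_state(labels, predict)
except TypeError as err:
    message = str(err)
print(message)  # ends with: missing 1 required positional argument: 'y_pred'

# Called on an instance: Python supplies `self` automatically.
metric = ToyAccuracy()
metric.update_state(labels, predict)
print(metric.correct, metric.total)  # 2 2
```

The exact prefix of the message (update_state() versus ToyAccuracy.update_state()) depends on the Python version, which is why only the ending is noted in the comment.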

So I checked the TensorFlow 2.1.0 documentation, and the parameters of tf.keras.metrics.Accuracy.update_state() seem to be lists (of the form [ , , , ]). Then I searched for a way to convert a tensor to a list and found this:

labels = tf.make_tensor_proto(labels)
labels = tf.make_ndarray(labels)

When I run this code, it gives the following error:

TypeError: List of Tensors when single Tensor expected

So I tried to turn the list of Tensors into a single Tensor with:

labels = tf.stack(labels)
#or
labels = torch.stack(labels)

tf.stack() did not help: the updated line still raised the original TypeError about the missing 'y_pred' argument.

torch.stack(), however, gave the following error.

TypeError: stack() : argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

So I am guessing torch.stack() only accepts a tuple, NOT a list. tf.stack(), on the other hand, seems to accept a list, but it does not turn it into a Tensor?

Are my labels and predict even a list of Tensors in the first place? If so, why would tf.stack() not turn them into Tensors? How can I correctly convert labels and predict so that they can be passed into tf.keras.metrics.Accuracy.update_state()?

I would very much appreciate a solution that avoids compat.v1 unless it is absolutely necessary.


Solution

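  • update_state() is an instance method, so it has to be called on a tf.keras.metrics.Accuracy object, not on the class itself. When it is called on the class, labels is bound to self and predict to y_true, which is why y_pred is reported as missing. labels and predict are already single tensors, so no conversion is needed. Try it this way:

    import numpy as np
    import tensorflow as tf

    labels = [0, 1]
    logits = np.asarray([[0.9, 0.1], [0.1, 0.9]])

    labels = tf.cast(labels, tf.int64)
    predict = tf.argmax(input=logits, axis=1)

    # Create the metric object first, then update its state.
    acc = tf.keras.metrics.Accuracy()
    acc.update_state(y_true=labels, y_pred=predict)

    # update_state() only accumulates; read the value with result().
    acc.result().numpy()  # 1.0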
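A Keras metric is stateful: each update_state() call adds to running totals, and result() reports the value over everything seen so far. The sketch below is a hypothetical pure-Python analogue (the StreamingAccuracy class and its reset() method are illustrative names, not the Keras source) showing that update/read split across two batches.

```python
class StreamingAccuracy:
    """Hypothetical pure-Python analogue of a streaming accuracy metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the running totals.
        self.correct = 0
        self.count = 0

    def update_state(self, y_true, y_pred):
        # Accumulate running totals; nothing useful is returned.
        if len(y_true) != len(y_pred):
            raise ValueError("y_true and y_pred must have the same length")
        self.correct += sum(1 for t, p in zip(y_true, y_pred) if t == p)
        self.count += len(y_true)

    def result(self):
        # Accuracy over all batches since the last reset;
        # returning 0.0 before any update is a choice made for this sketch.
        return self.correct / self.count if self.count else 0.0


acc = StreamingAccuracy()
acc.update_state([0, 1], [0, 1])  # batch 1: 2 of 2 correct
acc.update_state([1, 1], [0, 1])  # batch 2: 1 of 2 correct
print(acc.result())               # 0.75 (3 of 4 overall)
acc.reset()
print(acc.result())               # 0.0
```

Because the totals persist between calls, one metric object can be updated once per batch and read at the end of an epoch; in TF 2.1 the Keras metrics clear their state with reset_states().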