
Accessing predictions when using k_argmax() instead of predict_classes()


I've created a CNN in RStudio using keras to predict MNIST digits, and I am now trying to generate predictions with this model. Since predict_classes() has been deprecated, I'm attempting to use k_argmax() with the following code:

cnn_pred <- cnn_model %>%
    predict(x_test) %>%
    k_argmax()

When I print cnn_pred, this is what I get:

tf.Tensor([8 7 7 ... 3 4 9], shape=(4252), dtype=int64)

How do I access the predicted values in order to examine them and then print a confusion matrix?


Solution

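  • As per the answers on a previous post, I've found the following code to work. Appending as.integer() converts the tensor returned by k_argmax() into an ordinary R integer vector: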

    cnn_pred <- cnn_model %>%
       predict(x_test) %>%
       k_argmax() %>%
       as.integer()
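
For the confusion matrix part of the question, here is a minimal sketch using base R's table(). The vectors `pred` and `truth` are hypothetical stand-ins so the snippet runs on its own; in the real workflow `pred` would be `cnn_pred` from the pipeline above and `truth` the test labels as plain digits. This assumes the labels are digits 0 to 9; if they were one-hot encoded (for example with to_categorical()), they would need converting back to digits first.

```r
# Hypothetical stand-ins for the real predictions and labels.
pred  <- c(8L, 7L, 7L, 3L, 4L, 9L, 3L, 4L)
truth <- c(8L, 7L, 1L, 3L, 4L, 9L, 5L, 4L)

# Shared factor levels give every digit 0-9 a row and a column,
# even if a digit never appears in this sample.
digits <- 0:9
conf_mat <- table(
  predicted = factor(pred,  levels = digits),
  actual    = factor(truth, levels = digits)
)
print(conf_mat)

# Overall accuracy: share of predictions that fall on the diagonal.
accuracy <- sum(diag(conf_mat)) / sum(conf_mat)
print(accuracy)
```

Fixing the factor levels keeps the matrix square, which is what makes diag() meaningful for the accuracy calculation.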