Tensorflow Prediction in Bigquery Softmax


I have a multiclass classification TensorFlow model imported into BigQuery on GCP. The prediction output is an array of class probabilities, returned as a column of type FLOAT with mode REPEATED. What is the best way to get the index of the maximum value using SQL in BigQuery?


Solution

  • To find the index of the maximum value in an array, a UDF (user-defined function) is a handy option.

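    CREATE TEMP FUNCTION index_of_max(probabilities ARRAY<FLOAT64>) AS ((
      -- Sort by probability (highest first) and keep the first offset.
      -- The secondary sort on i breaks ties, so exactly one row is returned.
      SELECT i FROM UNNEST(probabilities) AS p WITH OFFSET AS i
       ORDER BY p DESC, i LIMIT 1
    ));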
    
    SELECT index_of_max(dense_1) AS index_of_max FROM UNNEST([
      STRUCT([0.8611106872558594, 0.06648489832878113, 0.07240447402000427] AS dense_1),
      STRUCT([0.6251607537269592, 0.2989124655723572, 0.07592668384313583]),
      STRUCT([0.01427623350173235, 0.972910463809967, 0.01281337533146143])
    ]);
    

    Output: index_of_max is 0, 0 and 1 for the three sample arrays, in the order listed.

    [note] The index is zero-based.

    Applied to the predictions of an imported model, for example:

     SELECT dense_1, index_of_max(dense_1) AS index_of_max
       FROM ML.PREDICT (
              MODEL `testset_us.imported_tf_model`,
              (SELECT title AS input FROM `bigquery-public-data.hacker_news.stories`)
            )
    

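    If creating a function is not convenient, the same lookup can probably be written inline as a scalar subquery. This is a minimal sketch rather than part of the original answer: it assumes the prediction column is named dense_1, and the sample arrays are made up for illustration.

    ```sql
    -- Inline argmax: unnest the array with its offset, sort by probability
    -- (highest first), and keep the first offset. Ties go to the lowest index.
    SELECT
      dense_1,
      (SELECT i
         FROM UNNEST(dense_1) AS p WITH OFFSET AS i
        ORDER BY p DESC, i
        LIMIT 1) AS index_of_max
    FROM UNNEST([
      STRUCT([0.1, 0.7, 0.2] AS dense_1),  -- hypothetical sample row, max at index 1
      STRUCT([0.5, 0.5, 0.0])              -- hypothetical tie, resolves to index 0
    ]);
    ```

    To use it on real predictions, the `FROM UNNEST([...])` clause would be replaced by the `ML.PREDICT (...)` call over the imported model.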