None gradients from GradientTape inside a batch training loop


I am new to TensorFlow and am trying to implement simple collaborative filtering in TensorFlow 2. Training on the whole training set at once works fine, but batch training fails: when I compute the gradients, the result is [None, None]. The colab file with the full attempt can be found here.

        with tf.GradientTape() as tape:
            tape.watch(user_embedding)
            tape.watch(item_embedding)

            ## Compute the predicted ratings
            predicted_ratings = tf.reduce_sum(user_batch * item_batch, axis=1)

            ## Compute loss
            true_ratings = tf.cast(train_batch_st.values, tf.float32)
            loss = tf.losses.mean_squared_error(true_ratings, predicted_ratings) # batch loss
            # Cumulative epoch loss (across all batches)
            epoch_loss += loss

            ## Compute gradients of loss with respect to user and item embeddings
            grads = tape.gradient(loss, [user_embedding, item_embedding])
            print(grads) # grads None, None thus causing error below

            # Apply gradients
            optimizer.apply_gradients(zip(grads, [user_embedding, item_embedding]))

Thanks for any help!


Solution

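  • This works for me: do the tf.nn.embedding_lookup inside the gradient tape. The tape only records operations executed within its context, so if user_batch and item_batch are looked up before the tape is opened, nothing connects the loss to user_embedding and item_embedding, and the gradients come back as None.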

    with tf.GradientTape() as tape:
        user_batch = tf.nn.embedding_lookup(user_embedding, user_ids) # shape = batch_size x embedding_dims
        item_batch = tf.nn.embedding_lookup(item_embedding, item_ids) # shape = batch_size x embedding_dims
    
        ## Compute the predicted ratings
        true_ratings = tf.cast(train_batch_st.values, tf.float32)
        predicted_ratings = tf.reduce_sum(user_batch * item_batch, axis=1)
    
        ## Compute loss
        # Using MSE here
        loss = tf.losses.mean_squared_error(true_ratings, predicted_ratings) # batch loss
    
    # Cumulative epoch loss (across all batches)
    epoch_loss += loss
    
    ## Compute gradients of loss with respect to user and item embeddings
    grads = tape.gradient(loss, [user_embedding, item_embedding])
    
    # Apply gradients (update user and item embeddings)
    optimizer.apply_gradients(zip(grads, [user_embedding, item_embedding]))
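
To see the difference in isolation, here is a minimal sketch that runs the same loss twice, once with the lookup outside the tape and once inside. The sizes, ids, ratings, and the batch_loss helper are made up for illustration and are not taken from the original notebook; the loss is written out with tf.reduce_mean and tf.square, which should match MSE for this 1-D case.

```python
import tensorflow as tf

# Hypothetical toy data, chosen only for illustration.
num_users, num_items, dims = 4, 5, 3
user_embedding = tf.Variable(tf.random.normal([num_users, dims], seed=0))
item_embedding = tf.Variable(tf.random.normal([num_items, dims], seed=1))
user_ids = tf.constant([0, 2, 3])
item_ids = tf.constant([1, 1, 4])
true_ratings = tf.constant([5.0, 3.0, 1.0])


def batch_loss(user_batch, item_batch):
    """Mean squared error between dot-product predictions and true ratings."""
    predicted = tf.reduce_sum(user_batch * item_batch, axis=1)
    return tf.reduce_mean(tf.square(true_ratings - predicted))


# Case 1: the lookup runs before the tape starts recording, so the tape
# never sees an operation that reads the embedding variables.
user_batch = tf.nn.embedding_lookup(user_embedding, user_ids)
item_batch = tf.nn.embedding_lookup(item_embedding, item_ids)
with tf.GradientTape() as tape:
    loss = batch_loss(user_batch, item_batch)
grads_outside = tape.gradient(loss, [user_embedding, item_embedding])

# Case 2: the lookup runs inside the tape, so the path from the variables
# to the loss is recorded.
with tf.GradientTape() as tape:
    user_batch = tf.nn.embedding_lookup(user_embedding, user_ids)
    item_batch = tf.nn.embedding_lookup(item_embedding, item_ids)
    loss = batch_loss(user_batch, item_batch)
grads_inside = tape.gradient(loss, [user_embedding, item_embedding])

print([g is None for g in grads_outside])  # [True, True]
print([g is None for g in grads_inside])   # [False, False]
```

Note that tape.watch is not needed in either case: trainable tf.Variable objects are watched automatically, so the placement of the lookup is what matters.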