Tags: python, tensorflow, entropy, loss-function

Tensorflow pairwise custom loss for mutual information


I'm learning how to use tensorflow and have run into a problem implementing a custom loss function. Specifically, I'm trying to compute the average mutual information between all pairs of variables (the idea being to determine which predictions for one class are tightly correlated with those for another).

For example, if I have an array

# In simple case, 2 entries of data showing predictions for non-exclusive
# properties A, B, C, and D in that order.
data = np.array([[0.99, 0.05, 0.85, 0.2], [0.97, 0.57, 0.88, 0.1]])

I'd like to get back a tensor showing the mutual information between A and B, A and C, A and D, B and A, etc., where A is the 1st element of each row vector, B is the 2nd, and so on. I would also be OK with just getting the average pairwise mutual information for each variable (e.g. the average of MI(A, B), MI(A, C), and MI(A, D)).

The way I would do this is by calculating the joint entropy across rows for every pair of variables and combining it with the entropy of each variable alone, since MI(A, B) = H(A) + H(B) - H(A, B).
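
To make the entropy-based definition concrete, here is a minimal NumPy sketch that computes MI(A, B) = H(A) + H(B) - H(A, B) from a discrete joint probability table. The helper names `entropy` and `mutual_information_from_joint` are hypothetical and chosen only for illustration; the sketch assumes the joint distribution of the two variables is already available as a 2-D table, which is not the case for the raw predictions above.

```python
import numpy as np

def entropy(p):
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return float(-np.sum(p * np.log(p)))

def mutual_information_from_joint(joint):
    """MI(A, B) = H(A) + H(B) - H(A, B) for a 2-D joint table (hypothetical helper)."""
    joint = np.asarray(joint, dtype=float)
    joint = joint / joint.sum()   # normalise in case raw counts were passed
    p_a = joint.sum(axis=1)       # marginal distribution of A (rows)
    p_b = joint.sum(axis=0)       # marginal distribution of B (columns)
    return entropy(p_a) + entropy(p_b) - entropy(joint)

# Two illustrative joint tables: independent variables and identical variables.
independent = np.outer([0.5, 0.5], [0.3, 0.7])
identical = np.array([[0.5, 0.0], [0.0, 0.5]])

print(mutual_information_from_joint(independent))  # close to 0
print(mutual_information_from_joint(identical))    # close to log(2)
```

Independent variables give a mutual information of roughly zero, while two identical binary variables give roughly log(2) nats, which is the entropy of either variable on its own.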

As a starting point, I looked at existing code for computing the covariance of two variables:

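import tensorflow as tf

def tf_cov(x):
    # Column means, shape (1, n_features)
    mean_x = tf.reduce_mean(x, axis=0, keepdims=True)
    # Outer product of the means: E[x_i] * E[x_j]
    mx = tf.matmul(tf.transpose(mean_x), mean_x)
    # Second moment E[x_i * x_j], averaged over the rows (samples)
    vx = tf.matmul(tf.transpose(x), x)/tf.cast(tf.shape(x)[0], tf.float32)
    cov_xx = vx - mx
    return cov_xx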

This is a nice example of how to get pair-wise statistics, but it doesn't get me quite the metrics I want.

I'm able to compute the entropy for a single variable as well:

def tf_entropy(prob_a):
    # Calculates the entropy along each column: H = -sum(p * log(p)).
    # Note the leading minus sign; tf.math.log replaces tf.log in TF 2.x.
    col_entropy = -tf.reduce_sum(prob_a * tf.math.log(prob_a), axis=0)

    return col_entropy

Does anyone know of a good way to compute the pairwise entropies? I imagine it will look a lot like matmul, but instead of summing the element-wise products, I would compute the entropy. Of course, if you know of existing tensorflow functions that already do what I want, that would be great. I've been reading up on various entropy-related functions, but they never seem to be quite what I want.


Solution

  • If you want to calculate the mutual information between, say, X and Y, the right approach depends on the assumptions you can make about the underlying distributions. If you have very high-dimensional data and complicated distributions, I suggest binning, which is nonparametric. There are also some more sophisticated methods that I am using. See here, here and here.

    The first two don't really scale well, and the last one involves some hyperparameter tuning that can throw your numbers totally off (or I am doing something wrong), but it scales relatively well.
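
As a rough illustration of the binning suggestion, the sketch below estimates the mutual information between every pair of columns from 2-D histograms. It is an assumption-laden example, not the answerer's actual code: the function name `pairwise_binned_mi`, the default of 4 bins, and the `[0, 1]` value range (chosen because the question's data are predicted probabilities) are all hypothetical choices. It uses NumPy rather than TensorFlow, and because hard binning is piecewise constant in its inputs, it is suited to inspecting predictions rather than serving directly as a differentiable loss.

```python
import numpy as np

def pairwise_binned_mi(data, bins=4):
    """Histogram estimate of MI (in nats) between every pair of columns.

    data: array of shape (n_samples, n_variables), values assumed to lie in [0, 1].
    bins: number of equal-width bins per variable (illustrative default).
    """
    data = np.asarray(data, dtype=float)
    n_vars = data.shape[1]
    mi = np.zeros((n_vars, n_vars))
    for i in range(n_vars):
        for j in range(i, n_vars):
            # Joint histogram of columns i and j over the unit square.
            counts, _, _ = np.histogram2d(
                data[:, i], data[:, j], bins=bins, range=[[0, 1], [0, 1]]
            )
            joint = counts / counts.sum()
            p_i = joint.sum(axis=1, keepdims=True)   # marginal of column i
            p_j = joint.sum(axis=0, keepdims=True)   # marginal of column j
            nz = joint > 0                           # skip empty cells (0 * log 0 = 0)
            value = np.sum(joint[nz] * np.log(joint[nz] / (p_i @ p_j)[nz]))
            mi[i, j] = mi[j, i] = value
    return mi

# Synthetic check: column 1 duplicates column 0, column 2 is independent noise.
rng = np.random.default_rng(0)
a = rng.random(5000)
noise = rng.random(5000)
mi = pairwise_binned_mi(np.column_stack([a, a, noise]))
print(mi)
```

On this synthetic input the entry for the duplicated pair is large (near the entropy of the binned variable), while the entry for the independent pair is near zero. The estimate depends on the number of bins and the sample size, so the two-row array from the question is far too small to give meaningful values.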