python, numpy, data-science, derivative, entropy

What is the derivative of Shannon's Entropy?


I have the following simple Python function that calculates the entropy of a single input array X according to Shannon's information theory:

import numpy as np

def entropy(X: np.ndarray):
  # Count how often each distinct value appears in X
  _, frequencies = np.unique(X, return_counts=True)
  # Turn the counts into empirical probabilities
  probabilities  = frequencies/X.shape[0]
  return -np.sum(probabilities*np.log2(probabilities))

a = np.array([1., 1., 1., 3., 3., 2.])
b = np.array([1., 1., 1., 3., 3., 3.])
c = np.array([1., 1., 1., 1., 1., 1.])

print(f"entropy(a): {entropy(a)}")
print(f"entropy(b): {entropy(b)}")
print(f"entropy(c): {entropy(c)}")

With the output being the following:

entropy(a): 1.4591479170272446
entropy(b): 1.0
entropy(c): -0.0
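
For reference, entropy(a) can be checked by hand (a quick sanity check, not essential to the question): the counts are {1: 3, 3: 2, 2: 1}, so the probabilities are 1/2, 1/3 and 1/6:

import numpy as np

p = np.array([1/2, 1/3, 1/6])
print(-np.sum(p * np.log2(p)))  # ~1.4591, matches entropy(a)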

However, I also need to calculate the derivative of the entropy with respect to x:

d entropy / dx

This is not an easy task since the main formula

-np.sum(probabilities*np.log2(probabilities))

takes in probabilities, not x values, so it is not clear how to differentiate with respect to x.

Does anyone have an idea on how to do this?


Solution

  • One way to solve this is to use finite differences to compute the derivative numerically.

    To do this, we define a small step size and a helper that takes a one-argument function f and approximates its derivative at x with a forward difference:

    ε = 1e-12  # step size for the forward-difference approximation
    def derivative(f, x):
        return (f(x + ε) - f(x)) / ε
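
    As a quick sanity check (reusing the derivative helper and the numpy import from above), the helper recovers a known derivative:

    # d/dx sin(x) at x = 0 is cos(0) = 1; the forward difference agrees.
    print(derivative(np.sin, 0.0))  # ~1.0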
    

    To make our work easier, let us define a function that computes the inner term of the entropy sum, x*log2(x):

    def inner(x):
        return x * np.log2(x)
    

    Recall that the derivative of a sum is the sum of the derivatives. Differentiating the entropy therefore reduces to differentiating the inner function we just defined at each probability and summing the results.

    So, the numerical derivative of the entropy is:

    def numerical_dentropy(X):
        _, frequencies = np.unique(X, return_counts=True)
        probabilities = frequencies / X.shape[0]
        # Sum the numerical derivative of p*log2(p) at each probability
        return -np.sum([derivative(inner, p) for p in probabilities])
    

    Can we do better? Of course we can! The key insight here is the product rule: (f g)' = f g' + g f', with f = x and g = log2(x). Since d[log_a(x)]/dx = 1/(x ln(a)), this gives d[x log2(x)]/dx = log2(x) + 1/ln(2), which is exactly the term that appears in the sum below (a quick symbolic check follows).
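
    If sympy is available (an optional symbolic check, not part of the original derivation), the same derivative can be confirmed:

    import sympy as sp

    p = sp.symbols('p', positive=True)
    # Differentiate p * log2(p); sympy returns an expression equivalent to
    # log(p)/log(2) + 1/log(2), i.e. log2(p) + 1/ln(2).
    print(sp.simplify(sp.diff(p * sp.log(p, 2), p)))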

    So, the analytical derivative of the entropy can be computed as:

    import math
    def dentropy(X):
        _, frequencies = np.unique(X, return_counts=True)
        probabilities = frequencies / X.shape[0]
        # d/dp [p * log2(p)] = log2(p) + 1/ln(2), summed over all probabilities
        return -np.sum([(1/math.log(2) + np.log2(p)) for p in probabilities])
    

    Using the sample vectors for testing, we have:

    a = np.array([1., 1., 1., 3., 3., 2.])
    b = np.array([1., 1., 1., 3., 3., 3.])
    c = np.array([1., 1., 1., 1., 1., 1.])
    
    print(f"numerical d[entropy(a)]: {numerical_dentropy(a)}")
    print(f"numerical d[entropy(b)]: {numerical_dentropy(b)}")
    print(f"numerical d[entropy(c)]: {numerical_dentropy(c)}")
    
    print(f"analytical d[entropy(a)]: {dentropy(a)}")
    print(f"analytical d[entropy(b)]: {dentropy(b)}")
    print(f"analytical d[entropy(c)]: {dentropy(c)}")
    

    Which, when executed, gives us:

    numerical d[entropy(a)]: 0.8417710972707937
    numerical d[entropy(b)]: -0.8854028621385623
    numerical d[entropy(c)]: -1.4428232973189605
    analytical d[entropy(a)]: 0.8418398787754222
    analytical d[entropy(b)]: -0.8853900817779268
    analytical d[entropy(c)]: -1.4426950408889634
    

    As a bonus, we can test whether this is correct with an automatic differentiation library:

    import torch
    
    a, b, c = torch.from_numpy(a), torch.from_numpy(b), torch.from_numpy(c)
    
    def torch_entropy(X):
        _, frequencies = torch.unique(X, return_counts=True)
        frequencies = frequencies.type(torch.float32)
        probabilities = frequencies / X.shape[0]
        # Track gradients with respect to the probabilities, as in dentropy
        probabilities.requires_grad_(True)
        return -(probabilities * torch.log2(probabilities)).sum(), probabilities
    
    for v in a, b, c:
        h, p = torch_entropy(v)
        print(f'torch entropy: {h}')
        h.backward()
        print(f'torch derivative: {p.grad.sum()}')
    

    Which gives us:

    torch entropy: 1.4591479301452637
    torch derivative: 0.8418397903442383
    torch entropy: 1.0
    torch derivative: -0.885390043258667
    torch entropy: -0.0
    torch derivative: -1.4426950216293335
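
    The per-element gradients also line up with the analytical formula. Here is a minimal sketch using an arbitrary probability vector (not taken from the answer above):

    import numpy as np
    import torch

    # Each gradient component should equal -(log2(p_i) + 1/ln(2)).
    p = torch.tensor([0.5, 1/3, 1/6], requires_grad=True)
    h = -(p * torch.log2(p)).sum()
    h.backward()
    print(p.grad)                                   # autodiff gradient
    print(-(torch.log2(p.detach()) + 1/np.log(2)))  # analytical gradient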