Search code examples

How to calculate the KL divergence using kernel densities in Python?

I have a (100, 3) array of Jaccard index values, where the 3 columns correspond to the model's three input parameters. I want to compute the KL divergence for each pair of features using the sampling distributions of the Jaccard indices. I'm truly at a loss as to how to compute (numerically) a probability distribution — specifically, a kernel density — for each column of Jaccard index values so that I can then compute the KL divergence. At this point, I feel as if I don't know what I'm doing, or at least that there's a gap in my understanding somewhere. If someone could 'shepherd' me in the right direction, I'd greatly appreciate it. Below is the array of Jaccard index values. I apologize if this isn't the right way to share a dataset this large (I don't know any other way), but I really need help.

np.array([[0.12    , 1.      , 0.272727],
       [0.074074, 0.882353, 0.208333],
       [0.16    , 0.933333, 0.25    ],
       [0.192308, 0.888889, 0.148148],
       [0.238095, 0.473684, 0.45    ],
       [0.26087 , 0.933333, 0.26087 ],
       [0.333333, 0.75    , 0.3     ],
       [0.2     , 0.722222, 0.166667],
       [0.318182, 0.823529, 0.25    ],
       [0.076923, 0.631579, 0.217391],
       [0.142857, 0.722222, 0.26087 ],
       [0.115385, 0.555556, 0.125   ],
       [0.136364, 0.875   , 0.3     ],
       [0.166667, 0.722222, 0.318182],
       [0.115385, 0.764706, 0.166667],
       [0.166667, 1.      , 0.130435],
       [0.125   , 1.      , 0.08    ],
       [0.272727, 0.666667, 0.181818],
       [0.222222, 0.571429, 0.142857],
       [0.26087 , 0.5     , 0.304348],
       [0.076923, 1.      , 0.333333],
       [0.470588, 0.8     , 0.25    ],
       [0.125   , 0.611111, 0.285714],
       [0.083333, 0.625   , 0.2     ],
       [0.26087 , 0.866667, 0.16    ],
       [0.181818, 0.555556, 0.25    ],
       [0.304348, 0.9375  , 0.318182],
       [0.166667, 0.928571, 0.388889],
       [0.192308, 0.684211, 0.304348],
       [0.26087 , 0.866667, 0.190476],
       [0.2     , 0.555556, 0.3     ],
       [0.076923, 0.705882, 0.16    ],
       [0.115385, 1.      , 0.444444],
       [0.208333, 0.611111, 0.2     ],
       [0.2     , 0.833333, 0.333333],
       [0.125   , 0.555556, 0.333333],
       [0.166667, 0.681818, 0.103448],
       [0.166667, 1.      , 0.166667],
       [0.173913, 0.764706, 0.153846],
       [0.208333, 0.6     , 0.16    ],
       [0.2     , 0.444444, 0.2     ],
       [0.153846, 0.736842, 0.166667],
       [0.074074, 0.777778, 0.4     ],
       [0.071429, 0.9375  , 0.16    ],
       [0.153846, 0.705882, 0.173913],
       [0.111111, 1.      , 0.130435],
       [0.142857, 0.789474, 0.148148],
       [0.166667, 0.722222, 0.148148],
       [0.142857, 0.888889, 0.16    ],
       [0.285714, 0.526316, 0.125   ],
       [0.103448, 0.6     , 0.107143],
       [0.148148, 0.777778, 0.192308],
       [0.137931, 0.714286, 0.2     ],
       [0.181818, 0.8125  , 0.285714],
       [0.136364, 0.388889, 0.125   ],
       [0.227273, 0.588235, 0.25    ],
       [0.136364, 0.625   , 0.238095],
       [0.157895, 0.5625  , 0.095238],
       [0.3     , 0.473684, 0.181818],
       [0.130435, 1.      , 0.285714],
       [0.318182, 0.75    , 0.388889],
       [0.24    , 0.6     , 0.26087 ],
       [0.173913, 0.55    , 0.26087 ],
       [0.173913, 0.684211, 0.217391],
       [0.111111, 1.      , 0.208333],
       [0.115385, 0.684211, 0.2     ],
       [0.227273, 0.8125  , 0.368421],
       [0.227273, 1.      , 0.130435],
       [0.24    , 0.7     , 0.192308],
       [0.173913, 0.764706, 0.272727],
       [0.304348, 0.8125  , 0.333333],
       [0.291667, 0.666667, 0.347826],
       [0.107143, 0.6     , 0.208333],
       [0.192308, 0.842105, 0.2     ],
       [0.217391, 0.714286, 0.12    ],
       [0.217391, 1.      , 0.076923],
       [0.190476, 0.764706, 0.208333],
       [0.033333, 0.571429, 0.153846],
       [0.208333, 0.941176, 0.272727],
       [0.333333, 0.722222, 0.153846],
       [0.153846, 0.823529, 0.217391],
       [0.136364, 0.8     , 0.529412],
       [0.1     , 0.714286, 0.185185],
       [0.142857, 0.785714, 0.333333],
       [0.173913, 0.526316, 0.25    ],
       [0.208333, 0.388889, 0.25    ],
       [0.16    , 0.6     , 0.16    ],
       [0.272727, 0.764706, 0.238095],
       [0.086957, 0.625   , 0.190476],
       [0.103448, 0.684211, 0.24    ],
       [0.238095, 0.444444, 0.173913],
       [0.35    , 0.875   , 0.227273],
       [0.12    , 0.823529, 0.238095],
       [0.28    , 0.722222, 0.178571],
       [0.208333, 0.944444, 0.185185],
       [0.2     , 0.611111, 0.208333],
       [0.095238, 0.473684, 0.12    ],
       [0.190476, 0.733333, 0.210526],
       [0.192308, 0.75    , 0.217391],
       [0.12    , 0.764706, 0.217391]])

I have used Seaborn's `displot` (with `kind='kde'`) to plot a KDE for each column of Jaccard index values (features 0, 1, and 2). I have also tried `scipy.stats.gaussian_kde`, hoping it would return probabilities, but to no avail (of course), and I've tried `plt.hist` from `matplotlib.pyplot`.


  • OK, to get you started, I've extracted some code from the NPEET library. It produced something reasonable for one case, though other cases ran into log(0) — most likely because repeated values in the data make some nearest-neighbor distances exactly zero. Here `z` is your samples array.

    Python 3.10, Windows 10 x64

    import numpy as np
    from sklearn.neighbors import BallTree, KDTree
    # z is the (n_samples, n_features) array of Jaccard indices from the question.
    # Take the sample count from z instead of hard-coding 100, so the snippet
    # keeps working when the number of rows changes.
    L = z.shape[0]
    # z[:, [j]] keeps each feature as a 2-D (L, 1) column, i.e. the
    # "list of vectors" layout that kldiv expects.
    a = z[:, [0]]
    b = z[:, [1]]
    c = z[:, [2]]
    def query_neighbors(tree, x, k):
        return tree.query(x, k = k + 1)[0][:, k]
    def build_tree(points):
        """Build a nearest-neighbour index over ``points`` using the Chebyshev (max-norm) metric.

        ``points`` is a 2-D array (rows = samples). A KDTree is used when it has
        fewer than 20 columns; from 20 columns up a BallTree is used instead
        (presumably because KD-trees scale poorly with dimension).
        """
        tree_cls = KDTree if points.shape[1] < 20 else BallTree
        return tree_cls(points, metric="chebyshev")
    def kldiv(x, xp, k=3, base=2, *, jitter=0.0, seed=None):
        """Estimate the KL divergence D(p || q) from samples x ~ p(x) and xp ~ q(x).

        k-nearest-neighbour estimator (as in the NPEET library): for every point
        of ``x`` it compares the distance to its k-th neighbour inside ``x`` with
        the distance to its k-th neighbour inside ``xp``.

        Parameters
        ----------
        x, xp : array-like
            Samples from p and q. Either a list of vectors, e.g.
            ``[[1.3], [3.7], [5.1], [2.4]]`` for four 1-D samples, or a flat
            1-D sequence of scalars (treated as a single column).
        k : int
            Neighbour order; must satisfy ``1 <= k < min(len(x), len(xp))``.
        base : float
            Logarithm base of the result (2 -> bits, e -> nats).
        jitter : float, keyword-only
            If > 0, add uniform noise in ``[0, jitter)`` to both samples before
            the neighbour search. Repeated values (ties) can make a k-th
            neighbour distance exactly 0, and ``log(0)`` then turns the estimate
            into +/-inf or nan; a tiny jitter such as 1e-10 avoids that.
            The default 0.0 leaves the data untouched.
        seed : int or numpy.random.Generator, optional, keyword-only
            Seed for the jitter noise, for reproducible results.

        Returns
        -------
        float
            The estimated divergence, in units of log base ``base``.

        Raises
        ------
        ValueError
            If ``k`` is out of range or the two samples differ in dimension.
        """
        x, xp = np.asarray(x), np.asarray(xp)
        # Reshape first so a flat array of scalars becomes one (n, 1) column;
        # the dimension check below then works for 1-D input too.
        x, xp = x.reshape(x.shape[0], -1), xp.reshape(xp.shape[0], -1)
        n, d = x.shape
        m = xp.shape[0]
        # Explicit raises instead of assert: asserts vanish under `python -O`.
        if not 1 <= k < min(n, m):
            raise ValueError(
                f"k must satisfy 1 <= k < min(len(x), len(xp)) = {min(n, m)}; got k={k}"
            )
        if xp.shape[1] != d:
            raise ValueError("Two distributions must have same dim.")
        if jitter > 0:
            rng = np.random.default_rng(seed)
            x = x + jitter * rng.random(x.shape)
            xp = xp + jitter * rng.random(xp.shape)
        # Sample-size term of the k-NN estimator.
        const = np.log(m) - np.log(n - 1)
        tree = build_tree(x)
        treep = build_tree(xp)
        # Distance from each x_i to its k-th neighbour within x (self skipped).
        nn = query_neighbors(tree, x, k)
        # Distance from each x_i to its k-th neighbour within xp; x_i is not a
        # member of xp, so nothing needs to be skipped (hence k - 1).
        nnp = query_neighbors(treep, x, k - 1)
        return (const + d * (np.log(nnp).mean() - np.log(nn).mean())) / np.log(base)
    # D(a || b): divergence of feature 0's samples from feature 1's samples,
    # estimated with the 6th nearest neighbour and reported in base 2 (bits).
    ab_kldiv = kldiv(a, b, k=6, base=2)