python, scikit-learn, nearest-neighbor

How to get centroids from Ball Tree?


According to the scikit-learn documentation, the sklearn.neighbors.BallTree class

recursively divides the data into nodes defined by a centroid C and radius r, such that each point in the node lies within the hyper-sphere defined by C and r.

Is there a way, given a BallTree instance, to extract those centroids without recomputing them? The get_arrays() method exposes the radii and the data points contained in each node, but does not expose the centroids. In Euclidean space, one could easily compute the centroids by averaging the data points in each node, but this becomes harder under other metrics. Besides, a user shouldn't have to redo this computation if the BallTree instance has already done it internally.
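For reference, the manual Euclidean recomputation described above might look roughly like this. It is a sketch that assumes get_arrays() returns the tuple (data, idx_array, node_data, node_bounds) and that node_data is a structured array whose idx_start/idx_end fields delimit each node's slice of idx_array. The random data set and leaf_size=10 are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.neighbors import BallTree

rng = np.random.default_rng(0)
X = rng.random((100, 3))
tree = BallTree(X, leaf_size=10)  # default metric: minkowski with p=2 (Euclidean)

# get_arrays() -> (data, idx_array, node_data, node_bounds)
data, idx_array, node_data, _ = tree.get_arrays()

# Each node owns the points data[idx_array[idx_start:idx_end]];
# averaging those points gives that node's Euclidean centroid.
manual_centroids = np.array([
    data[idx_array[start:end]].mean(axis=0)
    for start, end in zip(node_data["idx_start"], node_data["idx_end"])
])
print(manual_centroids.shape)  # (n_nodes, n_features)
```

Node 0 is the root, so its manually computed centroid is simply the mean of the whole data set.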


Solution

  • get_arrays() returns a tuple (data, idx_array, node_data, node_bounds). The last array, node_bounds, is the "node bounds" array, and it contains the centroids:

    node_bounds : the [* x n_nodes x n_features] array containing the node bound information. For ball tree, the first dimension is 1, and each row contains the centroid of the node. [...]

    [source]
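A minimal sketch of reading the centroids straight from node_bounds, with a non-Euclidean metric to address the concern in the question. It assumes the same get_arrays() layout as above and that node_data has a radius field. The loop checks, rather than asserts as documented behavior, that in the installed scikit-learn version each centroid is the plain arithmetic mean of its node's points and each radius is the largest Manhattan distance from that centroid. The data set, leaf_size=10, and metric="manhattan" are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.neighbors import BallTree

rng = np.random.default_rng(0)
X = rng.random((100, 3))
tree = BallTree(X, leaf_size=10, metric="manhattan")

data, idx_array, node_data, node_bounds = tree.get_arrays()

# For a ball tree the first axis of node_bounds has length 1,
# so node_bounds[0] is an (n_nodes, n_features) array of centroids.
centroids = node_bounds[0]
radii = node_data["radius"]

# Sanity checks against the points each node owns
for i, (start, end) in enumerate(zip(node_data["idx_start"], node_data["idx_end"])):
    pts = data[idx_array[start:end]]
    # Centroid: arithmetic mean of the node's points, even under this metric
    assert np.allclose(centroids[i], pts.mean(axis=0))
    # Radius: largest Manhattan distance from the centroid to a point in the node
    assert np.isclose(radii[i], np.abs(pts - centroids[i]).sum(axis=1).max())

print(centroids.shape, radii.shape)
```

If both checks pass, the centroids are available without recomputation, and for a non-Euclidean metric the mean is still what the tree stores. Only the radius is measured with the chosen metric.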