Tags: python, scipy, scikit-learn, hierarchical-clustering

Convert distance pairs to distance matrix to use in hierarchical clustering


I am trying to convert a dictionary into a distance matrix that I can then use as input to hierarchical clustering. The dictionary has:

  • key: tuple of length 2 with the objects for which I have the distance
  • value: the actual distance value

    for k, v in obj_distances.items():
        print(k, v)
    

and the result is:

    ('obj1', 'obj2') 2.0
    ('obj3', 'obj4') 1.58
    ('obj1', 'obj3') 1.95
    ('obj2', 'obj3') 1.80

My question is: how can I convert this into a distance matrix that I can later use for clustering in scipy?


Solution

  • You say you will use scipy for clustering, so I assume that means you will use the function scipy.cluster.hierarchy.linkage. linkage accepts the distance data in "condensed" form, so you don't have to create the full symmetric distance matrix. (See, e.g., How does condensed distance matrix work? (pdist), for a discussion on the condensed form.)

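    So all you have to do is get obj_distances.values() into the order that the condensed form expects: sort the two objects within each pair, then sort the pairs, and pass the resulting distances to linkage. This requires a distance for every pair of objects. That's what is done in the following snippet: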

    import matplotlib.pyplot as plt
    from scipy.cluster.hierarchy import linkage, dendrogram

    obj_distances = {
        ('obj2', 'obj3'): 1.8,
        ('obj3', 'obj1'): 1.95,
        ('obj1', 'obj4'): 2.5,
        ('obj1', 'obj2'): 2.0,
        ('obj4', 'obj2'): 2.1,
        ('obj3', 'obj4'): 1.58,
    }

    # Put each key pair in a canonical order, so we know that if (a, b) is a key,
    # then a < b.  If this is already true, then the next three lines can be
    # replaced with
    #     sorted_keys, distances = zip(*sorted(obj_distances.items()))
    # Note: we assume there are no keys where the two objects are the same,
    # and that every pair of objects appears exactly once.
    keys = [sorted(k) for k in obj_distances.keys()]
    values = obj_distances.values()
    sorted_keys, distances = zip(*sorted(zip(keys, values)))

    # linkage accepts the "condensed" format of the distances.
    Z = linkage(distances)

    # Optional: create a sorted list of the objects to label the leaves.
    labels = sorted({obj for key in sorted_keys for obj in key})

    dendrogram(Z, labels=labels)
    plt.show()  # needed when running as a script rather than in a notebook
    

    The dendrogram:

    [Figure: dendrogram]
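
If you do want the full square matrix the question asks about, or you are not sure that every pair is present, the following is a minimal sketch of one way to build it and then derive the condensed vector with scipy.spatial.distance.squareform. The helper name dict_to_condensed is hypothetical (it is not part of SciPy), and the sketch assumes the distances are symmetric and that no key pairs an object with itself. Note that the four pairs printed in the question do not include obj1/obj4 or obj2/obj4; if that output is the whole dictionary, the condensed form cannot be built, and the sketch raises an error instead of silently misaligning the values.

```python
from itertools import combinations

import numpy as np
from scipy.spatial.distance import squareform


def dict_to_condensed(obj_distances):
    """Return (labels, condensed) from a {(a, b): distance} dictionary.

    Hypothetical helper, not part of SciPy. Assumes symmetric distances
    and no key that pairs an object with itself.
    """
    labels = sorted({obj for pair in obj_distances for obj in pair})
    index = {obj: i for i, obj in enumerate(labels)}

    # Full square matrix; NaN marks pairs for which no distance was given.
    square = np.full((len(labels), len(labels)), np.nan)
    np.fill_diagonal(square, 0.0)
    for (a, b), dist in obj_distances.items():
        square[index[a], index[b]] = dist
        square[index[b], index[a]] = dist

    missing = [(a, b) for a, b in combinations(labels, 2)
               if np.isnan(square[index[a], index[b]])]
    if missing:
        raise ValueError(f"no distance given for pairs: {missing}")

    # squareform turns the symmetric square matrix into the condensed vector.
    return labels, squareform(square)


obj_distances = {
    ('obj2', 'obj3'): 1.8,
    ('obj3', 'obj1'): 1.95,
    ('obj1', 'obj4'): 2.5,
    ('obj1', 'obj2'): 2.0,
    ('obj4', 'obj2'): 2.1,
    ('obj3', 'obj4'): 1.58,
}
labels, condensed = dict_to_condensed(obj_distances)
```

The condensed array can be passed to linkage in place of distances, and labels can be passed to dendrogram as before. The intermediate square matrix is only there to make missing pairs easy to detect; for a complete dictionary the sorting approach above gives the same ordering.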