
Connecting dendrograms in matplotlib


I'm currently working on a project where I need to visualize hierarchical clustering dendrograms using matplotlib in Python. I've managed to generate dendrograms for two datasets (X1 and X2) and plotted them side by side. However, I'm facing difficulties in connecting the dendrograms together.

I previously asked "Connecting tips between dendrograms in side-by-side subplots", but that question used plotly.

I've tried extracting the tip labels from both dendrograms and sorting them, but when I attempt to connect them, the connections seem to be misplaced or missing.

Here's the code snippet I'm using:

import matplotlib.pyplot as plt
import scipy.cluster.hierarchy as hierarchy
import numpy as np

# Sample data for X1 , X2
np.random.seed(20240406)
X1 = np.random.rand(10, 12)
X2 = np.random.rand(10, 12)
names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark', 'Alice', 'Charlie', 'Rob', 'Lisa', 'Lily']

# Plotting
fig, axes = plt.subplots(1, 3, figsize=(12, 6))

# Generate dendrogram structure for X1
Z1 = hierarchy.linkage(X1, method='complete')
dn1 = hierarchy.dendrogram(Z1, ax=axes[0], orientation='left', labels=names)
axes[0].set_title('Left Dendrogram for X1')
axes[0].set_xlabel('Distance')

# Generate dendrogram structure for X2
Z2 = hierarchy.linkage(X2, method='complete')
dn2 = hierarchy.dendrogram(Z2, ax=axes[2], orientation='right', labels=names)
axes[2].set_title('Right Dendrogram for X2')
axes[2].set_xlabel('Distance')

# Extract leaves and match them with names
leaves_left = dn1['leaves']
leaves_right = dn2['leaves']

# Use leaves and names to create connections
connections = []
for i in range(len(leaves_left)):
    left_name = names[leaves_left[i]]
    try:
        right_index = names.index(left_name)
    except ValueError:
        continue  # Skip to the next iteration if the name is not found
    connections.append((0, 1, i , right_index))


# Draw connections
for left, right, y_left, y_right in connections:
  axes[1].plot([left, right], [y_left, y_right], 'k-', alpha=0.5)

# Customize the third plot for connections
axes[1].set_title('Connections')
axes[1].set_xlabel('Connection')
axes[1].set_xlim(0, 1)  # Set limits for connection plot
axes[1].set_ylim(-0.5, len(names) - 0.5)  # Adjust y-axis limits for connections
axes[1].axis('off')

plt.tight_layout()
plt.show()


In the resulting plot, however, the lines do not link the matching tips. I want each label in the X1 dendrogram to be joined by a line to the same label in the X2 dendrogram. How can I implement this connection properly?


Solution

  • I haven't followed how the dendrogram's leaves are ordered, but one option might be to get hold of the y tick labels on each plot and match those up. I also passed clip_on=False in the call to plot to give a more rounded appearance to the ends of the lines (rather than being clipped by the axes).

    import matplotlib.pyplot as plt
    import scipy.cluster.hierarchy as hierarchy
    import numpy as np
    
    # Sample data for X1 , X2
    np.random.seed(20240406)
    X1 = np.random.rand(10, 12)
    X2 = np.random.rand(10, 12)
    names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark', 'Alice', 'Charlie', 'Rob', 'Lisa', 'Lily']
    
    # Plotting
    fig, axes = plt.subplots(1, 3, figsize=(12, 6))
    
    # Generate dendrogram structure for X1
    Z1 = hierarchy.linkage(X1, method='complete')
    dn1 = hierarchy.dendrogram(Z1, ax=axes[0], orientation='left', labels=names)
    axes[0].set_title('Left Dendrogram for X1')
    axes[0].set_xlabel('Distance')
    
    # Generate dendrogram structure for X2
    Z2 = hierarchy.linkage(X2, method='complete')
    dn2 = hierarchy.dendrogram(Z2, ax=axes[2], orientation='right', labels=names)
    axes[2].set_title('Right Dendrogram for X2')
    axes[2].set_xlabel('Distance')
    
    # Get hold of the labels for each dendrogram    
    left_labels = axes[0].get_yticklabels()
    right_labels = axes[2].get_yticklabels()
    right_names = [label.get_text() for label in right_labels]
    
    # Use label positions and texts to create connections
    connections = []
    for i, left_label in enumerate(left_labels):
        left_name = left_label.get_text()
        try:
            right_index = right_names.index(left_name)
        except ValueError:
            continue  # Skip to the next iteration if the name is not found
        connections.append((0, 1, left_label.get_position()[1] , right_labels[right_index].get_position()[1]))
    
    # Draw connections
    for left, right, y_left, y_right in connections:
        axes[1].plot([left, right], [y_left, y_right], 'k-', alpha=0.5, clip_on=False)
    
    # Customize the middle plot for connections
    axes[1].set_title('Connections')
    axes[1].set_xlabel('Connection')
    axes[1].set_xlim(0, 1)  # Set limits for connection plot
    axes[1].set_ylim(axes[0].get_ylim())  # Share the dendrograms' y-range so the line ends align with the leaves
    axes[1].axis('off')
    
    plt.tight_layout()
    plt.show()
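
As a possible alternative to reading the tick labels, the leaf order can be taken from the dictionary that `dendrogram` returns: its `'ivl'` entry lists the labels in the order in which the leaves are drawn. The sketch below is not part of the original answer. It uses `no_plot=True` so that nothing is drawn, and it assumes that SciPy places the i-th leaf at coordinate `5 + 10 * i` along the leaf axis, which is an implementation detail worth checking against your SciPy version. The helper name `leaf_positions` is hypothetical. It also suggests why the code in the question misplaces the lines: it uses the loop counter and `names.index(...)` as y-values, neither of which is the plotted position of a leaf.

```python
import numpy as np
import scipy.cluster.hierarchy as hierarchy

# Same sample data as in the question
np.random.seed(20240406)
X1 = np.random.rand(10, 12)
X2 = np.random.rand(10, 12)
names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark',
         'Alice', 'Charlie', 'Rob', 'Lisa', 'Lily']

Z1 = hierarchy.linkage(X1, method='complete')
Z2 = hierarchy.linkage(X2, method='complete')

# no_plot=True returns the layout dictionary without drawing anything
dn1 = hierarchy.dendrogram(Z1, labels=names, orientation='left', no_plot=True)
dn2 = hierarchy.dendrogram(Z2, labels=names, orientation='right', no_plot=True)


def leaf_positions(dn):
    """Map each leaf label to its coordinate along the leaf axis.

    Assumption: the i-th label in dn['ivl'] is drawn at 5 + 10 * i.
    """
    return {label: 5 + 10 * i for i, label in enumerate(dn['ivl'])}


left_pos = leaf_positions(dn1)
right_pos = leaf_positions(dn2)

# One (y_left, y_right) pair per name that appears in both dendrograms
connections = [(left_pos[name], right_pos[name])
               for name in names
               if name in left_pos and name in right_pos]

for name, (y_left, y_right) in zip(names, connections):
    print(f"{name}: {y_left} -> {y_right}")
```

Each pair could then be drawn on the middle axes with `axes[1].plot([0, 1], [y_left, y_right], ...)`, provided the middle axes use the same y-limits as the two dendrogram axes.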
    
