Search code examples
Tags: python-3.x, matplotlib, plot, bokeh, hexagonal-tiles

How to translate hexagon matplotlib plot to an interactive bokeh plot?


I have been working with the excellent minisom package and want to interactively plot the hexagonal map that reflects the results of the self-organising map's training process. There's already a code example that does this statically using matplotlib, but to do it interactively I would like to use bokeh. This is where I am struggling.

This code generates a simplified version of the matplotlib example that is already on the package page:

from minisom import MiniSom
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.patches import RegularPolygon
from matplotlib import cm

from bokeh.plotting import figure
from bokeh.io import save, show, output_file, output_notebook

output_notebook()

# Load the seeds dataset: seven feature columns plus the 'target' class label.
# Columns are separated by one or more tabs, hence the regex separator. Regex
# separators are only supported by the python parser engine, so request it
# explicitly instead of relying on pandas' fallback (which emits a ParserWarning).
data = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt',
                   names=['area', 'perimeter', 'compactness', 'length_kernel', 'width_kernel',
                          'asymmetry_coefficient', 'length_kernel_groove', 'target'],
                   sep='\t+', engine='python')
t = data['target'].values
data = data[data.columns[:-1]]
# data normalisation: z-score every feature column
data = (data - np.mean(data, axis=0)) / np.std(data, axis=0)
data = data.values

# initialisation and training: a 15x15 SOM with hexagonal topology
som = MiniSom(15, 15, data.shape[1], sigma=1.5, learning_rate=.7, activation_distance='euclidean',
              topology='hexagonal', neighborhood_function='gaussian', random_seed=10)

som.train(data, 1000, verbose=True)

# Plot the hexagonal topology: one pointy-top hexagon per SOM node, shaded by
# its U-matrix value, plus a marker on each sample's winning node.
f = plt.figure(figsize=(10, 10))
ax = f.add_subplot(111)

# Equal aspect so the hexagons stay regular instead of being stretched.
ax.set_aspect('equal')

xx, yy = som.get_euclidean_coordinates()
umatrix = som.distance_map()
weights = som.get_weights()

for i in range(weights.shape[0]):
    for j in range(weights.shape[1]):
        # Scale the row coordinate by 2/sqrt(3) * 3/4 (= sqrt(3)/2), the row
        # spacing at which hexagons of circumradius 1/sqrt(3) interlock
        # (assumes yy advances in unit steps per row -- TODO confirm).
        wy = yy[(i, j)] * 2 / np.sqrt(3) * 3 / 4
        # Named `hexagon`, not `hex`, to avoid shadowing the built-in hex().
        hexagon = RegularPolygon((xx[(i, j)], wy), numVertices=6,
                                 radius=.95 / np.sqrt(3),
                                 facecolor=cm.Blues(umatrix[i, j]), alpha=.4,
                                 edgecolor='gray')
        ax.add_patch(hexagon)

# Collect the position of each sample x's winning node (same row scaling as
# the hexagons) and draw all of them in one call. The 'o' marker is required:
# plot() without a marker draws only a line, so a single point is invisible.
winner_x, winner_y = [], []
for x in data:
    wx, wy = som.convert_map_to_euclidean(som.winner(x))
    winner_x.append(wx)
    winner_y.append(wy * 2 / np.sqrt(3) * 3 / 4)
ax.plot(winner_x, winner_y, 'o', markerfacecolor='None',
        markeredgecolor='black', markersize=12, markeredgewidth=2)

plt.show()

matplotlib hexagonal topology plot

I've tried to translate the code into bokeh, but to my (admittedly naive) eye the resulting hex plot looks as though it needs to be flipped vertically onto the points and its skew straightened out.

# For every SOM node (row-major over the grid), record the hexagon centre --
# x as-is, y rescaled like the matplotlib version -- and its U-matrix colour.
tile_centres_column, tile_centres_row, colours = [], [], []
for i, j in np.ndindex(weights.shape[0], weights.shape[1]):
    tile_centres_column.append(xx[(i, j)])
    tile_centres_row.append(yy[(i, j)] * 2 / np.sqrt(3) * 3 / 4)
    colours.append(cm.Blues(umatrix[i, j]))

# Euclidean position of each sample's winning node, with the same y rescaling.
winner_positions = [som.convert_map_to_euclidean(xy=som.winner(x)) for x in data]
weight_x = [wx for wx, _ in winner_positions]
weight_y = [wy * 2 / np.sqrt(3) * 3/4 for _, wy in winner_positions]

# plot hexagonal topology
# NOTE(review): this is the attempt that gives the flipped/skewed result
# described above. Likely causes, to be confirmed against the Bokeh docs for
# the installed version:
#  - hex_tile's q/r are documented as axial hex-grid coordinates, not the
#    Cartesian centres computed above. Bokeh converts them to positions itself
#    (with r growing downwards and x shifting with r), which would explain
#    both the vertical flip and the skew.
#  - hex_tile's size is in data units and also determines the tile spacing.
#  - colours holds the tuples returned by cm.Blues (presumably RGBA floats in
#    [0, 1]); Bokeh expects CSS colour strings or 0-255 integer RGB tuples.
#  - plot_width/plot_height are the Bokeh 2.x names; Bokeh 3 uses width/height.
plot = figure(plot_width=800, plot_height=800,
              match_aspect=True) 
plot.hex_tile(q=tile_centres_column, r=tile_centres_row, 
              size=.95 / np.sqrt(3),
              color=colours,
              fill_alpha=.4,
              line_color='black')
# One dot per sample, at its winning node's (rescaled) position.
plot.dot(x=weight_x, y=weight_y,
         fill_color='black',
         size=12)

show(plot)

bokeh hexagonal topology plot

How can I translate this into a bokeh plot?


Solution

  • I found out how to do it after reaching out to the author of the minisom package for help. The complete code is available here.

    from bokeh.colors import RGB
    from bokeh.io import curdoc, show, output_notebook
    from bokeh.transform import factor_mark, factor_cmap
    from bokeh.models import ColumnDataSource, HoverTool
    from bokeh.plotting import figure, output_file
    
    # Accumulators: per-node hexagon centres/colours and per-sample labels.
    hex_centre_col, hex_centre_row = [], []
    hex_colour = []
    label = []

    # Species names; `t` holds 1-based class ids, hence the `- 1` lookup below.
    SPECIES = ['Kama', 'Rosa', 'Canadian']

    # Walk the SOM grid row-major: centre x as-is, centre y rescaled so that
    # pointy-top hexagons interlock, colour taken from the U-matrix.
    for i, j in np.ndindex(weights.shape[0], weights.shape[1]):
        hex_centre_col.append(xx[(i, j)])
        hex_centre_row.append(yy[(i, j)] * 2 / np.sqrt(3) * 3 / 4)
        hex_colour.append(cm.Blues(umatrix[i, j]))

    # For every sample: where its winning node sits (same y rescaling) and
    # which species it belongs to.
    weight_x, weight_y = [], []
    for idx, sample in enumerate(data):
        wx, wy = som.convert_map_to_euclidean(xy=som.winner(sample))
        weight_x.append(wx)
        weight_y.append(wy * 2 / np.sqrt(3) * 3 / 4)
        label.append(SPECIES[t[idx] - 1])
        
    # Convert the matplotlib colours into Bokeh hex strings. cm.Blues(...)
    # returns an RGBA tuple of floats in [0, 1]; only the R, G, B channels are
    # converted. The alpha channel is dropped rather than scaled to 0-255 and
    # handed to RGB's `a` argument (which expects a float in [0, 1]);
    # transparency is set on the glyph instead. Values are rounded, not
    # truncated, so e.g. 0.9999 maps to 255 rather than 254.
    hex_plt = [np.round(255 * np.asarray(c, dtype=float)[:3]).astype(int) for c in hex_colour]
    hex_bokeh = [RGB(*(int(v) for v in rgb)).to_hex() for rgb in hex_plt]
    
    # Bokeh writes this HTML file when show() is called. Create the target
    # folder first; otherwise writing fails when the folder does not exist.
    import os
    os.makedirs("resulting_images", exist_ok=True)
    output_file("resulting_images/som_seed_hex.html")

    # initialise figure/plot. width/height work in Bokeh 2.x and are the only
    # accepted names in Bokeh 3.x, which removed plot_width/plot_height.
    # match_aspect keeps data units equally long on both axes so the hexagons
    # stay regular.
    fig = figure(title="SOM: Hexagonal Topology",
                 width=800, height=800,
                 match_aspect=True,
                 tools="wheel_zoom,save,reset")
    
    # Column-oriented data for the glyphs.
    # Hexagons: one row per SOM node -- centre (x, y) and hex colour string c.
    source_hex = ColumnDataSource(data={
        'x': hex_centre_col,
        'y': hex_centre_row,
        'c': hex_bokeh,
    })

    # Samples: one row per data point -- winning-node position and species.
    source_pages = ColumnDataSource(data={
        'wx': weight_x,
        'wy': weight_y,
        'species': label,
    })
    
    # define markers (one per species, in SPECIES order)
    MARKERS = ['diamond', 'cross', 'x']

    # add shapes to plot
    # Each SOM node is drawn as a polygon with vertices computed in *data*
    # units, mirroring matplotlib's RegularPolygon(numVertices=6,
    # radius=.95/np.sqrt(3)). The first vertex points straight up (90 degrees),
    # so the tiles are pointy-top and interlock at the computed centres.
    # A `hex` scatter marker is sized in screen pixels instead, so it only
    # tiles at one zoom level and drifts apart once wheel_zoom is used.
    # Because the shape now matches matplotlib's, x and y are used unswapped.
    hex_radius = .95 / np.sqrt(3)
    vertex_angles = np.pi / 2 + np.arange(6) * np.pi / 3
    offsets_x = hex_radius * np.cos(vertex_angles)
    offsets_y = hex_radius * np.sin(vertex_angles)
    source_hex.add([(cx + offsets_x).tolist() for cx in source_hex.data['x']], 'xs')
    source_hex.add([(cy + offsets_y).tolist() for cy in source_hex.data['y']], 'ys')

    fig.patches(xs='xs', ys='ys', source=source_hex,
                alpha=.4,
                line_color='gray',
                fill_color='c')

    # One marker per sample at its winning node; shape and colour by species.
    fig.scatter(x='wx', y='wy', source=source_pages,
                legend_field='species',
                size=20,
                marker=factor_mark(field_name='species', markers=MARKERS, factors=SPECIES),
                color=factor_cmap(field_name='species', palette='Category10_3', factors=SPECIES))
    
    # Hover tooltip: the `species` column of the hovered glyph's data source
    # plus the cursor position in data space; the tooltip follows the mouse.
    hover = HoverTool(
        tooltips=[("label", '@species'), ("(x,y)", '($x, $y)')],
        mode="mouse",
        point_policy="follow_mouse",
    )
    fig.add_tools(hover)

    show(fig)