Search code examples
python bokeh

Can Bokeh create a facet_grid plot?


In Bokeh, we can create plots with categorical coordinates: https://docs.bokeh.org/en/latest/docs/user_guide/categorical.html

enter image description here

Can I plot something like ggplot's facet_grid, i.e. a two-level X axis whose levels are shown in different locations (one on top, one at the bottom)? Thanks. enter image description here


Solution

  • It can be approximated with a grid of separate figures, but the result is still not ideal:

    • You cannot properly set the background color of the minor titles (this can be mitigated by using Divs with manual positioning instead of Titles)
    • Plots don't have the same width/height due to the titles and axes (I don't know of any way to fix it)
    from random import random, randint
    
    from bokeh.io import show
    from bokeh.models import Div, Title
    from bokeh.plotting import figure
    from bokeh import layouts
    
    # Facet categories: one column of panels per major_x value and one row of
    # panels per major_y value.
    major_x = 'Fri Sat Sun Thur'.split()
    major_y = 'Female Male'.split()

    # (start, end) data ranges shared by every panel.
    minor_x = 0, 50
    minor_y = 0, 1

    # Per-panel size in pixels, passed to every figure.
    height = width = 300
    
    
    def generate_datum(start, end):
        """Return one random float drawn uniformly between *start* and *end*.

        ``random()`` yields a value in [0, 1), which is scaled by the span of
        the interval and shifted by its lower bound.
        """
        span = end - start
        return start + span * random()
    
    
    def generate_data():
        """Return a random point cloud as ``{'x': [...], 'y': [...]}``.

        Between 10 and 100 points (inclusive) are produced. Each x coordinate
        is drawn from ``minor_x`` and each y coordinate from ``minor_y``; the
        x values are all generated before the y values.
        """
        count = randint(10, 100)
        xs = [generate_datum(*minor_x) for _ in range(count)]
        ys = [generate_datum(*minor_y) for _ in range(count)]
        return {'x': xs, 'y': ys}
    
    
    # One independent random dataset per facet, keyed by (major_x, major_y).
    # The outer loop runs over major_x, so datasets are generated column by column.
    full_data = dict(
        ((x_key, y_key), generate_data())
        for x_key in major_x
        for y_key in major_y
    )
    
    
    def pad_range(start, end, fraction=0.1):
        """Return ``(start, end)`` widened on both sides by *fraction* of its span.

        This script uses it for every figure's x/y range, so the data drawn from
        [start, end] does not touch the edges of the plot frame.

        Args:
            start: Lower bound of the data range.
            end: Upper bound of the data range.
            fraction: Share of ``end - start`` subtracted from *start* and added
                to *end*. The default of 0.1 keeps the original 10% padding.

        Returns:
            A ``(padded_start, padded_end)`` tuple.
        """
        margin = (end - start) * fraction
        return start - margin, end + margin
    
    
    def add_title(p, text, position):
        """Attach a centre-aligned Bokeh ``Title`` showing *text* to figure *p*.

        *position* is forwarded unchanged to ``p.add_layout``; this script
        passes ``'above'`` for column labels and ``'right'`` for row labels.
        """
        p.add_layout(Title(text=text, align='center'), position)
    
    
    # Build one figure per (major_x, major_y) facet. Rows follow major_y and
    # columns follow major_x, mirroring ggplot's facet_grid arrangement.
    column = []
    for y in major_y:
        row = []
        for x in major_x:
            # frame_width/frame_height size the inner plotting area, not the
            # whole canvas. Every facet therefore gets the same data region,
            # even though only the edge plots carry axes and strip titles.
            p = figure(x_range=pad_range(*minor_x), y_range=pad_range(*minor_y),
                       toolbar_location=None, tools='hover',
                       frame_width=width, frame_height=height)
            # Axes appear only on the outer edges: x axes on the bottom row,
            # y axes on the left-most column.
            p.xaxis.visible = (y == major_y[-1])
            p.yaxis.visible = (x == major_x[0])
            data = full_data[(x, y)]
            # scatter() draws circle markers by default. Using circle() as a
            # marker method is deprecated as of Bokeh 3.4 in favour of scatter().
            p.scatter(x=data['x'], y=data['y'])
            # Facet strip labels: column names across the top row, row names
            # down the right-most column.
            if y == major_y[0]:
                add_title(p, x, 'above')
            if x == major_x[-1]:
                add_title(p, y, 'right')
            row.append(p)
        column.append(layouts.row(row))
    
    # Outer y-axis title, vertically centred next to the stack of facet rows.
    # `vertical-rl` is the standard CSS writing-mode value; `tb-rl` is a
    # deprecated SVG 1.1 alias. Rotating by 180deg makes the text read
    # bottom-to-top, like a conventional y-axis label.
    major_y_label = f'''
    <div style="display: flex; align-items: center; height: {height * len(major_y)}px;">
      <div style="writing-mode: vertical-rl; transform: rotate(-180deg); font-size: 1.5em;">
        tip/total_bill
      </div>
    </div>
    '''

    # Outer x-axis title, horizontally centred under the row of facet columns.
    major_x_label = f'''
    <div style="display: flex; justify-content: center; width: {width * len(major_x)}px;">
      <div style="font-size: 1.5em;">
        total_bill
      </div>
    </div>
    '''
    
    # 2x2 outer grid: the rotated y title sits left of the facet stack, the
    # x title sits beneath it, and the bottom-left cell stays empty.
    y_title = Div(text=major_y_label)
    facets = layouts.column(column)
    x_title = Div(text=major_x_label)
    full_plot = layouts.grid([
        [y_title, facets],
        [None, x_title],
    ])

    show(full_plot)
    

    plot