
Remove column name from chart title in sns.relplot and keep only horizontal grid lines


I am trying to create a diverging dot plot in Python, using seaborn's relplot to draw small multiples split by one of the columns.

The data source is MakeoverMonday 2018w18: MOM2018w48

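I got this far with this code, where df is the DataFrame holding the City, Item and Cost columns:

import seaborn as sns

sns.set_style("whitegrid")
g = sns.relplot(data=df, x="Cost", y="City", col="Item", s=120, size="Cost", hue="Cost", col_wrap=2)
sns.despine(left=True, bottom=True)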

which generates a dot plot with one panel per item.

So far, so good. Now I want to keep only the horizontal grid lines, sort the cities, and get rid of the column-name prefix (item = ...) in the titles of the small multiples. Any ideas?

This is what I am trying to recreate: [image: target chart]


Solution

  • I can get reasonably close using just Matplotlib (no seaborn). Matplotlib is a bit more low-level, but that also allows for a lot of customization.

    There are definitely still some hacky things going on to mimic the appearance of your example image as closely as possible. There may well be more elegant ways to get there.

    The DataFrame is in long format, with one row per city and item and the columns City, Item, Cost and Total Cost.
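For illustration, a frame of that shape can be built and sorted as below. This is a sketch with made-up values (City A, City B and City C are hypothetical, not the MakeoverMonday data), and it assumes that "Total Cost" is the per-city sum of the plotted item costs; if your data already provides that column, use it as is. Sorting by Total Cost also covers the "sort it" part of the question: Matplotlib's categorical y-axis puts the first city it sees at position 0 (the bottom), and the plotting code places its labels by row position, so every item subset has to list the cities in the same order.

```python
import pandas as pd

# Hypothetical example values, NOT the MakeoverMonday data: three cities x three items.
rows = [
    ("City A", "Taxi", 12.0), ("City A", "Club entry", 20.0), ("City A", "Big Mac", 5.0),
    ("City B", "Taxi", 8.0), ("City B", "Club entry", 15.0), ("City B", "Big Mac", 4.0),
    ("City C", "Taxi", 20.0), ("City C", "Club entry", 30.0), ("City C", "Big Mac", 6.0),
]
df = pd.DataFrame(rows, columns=["City", "Item", "Cost"])

# Assumption: "Total Cost" is the per-city sum of the item costs.
df["Total Cost"] = df.groupby("City")["Cost"].transform("sum")

# Sort by total cost (ties broken by city name) so every item subset
# lists the cities in the same order.
df = df.sort_values(["Total Cost", "City"]).reset_index(drop=True)

# Sanity check: all per-item subsets must share one city order,
# because the plotting code places text by row position.
city_orders = {item: sub["City"].tolist() for item, sub in df.groupby("Item", sort=False)}
assert len({tuple(order) for order in city_orders.values()}) == 1
```

Passing ascending=False to sort_values puts the most expensive city at the bottom instead.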

    Creating the plot with:

    import matplotlib as mpl
    import matplotlib.pyplot as plt
    
    colors = {"Taxi": "C2", "Club entry": "C0", "Big Mac": "C3"}
    
    fig, axs = plt.subplot_mosaic(
        [["Taxi", "Club entry", "Big Mac"]], figsize=(12, 4.5), sharey=True, sharex=True,
        facecolor="w", dpi=86,
    )
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.1)
    
    for name, ax in axs.items():
        # subset the DataFrame to the current item
        df_subset = df.query(f"Item == '{name}'")
        
        ax.set_title(name, size=14, alpha=.5)
        ax.plot(
            "Cost", "City", "o", data=df_subset,
            ms=24, color=colors[name],
        )
        
        # add the value inside each circle (marker); row i sits at y position i
        for i, cost_value in enumerate(df_subset["Cost"].to_list()):
            ax.text(
                cost_value, i, f"${cost_value:1.0f}", ha="center", va="center",
                weight="bold", color="w", alpha=.8, size=10,
            )
    
    for name, ax in axs.items():
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(10))
        ax.xaxis.set_major_formatter(mpl.ticker.StrMethodFormatter("${x:1.0f}"))
        ax.grid(axis="y", linewidth=3, alpha=0.5)
        ax.grid(axis="x", linewidth=0.5, alpha=0.5)
        ax.tick_params(axis="both", which="both", length=0, labelcolor="#00000077")
        ax.xaxis.set_ticks_position("top")
        
        for sp in ax.spines:
            ax.spines[sp].set_visible(False)
        
        if name == "Taxi":  # only applies to the left axis
            ax.set_yticklabels([])  # hide the default labels
            ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(range(len(df_subset))))
            
            # add the y tick labels manually (for alignment); df_subset still holds the
            # last item's rows, which assumes every item lists the cities in the same order
            yticklabels = [
                (f"{city:<20s}", f"${total:<3.0f}")
                for city, total in zip(df_subset["City"], df_subset["Total Cost"])
            ]
            
            for ypos, (city_name, total_cost) in enumerate(yticklabels):
                # the negative x positions are in data units of the x-axis (dollars)
                ax.text(-18, ypos, city_name, ha="left", va="center", alpha=.5)
                ax.text(-5, ypos, total_cost, ha="left", va="center", alpha=.9)
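The label step above relies on hard-coded x positions (-18 and -5) in data units, so the labels move whenever the x-range changes. One possibly sturdier variant, sketched here with hypothetical cities and totals, is to draw the labels with ax.get_yaxis_transform(), whose x coordinate is a fraction of the axes width (0 is the left edge, 1 the right edge) while y stays in data units. The offsets -0.35 and -0.1 are assumptions to tune by eye, not values taken from the original chart.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Hypothetical labels, ordered from bottom (position 0) to top.
cities = ["City B", "City A", "City C"]
totals = [27.0, 37.0, 56.0]

fig, ax = plt.subplots(figsize=(4, 3))
fig.subplots_adjust(left=0.4)  # leave room on the left for the manual labels
ax.set_ylim(-0.5, len(cities) - 0.5)
ax.set_yticks(range(len(cities)))
ax.set_yticklabels([])  # hide the default labels; we draw our own below

# Blended transform: x in axes coordinates, y in data coordinates.
trans = ax.get_yaxis_transform()
labels = []
for ypos, (city, total) in enumerate(zip(cities, totals)):
    # x offsets are fractions of the axes width, independent of the x data range
    labels.append(ax.text(-0.35, ypos, city, transform=trans, ha="left", va="center", alpha=.5))
    labels.append(ax.text(-0.1, ypos, f"${total:.0f}", transform=trans, ha="left", va="center", alpha=.9))
```

The left margin here (subplots_adjust(left=0.4)) exists because the labels sit outside the axes; saving with bbox_inches="tight" is another way to keep them in frame.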
    

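If you would rather stay with seaborn, the three requests from the question map onto standard FacetGrid and Matplotlib calls: FacetGrid.set_titles with a "{col_name}" template drops the "Item = " prefix, turning off the x-axis grid on each facet leaves only the horizontal grid lines, and sorting the DataFrame before plotting controls the city order, since the categorical y values are placed in order of appearance. This is a minimal sketch with made-up numbers and the same assumed columns (City, Item, Cost, Total Cost); it is not the answer's method and does not attempt the rest of the styling.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import pandas as pd
import seaborn as sns

# Hypothetical long-format data (not the real MakeoverMonday values).
df = pd.DataFrame({
    "City": ["City A", "City B", "City C"] * 3,
    "Item": ["Taxi"] * 3 + ["Club entry"] * 3 + ["Big Mac"] * 3,
    "Cost": [12, 8, 20, 20, 15, 30, 5, 4, 6],
    "Total Cost": [37, 27, 56] * 3,
})

# Sort first: the categorical y values are placed in order of appearance.
df = df.sort_values(["Total Cost", "City"])

sns.set_style("whitegrid")
g = sns.relplot(data=df, x="Cost", y="City", col="Item", hue="Cost", size="Cost", col_wrap=2)

# Titles: keep only the facet value ("Taxi") instead of "Item = Taxi".
g.set_titles("{col_name}")

# Grid: turn off the vertical (x-axis) grid lines on every facet.
for ax in g.axes.flat:
    ax.xaxis.grid(False)

g.despine(left=True, bottom=True)
```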