python pandas dataframe scatter-plot

How to best plot a pandas dataframe as a figure?


I have a dataframe as follows:

layer   bit-idx exponent    accuracy
conv2d  0       0           0.683099
conv2d  1       0           0.683099
conv2d  2       0           0.683099
conv2d  3       0           0.683099
conv2d  0       1           0.682403
conv2d  1       1           0.668917
conv2d  2       1           0.472103
conv2d  3       1           0.668600
dense   0       0           0.683107
dense   1       0           0.683101
dense   2       0           0.683020
dense   3       0           0.513099
dense   0       1           0.683107
dense   1       1           0.683101
dense   2       1           0.483020
dense   3       1           0.553099

My first try on the whole dataframe is as follows:

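import matplotlib.pyplot as plt
import seaborn as sns
plt.grid()
# accuracy in percent per layer, colored by the 'index' column of df_bi
ax = sns.scatterplot(data=df_bi, x='layer', y=df_bi['accuracy']*100, hue='index', alpha=1, s=100, palette='RdBu', legend=True)
# dashed red reference line at 68.3099 across the layers
sns.lineplot(data=df_wi, x='layer', y=68.3099, linestyle='--', color='red', linewidth=1, ax=ax)
plt.ylim(10, 80)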

and I get the following result:

[Image: output of the first attempt]

How can I plot this dataframe as a scatter plot where the x-axis represents the layers, each tick is split into two columns (one for exponent=0 and one for exponent=1), and the y-axis represents accuracy?


Solution

  • It sounds like you want a swarmplot (seaborn's categorical scatter plot) rather than a plain scatterplot. The usage is as follows:

    import matplotlib.pyplot as plt
    import seaborn as sns
    sns.swarmplot(data=df, x='layer', y='accuracy', hue='exponent', dodge=True)  # df is the DataFrame from the question
    plt.show()

    "hue" colors the points by exponent, and "dodge=True" draws each hue level as its own column within a layer tick instead of overlapping them, so every layer gets one column for exponent=0 and one for exponent=1. Hope that helps, cheers.
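A self-contained sketch that combines the answer with the question's data, typed in by hand from the table above. Assumptions: the DataFrame is called `df`, the helper `plot_accuracy` is hypothetical, and the 68.3099 reference value and the 10 to 80 y-limits are taken from the first attempt. The reference line is drawn with matplotlib's `ax.axhline` instead of the question's `sns.lineplot` call, and `exponent` is converted to strings so seaborn treats it as two discrete hue levels.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# The table from the question, typed in by hand
df = pd.DataFrame({
    "layer": ["conv2d"] * 8 + ["dense"] * 8,
    "bit-idx": [0, 1, 2, 3] * 4,
    "exponent": ([0] * 4 + [1] * 4) * 2,
    "accuracy": [
        0.683099, 0.683099, 0.683099, 0.683099,
        0.682403, 0.668917, 0.472103, 0.668600,
        0.683107, 0.683101, 0.683020, 0.513099,
        0.683107, 0.683101, 0.483020, 0.553099,
    ],
})


def plot_accuracy(data: pd.DataFrame) -> plt.Axes:
    """Hypothetical helper: one column per exponent within each layer tick."""
    # Work on a copy: accuracy in percent, exponent as strings "0"/"1"
    plot_df = data.assign(
        accuracy_pct=data["accuracy"] * 100,
        exponent=data["exponent"].astype(str),
    )
    fig, ax = plt.subplots()
    ax.grid(True)
    # dodge=True places the two hue levels side by side within each layer tick
    sns.swarmplot(data=plot_df, x="layer", y="accuracy_pct",
                  hue="exponent", dodge=True, size=8, ax=ax)
    # Dashed reference line at the value used in the question's first attempt
    ax.axhline(68.3099, linestyle="--", color="red", linewidth=1)
    ax.set_ylim(10, 80)
    ax.set_ylabel("accuracy (%)")
    return ax


if __name__ == "__main__":
    plot_accuracy(df)
    plt.show()
```

If the points in each column should sit on one vertical line instead of being spread sideways, sns.stripplot accepts the same data/x/y/hue/dodge arguments plus jitter=False; identical values (such as the four conv2d points at exponent 0) then overlap.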