python pandas altair plot-annotations correspondence-analysis

How to annotate scatter points plotted with the prince library


I am using the prince library to perform Correspondence Analysis:

import pandas as pd
from prince import CA

My contingency table dummy_contingency looks like this:

{'v1': {'0': 4.479591836734694,
  '1': 75.08163265306122,
  '2': 1.1020408163265305,
  '3': 5.285714285714286,
  '4': 14.244897959183673,
  '5': 0.0,
  '6': 94.06122448979592,
  '7': 0.5102040816326531,
  '8': 87.62244897959184,
  '9': 16.102040816326532},
 'v2': {'0': 6.142857142857143,
  '1': 24.653061224489797,
  '2': 0.3979591836734694,
  '3': 2.63265306122449,
  '4': 18.714285714285715,
  '5': 0.0,
  '6': 60.92857142857143,
  '7': 1.030612244897959,
  '8': 71.73469387755102,
  '9': 14.76530612244898},
 'v3': {'0': 3.642857142857143,
  '1': 21.551020408163264,
  '2': 0.8061224489795918,
  '3': 2.979591836734694,
  '4': 14.5,
  '5': 0.030612244897959183,
  '6': 39.60204081632653,
  '7': 0.7551020408163265,
  '8': 71.89795918367346,
  '9': 11.571428571428571},
 'v4': {'0': 6.1020408163265305,
  '1': 25.632653061224488,
  '2': 0.6938775510204082,
  '3': 3.9285714285714284,
  '4': 21.581632653061224,
  '5': 0.22448979591836735,
  '6': 10.704081632653061,
  '7': 0.8469387755102041,
  '8': 71.21428571428571,
  '9': 12.489795918367347}}

A chi-square test reveals dependence:

Chi-square statistic: 69.6630377155341
p-value: 1.2528156966101567e-05
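
The code behind these numbers is not shown here. A common way to run the test is scipy.stats.chi2_contingency; as a rough sketch of what the statistic measures, the Pearson chi-square statistic can also be computed by hand. The function name chi_square_statistic and the toy table below are illustrative assumptions, not part of the original post, and the sketch skips the p-value and any continuity correction.

```python
def chi_square_statistic(table):
    """Return (statistic, degrees_of_freedom) for a table given as a list of rows.

    Assumes every row total and column total is positive.
    """
    row_totals = [sum(row) for row in table]
    col_totals = [sum(col) for col in zip(*table)]
    grand_total = sum(row_totals)

    statistic = 0.0
    for i, row in enumerate(table):
        for j, observed in enumerate(row):
            # expected count under independence of rows and columns
            expected = row_totals[i] * col_totals[j] / grand_total
            statistic += (observed - expected) ** 2 / expected

    dof = (len(row_totals) - 1) * (len(col_totals) - 1)
    return statistic, dof


# toy 2x2 table (made-up numbers, not the data from this question)
stat, dof = chi_square_statistic([[10, 20], [20, 10]])
print(round(stat, 4), dof)
```

For the table in this question, the rows of the dataframe could be passed as dummy_contingency.values.tolist() once it has been converted with pd.DataFrame.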

Now I fit the data:

dummy_contingency = pd.DataFrame(dummy_contingency)

ca_dummy = CA(n_components=2)  # Number of components for correspondence analysis
ca_dummy.fit(dummy_contingency)

And the plot:

fig = ca_dummy.plot(
    X=dummy_contingency)
fig

Actual Output

How do I add labels to this plot? The examples posted by others (Using mca package in Python) use the function plot_coordinates(), which has an option for adding labels. But it looks like this function is no longer available in the prince package, and the plot() function that replaces it has no option for labels. I would appreciate any help with this.

Edit: Example of an output with labels: Expected Output

The labels I am looking for are the text next to each point in the plot, like "strawberries", "banana", "yogurt", etc. In this case they would be the index values 0,1,2,3,4,5,6,7,8,9 for the blue points and the column names "v1", "v2", "v3", "v4" for the orange points.


Solution

    • The approach for adding annotations to the scatter plot comes from "How to do annotations with Altair"; however, that answer does not cover the steps needed to plot the points from ca.
    • To annotate the correspondence-analysis plot, extract the coordinates from the fitted ca model with .column_coordinates and .row_coordinates. These coordinates, not the raw values in df, are the points shown on the plot.
    import pandas as pd
    import prince
    import altair as alt
    
    # convert the dictionary of data to a dataframe
    df = pd.DataFrame(dummy_contingency)
    
    # create the model
    ca = prince.CA()
    
    # fit the model
    ca = ca.fit(df)
    
    # extract the column coordinate dataframe, and change the column names
    cc = ca.column_coordinates(df).reset_index()
    cc.columns = ['name', 'x', 'y']
    
    # extract the row coordinates dataframe, and change the column names
    rc = ca.row_coordinates(df).reset_index()
    rc.columns = ['name', 'x', 'y']
    
    # combine the dataframes
    crc_df = pd.concat([cc, rc], ignore_index=True)
    
    # plot and annotate
    points = ca.plot(df)
    
    annot = alt.Chart(crc_df).mark_text(
        align='left',
        baseline='middle',
        fontSize=20,
        dx=7
    ).encode(
        x='x',
        y='y',
        text='name'
    )
    
    points + annot
    
    • Note that the plot already has floating annotations, even without adding annot.


    • The annotations can also be added without combining cc and rc into a single dataframe.
    points = ca.plot(df)
    
    annot1 = alt.Chart(cc).mark_text(
        align='left',
        baseline='middle',
        fontSize=20,
        dx=7
    ).encode(
        x='x',
        y='y',
        text='name'
    )
    
    annot2 = alt.Chart(rc).mark_text(
        align='left',
        baseline='middle',
        fontSize=20,
        dx=7
    ).encode(
        x='x',
        y='y',
        text='name'
    )
    
    points + annot1 + annot2
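
As a possible variation, the renaming step can be wrapped in a small helper that keeps only the first two components and records whether each label belongs to a row or a column. This is only a sketch: label_frame and the 'kind' column are hypothetical names introduced here, and the stand-in coordinates are made-up numbers so the snippet runs without a fitted model.

```python
import pandas as pd


def label_frame(coords, kind):
    """Turn a coordinates dataframe into name/x/y/kind columns for a text layer."""
    out = coords.iloc[:, :2].copy()  # keep only the first two components
    out.columns = ['x', 'y']
    out.insert(0, 'name', [str(label) for label in coords.index])
    out['kind'] = kind               # 'row' or 'column', used to tell labels apart
    return out.reset_index(drop=True)


# stand-in coordinates (made-up numbers); with a fitted model these would be
# ca.row_coordinates(df) and ca.column_coordinates(df)
rows = pd.DataFrame({0: [0.1, -0.2], 1: [0.3, 0.0]}, index=['0', '1'])
cols = pd.DataFrame({0: [0.5], 1: [-0.4]}, index=['v1'])

labels = pd.concat(
    [label_frame(rows, 'row'), label_frame(cols, 'column')],
    ignore_index=True,
)
print(labels)
```

The resulting labels dataframe can be passed to alt.Chart(labels).mark_text(...) exactly like crc_df above. Adding color='kind' to .encode() should give row and column labels different colors, although these will not necessarily match the colors that ca.plot uses for the points.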