python, plotly, data-visualization, pca

Plotly annotations too close to each other (not readable)


I have the following code that creates a plot for the loadings after PCA:

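# Imports needed by the snippet below
import numpy as np
import plotly.express as px
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Creating pipeline objects
## PCA
pca = PCA(n_components=2)
## Create a ColumnTransformer that scales only a selected set of features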
categorical_ix = X.select_dtypes(exclude=np.number).columns

features = X.columns

ct = ColumnTransformer([
        ('encoder', OneHotEncoder(), categorical_ix),
        ('scaler', StandardScaler(), ['tenure', 'MonthlyCharges', 'TotalCharges'])
    ], remainder='passthrough')

# Create pipeline
pca_pipe = make_pipeline(ct,
                         pca)

# Fit data to pipeline
pca_result = pca_pipe.fit_transform(X)

loadings = pca.components_.T * np.sqrt(pca.explained_variance_)

fig = px.scatter(pca_result, x=0, y=1, color=customer_data_raw['Churn'])

for i, feature in enumerate(features):
    fig.add_shape(
        type='line',
        x0=0, y0=0,
        x1=loadings[i, 0],
        y1=loadings[i, 1]
    )
    fig.add_annotation(
        x=loadings[i, 0],
        y=loadings[i, 1],
        ax=0, ay=0,
        xanchor="center",
        yanchor="bottom",
        text=feature,
    )
fig.show()

Which produces the following output:

[Image: the resulting plot, with the loading labels too close together to read]

How can I make the labels for the loadings readable?

Edit: There are 19 features in X.

    gender  SeniorCitizen   Partner Dependents  tenure  PhoneService    MultipleLines   InternetService OnlineSecurity  OnlineBackup    DeviceProtection    TechSupport StreamingTV StreamingMovies Contract    PaperlessBilling    PaymentMethod   MonthlyCharges  TotalCharges
customerID                                                                          
7590-VHVEG  Female  0   Yes No  1   No  No phone service    DSL No  Yes No  No  No  No  Month-to-month  Yes Electronic check    29.85   29.85
5575-GNVDE  Male    0   No  No  34  Yes No  DSL Yes No  Yes No  No  No  One year    No  Mailed check    56.95   1889.50
3668-QPYBK  Male    0   No  No  2   Yes No  DSL Yes Yes No  No  No  No  Month-to-month  Yes Mailed check    53.85   108.15
7795-CFOCW  Male    0   No  No  45  No  No phone service    DSL Yes No  Yes Yes No  No  One year    No  Bank transfer (automatic)   42.30   1840.75
9237-HQITU  Female  0   No  No  2   Yes No  Fiber optic No  No  No  No  No  No  Month-to-month  Yes Electronic check    70.70   151.65

Solution

  • Based on your DataFrame, you have 19 features, and each annotation is placed exactly at the end of its line because ax and ay are both set to 0.

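    We can vary ax and ay as we loop through the features so that the annotation tails rotate around a circle, which should make the annotations easier to tell apart. This is the conversion from polar to Cartesian coordinates, x = r*cos(theta) and y = r*sin(theta), where theta takes the values 0*360/19, 1*360/19, ... , 18*360/19 degrees. (The code below uses sin for x and cos for y, which only changes the starting angle and the direction of rotation.) Set axref and ayref to the x- and y-axes so that ax and ay are read in data coordinates instead of the default pixel offsets, and set r=2 or some value comparable to the scale of your plot. Because axref then matches xref, Plotly treats ax and ay as absolute positions, so the label ends of the annotation lines sit on a circle of radius r around the origin.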

    from math import sin, cos, pi
    r = 2 # this can be modified as needed, and is in units of the axis
    theta = 2*pi/len(features)
    
    for i, feature in enumerate(features):
        fig.add_shape(
            type='line',
            x0=0, y0=0,
            x1=loadings[i, 0],
            y1=loadings[i, 1]
        )
        fig.add_annotation(
            x=loadings[i, 0],
            y=loadings[i, 1],
            ax=r*sin(i*theta), 
            ay=r*cos(i*theta),
            axref="x",
            ayref="y",
            xanchor="center",
            yanchor="bottom",
            text=feature,
        )
    fig.show()
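
To pick a sensible r before plotting, it can help to compute the tail positions on their own and inspect them. The following is a minimal sketch using only the standard library; `annotation_tails` is a hypothetical helper name, not a Plotly function, and it assumes the same even spacing and the same sin/cos convention as the loop above.

```python
from math import cos, pi, sin


def annotation_tails(n_labels, r=2.0):
    """Return n_labels (ax, ay) points evenly spaced on a circle of radius r.

    Hypothetical helper (not part of Plotly). It mirrors the loop above:
    label i gets the angle i * 2*pi / n_labels, with sin for x and cos for y,
    so the first label sits at the top of the circle and later ones follow
    clockwise.
    """
    theta = 2 * pi / n_labels
    return [(r * sin(i * theta), r * cos(i * theta)) for i in range(n_labels)]


# 19 features, as in the question; r is in axis units and can be tuned.
tails = annotation_tails(19, r=2.0)
for i, (ax, ay) in enumerate(tails[:3]):
    print(f"label {i}: ax={ax:.3f}, ay={ay:.3f}")
```

Inside the plotting loop, these values would be passed as ax=tails[i][0] and ay=tails[i][1], together with axref="x" and ayref="y". If the printed positions fall well inside or far outside the range of the loadings, adjust r accordingly.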