python matplotlib scikit-learn data-science data-visualization

How to colour a scatter plot of a 2D data frame (reduced using t-SNE/UMAP) according to label information in the index ({country, year}) of a different data frame


I have a data frame, sector_features_, which looks like this: [screenshot 1]

After running t-SNE on it, I have a 2D data frame, which I plot as a scatter plot. I don't know how to colour the points using the original label information contained in the index seen in screenshot 1: each index entry is a tuple containing the {country} and {year} the observation belongs to. Ideally, I would like to colour by country only or by year only, to see how this changes the visualisation.

The data frame containing the reduced data (t-SNE) looks like this: [screenshot 2]

I am using matplotlib and seaborn, but have seen some solutions using altair and I am not sure how to proceed.

The imports are:

import pandas as pd
import numpy as np
import random as rd
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
from sklearn import preprocessing 
from sklearn.manifold import TSNE 
import matplotlib.pyplot as plt 
import seaborn as sns

Solution

  • Judging from your screenshot, I am guessing you have a data frame with a MultiIndex. Using an example dataset:

    import pandas as pd
    import numpy as np
    import seaborn as sns
    from sklearn.manifold import TSNE
    from sklearn.datasets import make_blobs
    
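    # 100 samples with 5 features, drawn from 4 Gaussian blobs
    data, _ = make_blobs(n_samples=100, n_features=5, centers=4, cluster_std=3.5)
    
    data = pd.DataFrame(data)
    # Labels: 25 consecutive rows per country, 4 consecutive rows per year value
    data['country'] = np.repeat(['A', 'B', 'C', 'D'], 25)
    data['year'] = np.repeat(np.arange(1, 26), 4)
    # Move both label columns into a (country, year) MultiIndex
    data = data.set_index(['country', 'year'])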
    

    The index looks like this:

    data.index[:10]
    
    MultiIndex([('A', 1),
                ('A', 1),
                ('A', 1),
                ('A', 1),
                ('A', 2),
                ('A', 2),
                ('A', 2),
                ('A', 2),
                ('A', 3),
                ('A', 3)],
               names=['country', 'year'])
    

    Perform t-SNE:

    X_embedded = TSNE(n_components=2,init='pca',learning_rate='auto').fit_transform(data.values)
    

    Call reset_index() to turn the two index levels into columns (in my case ['country','year']) and concatenate them with the t-SNE results. Both frames then carry the default 0..n-1 index, so the rows line up:

    tsne_result = pd.concat([
        data.reset_index()[['country','year']],
        pd.DataFrame(X_embedded,columns=['tsne1','tsne2'])
    ],axis=1)
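
An alternative to reset_index() plus concat is to keep the original index on the embedding and copy each level into a column with get_level_values(). The sketch below is only illustrative: it uses random numbers as a stand-in for the t-SNE output and assumes the rows of the embedding are in the same order as the original frame. If your index levels are unnamed, the integer positions 0 and 1 can be passed instead of the names.

```python
import numpy as np
import pandas as pd

# Same (country, year) index as in the example data above
index = pd.MultiIndex.from_arrays(
    [np.repeat(['A', 'B', 'C', 'D'], 25), np.repeat(np.arange(1, 26), 4)],
    names=['country', 'year'],
)

# Stand-in for the t-SNE output (assumption: random values, not a real embedding)
rng = np.random.default_rng(0)
X_embedded = rng.normal(size=(len(index), 2))

# Attach the original index to the embedding, then copy each level into a column
tsne_result = pd.DataFrame(X_embedded, columns=['tsne1', 'tsne2'], index=index)
tsne_result['country'] = tsne_result.index.get_level_values('country')
tsne_result['year'] = tsne_result.index.get_level_values('year')
```

The resulting frame has the same tsne1, tsne2, country and year columns as the concat version, so it can be passed to the plotting call unchanged.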
    

    And plot:

    sns.scatterplot(data = tsne_result, x = "tsne1", y = "tsne2",hue = "country")
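
To colour by year instead, the same call should work with hue="year". Because year is numeric, seaborn maps it onto a sequential palette rather than giving each value its own discrete colour. A minimal side-by-side sketch follows; the random coordinates are a placeholder for the real embedding, and the palette name and figure size are arbitrary choices.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Placeholder for tsne_result (assumption: random coordinates, not a real embedding)
rng = np.random.default_rng(0)
tsne_result = pd.DataFrame({
    'country': np.repeat(['A', 'B', 'C', 'D'], 25),
    'year': np.repeat(np.arange(1, 26), 4),
    'tsne1': rng.normal(size=100),
    'tsne2': rng.normal(size=100),
})

fig, (ax_country, ax_year) = plt.subplots(1, 2, figsize=(12, 5))

# Categorical hue: one discrete colour per country
sns.scatterplot(data=tsne_result, x='tsne1', y='tsne2', hue='country', ax=ax_country)
ax_country.set_title('Coloured by country')

# Numeric hue: year is mapped onto a sequential palette
sns.scatterplot(data=tsne_result, x='tsne1', y='tsne2', hue='year',
                palette='viridis', ax=ax_year)
ax_year.set_title('Coloured by year')

fig.tight_layout()
```

Outside a notebook, a final plt.show() is needed to display the figure.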
    

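
If you would rather stay with plain matplotlib, roughly the same two plots can be drawn without seaborn. This is a sketch under the same assumptions as above (placeholder coordinates, arbitrary colormap): one scatter call per country gives each group a legend entry, while the numeric year is passed as `c` and explained with a colorbar.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Placeholder for tsne_result (assumption: random coordinates, not a real embedding)
rng = np.random.default_rng(0)
tsne_result = pd.DataFrame({
    'country': np.repeat(['A', 'B', 'C', 'D'], 25),
    'year': np.repeat(np.arange(1, 26), 4),
    'tsne1': rng.normal(size=100),
    'tsne2': rng.normal(size=100),
})

fig, (ax_country, ax_year) = plt.subplots(1, 2, figsize=(12, 5))

# One scatter call per country so that each group gets its own colour and legend entry
for country, group in tsne_result.groupby('country'):
    ax_country.scatter(group['tsne1'], group['tsne2'], label=country)
ax_country.legend(title='country')
ax_country.set_title('Coloured by country')

# Numeric labels: pass year as the colour values and attach a colorbar
points = ax_year.scatter(tsne_result['tsne1'], tsne_result['tsne2'],
                         c=tsne_result['year'], cmap='viridis')
fig.colorbar(points, ax=ax_year, label='year')
ax_year.set_title('Coloured by year')

fig.tight_layout()
```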