matlab cluster-analysis gmm

How can I color-label the cluster data after GMM is fitted?


I am trying to do some labelling on cluster data following GMMs but haven't found a way to do it.

Let me explain:

I have some (x,y) data pairs stored in a 30000x2 array X. The array contains data from several known sources, and each source contributes the same number of points (so source 1 has 500 (x,y) pairs, source 2 has 500 (x,y) pairs, and so on), all appended one after another into X.

I have fitted a GMM on X. The cluster results are fine and as expected, but now that the data are clustered I want to colour-code the points by the source they originally came from.

So let's say I want to show in black the data points of source 1 that are in cluster 2.

Is that possible?

Example: in the original array we have three sources for the data. Source 1 is rows 1-10000, source 2 is rows 10001-20000 and source 3 is rows 20001-30000.

After GMM fitting and clustering, my data are clustered as in figure 1, and I got two clusters. The red colour in all of them is irrelevant.

I want to modify the colour of the data points in cluster 2 based on their index in the original array X. E.g., if a data point belongs to cluster 2 (clusteridx=2), I want to check which source it belongs to and then colour and label it accordingly, so that you can tell which source each data point in cluster 2 comes from, as shown in the second figure.

[Figure 1: original clusters]

[Figure 2: desired labelling]


Solution

  • You could add a "source_id" column and then plot through a loop on that. For example:

    % setup fake data
    source1 = rand(10,2);
    source2 = rand(15,2);
    source3 = rand(8,2);
    % end setup
    
    % append column with source_id (you could do this in a loop if you have many sources)
    % size(..., 1) counts rows; length() returns the largest dimension instead
    source1 = [source1, repmat(1, size(source1, 1), 1)];
    source2 = [source2, repmat(2, size(source2, 1), 1)];
    source3 = [source3, repmat(3, size(source3, 1), 1)];
    
    mytable = array2table([source1; source2; source3]);
    mytable.Properties.VariableNames = {'X' 'Y' 'source_id'};
    
    figure
    hold on;
    for ii = 1:max(mytable.source_id)
        rows = mytable.source_id==ii;
        x = mytable.X(rows);
        y = mytable.Y(rows);
        label = sprintf('Source ID = %d', ii);   % legend entry for this source
        mycolor = rand(1,3);                     % random RGB colour per source
        scatter(x,y, 'MarkerEdgeColor', mycolor, 'MarkerFaceColor', mycolor, 'DisplayName', label);
    end
    legend('Location', 'best')
    

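The same idea can be combined with the cluster index so that only the points in cluster 2 are coloured by source, which is closer to what the question asks for. The following is a minimal sketch, not a definitive implementation. It assumes the Statistics and Machine Learning Toolbox (`fitgmdist`, `cluster`) is available and that every source contributes the same number of consecutive rows; the data are fake, and the names `nPerSource`, `sourceId`, `clusterIdx` and `inC2` are made up for illustration.

```matlab
% Fake data standing in for X (assumption: 3 sources, equal row counts,
% stacked one after another as described in the question)
rng(0);                                   % reproducible fake data
nPerSource = 200;
nSources   = 3;
X = [randn(nPerSource, 2); ...                                  % source 1, around (0,0)
     randn(nPerSource, 2) + 5; ...                              % source 2, around (5,5)
     randn(nPerSource/2, 2); randn(nPerSource/2, 2) + 5];       % source 3, split over both

% Row i belongs to source ceil(i / nPerSource) because sources are stacked in order
sourceId = ceil((1:size(X, 1))' / nPerSource);

% Fit a 2-component GMM and assign each row to a cluster
gm         = fitgmdist(X, 2);
clusterIdx = cluster(gm, X);

% Plot everything outside cluster 2 in grey, then cluster 2 coloured by source
inC2 = clusterIdx == 2;
figure
hold on
scatter(X(~inC2, 1), X(~inC2, 2), 10, [0.7 0.7 0.7], 'filled', ...
        'DisplayName', 'Other clusters');
colors = lines(nSources);                 % one distinct colour per source
for s = 1:nSources
    rows = inC2 & sourceId == s;          % points of source s that fell in cluster 2
    if ~any(rows), continue; end          % skip sources absent from cluster 2
    scatter(X(rows, 1), X(rows, 2), 10, colors(s, :), 'filled', ...
            'DisplayName', sprintf('Cluster 2, source %d', s));
end
legend('Location', 'best')
```

Which component ends up numbered 2 depends on the fit, so on real data the cluster of interest may need to be picked by inspecting `gm.mu`. If a legend per source is all that is needed, `gscatter(X(inC2,1), X(inC2,2), sourceId(inC2))` from the same toolbox is a shorter alternative to the loop.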