Tags: apache-spark, dataframe, pyspark, distributed-computing, k-means

Calculate cost of clustering in a PySpark DataFrame


I have a DataFrame of a million records and I have used pyspark.ml's KMeans to identify clusters. Now I want to find the within-set sum of squared errors (WSSSE) for the number of clusters that I have used.

My Spark version is 1.6.0, and computeCost is not available in pyspark.ml until Spark 2.0.0, so I have to compute it on my own.

I have used the method below to find the squared error, but it is taking a long time to give me the output. I am looking for a better way to find the WSSSE.

import math
from pyspark.sql.functions import col

# feature vector (C5) and assigned cluster for every record
check_error_rdd = clustered_train_df.select(col("C5"), col("prediction"))
# cluster centres from the KMeans stage of the fitted pipeline
c_center = cluster_model.stages[6].clusterCenters()
check_error_rdd = check_error_rdd.rdd
print math.sqrt(check_error_rdd.map(lambda row: (row.C5 - c_center[row.prediction]) ** 2).reduce(lambda x, y: x + y))

clustered_train_df is my original training data after fitting an ML Pipeline; C5 is the featuresCol in KMeans.

check_error_rdd looks like below:

check_error_rdd.take(2)
Out[13]: 
[Row(C5=SparseVector(18046, {2398: 1.0, 17923: 1.0, 18041: 1.0, 18045: 0.19}), prediction=0),
 Row(C5=SparseVector(18046, {1699: 1.0, 17923: 1.0, 18024: 1.0, 18045: 0.91}), prediction=0)]

c_center is the list of cluster centres, where every centre is an array of length 18046:

print len(c_center[1]) 
18046

Solution

  • I have computed the cost of k-means myself prior to Spark 2.0.

    As for the "slow"-ness you are mentioning: For 100m points, with 8192 centroids, it took me 50 minutes to compute the cost, with 64 executors and 202092 partitions, with 8G memory and 6 cores for every machine, in client mode.


    Quoting the documentation:

    computeCost(rdd)

    Return the K-means cost (sum of squared distances of points to their nearest center) for this model on the given data.

    Parameters: rdd – The RDD of points to compute the cost on.

    New in version 1.4.0.

    If you are somehow failing to use this because you have a DataFrame, just read: How to convert a DataFrame back to normal RDD in pyspark?
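
    For illustration, here is a minimal sketch of that route for the pipeline in the question (an assumption on my part, not part of the original answer): in Spark 1.6, pyspark.ml still uses pyspark.mllib.linalg vectors, and the mllib KMeansModel can be built from a list of centres, so you can reuse the centres from the fitted pipeline and call computeCost on the features RDD.

    from pyspark.mllib.clustering import KMeansModel

    # Reuse the centres from the fitted ml pipeline (stage 6 is the KMeans
    # stage, as in the question) to get an mllib model that has computeCost.
    mllib_model = KMeansModel(cluster_model.stages[6].clusterCenters())

    # Convert the DataFrame's feature column back to an RDD of vectors.
    features_rdd = clustered_train_df.select("C5").rdd.map(lambda row: row.C5)

    # WSSSE: sum of squared distances of the points to their nearest centre.
    print mllib_model.computeCost(features_rdd)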


    As for your approach, I don't see anything wrong with it at a glance.
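
    If you prefer to keep a hand-rolled version, here is a minimal sketch of the same idea (again an assumption on my part, not part of the original answer) that leans on the vectors' built-in squared_distance instead of subtracting the centre array element-wise:

    from pyspark.mllib.linalg import Vectors

    # WSSSE by hand: squared distance of every point to its assigned centre,
    # summed over the whole RDD of (C5, prediction) rows.
    wssse = (check_error_rdd
             .map(lambda row: Vectors.squared_distance(row.C5, c_center[row.prediction]))
             .reduce(lambda x, y: x + y))
    print wssse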