
Image segmentation using K-means for Spark in Scala


I'm following a tutorial (from a book) that implements the K-Means algorithm for image segmentation using Spark. The book's implementation is in Python, and I thought it would be good to implement it in Scala.

But I'm not managing to rebuild the image from the segmentation.

I'm trying it with this 256x256 image from The Cancer Imaging Archive (TCIA):

[image: healthy-brain MRI scan]

Here is my code:

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Assumed, since its definition is not shown here: case class Image(data: Byte)

val spark = SparkSession.builder().appName("mriClass").master("local[*]").getOrCreate()
val mri_healthy_brain_image = "src/main/resources/datasets/clustering/data/mri-images-data/mri-healthy-brain.png"

// Load the image and flatten the "image" struct into top-level columns
val image_df = spark.read.format("image").load(mri_healthy_brain_image).select(col("image.*"))
image_df.show
image_df.printSchema
import spark.implicits._

// Column 5 of the image schema is the binary "data" field
val data = image_df.rdd.collect().map(f => f(5))

val data_array: Array[Byte] = data(0).asInstanceOf[Array[Byte]]

// One row per byte of image data
val transposed_df = spark.sparkContext.parallelize(data_array).map(f => Image(f)).toDF

transposed_df.show

val features_col = Array("data")
val vector_assembler = new VectorAssembler()
  .setInputCols(features_col)
  .setOutputCol("features")

val mri_healthy_brain_df = vector_assembler.transform(transposed_df).select("features")

val k = 5
val kmeans = new KMeans().setK(k).setSeed(12345).setFeaturesCol("features")
val kmeans_model = kmeans.fit(mri_healthy_brain_df)
val kmeans_centers = kmeans_model.clusterCenters
println("Cluster Centers --------")
for (center <- kmeans_centers)
  println(center)

val mri_healthy_brain_clusters_df = kmeans_model.transform(mri_healthy_brain_df)
  .select("features", "prediction")

// Collect the predicted cluster of every row to the driver
val image_array = mri_healthy_brain_clusters_df.select("prediction").rdd.map(f => f.getAs[Int](0)).collect()

In the end, image_array contains 65536 entries, each holding the cluster assignment of one position.

When loading an image into a DataFrame, I'm assuming that Spark simply flattens the image matrix into a 1D array, which is the binary-typed column in the DF.

Given this, I'm simply taking image_array and turning it into a 256x256 image.
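For reference, the flattening assumed above can be written down as index arithmetic. The sketch below assumes the row-major (height, width, nChannels) layout that Spark documents for its image schema; the object and method names are hypothetical helpers, not part of Spark.

```scala
// Hypothetical helpers (not a Spark API): index arithmetic for a
// row-major (height, width, nChannels) byte layout.
object RowMajorLayout {
  /** Offset of channel c of the pixel at column x, row y. */
  def offset(x: Int, y: Int, c: Int, width: Int, nChannels: Int): Int =
    (y * width + x) * nChannels + c

  /** Inverse mapping for a single-channel image: flat index -> (x, y). */
  def toXY(index: Int, width: Int): (Int, Int) =
    (index % width, index / width)
}
```

Under this layout, consecutive flat indices walk along a row (x changes fastest), so any loop that consumes the flat array sequentially needs y in the outer loop and x in the inner loop. It also means the array only has width * height entries when nChannels is 1.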

I used a map to predefine the colour for each class:

val colors:Map[Int,Int] = Map(
  0 -> 0x717171, //gray
  1 -> 0x0074FF, //light blue
  2 -> 0x95FFDF, //cyan
  3 -> 0xFF3333, //red
  4 -> 0x0058B6, //blue
  5 -> 0xE2CE06, //yellow
  6 -> 0xDB06E2, //pink
  7 -> 0x67C82C, //green
  8 -> 0x8136DC, //purple
  9 -> 0x356F07, //darkgreen
  10 -> 0xE5A812 //orange
  )

and I'm using this function to generate the image:

import java.awt.image.BufferedImage

def generateImage(img: BufferedImage, image_array: Array[Byte]): BufferedImage = {
  // img is the original image
  // obtain width and height of image
  val w = img.getWidth
  val h = img.getHeight

  if (w * h != image_array.size)
    throw new IllegalArgumentException("image array does not fit the provided image")

  // create new image of the same size
  val out = new BufferedImage(w, h, BufferedImage.TYPE_INT_RGB)

  // walk the flat array and colour each pixel by its class
  var s = 0
  for (x <- 0 until w)
    for (y <- 0 until h) {
      out.setRGB(x, y, colors(image_array(s).toInt))
      s += 1
    }
  out
}

But this is the image I'm getting:

[image: incorrectly reconstructed output]

I can say for sure that my clustering pipeline is correct, because it matches the results in the book.

But I'm not sure whether Spark reorders the bytes in the DataFrame after classification, which could corrupt the result.

Can anyone give me a hint about what I'm doing wrong?

Thanks in advance.


Solution

  • I found out how the image data is organized in Spark's ImageSchema. Image data is represented as a 3-dimensional array with the dimension shape (height, width, nChannels) and array values of type t specified by the mode field. The array is stored in row-major order (row-wise BGR in most cases).

    Since I have no experience with OpenCV, and it would take more time to understand its basic principles just to reconstruct the image, I decided to read the image using Java ImageIO, store each pixel's RGB value in an Array, and then create a DataFrame from it.

    Then I followed the same process as before: I ran the KMeans classifier, generated the predictions for an image with a tumor, and reconstructed the image by writing the bytes in the same order.

    This is the result I get now:

    [image: segmented output after the fix]

    You can find my full code here:

    https://github.com/gsjunior86/ScalaMLPratice/blob/master/src/main/scala/br/org/gsj/ml/spark/clustering/kmeans/MriClustering.scala
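The read-then-rewrite idea described above can be sketched without Spark. This is a minimal illustration under stated assumptions, not the code from the linked repository: the object name PixelRoundTrip, its method names, and the palette parameter are hypothetical, and the clustering step in between is left out. The only point it shows is that the pixels are read and written in the same row-by-row order.

```scala
import java.awt.image.BufferedImage

// Hypothetical helper object, for illustration only.
object PixelRoundTrip {
  /** Flatten an image to one packed RGB Int per pixel, row by row. */
  def flatten(img: BufferedImage): Array[Int] =
    (for {
      y <- 0 until img.getHeight // rows in the outer loop
      x <- 0 until img.getWidth  // columns change fastest
    } yield img.getRGB(x, y) & 0xFFFFFF).toArray // drop the alpha bits

  /** Rebuild an image from per-pixel labels given in the same row-by-row order. */
  def rebuild(labels: Array[Int], width: Int, height: Int,
              palette: Map[Int, Int]): BufferedImage = {
    require(labels.length == width * height,
      "label count must equal width * height")
    val out = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB)
    for (i <- labels.indices)
      out.setRGB(i % width, i / width, palette(labels(i)))
    out
  }
}
```

An input image could be loaded with javax.imageio.ImageIO.read(new java.io.File(path)), where path is a placeholder. The array returned by flatten would then feed the DataFrame, and the collected predictions would go to rebuild together with a colour map such as the colors map from the question.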