scala, apache-spark, rdd

How to shift elements in an RDD?


I am new to Scala and am trying to figure out how to shift elements in an RDD.
I read the pairs from a CSV file:

var listOfPairs = Spark.sc.textFile(<filePath>)
                  .map(aLine => aLine.split(","))
                  .map(aPair => (aPair(0), aPair(1)))

The content of the file is as follows:

a,1
b,2
c,3
d,4
e,5

On every loop iteration, I want to shift the elements by one position.

for (i <- 1 to numberOfLoops) { ...?... }

For numberOfLoops=3, the steps should look like this:

[(a,1),(b,2),(c,3),(d,4),(e,5)]  
1: [ (b,2), (c,3), (d,4), (e,5), (a,1) ]  
2: [ (c,3), (d,4), (e,5), (a,1), (b,2) ]  
3: [ (d,4), (e,5), (a,1), (b,2), (c,3) ]  

Solution

  • Here is the basic idea of how to perform a shift. It can be improved for performance (in particular, there is a way to avoid multiple iterations for many consecutive shifts), but that is left as an exercise for the reader; one possible take on it is sketched after the code below.

    The basis of the algorithm is to give each element a unique key, then create a copy of the data with its key shifted, and then join the two copies by key.

    import org.apache.spark.sql.SparkSession
    
    val spark = SparkSession.builder.master("local[1]").getOrCreate()
    val sc = spark.sparkContext
    
    val data = sc.parallelize(List("a,1", "b,2", "c,3", "d,4", "e,5"))
    val listOfPairs = data.map(_.split(",")).map { case Array(a, b) => a -> b }
    
    // Give every element a unique key: its position in the RDD.
    val indexed = listOfPairs.zipWithIndex.map { case (tuple, idx) => idx -> tuple }
    val lastIndex = indexed.count() - 1
    
    // Copy of the data with every key moved one position to the left,
    // wrapping the first element around to the end.
    val newIndexed = indexed.map {
      case (0L, pair)  => lastIndex -> pair
      case (idx, pair) => (idx - 1) -> pair
    }
    
    // Join the shifted copy with the original by key and keep the pair that
    // now occupies each position; the key is kept so the new order stays
    // visible. Sorted by key this is [(b,2), (c,3), (d,4), (e,5), (a,1)].
    val shifted = newIndexed.join(indexed).map {
      case (idx, (moved, _)) => idx -> moved
    }
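
    As a quick sanity check, sorting by the kept keys recovers the expected order (collect is only reasonable here because the sample data is tiny):

    shifted.sortByKey().values.collect().foreach(println)
    // (b,2)
    // (c,3)
    // (d,4)
    // (e,5)
    // (a,1)

    And for many consecutive shifts, the loop can be avoided with a single pass of modular re-keying. This is just a sketch of the exercise mentioned above; the names k, size, and rotated are illustrative, with k standing in for numberOfLoops:

    val k = 3L                // how many positions to shift (numberOfLoops)
    val size = lastIndex + 1
    val rotated = indexed.map { case (idx, pair) =>
      // left-rotate by k in one map: (idx - k) mod size, kept non-negative
      (((idx - k) % size + size) % size) -> pair
    }
    // rotated.sortByKey().values.collect() yields
    // (d,4), (e,5), (a,1), (b,2), (c,3), matching step 3 above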