Tags: scala, apache-spark, dataframe, parallel-processing, parallel-collections

Attempting to parallelize a nested loop in Scala


I am comparing two DataFrames in Scala/Spark using a nested loop and an external jar.

for (nrow <- dfm.rdd.collect) {   
  var mid = nrow.mkString(",").split(",")(0)
  var mfname = nrow.mkString(",").split(",")(1)
  var mlname = nrow.mkString(",").split(",")(2)  
  var mlssn = nrow.mkString(",").split(",")(3)  

  for (drow <- dfn.rdd.collect) {
    var nid = drow.mkString(",").split(",")(0)
    var nfname = drow.mkString(",").split(",")(1)
    var nlname = drow.mkString(",").split(",")(2)  
    var nlssn = drow.mkString(",").split(",")(3)  

    val fNameArray = Array(mfname,nfname)
    val lNameArray = Array (mlname,nlname)
    val ssnArray = Array (mlssn,nlssn)

    val fnamescore = Main.resultSet(fNameArray)
    val lnamescore = Main.resultSet(lNameArray)
    val ssnscore =  Main.resultSet(ssnArray)

    val overallscore = (fnamescore +lnamescore +ssnscore) /3

    if(overallscore >= .95) {
       println("MeditechID:".concat(mid)
         .concat(" MeditechFname:").concat(mfname)
         .concat(" MeditechLname:").concat(mlname)
         .concat(" MeditechSSN:").concat(mlssn)
         .concat(" NextGenID:").concat(nid)
         .concat(" NextGenFname:").concat(nfname)
         .concat(" NextGenLname:").concat(nlname)
         .concat(" NextGenSSN:").concat(nlssn)
         .concat(" FnameScore:").concat(fnamescore.toString)
         .concat(" LNameScore:").concat(lnamescore.toString)
         .concat(" SSNScore:").concat(ssnscore.toString)
         .concat(" OverallScore:").concat(overallscore.toString))
    }
  }
}

What I'm hoping to do is add some parallelism to the outer loop: create a thread pool of 5 threads, pull 5 records at a time from the outer loop's collection, and compare each of them against the inner loop's collection, rather than doing everything serially. In other words, I want to be able to specify the number of threads and have that many records from the outer loop's collection being processed against the inner loop's collection at any given time. How would I go about doing this?


Solution

Let's start by analyzing what you are doing. You collect the data of dfm to the driver. Then, for each of its elements, you collect the data of dfn, transform it and compute a score for each pair of elements.

That's problematic in many ways. First, even without considering parallel computing, the transformations on the elements of dfn are performed as many times as dfm has elements. Also, you collect the data of dfn for every row of dfm, which means a lot of network communication between the driver and the executors. Finally, because everything after collect runs on the driver, none of the scoring is actually parallelized by Spark.
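To make the first two problems concrete, here is a minimal plain-Scala sketch with no Spark involved. The rows are stand-in comma-separated strings, and score is a hypothetical exact-match scorer standing in for Main.resultSet; HoistedParsing and its members are illustrative names only. Parsing each side once, before the nested loop, costs one parse per row on each side instead of re-parsing every inner row for every outer row.

```scala
object HoistedParsing {
  final case class Person(id: String, fname: String, lname: String, ssn: String)

  // Counts parses so the cost of re-parsing is visible.
  var parseCalls = 0

  def parse(line: String): Person = {
    parseCalls += 1
    val f = line.split(",", -1)
    Person(f(0), f(1), f(2), f(3))
  }

  // Hypothetical stand-in for Main.resultSet: 1.0 on an exact match, else 0.0.
  def score(a: String, b: String): Double = if (a == b) 1.0 else 0.0

  // Each side is parsed exactly once, before the nested loop.
  def matches(mRows: Seq[String], nRows: Seq[String], threshold: Double): Seq[(Person, Person)] = {
    val ms = mRows.map(parse)
    val ns = nRows.map(parse)
    for {
      m <- ms
      n <- ns
      overall = (score(m.fname, n.fname) + score(m.lname, n.lname) + score(m.ssn, n.ssn)) / 3
      if overall >= threshold
    } yield (m, n)
  }
}
```

With 2 outer and 3 inner rows this parses 5 times in total. It still runs on a single thread, though, which is why the rest of this answer moves the work into Spark.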

If you want to use Spark to parallelize your computations, you need to use its API (RDDs, Spark SQL or Datasets) so that the work runs on the executors. Here you seem to want a Cartesian product, which RDDs provide through cartesian. It is O(N*M), so be careful: it may take a while.
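To put O(N*M) in perspective, here is a quick count, writing N for the number of rows in dfm and M for the number in dfn. The sizes in the last line are hypothetical; the question does not say how large the DataFrames are. Each pair calls Main.resultSet three times (first name, last name, SSN).

```latex
\begin{aligned}
\text{pairs scored} &= N \cdot M \\
\text{calls to } \texttt{Main.resultSet} &= 3 \cdot N \cdot M \\
N = M = 10^{5} \;&\Rightarrow\; 10^{10} \text{ pairs}, \; 3 \times 10^{10} \text{ calls}
\end{aligned}
```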

Let's start by transforming the data before the Cartesian product, so that each element is transformed only once. Also, for clarity, let's define a case class to hold your data and a function that transforms your DataFrames into RDDs of that case class.

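import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame

case class X(id: String, fname: String, lname: String, lssn: String)

// Convert each Row to an X by reading the first four columns by position.
// This avoids the mkString(",").split(",") round trip, which breaks when a
// field itself contains a comma or ends with an empty value.
def toRDDofX(df: DataFrame): RDD[X] =
  df.rdd.map { row =>
    X(String.valueOf(row.get(0)), String.valueOf(row.get(1)),
      String.valueOf(row.get(2)), String.valueOf(row.get(3)))
  }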
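If you do keep the string-splitting approach from the question, note that split(",") drops trailing empty fields, so a row with an empty last column yields only three elements and a strict Array(a, b, c, d) match throws a MatchError. A minimal sketch of a more forgiving parser, assuming you would rather skip malformed rows than fail; SafeParse is a hypothetical name:

```scala
object SafeParse {
  final case class X(id: String, fname: String, lname: String, lssn: String)

  // limit -1 keeps trailing empty fields, so "1,Ann,Lee," still has four fields
  def parse(line: String): Option[X] =
    line.split(",", -1) match {
      case Array(id, fname, lname, lssn) => Some(X(id, fname, lname, lssn))
      case _                             => None // wrong field count: skip instead of a MatchError
    }
}
```

On an RDD, something like df.rdd.flatMap(row => SafeParse.parse(row.mkString(","))) should drop the malformed rows instead of failing the job. It still cannot cope with a comma inside a field; reading the columns by position avoids that.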
    

Then I use filter to keep only the pairs whose overall score passes the .95 threshold, but you could use map, foreach, etc., depending on what you intend to do.

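val rddn = toRDDofX(dfn)
val rddm = toRDDofX(dfm)

// Each (dfn row, dfm row) pair is scored on the executors, in parallel.
// Main.resultSet comes from your external jar, so that jar must also be
// available on the executors (for example via spark-submit --jars).
val matches = rddn.cartesian(rddm).filter { case (xn, xm) =>
  val fnamescore = Main.resultSet(Array(xm.fname, xn.fname))
  val lnamescore = Main.resultSet(Array(xm.lname, xn.lname))
  val ssnscore = Main.resultSet(Array(xm.lssn, xn.lssn))

  val overallscore = (fnamescore + lnamescore + ssnscore) / 3
  // same threshold as the loop in the question
  overallscore >= .95
}

// filter is lazy: nothing is computed until an action such as collect() runs
matches.collect().foreach { case (xn, xm) => println(s"$xm matches $xn") }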
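For comparison, here is a minimal sketch of the driver-side thread pool the question describes, using a fixed-size Java executor wrapped in a Scala ExecutionContext. It assumes both collections are small enough to collect to the driver once; DriverThreadPool, matchAll and the isMatch parameter (which would wrap the Main.resultSet scoring) are hypothetical names. At most threads outer records are scored at any one time, but all the work stays on the driver's cores, which is why the cartesian version above scales further.

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object DriverThreadPool {
  final case class X(id: String, fname: String, lname: String, lssn: String)

  // Scores each outer record against the whole inner collection on a fixed pool
  // of `threads` threads, so at most `threads` outer records run at the same time.
  def matchAll(outer: Seq[X], inner: Seq[X], threads: Int,
               isMatch: (X, X) => Boolean): Seq[(X, X)] = {
    val pool = Executors.newFixedThreadPool(threads)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
    try {
      // one Future per outer record; the pool queues the rest until a thread frees up
      val perOuter: Seq[Future[Seq[(X, X)]]] = outer.map { xm =>
        Future(inner.filter(xn => isMatch(xm, xn)).map(xn => (xm, xn)))
      }
      // Future.sequence keeps the outer collection's order
      Await.result(Future.sequence(perOuter), Duration.Inf).flatten
    } finally {
      pool.shutdown()
    }
  }
}
```

Each Future handles one outer record against the full inner collection, so with threads = 5 you get exactly the "5 outer records in flight" behavior from the question.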