Tags: scala, apache-spark, data-mining, apriori

Apache Spark flatMap time complexity


I've been trying to find a way to count the number of times sets of Strings occur in a transaction database (implementing the Apriori algorithm in a distributed fashion). The code I have currently is as follows:

// Broadcast the candidate itemsets, then count, for each candidate,
// how many transactions contain it.
val cand_br = sc.broadcast(cand)
transactions.flatMap(trans => freq(trans, cand_br.value))
            .reduceByKey(_ + _)

import scala.collection.mutable.ArrayBuffer

// Emit a (candidate, 1) pair for every candidate itemset contained in the transaction.
def freq(trans: Set[String], cand: Array[Set[String]]): Array[(Set[String], Int)] = {
    val res = ArrayBuffer[(Set[String], Int)]()
    for (c <- cand) {
        if (c.subsetOf(trans)) {
            res += ((c, 1))
        }
    }
    res.toArray
}

transactions starts out as an RDD[Set[String]], and I'm trying to convert it to an RDD[(K, V)], where K is an element of cand and V is the number of transactions in the list that contain it.
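
To make the intent concrete, here is a tiny made-up example of the mapping I'm after:

// Toy data, purely to illustrate the desired output shape.
val transactions = sc.parallelize(Seq(
    Set("a", "b", "c"),
    Set("a", "c"),
    Set("b", "d")))
val cand = Array(Set("a"), Set("a", "c"), Set("b", "d"))

// Expected result after counting:
// (Set("a"), 2), (Set("a", "c"), 2), (Set("b", "d"), 1)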

When watching performance in the Spark UI, the flatMap stage takes about 3 minutes to finish, whereas the rest takes < 1 ms.

To give an idea of the data I'm dealing with: transactions.count() ~= 88000 and cand.length ~= 24000. I've tried different ways of persisting the data, but I'm fairly sure the problem I'm facing is algorithmic.

Is there a more efficient way to solve this subproblem?

PS: I'm fairly new to Scala and the Spark framework, so there may be some strange constructions in this code.


Solution

  • Probably the right question to ask here is: "what is the time complexity of this algorithm?" I think it has very little to do with Spark's flatMap operation.

    Rough O-complexity analysis

    Given two collections of sets, of sizes m and n, this algorithm counts how many elements of one collection are subsets of elements of the other, which on its own looks like O(m × n). Looking one level deeper, we also see that subsetOf is linear in the number of elements of the subset: x.subsetOf(y) is essentially x.forall(y.contains), so the actual complexity is O(m × n × s), where s is the cardinality of the subsets being checked.

    In other words, this flatMap operation has a lot of work to do.
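
    To put rough numbers on it (a sketch: the helper below just restates what subsetOf does, and the average cardinality s = 3 is a made-up figure for illustration):

    // c.subsetOf(trans) amounts to one membership lookup per element of c:
    def subsetOfNaive(c: Set[String], trans: Set[String]): Boolean =
      c.forall(trans.contains)

    // Back-of-envelope with the sizes from the question:
    val subsetChecks = 88000L * 24000L   // ≈ 2.1 billion subsetOf calls
    val lookups = subsetChecks * 3       // ≈ 6.3 billion element lookups for s = 3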

    Going Parallel

    Now, going back to Spark, we can also observe that this algorithm is embarrassingly parallel, so we can put Spark's capabilities to good use.

    To compare some approaches, I loaded the 'retail' dataset [1] and ran the algorithm with val cand = transactions.filter(_.size < 4).collect. The data size is close to that of the question:

    • transactions.count = 88162
    • cand.size = 15451

    Some comparative runs on local mode:

    • Vanilla: 1.2 minutes
    • Increasing the number of transactions partitions up to the number of cores (8): 33 secs (sketched below)
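
    The repartitioned run amounts to something along these lines (a sketch; the partition count of 8 simply matches my local core count):

    // Spread the transactions over as many partitions as there are cores,
    // so each core handles a slice of the subset checks.
    val cand_br = sc.broadcast(cand)
    transactions
        .repartition(8)
        .flatMap(trans => freq(trans, cand_br.value))
        .reduceByKey(_ + _)
        .collect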

    I also tried an alternative implementation, using cartesian instead of flatMap:

    // candRDD is assumed to be the candidate sets parallelized as an RDD,
    // e.g. val candRDD = sc.parallelize(cand)
    transactions
        .cartesian(candRDD)
        .map { case (tx, cd) => (cd, if (cd.subsetOf(tx)) 1 else 0) }
        .reduceByKey(_ + _)
        .collect
    

    But that resulted in much longer runs, as seen in the top two lines of the Spark UI below (cartesian, and cartesian with a higher number of partitions): 2.5 min.

    Given that I only have 8 logical cores available, going above that does not help.

    [Spark UI screenshot: stage durations, with the two cartesian runs at the top]

    Sanity checks:

    Is there any added 'Spark flatMap time complexity'? Probably some, as it involves serializing closures and unpacking collections, but it is negligible compared with the function being executed.

    Let's see if we can do a better job: I implemented the same algorithm in plain Scala:

    val resLocal = reduceByKey(transLocal.flatMap(trans => freq(trans, cand)))

    Where the reduceByKey operation is a naive implementation taken from [2]. Execution time: 3.67 seconds. Spark gives you parallelism out of the box; this implementation is completely sequential and therefore takes longer to complete.
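
    For reference, such a naive local reduceByKey boils down to a groupBy plus a per-key sum, roughly like this (a sketch of the idea; the exact code in [2] may differ):

    // Purely sequential reduceByKey over an in-memory collection of pairs.
    def reduceByKey(pairs: Seq[(Set[String], Int)]): Map[Set[String], Int] =
      pairs.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }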

    Last sanity check: a trivial flatMap operation:

    transactions
        .flatMap(trans => Seq((trans, 1)))
        .reduceByKey( _ + _)
        .collect
    

    Execution time: 0.88 secs

    Conclusions:

    Spark buys you parallelism and clustering, and this algorithm can take advantage of them. Use more cores and partition the input data accordingly. There's nothing wrong with flatMap; the time-complexity prize goes to the function inside it.