Tags: apache-spark, pyspark, rdd, apache-spark-sql

Avoiding a shuffle in Spark by pre-partitioning files (PySpark)


I have a table named dataset which is partitioned on values 00-99, and I want to create an RDD first_rdd to read in the data.

I then want to count how many times the word "foo" occurs in the second element of each record, and to store the records of each partition in a list. My output would be final_rdd, where each record is of the form (partition_key, (count, record_list)).

# combineByKey helpers: createCombiner, mergeValue, and mergeCombiners
def to_list(a):
    return [a]

def append(a, b):
    a.append(b)
    return a

def extend(a, b):
    a.extend(b)
    return a

first_rdd = sqlContext.sql("select * from dataset").rdd
kv_rdd = first_rdd.map(lambda x: (x[4], x)) # x[4] is the partition value
# Group each partition to (partition_key, [list_of_records])
grouped_rdd = kv_rdd.combineByKey(to_list, append, extend)

# x is the full list of records for one key; count those whose second element is "foo"
def count_foo(x):
    count = 0
    for record in x:
        if record[1] == "foo":
            count = count + 1
    return (count, x)

final_rdd = grouped_rdd.mapValues(count_foo)
print("Counted 'foo' for %s partitions" % (final_rdd.count))

Since each partition of the dataset is computationally independent of the others, Spark shouldn't need to shuffle, yet when I look at the Spark UI I see that the combineByKey is resulting in a very large shuffle.

I have the correct number of initial partitions, and I have also tried reading the partitioned data directly from HDFS. Either way, I still get a shuffle. What am I doing wrong?


Solution

  • I've solved my problem by using the mapPartitions function and passing it my own reduce function, so that it "reduces" locally on each node and never performs a shuffle.

    In the scenario where data are isolated between partitions, this works perfectly. When the same key exists on more than one partition, a shuffle becomes necessary, but that case has to be detected and handled separately; a rough sketch of the shuffle-free path is shown below.
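
    As a minimal sketch of this approach, assuming each Spark partition holds records for exactly one partition key in x[4] (aggregate_partition is a hypothetical helper, not part of the original code):

    def aggregate_partition(records):
        # Aggregate one partition locally: count "foo" records and collect all records
        key = None
        count = 0
        record_list = []
        for record in records:
            key = record[4]  # partition value, as in the question
            if record[1] == "foo":
                count += 1
            record_list.append(record)
        if key is not None:  # skip empty partitions
            yield (key, (count, record_list))

    first_rdd = sqlContext.sql("select * from dataset").rdd
    final_rdd = first_rdd.mapPartitions(aggregate_partition)
    print("Counted 'foo' for %s partitions" % final_rdd.count())

    Because mapPartitions only operates on data already local to each partition, no shuffle stage is produced for this job.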