apache-spark pyspark rdd

Spark RDD partition by key in exclusive way


I would like to partition an RDD by key so that each partition contains only the values of a single key. For example, if I have 100 different key values and I call repartition(102), the RDD should have 2 empty partitions and 100 partitions, each containing a single key value.

I tried groupByKey(k).repartition(102), but this does not guarantee that each key is exclusive to one partition: I see some partitions containing values of more than one key, and more than 2 empty partitions.
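Why this attempt does not give exclusivity can be modelled in plain Python (this is not Spark code). groupByKey(numPartitions) routes each key with PySpark's default portable_hash, roughly `portable_hash(key) % numPartitions`, and for plain int keys portable_hash returns Python's hash(), which for small non-negative ints is the int itself. Two distinct keys therefore share a partition whenever their hashes agree modulo the partition count, and the subsequent repartition(102) reshuffles the grouped records without looking at keys at all. The helper names and the key set below are hypothetical.

```python
def hash_bucket(key, num_partitions):
    """Model of default hash partitioning: hash(key) modulo the partition count.

    Assumption: for plain int keys PySpark's portable_hash(key) equals hash(key),
    and CPython's hash() of a small non-negative int is the int itself.
    """
    return hash(key) % num_partitions


def group_keys_by_bucket(keys, num_partitions):
    """Return {partition_index: sorted list of the keys routed there}."""
    buckets = {}
    for k in keys:
        buckets.setdefault(hash_bucket(k, num_partitions), []).append(k)
    return {p: sorted(ks) for p, ks in buckets.items()}


# 100 distinct int keys (0, 2, ..., 198), hashed into 102 partitions.
keys = list(range(0, 200, 2))
layout = group_keys_by_bucket(keys, 102)
shared = {p: ks for p, ks in layout.items() if len(ks) > 1}
print(len(layout), "non-empty partitions;", len(shared), "hold more than one key")
```

Here keys 0 and 102 (and 2 and 104, and so on) collide, so 51 partitions are used and 49 of them hold two keys, even though there are more partitions than keys.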

Is there a way in the standard API to do this?


Solution

  • To use partitionBy(), the RDD must consist of tuple (key, value) pair objects. Let's see an example:

    Suppose I have an input file with the following data:

    OrderId|OrderItem|OrderDate|OrderPrice|ItemQuantity
    1|Gas|2018-01-17|1895|1
    1|Air Conditioners|2018-01-28|19000|3
    1|Television|2018-01-11|45000|2
    2|Gas|2018-01-17|1895|1
    2|Air Conditioners|2017-01-28|19000|3
    2|Gas|2016-01-17|2300|1
    1|Bottle|2018-03-24|45|10
    1|Cooking oil|2018-04-22|100|3
    3|Inverter|2015-11-02|29000|1
    3|Gas|2014-01-09|2300|1
    3|Television|2018-01-17|45000|2
    4|Gas|2018-01-17|2300|1
    4|Television$$|2018-01-17|45000|2
    5|Medicine|2016-03-14|23.50|8
    5|Cough Syrup|2016-01-28|190|1
    5|Ice Cream|2014-09-23|300|7
    5|Pasta|2015-06-30|65|2
    
    PATH_TO_FILE="file:///u/vikrant/OrderInputFile"
    

    Read the file into an RDD and skip the header:

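    # sc is the SparkContext (predefined in the pyspark shell)
    RDD = sc.textFile(PATH_TO_FILE)
    header = RDD.first()
    # drop every line equal to the header line
    newRDD = RDD.filter(lambda x: x != header)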
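The filter above removes every line equal to the header, which is fine for this file. If data lines could repeat the header text, a common alternative is to drop only the first line of partition 0 with `RDD.mapPartitionsWithIndex(drop_first_line)`. This sketch assumes a single input file, whose first line lands at the start of partition 0; `drop_first_line` is a hypothetical helper name, exercised below on plain Python iterators.

```python
from itertools import islice


def drop_first_line(partition_index, lines):
    """Skip the first element of partition 0 only; pass other partitions through.

    Intended for RDD.mapPartitionsWithIndex(drop_first_line), assuming the
    header is the first line of partition 0 (true for a single text file).
    """
    if partition_index == 0:
        return islice(lines, 1, None)
    return lines


# Plain-Python check that mimics two partitions of the input file.
part0 = iter(["OrderId|OrderItem|OrderDate|OrderPrice|ItemQuantity", "1|Gas|2018-01-17|1895|1"])
part1 = iter(["5|Pasta|2015-06-30|65|2"])
print(list(drop_first_line(0, part0)), list(drop_first_line(1, part1)))
```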
    

    Now let's repartition the RDD into 5 partitions:

    partitionRDD = newRDD.repartition(5)
    

    Let's look at how the data is distributed across these 5 partitions:

    print("Partitions structure: {}".format(partitionRDD.glom().collect()))
    

    Here you can see that the data is written into only two partitions, the other three are empty, and the records are not distributed evenly:

    Partitions structure: [[], 
    [u'1|Gas|2018-01-17|1895|1', u'1|Air Conditioners|2018-01-28|19000|3', u'1|Television|2018-01-11|45000|2', u'2|Gas|2018-01-17|1895|1', u'2|Air Conditioners|2017-01-28|19000|3', u'2|Gas|2016-01-17|2300|1', u'1|Bottle|2018-03-24|45|10', u'1|Cooking oil|2018-04-22|100|3', u'3|Inverter|2015-11-02|29000|1', u'3|Gas|2014-01-09|2300|1'], 
    [u'3|Television|2018-01-17|45000|2', u'4|Gas|2018-01-17|2300|1', u'4|Television$$|2018-01-17|45000|2', u'5|Medicine|2016-03-14|23.50|8', u'5|Cough Syrup|2016-01-28|190|1', u'5|Ice Cream|2014-09-23|300|7', u'5|Pasta|2015-06-30|65|2'], 
    [], []]
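A small check makes the problem in the question concrete: list which keys each partition holds. This sketch works on glom()-style output (one list per partition); the records are an abbreviated copy of the output above, keeping the same keys per partition, and `keys_per_partition` and `order_id` are hypothetical helper names.

```python
def keys_per_partition(glommed, key_of):
    """For glom()-style output (one list per partition), return the sorted keys in each."""
    return [sorted({key_of(record) for record in part}) for part in glommed]


def order_id(line):
    # Assumption: the OrderId is the text before the first '|'.
    return line.split("|", 1)[0]


# Abbreviated copy of the repartition(5) output above (same keys per partition).
glommed = [
    [],
    ["1|Gas|2018-01-17|1895|1", "2|Gas|2018-01-17|1895|1", "3|Gas|2014-01-09|2300|1"],
    ["3|Television|2018-01-17|45000|2", "4|Gas|2018-01-17|2300|1", "5|Pasta|2015-06-30|65|2"],
    [],
    [],
]
print(keys_per_partition(glommed, order_id))
```

Partitions 1 and 2 each mix several OrderIds, and key 3 is split across both. On the cluster, the same check could be run as `partitionRDD.glom().map(lambda part: sorted({order_id(x) for x in part})).collect()`.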
    

    We need to create a pair RDD so that the data can be distributed across the partitions by key. Let's create a pair RDD by splitting each line into a key-value pair:

    # key = first character (the OrderId, a single digit in this file), value = rest of the line
    pairRDD = newRDD.map(lambda x: (x[0], x[1:]))
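Taking `x[0]` as the key works here only because every OrderId is a single digit; an OrderId such as `12` would become key `'1'` with value `'2|...'`. A sketch of a split-based alternative, assuming `|` never appears inside an OrderId (`to_pair` is a hypothetical name), which could be used as `newRDD.map(to_pair)`. Its values no longer start with `|`, so the printed output would differ slightly from the one shown in this answer.

```python
def to_pair(line):
    """Split 'OrderId|rest' at the first '|' into an (OrderId, rest) tuple."""
    order_id, rest = line.split("|", 1)
    return (order_id, rest)


print(to_pair("1|Gas|2018-01-17|1895|1"))    # ('1', 'Gas|2018-01-17|1895|1')
print(to_pair("12|Pasta|2015-06-30|65|2"))   # ('12', 'Pasta|2015-06-30|65|2')
```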
    

    Now let's repartition this pair RDD into 5 partitions, using the key (the character at position [0] of each line) to decide where each record goes:

    # each record goes to partition int(key) % 5
    newpairRDD = pairRDD.partitionBy(5, lambda k: int(k[0]))
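PySpark's `partitionBy(numPartitions, partitionFunc)` sends each pair to partition `partitionFunc(key) % numPartitions`. A plain-Python model of that rule for this partition function explains the order of the output below: key `'5'` maps to `5 % 5 = 0`, so its records come first. `bucket_of` and `by_first_digit` are hypothetical helper names.

```python
def by_first_digit(k):
    # Same partition function as in the answer: int of the key's first character.
    return int(k[0])


def bucket_of(key, num_partitions, partition_func):
    """Model of partitionBy's routing: partition_func(key) modulo num_partitions."""
    return partition_func(key) % num_partitions


mapping = {k: bucket_of(k, 5, by_first_digit) for k in ["1", "2", "3", "4", "5"]}
print(mapping)  # {'1': 1, '2': 2, '3': 3, '4': 4, '5': 0}
```

Five distinct keys map to five distinct indices, which is why each partition ends up with a single key. A sixth key `'6'` would map to `6 % 5 = 1` and share a partition with `'1'`, so exclusivity depends on the partition function giving every key its own index.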
    

    Now we can see that the records are grouped by key: each partition contains the records of exactly one key.

    print("Partitions structure: {}".format(newpairRDD.glom().collect()))
    Partitions structure: [
    [(u'5', u'|Medicine|2016-03-14|23.50|8'), 
    (u'5', u'|Cough Syrup|2016-01-28|190|1'), 
    (u'5', u'|Ice Cream|2014-09-23|300|7'), 
    (u'5', u'|Pasta|2015-06-30|65|2')],
    
    [(u'1', u'|Gas|2018-01-17|1895|1'), 
    (u'1', u'|Air Conditioners|2018-01-28|19000|3'), 
    (u'1', u'|Television|2018-01-11|45000|2'), 
    (u'1', u'|Bottle|2018-03-24|45|10'), 
    (u'1', u'|Cooking oil|2018-04-22|100|3')], 
    
    [(u'2', u'|Gas|2018-01-17|1895|1'), 
    (u'2', u'|Air Conditioners|2017-01-28|19000|3'), 
    (u'2', u'|Gas|2016-01-17|2300|1')], 
    
    [(u'3', u'|Inverter|2015-11-02|29000|1'), 
    (u'3', u'|Gas|2014-01-09|2300|1'), 
    (u'3', u'|Television|2018-01-17|45000|2')], 
    
    [(u'4', u'|Gas|2018-01-17|2300|1'), 
    (u'4', u'|Television$$|2018-01-17|45000|2')]
    ]
    

    Below you can verify the number of records in each partition:

    # glom() turns each partition into a list; len gives its record count
    partitionSizes = newpairRDD.glom().map(len).collect()
    
    [4, 5, 3, 3, 2]
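These sizes follow directly from the number of rows per OrderId, so partitioning by key gives one key per partition here, not an even record count. A Spark-free check that reproduces them from the sample file (the helper name `partition_sizes` is hypothetical):

```python
from collections import Counter

# OrderIds of the 17 data rows in the input file, in file order.
order_ids = ["1", "1", "1", "2", "2", "2", "1", "1", "3", "3", "3", "4", "4", "5", "5", "5", "5"]


def by_first_digit(k):
    # Same partition function as in the answer.
    return int(k[0])


def partition_sizes(keys, num_partitions, partition_func):
    """Count records per partition under partitionBy's routing rule."""
    counts = Counter(partition_func(k) % num_partitions for k in keys)
    return [counts.get(p, 0) for p in range(num_partitions)]


print(partition_sizes(order_ids, 5, by_first_digit))  # [4, 5, 3, 3, 2]
```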
    

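    Please note that the partition function passed to partitionBy() must return an integer, which is why the key is converted with int() here; PySpark uses that integer modulo the number of partitions. The key itself does not have to be an int (by default partitionBy() hashes it with portable_hash), but then you no longer control which keys share a partition.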
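For the case in the question (100 keys, 102 partitions, keys not necessarily ints), one option is to give every distinct key its own integer index and use that as the partition function. A minimal sketch, assuming the distinct keys are few enough to collect to the driver and are mutually comparable (for sorting); `make_exclusive_partitioner`, `simulate_partition_by` and the `key-NNN` keys are hypothetical, and the routing is modelled in plain Python rather than run on Spark.

```python
def make_exclusive_partitioner(distinct_keys):
    """Map each distinct key to its own partition index 0..len(keys)-1.

    Assumption: distinct_keys is small enough to collect to the driver,
    e.g. via pairRDD.keys().distinct().collect().
    """
    key_index = {k: i for i, k in enumerate(sorted(distinct_keys))}
    return lambda k: key_index[k]


def simulate_partition_by(pairs, num_partitions, partition_func):
    """Plain-Python model of partitionBy: one list of pairs per partition."""
    parts = [[] for _ in range(num_partitions)]
    for k, v in pairs:
        parts[partition_func(k) % num_partitions].append((k, v))
    return parts


# 100 distinct string keys, two records each, routed into 102 partitions.
keys = ["key-%03d" % i for i in range(100)]
pairs = [(k, n) for k in keys for n in range(2)]
parts = simulate_partition_by(pairs, 102, make_exclusive_partitioner(keys))

empty = sum(1 for p in parts if not p)
single_key = all(len({k for k, _ in p}) <= 1 for p in parts)
print(empty, single_key)  # 2 True
```

On the cluster this would be roughly `pairRDD.partitionBy(102, make_exclusive_partitioner(pairRDD.keys().distinct().collect()))`. Indices 0..99 are used, so partitions 100 and 101 stay empty, which matches the layout asked for in the question.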

    Hope this helps!