Tags: apache-spark, rdd, pyspark, semi-join

What is the right way to do a semi-join on two Spark RDDs (in PySpark)?


In my PySpark application, I have two RDDs:

  • items - This contains the item ID and item name for all valid items. Approximately 100,000 items.

  • attributeTable - This contains the fields user ID, item ID, and an attribute value for this combination, in that order. There is a certain attribute for each user-item combination in the system. This RDD has several hundred thousand rows.

I would like to discard all rows in the attributeTable RDD that don't correspond to a valid item ID (or name) in the items RDD. In other words, a semi-join by item ID. For instance, if these were R data frames, I would have done semi_join(attributeTable, items, by="itemID").

I tried the following approach first, but found that this takes forever to return (on my local Spark installation running on a VM on my PC). Understandably so, because there are such a huge number of comparisons involved:

# Create a broadcast variable of all valid item IDs, for filtering on the workers
validItemIDs = sc.broadcast(items.map(lambda (itemID, itemName): itemID).collect())
# Note: set(...) here rebuilds the set for every record being filtered
attributeTable = attributeTable.filter(lambda (userID, itemID, attributes): itemID in set(validItemIDs.value))
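One likely culprit here is that set(validItemIDs.value) rebuilds the set for every record. A sketch of the same idea that instead broadcasts the set itself, built once on the driver (same Python 2 tuple-unpacking style as above):

# Broadcast the set itself rather than a list, so the set is built only
# once and membership checks on the workers are O(1).
validItemIDs = sc.broadcast(set(items.map(lambda (itemID, itemName): itemID).collect()))
attributeTable = attributeTable.filter(
    lambda (userID, itemID, attributes): itemID in validItemIDs.value)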

After a bit of fiddling around, I found that the following approach works pretty fast (a minute or so on my system):

# Create a broadcast variable for item ID to item name mapping (dictionary) 
itemIdToNameMap = sc.broadcast(items.collectAsMap())

# From the attribute table, remove records that don't correspond to a valid item name.
# First go over all records in the table and add a dummy field indicating whether the item name is valid
# Then, filter out all rows with invalid names. Finally, remove the dummy field we added.
attributeTable = (attributeTable
                  .map(lambda (userID, itemID, attributes): (userID, itemID, attributes, itemIdToNameMap.value.get(itemID, 'Invalid')))
                  .filter(lambda (userID, itemID, attributes, itemName): itemName != 'Invalid')
                  .map(lambda (userID, itemID, attributes, itemName): (userID, itemID, attributes)))
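Incidentally, since testing membership in a dictionary checks its keys directly, the dummy-field detour above could be collapsed into a single filter over the same broadcast dictionary (a sketch, same assumptions as above):

# Keep only the rows whose itemID appears as a key in the broadcast map
attributeTable = attributeTable.filter(
    lambda (userID, itemID, attributes): itemID in itemIdToNameMap.value)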

Although this works well enough for my application, it feels like a dirty workaround, and I am pretty sure there must be a cleaner or more idiomatic (and possibly more efficient) way to do this in Spark. What would you suggest? I am new to both Python and Spark, so any RTFM advice will also be helpful if you can point me to the right resources.

My Spark version is 1.3.1.


Solution

  • As others have pointed out, this is probably most easily accomplished by leveraging DataFrames (a sketch is given after the example below). However, you might be able to accomplish your intended goal using the leftOuterJoin and filter functions. Something a bit hackish like the following might suffice:

    items = sc.parallelize([(123, "Item A"), (456, "Item B")])
    attributeTable = sc.parallelize([(123456, 123, "Attribute for A")])
    # Key the attribute rows by item ID, left-outer-join the items against
    # them, keep only the items that matched at least one attribute row, and
    # then strip the attribute payload back off. sorted() just makes the
    # collected output deterministic for display.
    sorted(items.leftOuterJoin(attributeTable.keyBy(lambda x: x[1]))
           .filter(lambda x: x[1][1] is not None)
           .map(lambda x: (x[0], x[1][0])).collect())
    

    returns

    [(123, 'Item A')]
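
    As for the DataFrame route mentioned above, here is a minimal sketch. The column names (itemID, itemName, userID, attributes) are placeholder assumptions for whatever your schema actually uses. Note that, unlike the RDD example above, this filters attributeTable itself, which is the direction the question asked for. Recent Spark spells the join type "leftsemi"; the 1.3.x docs list it as "semijoin".

    from pyspark.sql import SQLContext

    sqlContext = SQLContext(sc)

    # Hypothetical column names; adjust these to match your data.
    itemsDF = sqlContext.createDataFrame(items, ["itemID", "itemName"])
    attributesDF = sqlContext.createDataFrame(
        attributeTable, ["userID", "itemID", "attributes"])

    # A left semi join keeps only the attributeTable rows whose itemID has a
    # match in itemsDF, and returns only the left-hand side's columns.
    filteredDF = attributesDF.join(itemsDF,
                                   attributesDF.itemID == itemsDF.itemID,
                                   "leftsemi")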