Tags: apache-spark, pyspark, apache-spark-sql, spark3

SPARK 3 - Populate value with value from previous rows (lookup)


I am new to Spark. I have 2 dataframes, events and players.

events dataframe consists of columns

event_id| player_id| match_id| impact_score

players dataframe consists of columns

player_id| player_name| nationality

I am merging the two datasets by player_id with this query:

df_final = (events
  .orderBy("player_id") 
  .join(players.orderBy("player_id"))
  .withColumn("current_team", when([no idea what goes in here]).otherwise(getCurrentTeam(col("player_id"))))
  .write.mode("overwrite")
  .partitionBy("current_team")
)

The getCurrentTeam function triggers an HTTP call that returns a value (the player's current team).

I have data for over 30 million soccer plays and 97 players. I need help creating the column current_team. Imagine a certain player appearing 130,000 times in the events dataframe. I need to look up values from previous rows: if the player has already appeared, I just grab that value (like an in-memory catalog); if it has not, I call the webservice.


Solution

  • Due to its distributed nature, Spark cannot express "if the value was already populated by a previous call, reuse it, otherwise call the service" across rows. There are two possible options.

    1. Since you are applying an inner join and the players df already contains every distinct player, you can add the current_team column to that df before the join. If the players df is cached before joining, it is likely that the UDF is invoked only once per player. See the discussion here for why a UDF can be called multiple times for each record.
    2. You can memoize getCurrentTeam so that, within each worker, repeated lookups for the same player_id are served from an in-process cache instead of the webservice.

    Working Example - Prepopulate current_team

    from pyspark.sql import functions as F
    from pyspark.sql.functions import udf
    from pyspark.sql.types import StringType
    
    events_data = [(1, 1, 1, 10), (1, 2, 1, 20, ), (1, 3, 1, 30, ), (2, 3, 1, 30, ), (2, 1, 1, 10), (2, 2, 1, 20, ), ]
    players_data = [(1, "Player1", "Nat", ), (2, "Player2", "Nat", ), (3, "Player3", "Nat", ), ]
    
    events = spark.createDataFrame(events_data, ("event_id", "player_id", "match_id", "impact_score", ), ).repartition(3)
    players = spark.createDataFrame(players_data, ("player_id", "player_name", "nationality", ), ).repartition(3)
    
    
    # Stand-in for the real webservice call that returns the player's current team
    @udf(StringType())
    def getCurrentTeam(player_id):
        return f"player_{player_id}_team"
    
    players_with_current_team = players.withColumn("current_team", getCurrentTeam(F.col("player_id"))).cache()
    
    events.join(players_with_current_team, ["player_id"]).show()
    

    Output

    +---------+--------+--------+------------+-----------+-----------+-------------+
    |player_id|event_id|match_id|impact_score|player_name|nationality| current_team|
    +---------+--------+--------+------------+-----------+-----------+-------------+
    |        2|       2|       1|          20|    Player2|        Nat|player_2_team|
    |        2|       1|       1|          20|    Player2|        Nat|player_2_team|
    |        3|       2|       1|          30|    Player3|        Nat|player_3_team|
    |        3|       1|       1|          30|    Player3|        Nat|player_3_team|
    |        1|       2|       1|          10|    Player1|        Nat|player_1_team|
    |        1|       1|       1|          10|    Player1|        Nat|player_1_team|
    +---------+--------+--------+------------+-----------+-----------+-------------+
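
    To complete the pipeline from the question, the joined result can then be written out partitioned by current_team. A minimal sketch, assuming a Parquet sink; the output path is hypothetical, and since players_with_current_team is tiny (97 players in your case) a broadcast hint may also help:

    (events
        .join(F.broadcast(players_with_current_team), ["player_id"])
        .write.mode("overwrite")
        .partitionBy("current_team")
        .parquet("/tmp/events_by_current_team"))  # hypothetical output path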
    

    Working Example - Memoization

    I have used a Python dict to mimic caching and an accumulator to count the number of mimicked network calls made.

    from pyspark.sql import functions as F
    from pyspark.sql.functions import udf
    from pyspark.sql.types import StringType
    import time
    
    events_data = [(1, 1, 1, 10), (1, 2, 1, 20, ), (1, 3, 1, 30, ), (2, 3, 1, 30, ), (2, 1, 1, 10), (2, 2, 1, 20, ), ]
    players_data = [(1, "Player1", "Nat", ), (2, "Player2", "Nat", ), (3, "Player3", "Nat", ), ]
    
    events = spark.createDataFrame(events_data, ("event_id", "player_id", "match_id", "impact_score", ), ).repartition(3)
    players = spark.createDataFrame(players_data, ("player_id", "player_name", "nationality", ), ).repartition(3)
    
    players_events_joined = events.join(players, ["player_id"])
    
    memoized_call_counter = spark.sparkContext.accumulator(0)
    def memoize_call():
        cache = {}
        def getCurrentTeam(player_id):
            global memoized_call_counter
            cached_value = cache.get(player_id, None)
            if cached_value is not None:
                return cached_value
            # sleep to mimic network call
            time.sleep(1)
            # Increment the counter every time the value can't be looked up in the cache
            memoized_call_counter.add(1)
            cache[player_id] = f"player_{player_id}_team"
            return cache[player_id]
        return getCurrentTeam
        
    getCurrentTeam_udf = udf(memoize_call(), StringType())
    
    players_events_joined.withColumn("current_team", getCurrentTeam_udf(F.col("player_id"))).show()
    

    Output

    +---------+--------+--------+------------+-----------+-----------+-------------+
    |player_id|event_id|match_id|impact_score|player_name|nationality| current_team|
    +---------+--------+--------+------------+-----------+-----------+-------------+
    |        2|       2|       1|          20|    Player2|        Nat|player_2_team|
    |        2|       1|       1|          20|    Player2|        Nat|player_2_team|
    |        3|       2|       1|          30|    Player3|        Nat|player_3_team|
    |        3|       1|       1|          30|    Player3|        Nat|player_3_team|
    |        1|       2|       1|          10|    Player1|        Nat|player_1_team|
    |        1|       1|       1|          10|    Player1|        Nat|player_1_team|
    +---------+--------+--------+------------+-----------+-----------+-------------+
    
    >>> memoized_call_counter.value
    3
    

    Since there are 3 unique players in total, the logic after time.sleep(1) was invoked only three times. The total number of calls depends on the number of worker processes, because the cache is not shared across workers: each worker calls the service at most once per unique player it processes. Since I ran the example in local mode with a single worker, the number of calls here equals the number of unique players.
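
    If that per-worker duplication matters, a variant of option 1 is to resolve the ~97 teams on the driver and join them in, so the webservice is hit exactly once per player regardless of the number of workers. A minimal self-contained sketch; fetch_current_team is a hypothetical stand-in for the real HTTP call and is not part of the original answer:

    def fetch_current_team(player_id):
        # Stand-in for the real HTTP call, executed on the driver.
        return f"player_{player_id}_team"

    # Collect the (small) set of distinct players and call the service once per player on the driver.
    player_ids = [row.player_id for row in players.select("player_id").distinct().collect()]
    teams = spark.createDataFrame(
        [(pid, fetch_current_team(pid)) for pid in player_ids],
        ("player_id", "current_team"),
    )

    players_with_current_team = players.join(teams, ["player_id"])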