I am new to Spark. I have 2 dataframes, events and players.
The events dataframe consists of the columns:
event_id | player_id | match_id | impact_score
The players dataframe consists of the columns:
player_id | player_name | nationality
I am merging the two datasets by player_id with this query:
df_final = (events
    .orderBy("player_id")
    .join(players.orderBy("player_id"), "player_id")
    .withColumn("current_team", when([no idea what goes in here]).otherwise(getCurrentTeam(col("player_id"))))
    .write.mode("overwrite")
    .partitionBy("current_team")
)
The getCurrentTeam function triggers an HTTP call that returns a value (the player's current team).
I have data on over 30 million soccer plays and 97 players, and I need help creating the column current_team. Imagine a certain player appearing 130,000 times in the events dataframe. I need to look up values from previous rows: if the player has already appeared, I just grab that value (like an in-memory catalog); if it has not, I call the webservice.
Due to its distributed nature, Spark has no built-in way to say "if this value was populated by a previous call then reuse it, otherwise compute a new one" across rows. There are two possible options.
Option 1: The players df has the list of all distinct players, so you can add the current_team column to this df before applying the join. If the players df is cached before joining, then it's likely that the getCurrentTeam UDF is invoked only once per player. See the discussion here for why a UDF can be called multiple times for each record.
from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

events_data = [(1, 1, 1, 10), (1, 2, 1, 20), (1, 3, 1, 30), (2, 3, 1, 30), (2, 1, 1, 10), (2, 2, 1, 20)]
players_data = [(1, "Player1", "Nat"), (2, "Player2", "Nat"), (3, "Player3", "Nat")]

events = spark.createDataFrame(events_data, ("event_id", "player_id", "match_id", "impact_score")).repartition(3)
players = spark.createDataFrame(players_data, ("player_id", "player_name", "nationality")).repartition(3)

# Stand-in for the real HTTP call
@udf(StringType())
def getCurrentTeam(player_id):
    return f"player_{player_id}_team"

# Add current_team to the small players df and cache it before the join,
# so the UDF runs once per player rather than once per event row
players_with_current_team = players.withColumn("current_team", getCurrentTeam(F.col("player_id"))).cache()

events.join(players_with_current_team, ["player_id"]).show()
+---------+--------+--------+------------+-----------+-----------+-------------+
|player_id|event_id|match_id|impact_score|player_name|nationality| current_team|
+---------+--------+--------+------------+-----------+-----------+-------------+
| 2| 2| 1| 20| Player2| Nat|player_2_team|
| 2| 1| 1| 20| Player2| Nat|player_2_team|
| 3| 2| 1| 30| Player3| Nat|player_3_team|
| 3| 1| 1| 30| Player3| Nat|player_3_team|
| 1| 2| 1| 10| Player1| Nat|player_1_team|
| 1| 1| 1| 10| Player1| Nat|player_1_team|
+---------+--------+--------+------------+-----------+-----------+-------------+
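With option 1 in place, the write from your original query can go on top of the joined result. A minimal sketch, assuming a parquet output at the hypothetical path /tmp/events_by_team (adjust the format and path to your setup):

df_final = events.join(players_with_current_team, ["player_id"])

(df_final.write
    .mode("overwrite")
    .partitionBy("current_team")      # one directory per current_team value, as in the original query
    .parquet("/tmp/events_by_team"))  # hypothetical output path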
Option 2: Memoize the webservice call inside the UDF itself. I have used a python dict to mimic caching and an accumulator to count the number of mimicked network calls made.
from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import time

events_data = [(1, 1, 1, 10), (1, 2, 1, 20), (1, 3, 1, 30), (2, 3, 1, 30), (2, 1, 1, 10), (2, 2, 1, 20)]
players_data = [(1, "Player1", "Nat"), (2, "Player2", "Nat"), (3, "Player3", "Nat")]

events = spark.createDataFrame(events_data, ("event_id", "player_id", "match_id", "impact_score")).repartition(3)
players = spark.createDataFrame(players_data, ("player_id", "player_name", "nationality")).repartition(3)

players_events_joined = events.join(players, ["player_id"])

# Counts how often the (mimicked) network call is actually made
memoized_call_counter = spark.sparkContext.accumulator(0)

def memoize_call():
    cache = {}
    def getCurrentTeam(player_id):
        global memoized_call_counter
        cached_value = cache.get(player_id, None)
        if cached_value is not None:
            return cached_value
        # sleep to mimic the network call
        time.sleep(1)
        # Increment the counter every time the value can't be looked up in the cache
        memoized_call_counter.add(1)
        cache[player_id] = f"player_{player_id}_team"
        return cache[player_id]
    return getCurrentTeam

getCurrentTeam_udf = udf(memoize_call(), StringType())

players_events_joined.withColumn("current_team", getCurrentTeam_udf(F.col("player_id"))).show()
+---------+--------+--------+------------+-----------+-----------+-------------+
|player_id|event_id|match_id|impact_score|player_name|nationality| current_team|
+---------+--------+--------+------------+-----------+-----------+-------------+
| 2| 2| 1| 20| Player2| Nat|player_2_team|
| 2| 1| 1| 20| Player2| Nat|player_2_team|
| 3| 2| 1| 30| Player3| Nat|player_3_team|
| 3| 1| 1| 30| Player3| Nat|player_3_team|
| 1| 2| 1| 10| Player1| Nat|player_1_team|
| 1| 1| 1| 10| Player1| Nat|player_1_team|
+---------+--------+--------+------------+-----------+-----------+-------------+
>>> memoized_call_counter.value
3
Since there are 3 unique players in total, the logic after time.sleep(1) was called only thrice. The number of calls depends on the number of workers, since the cache is not shared across workers. As I ran the example in local mode (with 1 worker), each player was looked up exactly once; with more workers, the total can grow to as many as the number of unique players times the number of workers.
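For the real webservice, the sleep inside getCurrentTeam would be replaced by the actual HTTP call, memoized the same way. A hedged sketch using the requests library, assuming a hypothetical endpoint https://example.com/players/<player_id>/team that returns JSON with a team field (adapt the URL and response parsing to your actual API):

import requests
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def memoize_call():
    cache = {}
    def getCurrentTeam(player_id):
        cached_value = cache.get(player_id, None)
        if cached_value is not None:
            return cached_value
        # Hypothetical endpoint and response shape; replace with the real webservice
        resp = requests.get(f"https://example.com/players/{player_id}/team", timeout=5)
        resp.raise_for_status()
        cache[player_id] = resp.json()["team"]
        return cache[player_id]
    return getCurrentTeam

getCurrentTeam_udf = udf(memoize_call(), StringType())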