I am writing PySpark code starting from the following table.
user | pets |
---|---|
aa | [["dog"], ["cat"], ["lizard"]] |
bb | [["dog"], ["spider"]] |
cc | [["dog"], ["cat"], ["monkey"]] |
... | ... |
Using explode_outer, I unnested the table into the one below (a sketch of that step follows the table).
user | pets |
---|---|
aa | "dog" |
aa | "cat" |
aa | "lizard" |
bb | "dog" |
bb | "spider" |
cc | "dog" |
cc | "cat" |
cc | "monkey" |
... | ... |
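For reference, the unnesting step was roughly the following (a sketch, assuming df holds the first table; since each inner array contains a single pet, I take element 0 after exploding):

from pyspark.sql import functions as F

# explode_outer yields one row per inner array, e.g. ["dog"];
# taking element 0 turns ["dog"] into "dog".
unnested = (df
    .select('user', F.explode_outer('pets').alias('pets'))
    .withColumn('pets', F.col('pets')[0]))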
I would like to make another table that contains the data below:
pet | dog | cat | lizard | spider | monkey |
---|---|---|---|---|---|
dog | 0 | 2 | 1 | 1 | 1 |
cat | 2 | 0 | 1 | 0 | 1 |
lizard | 1 | 1 | 0 | 0 | 0 |
spider | 1 | 0 | 0 | 0 | 0 |
monkey | 1 | 1 | 0 | 0 | 0 |
For instance, in terms of the spider, it only lives with a dog (in user 'bb'). So once the data is meaningfully large, I can answer a query like "Find the top 3 pets that go best with your dog."
However, I have no idea how to design the code that gets from the second table to the third one.
The following solution relies only on the Spark engine: find all pet combinations within each row, split each pair into two columns, and run a crosstab. The result omits repeated entries, because the filter x[0] < x[1] keeps every unordered pair exactly once, with its elements in alphabetical order; for example, ["dog", "cat", "lizard"] produces only ["cat", "dog"], ["cat", "lizard"] and ["dog", "lizard"]. Note that the Python higher-order functions F.transform and F.filter require Spark 3.1+.
from pyspark.sql import functions as F

df = spark.createDataFrame(
    [('aa', [["dog"], ["cat"], ["lizard"]]),
     ('bb', [["dog"], ["spider"]]),
     ('cc', [["dog"], ["cat"], ["monkey"]])],
    ['user', 'pets'])

# [["dog"], ["cat"], ["lizard"]] -> ["dog", "cat", "lizard"]
df = df.withColumn('pets', F.flatten('pets'))

# For every pet x in a row, zip x (repeated) with the full pets array to get
# all ordered pairs, then keep each unordered pair once: x[0] < x[1] drops
# self-pairs and the mirrored duplicate of every pair.
combinations = F.filter(
    F.transform(
        F.flatten(F.transform(
            'pets',
            lambda x: F.arrays_zip(F.array_repeat(x, F.size('pets')), 'pets')
        )),
        lambda x: F.array(x['0'], x['pets'])
    ),
    lambda x: x[0] < x[1]
)

# One row per pair, split into two columns for the crosstab.
df = df.withColumn('pets', F.explode(combinations))
df = df.withColumn('pet0', F.col('pets')[0])
df = df.withColumn('pet1', F.col('pets')[1])
df = df.crosstab('pet0', 'pet1')
df.show()
# +---------+---+------+------+------+
# |pet0_pet1|dog|lizard|monkey|spider|
# +---------+---+------+------+------+
# |      dog|  0|     1|     1|     1|
# |      cat|  2|     1|     1|     0|
# +---------+---+------+------+------+
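From here, to reproduce the full symmetric matrix from the question, or to answer "find the top 3 pets that go best with your dog", one option is to reuse the pair rows from just before the crosstab (a sketch; pairs is my own name for the DataFrame as it is right before the .crosstab line, and the union restores the mirrored entries that the x[0] < x[1] filter removed):

pairs = df.select('pet0', 'pet1')  # grab this before the .crosstab(...) reassignment

# Duplicate every pair in both orders to restore symmetric counts.
sym = pairs.union(
    pairs.select(F.col('pet1').alias('pet0'), F.col('pet0').alias('pet1')))

# Full symmetric co-occurrence matrix, matching the desired third table.
sym.crosstab('pet0', 'pet1').show()

# Top 3 companions for 'dog': rank dog's partners by how often they co-occur.
(sym
    .where(F.col('pet0') == 'dog')
    .groupBy('pet1')
    .count()
    .orderBy(F.desc('count'))
    .limit(3)
    .show())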