I have a dataframe:
import pyspark.sql.functions as F

# Sample data: one (year, id, bigram) record per row.
rows = [
    (2022, 1, ["apple", "edible"]),
    (2022, 1, ["edible", "fruit"]),
    (2022, 1, ["orange", "sweet"]),
    (2022, 4, ["flowering ", "plant"]),
    (2022, 3, ["green", "kiwi"]),
    (2022, 3, ["kiwi", "fruit"]),
    (2022, 3, ["fruit", "popular"]),
    (2022, 3, ["yellow", "lemon"]),
]
columns = ["year", "id", "bigram"]
sdf1 = spark.createDataFrame(rows, columns)
sdf1.show(truncate=False)
+----+---+-------------------+
|year|id |bigram |
+----+---+-------------------+
|2022|1 |[apple, edible] |
|2022|1 |[edible, fruit] |
|2022|1 |[orange, sweet] |
|2022|4 |[flowering , plant]|
|2022|3 |[green, kiwi] |
|2022|3 |[kiwi, fruit] |
|2022|3 |[fruit, popular] |
|2022|3 |[yellow, lemon] |
+----+---+-------------------+
And I wrote a function that returns bigrams sharing the same boundary words within the n-grams. I apply this function separately to the column.
from networkx import DiGraph, dfs_labeled_edges

# Collapse each (year, id) group's bigrams into a single array column and
# record how many distinct bigrams the group holds.
grouped = sdf1.groupby("year", "id")
sdf = grouped.agg(F.collect_set("bigram").alias("collect_bigramm"))
sdf = sdf.withColumn("size", F.size("collect_bigramm"))
data_collect = sdf.collect()
@udf(returnType=ArrayType(StringType()))
def myfunc(lst):
    """Chain overlapping bigrams of one group into a single word path.

    Two bigrams ``[a, b]`` and ``[b, c]`` are linked when the last word of
    one equals the first word of the other.  Linked bigrams are inserted
    into a directed graph, which is walked depth-first; the first completed
    path is returned as a flat, order-preserving, de-duplicated word list
    (e.g. ``[a, b, c]``).

    Parameters
    ----------
    lst : list[list[str]] | None
        The collected bigrams of a single (year, id) group.

    Returns
    -------
    list[str] | None
        The chained word path, or ``None`` when no chain exists.
    """
    # Fix: the original looped over the driver-side `data_collect` inside
    # the executor UDF and tested `row["size"]` of *every* collected group,
    # re-running the edge-building once per qualifying row.  The guard that
    # is actually needed is just this group's own bigram count.
    if lst is None or len(lst) < 2:
        return None

    graph = DiGraph()
    # Link every pair of bigrams that share a boundary word.  The edge
    # insertion order of the original if/elif is preserved, since networkx
    # traversal order follows insertion order.
    for i, first in enumerate(lst):
        for second in lst[i + 1:]:
            if first[0] == second[1]:
                graph.add_edge(second[0], second[1])
                graph.add_edge(first[0], first[1])
            elif first[1] == second[0]:
                graph.add_edge(first[0], first[1])
                graph.add_edge(second[0], second[1])

    # Depth-first walk: accumulate nodes on "forward" edges, and when a
    # subtree is fully explored ("reverse"), snapshot the collected path.
    path = []
    completed = []
    saw_forward = False
    for u, v, direction in dfs_labeled_edges(graph):
        if direction == "forward" and u != v:
            saw_forward = True
            path.append(u)
            path.append(v)
        elif direction == "reverse":
            if saw_forward:
                completed.append(path.copy())
                saw_forward = False
            if u in path:
                path.remove(u)
            if v in path:
                path.remove(v)

    if completed:
        # De-duplicate the first path while preserving word order (each
        # interior node appears twice: once as edge head, once as tail).
        first_path = completed[0]
        deduped = [w for n, w in enumerate(first_path) if w not in first_path[:n]]
        return deduped or None
    return None
# Apply the UDF to each group's collected bigram list and show the result.
sdf_new = sdf.withColumn("new_col", myfunc(F.col("collect_bigramm")))
sdf_new.show(truncate=False)
Output:
+----+---+-----------------------------------------------------------------+----+-----------------------------+
|year|id |collect_bigramm |size|new_col |
+----+---+-----------------------------------------------------------------+----+-----------------------------+
|2022|4 |[[flowering , plant]] |1 |null |
|2022|1 |[[edible, fruit], [orange, sweet], [apple, edible]] |3 |[apple, edible, fruit] |
|2022|3 |[[yellow, lemon], [green, kiwi], [kiwi, fruit], [fruit, popular]]|4 |[green, kiwi, fruit, popular]|
+----+---+-----------------------------------------------------------------+----+-----------------------------+
But now I want to use a pandas UDF. I would like to first group by and compute the collect_bigramm column inside the function — thus keeping all the columns of the dataframe while also adding a new one: the lst_res array built in the function.
# Output schema for the grouped-map pandas UDF: the original columns plus
# the aggregated bigram list and the computed word-chain column.
schema2 = StructType([
    StructField("year", IntegerType(), True),
    StructField("id", IntegerType(), True),
    StructField("bigram", ArrayType(StringType(), True), True),
    StructField("new_col", ArrayType(StringType(), True), True),
    StructField("collect_bigramm", ArrayType(ArrayType(StringType(), True), True), True),
])
@pandas_udf(schema2, functionType=PandasUDFType.GROUPED_MAP)
def myfunc(df):
    # Grouped-map pandas UDF sketch: receives the rows of one (year, id)
    # group as a pandas DataFrame and must return a DataFrame matching
    # `schema2`.
    graph = DiGraph()
    for index, row in df.iterrows():
        # Instead of the variable lst, I need to use this group's
        # collected bigrams here (the sdf['collect_bigramm'] column).
        ...
    return df
sdf_new = sdf.groupby(["year", "id"]).apply(myfunc)
You don't want to run groupBy twice (once for sdf1 and once inside the pandas_udf) — that would defeat the whole idea of pandas_udf, which is "group a list of records, vectorize it, then send it to a worker". You'd want to do something like this instead: sdf1.groupby("year", "id").applyInPandas(myfunc, schema2)
Your UDF is now a "pandas UDF", which is literally just a Python function that takes one pandas DataFrame and returns another pandas DataFrame. That means you can even run the function without Spark. The trick is simply how to shape the dataframe to feed it what you need. Check the running code below — I kept most of your networkx code and only fixed a little on the input and output side.
def myfunc(pdf):
    # Grouped-map pandas UDF: `pdf` contains the rows of exactly one
    # (year, id) group, so the groupby below collapses to a single
    # aggregated row and indexing with [0] is safe.
    pdf = (pdf
        .groupby(['year', 'id'])['bigram']
        .agg(list=list, len=len)  # you might want to fix the list here to set
        .reset_index()
        .rename(columns={
            'list': 'collect_bigram',
            'len': 'size',
        })
    )
    graph = DiGraph()
    # Only a group with at least two bigrams can form a chain.
    if pdf['size'][0] > 1:
        lst = pdf['collect_bigram'][0]
        for i, lst1 in enumerate(lst):
            ...  # same as original code (networkx edge-building + DFS walk)
        # `lst_res` is produced by the elided networkx code above.
        if lst_res == []:
            lst_res = None
        pdf['new_col'] = [lst_res]
    else:
        pdf['new_col'] = None
    return pdf