I have a data frame as a result of multiple joins. When I check, it tells me that I have a duplicate, even though that is impossible from my perspective. Here is an abstract example:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
import pyspark.sql.functions as f
from pyspark.sql.functions import lit
# Create a Spark session
spark = SparkSession.builder.appName("CreateDataFrame").getOrCreate()
# User input for number of rows
n_a = 10
n_a_c = 5
n_a_c_d = 3
n_a_c_e = 4
# Define the schema for the DataFrame
schema_a = StructType([StructField("id1", StringType(), True)])
schema_a_b = StructType(
    [
        StructField("id1", StringType(), True),
        StructField("id2", StringType(), True),
        StructField("extra", StringType(), True),
    ]
)
schema_a_c = StructType(
    [
        StructField("id1", StringType(), True),
        StructField("id3", StringType(), True),
    ]
)
schema_a_c_d = StructType(
    [
        StructField("id3", StringType(), True),
        StructField("id4", StringType(), True),
    ]
)
schema_a_c_e = StructType(
    [
        StructField("id3", StringType(), True),
        StructField("id5", StringType(), True),
    ]
)
# Create a list of rows with increasing integer values for "id1" and a constant value of "1" for "id2"
rows_a = [(str(i),) for i in range(1, n_a + 1)]
rows_a_integers = [str(i) for i in range(1, n_a + 1)]
rows_a_b = [(str(i), str(1), "A") for i in range(1, n_a + 1)]
def get_2d_list(ids_part_1: list, n_new_ids: int):
    rows = [
        [
            (str(i), str(i) + "_" + str(j))
            for i in ids_part_1
            for j in range(1, n_new_ids + 1)
        ]
    ]
    return [item for sublist in rows for item in sublist]
rows_a_c = get_2d_list(ids_part_1=rows_a_integers, n_new_ids=n_a_c)
rows_a_c_d = get_2d_list(ids_part_1=[i[1] for i in rows_a_c], n_new_ids=n_a_c_d)
rows_a_c_e = get_2d_list(ids_part_1=[i[1] for i in rows_a_c], n_new_ids=n_a_c_e)
# Create the DataFrame
df_a = spark.createDataFrame(rows_a, schema_a)
df_a_b = spark.createDataFrame(rows_a_b, schema_a_b)
df_a_c = spark.createDataFrame(rows_a_c, schema_a_c)
df_a_c_d = spark.createDataFrame(rows_a_c_d, schema_a_c_d)
df_a_c_e = spark.createDataFrame(rows_a_c_e, schema_a_c_e)
# Join everything
df_join = (
    df_a.join(df_a_b, on="id1")
    .join(df_a_c, on="id1")
    .join(df_a_c_d, on="id3")
    .join(df_a_c_e, on="id3")
)
# Nested structure
# show
df_nested = df_join.withColumn("id3", f.struct(f.col("id3"))).orderBy("id3")
for i, index in enumerate([(5, 3), (4, 3), (3, None)]):
    remaining_columns = list(set(df_nested.columns).difference(set([f"id{index[0]}"])))
    df_nested = (
        df_nested.groupby(*remaining_columns)
        .agg(f.collect_list(f.col(f"id{index[0]}")).alias(f"id{index[0]}_tmp"))
        .drop(f"id{index[0]}")
        .withColumnRenamed(
            f"id{index[0]}_tmp",
            f"id{index[0]}",
        )
    )
    if index[1]:
        df_nested = df_nested.withColumn(
            f"id{index[1]}",
            f.struct(
                f.col(f"id{index[1]}.*"),
                f.col(f"id{index[0]}"),
            ).alias(f"id{index[1]}"),
        ).drop(f"id{index[0]}")
I check for duplicates based on id3, which should be unique across the entire data frame on the second nesting level:
# Investigate for duplicates
df_test = df_nested.select("id2", "extra", f.explode(f.col("id3")["id3"]).alias("id3"))
df_test.groupby("id3").count().filter(f.col("count") > 1).show()
This tells me that id3 == 8_3 exists twice:
+---+-----+
|id3|count|
+---+-----+
|8_3| 2|
+---+-----+
However, the joined data frame is clearly unique with respect to id3, which can be shown by the following (id4 and id5 are on the next level):
df_join.groupby("id3", "id4", "id5").count().filter(f.col("count") > 1).show()
leading to
+---+---+---+-----+
|id3|id4|id5|count|
+---+---+---+-----+
+---+---+---+-----+
If it helps, I use Databricks Runtime 11.3 LTS (includes Apache Spark 3.3.0, Scala 2.12).
You are grouping the data frame and aggregating the grouped elements into an array, all of this in a loop. The grouping is also performed on previously collected arrays.
The missed assumption is that collect_list over the same elements will always produce the same array. This does not hold, because Spark does not guarantee the order of elements in the result of collect_list.
The docs specifically say:
The function is non-deterministic because the order of collected results depends on the order of the rows which may be non-deterministic after a shuffle.
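The mechanics can be reproduced without Spark: two collections holding the same elements in a different order are different grouping keys, while their sorted forms coincide. A minimal plain-Python analogy (the values are illustrative stand-ins for two shuffle orderings of the 8_3 group, not taken from the actual data):

```python
from collections import Counter

# Two hypothetical collect_list results for id3 == "8_3": same elements,
# different order (e.g. produced by two different shuffle orderings)
row_1 = ("8_3_1", "8_3_2", "8_3_3")
row_2 = ("8_3_2", "8_3_1", "8_3_3")

# Grouping on the raw arrays treats them as two distinct keys -> "duplicate" rows
raw_groups = Counter([row_1, row_2])

# Sorting before grouping (the array_sort fix below) collapses them into one key
sorted_groups = Counter([tuple(sorted(row_1)), tuple(sorted(row_2))])
```

Here `raw_groups` has two keys while `sorted_groups` has one, which is exactly the "duplicate" the groupby check reports.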
For some reason, the values for 8_3 are collected out of order (it was the same on my machine), so you get two distinct arrays, which are then collected into two separate rows by the following groupby.
To solve the "problem", explicitly state that you want the arrays to be sorted: apply array_sort after collect_list:
.agg(f.array_sort(f.collect_list(f.col(f"id{index[0]}"))).alias(f"id{index[0]}_tmp"))
or collect the values into a set with collect_set, since set equality does not depend on order:
.agg(f.collect_set(f.col(f"id{index[0]}")).alias(f"id{index[0]}_tmp"))
Keep in mind that collect_set also drops duplicate elements, and that the element order of the array it returns is not guaranteed either, so if the collected column is grouped on again in a later iteration, array_sort remains the safer choice.
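Why collecting to a set sidesteps the ordering issue, and what the trade-off is, can be sketched in plain Python (the values are illustrative, not from the actual data):

```python
# The same hypothetical child values arriving in two different orders
arrival_1 = ["8_3_1", "8_3_2", "8_3_3"]
arrival_2 = ["8_3_3", "8_3_1", "8_3_2"]

# Set equality ignores arrival order, so both orderings yield the same key
same_key = set(arrival_1) == set(arrival_2)

# Trade-off: a set also drops duplicate elements, unlike collect_list
deduplicated = set(["x", "x", "y"])
```

So if duplicate child values carry meaning in your data, prefer the array_sort variant over collect_set.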
The problem was not easy to find because there is a lot of generic code with looping, renaming, etc.
I hope that helps.