Tags: python, apache-spark, pyspark, apache-spark-sql, databricks

Duplicates even though there are no duplicates


I have a data frame that is the result of multiple joins. When I check it for duplicates, Spark tells me there is one, even though that should be impossible from my perspective. Here is an abstract example:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType
import pyspark.sql.functions as f

# Create a Spark session
spark = SparkSession.builder.appName("CreateDataFrame").getOrCreate()

# User input for number of rows
n_a = 10
n_a_c = 5
n_a_c_d = 3
n_a_c_e = 4

# Define the schema for the DataFrame
schema_a = StructType([StructField("id1", StringType(), True)])
schema_a_b = StructType(
    [
        StructField("id1", StringType(), True),
        StructField("id2", StringType(), True),
        StructField("extra", StringType(), True),
    ]
)
schema_a_c = StructType(
    [
        StructField("id1", StringType(), True),
        StructField("id3", StringType(), True),
    ]
)
schema_a_c_d = StructType(
    [
        StructField("id3", StringType(), True),
        StructField("id4", StringType(), True),
    ]
)
schema_a_c_e = StructType(
    [
        StructField("id3", StringType(), True),
        StructField("id5", StringType(), True),
    ]
)

# Create a list of rows with increasing integer values for "id1" and a constant value of "1" for "id2"
rows_a = [(str(i),) for i in range(1, n_a + 1)]
rows_a_integers = [str(i) for i in range(1, n_a + 1)]
rows_a_b = [(str(i), str(1), "A") for i in range(1, n_a + 1)]


def get_2d_list(ids_part_1: list, n_new_ids: int):
    # For every parent id, create n_new_ids child ids of the form "<parent>_<j>"
    return [
        (str(i), str(i) + "_" + str(j))
        for i in ids_part_1
        for j in range(1, n_new_ids + 1)
    ]


rows_a_c = get_2d_list(ids_part_1=rows_a_integers, n_new_ids=n_a_c)
rows_a_c_d = get_2d_list(ids_part_1=[i[1] for i in rows_a_c], n_new_ids=n_a_c_d)
rows_a_c_e = get_2d_list(ids_part_1=[i[1] for i in rows_a_c], n_new_ids=n_a_c_e)

# Create the DataFrame
df_a = spark.createDataFrame(rows_a, schema_a)
df_a_b = spark.createDataFrame(rows_a_b, schema_a_b)
df_a_c = spark.createDataFrame(rows_a_c, schema_a_c)
df_a_c_d = spark.createDataFrame(rows_a_c_d, schema_a_c_d)
df_a_c_e = spark.createDataFrame(rows_a_c_e, schema_a_c_e)

# Join everything
df_join = (
    df_a.join(df_a_b, on="id1")
    .join(df_a_c, on="id1")
    .join(df_a_c_d, on="id3")
    .join(df_a_c_e, on="id3")
)

# Nested structure
df_nested = df_join.withColumn("id3", f.struct(f.col("id3"))).orderBy("id3")

for i, index in enumerate([(5, 3), (4, 3), (3, None)]):
    remaining_columns = list(set(df_nested.columns).difference(set([f"id{index[0]}"])))
    df_nested = (
        df_nested.groupby(*remaining_columns)
        .agg(f.collect_list(f.col(f"id{index[0]}")).alias(f"id{index[0]}_tmp"))
        .drop(f"id{index[0]}")
        .withColumnRenamed(
            f"id{index[0]}_tmp",
            f"id{index[0]}",
        )
    )

    if index[1]:
        df_nested = df_nested.withColumn(
            f"id{index[1]}",
            f.struct(
                f.col(f"id{index[1]}.*"),
                f.col(f"id{index[0]}"),
            ).alias(f"id{index[1]}"),
        ).drop(f"id{index[0]}")

I check for duplicates based on id3, which should be unique across the entire data frame on the second level:

# Investigate for duplicates
df_test = df_nested.select("id2", "extra", f.explode(f.col("id3")["id3"]).alias("id3"))
df_test.groupby("id3").count().filter(f.col("count") > 1).show()

This tells me that id3 == 8_3 exists twice:

+---+-----+
|id3|count|
+---+-----+
|8_3|    2|
+---+-----+

However, the data frame is clearly unique with respect to id3, as can be shown by the following (id4 and id5 are on the next level):

df_join.groupby("id3", "id4", "id5").count().filter(f.col("count") > 1).show()

leading to

+---+---+---+-----+
|id3|id4|id5|count|
+---+---+---+-----+
+---+---+---+-----+

If it helps: I am using Databricks Runtime Version 11.3 LTS (includes Apache Spark 3.3.0, Scala 2.12).


Solution

  • You are grouping the data frame in a loop and aggregating the grouped elements into an array, and the grouping is also performed on previously collected arrays.

    The hidden assumption is that when you collect_list the same elements, you get the same array. That does not hold, because Spark does not guarantee the order of elements in the result of collect_list. The docs specifically say:

    The function is non-deterministic because the order of collected results depends on the order of the rows which may be non-deterministic after a shuffle.

    For whatever reason, the values for 8_3 are collected in different orders (the same happened on my machine), so you end up with two distinct arrays, which are then collected into two rows in the following groupby, as the sketch below illustrates.
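
    For illustration, here is a minimal, self-contained sketch (with made-up values, independent of your pipeline) of what happens: two arrays that hold the same elements in different orders are two distinct group keys.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("ArrayOrderDemo").getOrCreate()

    # Two rows for the same id3 whose collected arrays differ only in element order
    df_demo = spark.createDataFrame(
        [("8_3", ["8_3_1", "8_3_2"]), ("8_3", ["8_3_2", "8_3_1"])],
        ["id3", "id4"],
    )

    # Array comparison is element-wise and order-sensitive,
    # so this groupby yields two rows instead of one
    df_demo.groupby("id3", "id4").count().show()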

    To solve the "problem", explicitly state that you want the arrays to be sorted by applying array_sort to the result of collect_list:

    .agg(f.array_sort(f.collect_list(f.col(f"id{index[0]}"))).alias(f"id{index[0]}_tmp"))
    

    or collect the values into a set with collect_set, since sets have no order (note that collect_set also removes duplicate elements, which may or may not be what you want):

    .agg(f.collect_set(f.col(f"id{index[0]}")).alias(f"id{index[0]}_tmp"))
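
    Applied to the toy data from the sketch above (again, just a sketch), normalizing the arrays (here with array_sort) makes the two rows compare equal, so the groupby collapses them into a single group:

    import pyspark.sql.functions as f

    # After array_sort both arrays become ["8_3_1", "8_3_2"],
    # so the groupby now yields a single row with count == 2
    df_demo.withColumn("id4", f.array_sort("id4")).groupby("id3", "id4").count().show()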
    

    The problem was not easy to find because there is a lot of generic code with looping, renaming, etc.

    I hope that helps.