Tags: python, pandas, dataframe, pyspark

Get duplicate rows in a specific column from dataframe


I have a dataframe df:

num_rows = 5
num_cols = 3

data = [
    [10, 20, 30],
    [10, 50, 60],
    [70, 80, 90],
    [20, 30, 10],
    [20, 10, 20]
]

columns = [f"Column_{i+1}" for i in range(num_cols)]

df = spark.createDataFrame(data, columns)

+--------+--------+--------+
|Column_1|Column_2|Column_3|
+--------+--------+--------+
+--------+--------+--------+
|      10|      20|      30|
|      10|      50|      60|
|      70|      80|      90|
|      20|      30|      10|
|      20|      10|      20|
+--------+--------+--------+

I want to create another column with true/false values based on the first column, where the first occurrence of a value is "true" and any duplicate is "false". So it would look like:

+--------+--------+--------+--------+
|Column_1|Column_2|Column_3|Column_4|
+--------+--------+--------+--------+
|      10|      20|      30|    TRUE|
|      10|      50|      60|   FALSE|
|      70|      80|      90|    TRUE|
|      20|      30|      10|    TRUE|
|      20|      10|      20|   FALSE|
+--------+--------+--------+--------+

Solution

  • You did not specify which row counts as the "original" within each group of Column_1 values, so I'm assuming it is acceptable to order by Column_2 and Column_3. With that ordering, you can use row_number() to find the first row in each partition and populate the new column accordingly.

    from pyspark.sql.window import Window
    from pyspark.sql.functions import row_number, when

    df = df.withColumn(
        "Column_4",
        when(
            # row_number() restarts at 1 for each distinct Column_1 value,
            # so only the first row per group gets "TRUE"
            row_number().over(
                Window.partitionBy("Column_1").orderBy("Column_2", "Column_3")
            )
            == 1,
            "TRUE",
        ).otherwise("FALSE"),
    )
    df.show()
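
  • If "first" should mean first in the original row order rather than first by Column_2/Column_3, one sketch (an assumption, since Spark DataFrames have no inherent row order) is to capture an ordering with monotonically_increasing_id() before windowing. The `_row_id` helper column here is a name I introduced for illustration:

    ```python
    from pyspark.sql.window import Window
    from pyspark.sql.functions import monotonically_increasing_id, row_number, when

    # Capture the current row order in a temporary column; the ids are
    # guaranteed to be increasing within each partition, which matches
    # the input order when the data is loaded locally.
    df_ordered = df.withColumn("_row_id", monotonically_increasing_id())

    result = df_ordered.withColumn(
        "Column_4",
        when(
            # The earliest-seen row per Column_1 value gets "TRUE"
            row_number().over(
                Window.partitionBy("Column_1").orderBy("_row_id")
            )
            == 1,
            "TRUE",
        ).otherwise("FALSE"),
    ).drop("_row_id")

    result.show()
    ```

    Note that on a multi-partition DataFrame the ids reflect partition layout, not necessarily the order the data was written in, so this is only reliable when the source order survives into the partitions.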