Tags: apache-spark, apache-spark-sql, pyspark

Removing duplicates from rows based on specific columns in an RDD/Spark DataFrame


Let's say I have a rather large dataset in the following form:

data = sc.parallelize([('Foo', 41, 'US', 3),
                       ('Foo', 39, 'UK', 1),
                       ('Bar', 57, 'CA', 2),
                       ('Bar', 72, 'CA', 2),
                       ('Baz', 22, 'US', 6),
                       ('Baz', 36, 'US', 6)])

I would like to remove duplicate rows based on the values of the first, third and fourth columns only.

Removing entirely duplicate rows is straightforward:

data = data.distinct()

but this only removes rows that are identical in every column (in this example, no two rows agree in all four columns, so nothing would actually be dropped).

But how do I remove duplicate rows based only on columns 1, 3 and 4? I.e. remove either one of these:

('Baz', 22, 'US', 6)
('Baz', 36, 'US', 6)

In Python (pandas), this could be done by passing the relevant columns to .drop_duplicates(). How can I achieve the same in Spark/PySpark?
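
For reference, this is roughly the pandas version I have in mind (the column names are just illustrative labels):

import pandas as pd

pdf = pd.DataFrame([('Foo', 41, 'US', 3),
                    ('Foo', 39, 'UK', 1),
                    ('Bar', 57, 'CA', 2),
                    ('Bar', 72, 'CA', 2),
                    ('Baz', 22, 'US', 6),
                    ('Baz', 36, 'US', 6)],
                   columns=['name', 'age', 'country', 'score'])

# Keeps the first occurrence of each (name, country, score) combination
deduped = pdf.drop_duplicates(subset=['name', 'country', 'score'])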


Solution

  • PySpark does include a dropDuplicates() method, which was introduced in Spark 1.4.

    >>> from pyspark.sql import Row
    >>> df = sc.parallelize([ \
    ...     Row(name='Alice', age=5, height=80), \
    ...     Row(name='Alice', age=5, height=80), \
    ...     Row(name='Alice', age=10, height=80)]).toDF()
    >>> df.dropDuplicates().show()
    +---+------+-----+
    |age|height| name|
    +---+------+-----+
    |  5|    80|Alice|
    | 10|    80|Alice|
    +---+------+-----+
    
    >>> df.dropDuplicates(['name', 'height']).show()
    +---+------+-----+
    |age|height| name|
    +---+------+-----+
    |  5|    80|Alice|
    +---+------+-----+
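
  • Applied to the question's RDD, a rough sketch looks like this (the column names below are my own labels, since the tuples carry no schema):

    # Name the columns so dropDuplicates can refer to them
    df = data.toDF(['name', 'age', 'country', 'score'])

    # Keep one row per (name, country, score) combination; which of the
    # duplicate rows survives is not guaranteed
    df.dropDuplicates(['name', 'country', 'score']).show()

    Note that drop_duplicates() is an alias for dropDuplicates(), so the pandas-style spelling works in PySpark as well.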