Let's say I have a rather large dataset in the following form:
data = sc.parallelize([('Foo', 41, 'US', 3),
                       ('Foo', 39, 'UK', 1),
                       ('Bar', 57, 'CA', 2),
                       ('Bar', 72, 'CA', 2),
                       ('Baz', 22, 'US', 6),
                       ('Baz', 36, 'US', 6)])
I would like to remove duplicate rows based on the values of the first, third and fourth columns only.
Removing entirely duplicate rows is straightforward:
data = data.distinct()
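but distinct() only drops rows that are identical in every column, so here neither row 5 nor row 6 would be removed: they differ in the second column.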
How do I remove duplicate rows based on columns 1, 3 and 4 only? That is, keep just one of these (the two 'Bar' rows likewise match on columns 1, 3 and 4):
('Baz', 22, 'US', 6)
('Baz', 36, 'US', 6)
In pandas, this can be done by passing the columns to .drop_duplicates() through its subset parameter. How can I achieve the same in Spark/PySpark?
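For reference, here is a minimal pandas sketch of the behaviour being asked for, using the same rows. The column names (name, age, country, code) are hypothetical, since the question does not name the columns.

```python
import pandas as pd

# Same rows as the question's RDD; the column names are hypothetical.
rows = [('Foo', 41, 'US', 3),
        ('Foo', 39, 'UK', 1),
        ('Bar', 57, 'CA', 2),
        ('Bar', 72, 'CA', 2),
        ('Baz', 22, 'US', 6),
        ('Baz', 36, 'US', 6)]
pdf = pd.DataFrame(rows, columns=['name', 'age', 'country', 'code'])

# Compare only the first, third and fourth columns. With the default
# keep='first', the first row of each duplicate group is retained.
deduped = pdf.drop_duplicates(subset=['name', 'country', 'code'])
print(deduped)
```

Because 'age' is left out of subset, the second 'Bar' row and the second 'Baz' row are dropped, leaving four rows.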
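PySpark DataFrames have a dropDuplicates() method (added in Spark 1.4) that accepts an optional list of column names to compare. Note that it is a DataFrame method, not an RDD method.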
>>> from pyspark.sql import Row
>>> df = sc.parallelize([
...     Row(name='Alice', age=5, height=80),
...     Row(name='Alice', age=5, height=80),
...     Row(name='Alice', age=10, height=80)]).toDF()
>>> df.dropDuplicates().show()
+---+------+-----+
|age|height| name|
+---+------+-----+
|  5|    80|Alice|
| 10|    80|Alice|
+---+------+-----+
>>> df.dropDuplicates(['name', 'height']).show()
+---+------+-----+
|age|height| name|
+---+------+-----+
|  5|    80|Alice|
+---+------+-----+