I have a PySpark DataFrame with three columns:
df1 = spark.createDataFrame([
('a', 3, 4.2),
('a', 7, 4.2),
('b', 7, 2.6),
('c', 7, 7.21),
('c', 11, 7.21),
('c', 18, 7.21),
('d', 15, 9.0),
], ['model', 'number', 'price'])
df1.show()
+-----+------+-----+
|model|number|price|
+-----+------+-----+
| a| 3| 4.2|
| a| 7| 4.2|
| b| 7| 2.6|
| c| 7| 7.21|
| c| 11| 7.21|
| c| 18| 7.21|
| d| 15| 9.0|
+-----+------+-----+
Is there a way in PySpark to keep only the rows whose value in the 'price' column is repeated, like in df2 below?
df2 = spark.createDataFrame([
('a', 3, 4.2),
('a', 7, 4.2),
('c', 7, 7.21),
('c', 11, 7.21),
('c', 18, 7.21),
], ['model', 'number', 'price'])
df2.show()
+-----+------+-----+
|model|number|price|
+-----+------+-----+
| a| 3| 4.2|
| a| 7| 4.2|
| c| 7| 7.21|
| c| 11| 7.21|
| c| 18| 7.21|
+-----+------+-----+
I tried this, but it didn't work:
df = df1.groupBy("model","price").count().filter("count > 1")
df2 = df1.where((df.model == df1.model) & (df.price == df1.price))
df2.show()
It also included the values that are not repeated:
+-----+------+-----+
|model|number|price|
+-----+------+-----+
| a| 3| 4.2|
| a| 7| 4.2|
| b| 7| 2.6|
| c| 7| 7.21|
| c| 11| 7.21|
| c| 18| 7.21|
| d| 15| 9.0|
+-----+------+-----+
You can do so with a window function: partition by price, count the rows in each partition, and keep those where the count is greater than 1. (Your attempt returned every row because df is derived from df1, so df.model == df1.model resolves to a column compared against itself and is always true.)
from pyspark.sql import Window
from pyspark.sql import functions as f

# count the rows sharing each price value
w = Window.partitionBy('price')
df1.withColumn('_c', f.count('price').over(w)).filter('_c > 1').drop('_c').show()
+-----+------+-----+
|model|number|price|
+-----+------+-----+
| a| 3| 4.2|
| a| 7| 4.2|
| c| 7| 7.21|
| c| 11| 7.21|
| c| 18| 7.21|
+-----+------+-----+
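Alternatively, the groupBy approach from the question can be fixed with a left semi join: compute the prices that occur more than once, then keep only the rows of df1 whose price appears in that set. A minimal sketch, assuming the same df1 as above (repeated_prices is just a name introduced here for illustration):

from pyspark.sql import functions as f

# prices that occur in more than one row
repeated_prices = df1.groupBy('price').agg(f.count('*').alias('cnt')).filter('cnt > 1')

# left_semi keeps matching rows of df1 without adding columns from the right side
df1.join(repeated_prices, on='price', how='left_semi').show()

This should match df2 up to row ordering, since a semi join does not guarantee the order of the output rows.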