I have a dataframe:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

data = [('I ran home', 3, 1, 10),
        ('I went home', 3, 1, 11),
        ('I looked at the cat', 4, 2, 19),
        ('The cat looked at me', 5, 3, 20),
        ('I ran home', 3, 4, 10),
        ('I went homes', 3, 4, 12)]
schema = StructType([
    StructField("text", StringType(), True),
    StructField("word_count", IntegerType(), True),
    StructField("group", IntegerType(), True),
    StructField("len_text", IntegerType(), True)])
df = spark.createDataFrame(data=data, schema=schema)
df.show(truncate=False)
+--------------------+----------+-----+--------+
|text |word_count|group|len_text|
+--------------------+----------+-----+--------+
|I ran home |3 |1 |10 |
|I went home |3 |1 |11 |
|I looked at the cat |4 |2 |19 |
|The cat looked at me|5 |3 |20 |
|I ran home |3 |4 |10 |
|I went homes |3 |4 |12 |
+--------------------+----------+-----+--------+
I want to filter rows with two conditions: if the word_count values within a group are the same and one row's len_text is greater than the next, keep only the row with the greater len_text. In pandas I can do this with idxmax():
df1 = df.loc[df.groupby('group')['len_text'].idxmax()]
Is there an analogue in PySpark? I want this result:
+--------------------+----------+-----+--------+
|text |word_count|group|len_text|
+--------------------+----------+-----+--------+
|I went home |3 |1 |11 |
|I looked at the cat |4 |2 |19 |
|The cat looked at me|5 |3 |20 |
|I went homes |3 |4 |12 |
+--------------------+----------+-----+--------+
You can use window functions, e.g. row_number:
from pyspark.sql import functions as F, Window as W

# rank rows within each group by len_text, highest first
w = W.partitionBy('group').orderBy(F.desc('len_text'))
df = df.withColumn('_rn', F.row_number().over(w))

# keep only the top-ranked row per group, then drop the helper column
df = df.filter('_rn = 1').drop('_rn')
df.show()
# +--------------------+----------+-----+--------+
# | text|word_count|group|len_text|
# +--------------------+----------+-----+--------+
# | I went home| 3| 1| 11|
# | I looked at the cat| 4| 2| 19|
# |The cat looked at me| 5| 3| 20|
# | I went homes| 3| 4| 12|
# +--------------------+----------+-----+--------+
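If you would rather avoid a window function, a rough alternative sketch (starting from the original df, not the filtered one) is to aggregate each group to the max of a struct keyed by len_text and then unpack it. Note the assumption here: you only need one row per group, and ties on len_text are broken by the remaining struct fields rather than arbitrarily, which differs slightly from row_number.

from pyspark.sql import functions as F

# aggregate each group down to the struct with the largest len_text,
# then unpack the struct fields back into columns
result = (df.groupBy('group')
            .agg(F.max(F.struct('len_text', 'text', 'word_count')).alias('top'))
            .select('top.text', 'top.word_count', 'group', 'top.len_text'))
result.show(truncate=False)

The rows returned are the same as with the window approach; only the row and column ordering may differ, since groupBy does not preserve input order.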