I have a dataframe of the following format:
name merged
key1 (internalKey1, value1)
key1 (internalKey2, value2)
...
key2 (internalKey3, value3)
...
What I want to do is group the dataframe by name, collect the values into a list, and limit the size of that list.
This is how I group by name and collect the list:
val res = df.groupBy("name")
.agg(collect_list(col("merged")).as("final"))
The resulting dataframe looks something like:
key1 [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list
key2 [(internalKey3, value3),...]
What I want to do is limit the size of the list produced for each key. I've tried multiple ways to do that but had no success. I've already seen some posts that suggest third-party solutions, but I want to avoid that. Is there a way?
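You can create a function that limits the size of the aggregated ArrayType column, as shown below. Note that, with Spark's default (non-ANSI) settings, getItem returns null for an index past the end of the array, so a list with fewer than n elements is padded with nulls up to length n.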
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Column
// Needed for toDF and the $"..." syntax; assumes a SparkSession named `spark`
// (already in scope in spark-shell).
import spark.implicits._

case class KV(k: String, v: String)

val df = Seq(
  ("key1", KV("internalKey1", "value1")),
  ("key1", KV("internalKey2", "value2")),
  ("key2", KV("internalKey3", "value3")),
  ("key2", KV("internalKey4", "value4")),
  ("key2", KV("internalKey5", "value5"))
).toDF("name", "merged")

// Build a new array column from the first n elements of arrCol.
def limitSize(n: Int, arrCol: Column): Column =
  array((0 until n).map(i => arrCol.getItem(i)): _*)

df.
  groupBy("name").agg(collect_list(col("merged")).as("final")).
  select($"name", limitSize(2, $"final").as("final2")).
  show(false)
// +----+----------------------------------------------+
// |name|final2 |
// +----+----------------------------------------------+
// |key1|[[internalKey1,value1], [internalKey2,value2]]|
// |key2|[[internalKey3,value3], [internalKey4,value4]]|
// +----+----------------------------------------------+
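If you are on Spark 2.4 or later, the built-in slice function is another option that needs no third-party code. Here is a minimal sketch, assuming a local SparkSession; the object name LimitListSize, the helper collectLimited, and the sample data are made up for illustration. Unlike indexing with getItem, slice leaves lists shorter than n unchanged instead of filling them out to length n.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, collect_list, slice}

object LimitListSize {
  // Hypothetical helper: group by "name", collect "merged" into a list,
  // then keep at most n elements per list.
  def collectLimited(df: DataFrame, n: Int): DataFrame =
    df.groupBy("name")
      .agg(collect_list(col("merged")).as("final"))
      // slice uses 1-based indexing: start at element 1, take up to n elements.
      .select(col("name"), slice(col("final"), 1, n).as("final2"))

  def main(args: Array[String]): Unit = {
    // Assumption: a local session, for demonstration only.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("limit-list-size")
      .getOrCreate()
    import spark.implicits._

    // Illustrative data; plain strings stand in for the (key, value) pairs.
    val df = Seq(
      ("key1", "a"), ("key1", "b"), ("key1", "c"),
      ("key2", "d")
    ).toDF("name", "merged")

    collectLimited(df, 2).show(false)
    spark.stop()
  }
}
```

Keep in mind that collect_list does not guarantee element order after a shuffle, so which n elements survive may vary between runs unless you sort the collected array first.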