I have a dataset that looks like this:
+---+
|col|
+---+
|  a|
|  b|
|  c|
|  d|
|  e|
|  f|
|  g|
+---+
I want to reformat this dataset so that the rows are aggregated into arrays of a fixed length, like so:
+------+
|   col|
+------+
|[a, b]|
|[c, d]|
|[e, f]|
|   [g]|
+------+
I tried this:
spark.sql("select collect_list(col) from (select col, row_number() over (order by col) row_number from dataset) group by floor((row_number - 1) / 2)")
But my actual dataset is too large for this: a window with ORDER BY and no PARTITION BY moves every row into a single partition to compute row_number().
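A side note on the bucket arithmetic: row_number() starts at 1, so the bucket index has to be floor((row_number - 1) / 2); dividing row_number itself by 2 leaves the first bucket with a single row. The plain-Scala sketch below (no Spark involved; the object and method names are hypothetical) computes both variants so the difference is visible:

```scala
object BucketIndex {
  // Bucket index for a 1-based row number and array length k.
  // Subtracting 1 first puts rows 1..k in bucket 0, rows k+1..2k in bucket 1, and so on.
  // For non-negative values, Long division already rounds down like floor.
  def bucketOf(rowNumber: Long, k: Int): Long = (rowNumber - 1) / k

  // The uncorrected formula, for comparison: every bucket boundary shifts by one row.
  def bucketWithoutCorrection(rowNumber: Long, k: Int): Long = rowNumber / k
}
```

For `k = 2`, `bucketOf` maps row numbers 1 to 7 to buckets 0, 0, 1, 1, 2, 2, 3, which gives [a, b], [c, d], [e, f], [g]; `bucketWithoutCorrection` gives 0, 1, 1, 2, 2, 3, 3, i.e. [a], [b, c], [d, e], [f, g].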
Since you want to distribute the work, a few steps are necessary.
In case you want to run the code, I start from this:
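import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._ // for toDF and the $"..." column syntax
var df = List(
"a", "b", "c", "d", "e", "f", "g"
).toDF("col")
val desiredArrayLength = 2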
First, split your dataframe into a small one, which you can process on a single node, and a larger one whose number of rows is a multiple of the desired array length (2 in your example):
val nRowsPrune = (df.count() % desiredArrayLength).toInt // number of rows to prune so that the remaining
// dataframe has a number of rows that is a multiple of the desired array length
val dfPrune = df.sort(desc("col")).limit(nRowsPrune) // the nRowsPrune largest values
df = df.join(dfPrune, Seq("col"), "left_anti") // separate small from large dataframe (assumes the values in col are distinct)
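The number of rows to prune is just the remainder of the row count divided by the array length. A minimal sketch of that arithmetic in plain Scala, using a hypothetical helper `splitSizes`:

```scala
object PruneSplit {
  // Returns (rows kept in the large dataframe, rows moved to the small one).
  // The large part is always a multiple of k; the small part has fewer than k rows.
  def splitSizes(nRows: Long, k: Int): (Long, Long) = {
    require(k > 0, "array length must be positive")
    val remainder = nRows % k
    (nRows - remainder, remainder)
  }
}
```

For the seven-row example, `splitSizes(7, 2)` returns `(6, 1)`: six rows stay in the large dataframe and one row (`g`) is pruned.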
By construction, the small dataframe has fewer rows than desiredArrayLength, so it fits in a single array and can be aggregated directly, without a window:
val groupedPruneDf = dfPrune
.agg(sort_array(collect_list("col")).alias("col")) // one global group; sort_array fixes the element order
.where(size($"col") > 0) // an empty dfPrune would otherwise produce one empty array
.select("col")
Now we need a way to handle the remaining large dataframe. At this point we know that df has a number of rows that is a multiple of the array length.
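This is where a useful trick comes in: repartitioning with repartitionByRange. Range partitioning preserves the sort order across partitions (each partition holds a contiguous range of col), and it aims to give every partition roughly the same number of rows. The range boundaries are estimated from a sample, though, so equal sizes are approximate rather than guaranteed; a partition whose size is not a multiple of the array length simply ends with a shorter array.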
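Why the multiple matters can be checked locally. The plain-Scala sketch below (not Spark; names are hypothetical) models range partitioning as cutting the sorted rows into contiguous slices whose sizes are given as input, standing in for whatever boundaries Spark picks, and then chunks each slice independently. When every slice size is a multiple of k, the result equals chunking the whole sorted list; otherwise a slice ends in a short array.

```scala
object RangeChunking {
  // Cut an already sorted sequence into contiguous slices of the given sizes,
  // mimicking range partitioning, then chunk each slice into arrays of length k.
  def chunkPerPartition[A](sorted: Seq[A], partitionSizes: Seq[Int], k: Int): Seq[Seq[A]] = {
    require(k > 0, "array length must be positive")
    require(partitionSizes.sum == sorted.size, "slice sizes must cover every row")
    val (chunks, _) = partitionSizes.foldLeft((Vector.empty[Seq[A]], sorted)) {
      case ((acc, rest), size) =>
        val (slice, remaining) = rest.splitAt(size)
        // Chunks never cross a slice boundary, just like the per-partition window.
        (acc ++ slice.grouped(k).toVector, remaining)
    }
    chunks
  }
}
```

With six rows and `k = 2`, slice sizes `(2, 2, 2)` or `(4, 2)` give [a, b], [c, d], [e, f], while sizes `(3, 3)` give [a, b], [c], [d, e], [f].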
You can now collect the arrays within each partition:
val nRows = df.count()
val maxNRowsPartition = desiredArrayLength // target rows per partition; keep it a multiple of desiredArrayLength
// (here it equals desiredArrayLength, i.e. one array per partition)
val nPartitions = math.max(1L, nRows / maxNRowsPartition).toInt
df = df.repartitionByRange(nPartitions, $"col".desc)
.withColumn("partitionId", spark_partition_id())
val w = Window.partitionBy($"partitionId").orderBy("col")
val groupedDf = df
.withColumn("g", floor((lit(-1) + row_number().over(w)) / lit(desiredArrayLength))) // added -1 as row_number starts from 1
.groupBy("partitionId", "g")
.agg(sort_array(collect_list("col")).alias("col")) // collect_list alone does not guarantee element order
.select("col")
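The partition count is nRows / maxNRowsPartition, clamped to at least 1. The plain-Scala sketch below (hypothetical names) reproduces that formula and, assuming the idealized case where the rows are spread over the partitions as evenly as possible (Spark only approximates this, since it samples to find the boundaries), shows the resulting partition sizes:

```scala
object PartitionSizing {
  // Same formula as above: nRows / maxRowsPerPartition, but never fewer than one partition.
  def numPartitions(nRows: Long, maxRowsPerPartition: Long): Int =
    math.max(1L, nRows / maxRowsPerPartition).toInt

  // Idealized, perfectly balanced split of nRows over nPartitions:
  // the first (nRows % nPartitions) partitions get one extra row.
  def evenSplit(nRows: Long, nPartitions: Int): Seq[Long] = {
    val base = nRows / nPartitions
    val extra = nRows % nPartitions
    (0 until nPartitions).map(i => if (i < extra) base + 1 else base)
  }
}
```

With six remaining rows and `maxNRowsPartition = 2`, this gives three partitions of two rows each. With ten rows and `maxNRowsPartition = 4`, the ideal split is two partitions of five rows, so even then each partition would end with a one-element array when the array length is 2.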
Finally, combining the two results yields what you are looking for:
val result = groupedDf.union(groupedPruneDf)
result.show(truncate=false)
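To sanity-check the overall procedure without a cluster, here is a local, plain-Scala model of it under the same assumptions (distinct values, ideal range partitions whose sizes are multiples of the array length). It prunes the n % k largest values, cuts the rest into contiguous slices of `rowsPerPartition` rows, chunks each slice, and appends the pruned values as one short array. The names are hypothetical; this mirrors the logic, not Spark's API.

```scala
object LocalModel {
  // Model of the pipeline on an in-memory collection of distinct values.
  // rowsPerPartition stands in for maxNRowsPartition and must be a positive multiple of k.
  def arraysOf(values: Seq[String], k: Int, rowsPerPartition: Int): Seq[Seq[String]] = {
    require(k > 0 && rowsPerPartition > 0 && rowsPerPartition % k == 0,
      "rowsPerPartition must be a positive multiple of k")
    val sorted = values.sorted
    val nPrune = sorted.size % k
    // Large part: all but the nPrune largest values; small part: those nPrune values.
    val (large, pruned) = sorted.splitAt(sorted.size - nPrune)
    // Ideal range partitioning: contiguous slices, each chunked into arrays of length k.
    val chunks = large.grouped(rowsPerPartition).toList.flatMap(_.grouped(k).toList)
    // The small part becomes a single short array, added only when it is non-empty.
    val prunedArray = if (pruned.isEmpty) Nil else List(pruned)
    chunks ++ prunedArray
  }
}
```

For the example, `LocalModel.arraysOf(Seq("a", "b", "c", "d", "e", "f", "g"), 2, 2)` yields [a, b], [c, d], [e, f], [g], the same arrays as chunking the sorted values directly. Spark itself does not guarantee the row order of `result`, only which arrays it contains.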