I have a Spark DataFrame where, for each set of rows sharing a given value of col1, I want to grab a sample of the values in col2. The number of rows for each possible value of col1 may vary widely, so I'm just looking for a fixed number, say 10, of each type.
There may be a better way to do this, but the natural approach seemed to be df.groupby('col1').
In pandas, I could do df.groupby('col1').col2.head().
I understand that Spark DataFrames are not pandas DataFrames, but this is a good analogy.
I suppose I could loop over all of the col1 types, filtering on each one, but that seems terribly icky.
Any thoughts on how to do this? Thanks.
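For reference, here is a minimal sketch of the pandas behaviour the question uses as its analogy. The toy data simply mirrors the Spark example in the answer, and head(2) is used instead of head(10) so the cap is visible on such a small frame.

```python
import pandas as pd

# Toy data mirroring the Spark example (values are illustrative only)
pdf = pd.DataFrame({
    "col1": [1, 1, 1, 2, 3, 3, 4, 5, 5, 5],
    "col2": ["r1", "r2", "r2", "r1", "r1", "r2", "r1", "r1", "r2", "r1"],
})

# head(n) on a grouped Series keeps the first n rows of each group,
# in original row order, and keeps the original index labels
sample = pdf.groupby("col1").col2.head(2)
print(sample)
```

Group 1 has three rows, so only its first two survive; every other group already has two rows or fewer.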
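Let me create a sample Spark DataFrame with two columns.
from pyspark.sql import SparkSession

# Entry point for the DataFrame API (replaces the older SQLContext)
spark = SparkSession.builder.getOrCreate()

# col1 is the grouping key, col2 holds the values to sample
df = spark.createDataFrame([[1, 'r1'],
                            [1, 'r2'],
                            [1, 'r2'],
                            [2, 'r1'],
                            [3, 'r1'],
                            [3, 'r2'],
                            [4, 'r1'],
                            [5, 'r1'],
                            [5, 'r2'],
                            [5, 'r1']], schema=['col1', 'col2'])
df.show()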
+----+----+
|col1|col2|
+----+----+
| 1| r1|
| 1| r2|
| 1| r2|
| 2| r1|
| 3| r1|
| 3| r2|
| 4| r1|
| 5| r1|
| 5| r2|
| 5| r1|
+----+----+
After grouping by col1, we get a GroupedData object (instead of a Spark DataFrame). You can apply aggregate functions such as min, max and avg to it, but getting a head() is a little tricky. We need to convert the GroupedData object back to a Spark DataFrame. This can be done using the PySpark collect_list() aggregation function, which gathers every col2 value of a group into a single list.
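from pyspark.sql import functions

# collect_list gathers all col2 values for each col1 into one array column.
# Keep the DataFrame in df1; show() prints and returns None.
df1 = df.groupBy('col1').agg(functions.collect_list('col2'))
df1.show(n=3)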
Output is:
+----+------------------+
|col1|collect_list(col2)|
+----+------------------+
| 5| [r1, r2, r1]|
| 1| [r1, r2, r2]|
| 3| [r1, r2]|
+----+------------------+
only showing top 3 rows