
pyspark equivalent of pandas groupby('col1').col2.head()


I have a Spark DataFrame where, for each set of rows with a given value in one column (col1), I want to grab a sample of the values in another column (col2). The number of rows for each possible value of col1 may vary widely, so I'm just looking for a set number, say 10, of each type.

There may be a better way to do this, but the natural approach seemed to be df.groupby('col1').

In pandas, I could do df.groupby('col1').col2.head().
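
For example, with a small made-up frame, head() keeps the first few values of col2 from each group (it defaults to 5):

    import pandas as pd

    pdf = pd.DataFrame({'col1': [1, 1, 1, 2, 3, 3],
                        'col2': ['r1', 'r2', 'r2', 'r1', 'r1', 'r2']})

    # First 2 values of col2 from each col1 group
    print(pdf.groupby('col1').col2.head(2))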

I understand that Spark DataFrames are not pandas DataFrames, but this is a close analogy.

I suppose I could loop over all the values of col1 and filter on each, but that seems terribly clunky.

Any thoughts on how to do this? Thanks.


Solution

  • Let me create a sample Spark DataFrame with two columns.

    from pyspark.sql import SparkSession

    # Create (or reuse) a SparkSession and build a small two-column DataFrame
    spark = SparkSession.builder.getOrCreate()

    df = spark.createDataFrame([[1, 'r1'],
                                [1, 'r2'],
                                [1, 'r2'],
                                [2, 'r1'],
                                [3, 'r1'],
                                [3, 'r2'],
                                [4, 'r1'],
                                [5, 'r1'],
                                [5, 'r2'],
                                [5, 'r1']], schema=['col1', 'col2'])
    df.show()
    
    +----+----+
    |col1|col2|
    +----+----+
    |   1|  r1|
    |   1|  r2|
    |   1|  r2|
    |   2|  r1|
    |   3|  r1|
    |   3|  r2|
    |   4|  r1|
    |   5|  r1|
    |   5|  r2|
    |   5|  r1|
    +----+----+
    

    After grouping by col1, we get a GroupedData object (instead of a Spark DataFrame). You can use aggregate functions like min, max, and average on it, but getting a head() is a little tricky. We need to convert the GroupedData object back to a Spark DataFrame, which can be done with the pyspark collect_list() aggregation function.

    from pyspark.sql import functions

    df1 = df.groupBy('col1').agg(functions.collect_list('col2'))
    df1.show(n=3)
    

    Output is:

    +----+------------------+
    |col1|collect_list(col2)|
    +----+------------------+
    |   5|      [r1, r2, r1]|
    |   1|      [r1, r2, r2]|
    |   3|          [r1, r2]|
    +----+------------------+
    only showing top 3 rows
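
    Since the question asks for a set number of values per group (say 10), one way to trim each collected list is sketched below. This assumes Spark 2.4+ so that functions.slice() is available, and the aliases col2_values and col2_sample are just illustrative names:

    from pyspark.sql import functions

    # Collect col2 per group, then keep at most the first 10 elements of each list.
    # Note: collect_list() does not guarantee any particular ordering of the values.
    df_sampled = (df.groupBy('col1')
                    .agg(functions.collect_list('col2').alias('col2_values'))
                    .withColumn('col2_sample', functions.slice('col2_values', 1, 10)))
    df_sampled.select('col1', 'col2_sample').show()

    An alternative that avoids building lists at all is to assign functions.row_number() over a Window partitioned by col1 and then filter to the first 10 rows of each group.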