Tags: python, apache-spark, pyspark, apache-spark-sql, method-chaining

Python / PySpark - Correct method chaining order rules


Coming from a SQL development background, and currently learning PySpark / Python, I am a bit confused about querying data and chaining methods in Python.

For instance, the query below (taken from 'Learning Spark, 2nd Edition'):

(fire_ts_df
 .select("CallType")
 .where(col("CallType").isNotNull())
 .groupBy("CallType")
 .count()
 .orderBy("count", ascending=False)
 .show(n=10, truncate=False))

will execute just fine.

What I don't understand, though, is that if I had written the code like this (with the call to count() moved higher):

(fire_ts_df
 .select("CallType")
 .count()
 .where(col("CallType").isNotNull())
 .groupBy("CallType")
 .orderBy("count", ascending=False)
 .show(n=10, truncate=False))

it wouldn't work. The problem is that I don't want to memorize the order; I want to understand it. I feel it has something to do with proper method chaining in Python / PySpark, but I am not sure how to justify it. In other words, in a case like this, where multiple methods should be invoked and chained with (.), what is the right order, and is there any specific rule to follow?

Thanks a lot in advance


Solution

  • The important thing to note here is that chained method calls do not occur in some arbitrary order: the operations they represent are not associative transformations applied flatly to the data from left to right.

    Each method call can be written as a separate statement, where each statement produces a result that becomes the input to the next operation, and so on down to the final result.

    (fire_ts_df
      .select("CallType")                  # selects column CallType into a 1-col DF
      .where(col("CallType").isNotNull())  # filters rows of the 1-col DF from select()
      .groupBy("CallType")                 # groups the filtered DF by that column into a pyspark.sql.group.GroupedData object
      .count()                             # creates a new DF off the GroupedData, with the counts
      .orderBy("count", ascending=False)   # sorts the aggregated DF, as a new DF
      .show(n=10, truncate=False))         # prints the last DF
    

    Just to use your example to explain why this doesn't work: calling count() on a pyspark.sql.group.GroupedData object creates a new DataFrame with the aggregation results, but count() called on a DataFrame object returns just the number of records, i.e. a plain integer. That means the following call, .where(col("CallType").isNotNull()), would be made on an integer, which simply doesn't make sense: integers have no where method.
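
    You can make the difference between the two count() variants concrete by checking the return types yourself. A minimal sketch, assuming an active SparkSession and the fire_ts_df DataFrame from your example:

    grouped = fire_ts_df.select("CallType").groupBy("CallType")
    print(type(grouped.count()))     # a DataFrame holding the per-group counts
    print(type(fire_ts_df.count()))  # <class 'int'> -- just a number
    # fire_ts_df.count().where(...)  # AttributeError: 'int' object has no attribute 'where'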

    As noted above, you can visualize this by rewriting the chain as separate statements:

    from pyspark.sql.functions import col  # needed for col()

    call_type_df = fire_ts_df.select("CallType")
    non_null_call_type = call_type_df.where(col("CallType").isNotNull())
    groupings = non_null_call_type.groupBy("CallType")
    counts_by_call_type_df = groupings.count()
    ordered_counts = counts_by_call_type_df.orderBy("count", ascending=False)

    ordered_counts.show(n=10, truncate=False)
    

    As you can see, the ordering is meaningful: each operation in the succession is applied to the type of output produced by the one before it.
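
    If you are ever unsure what a given step returns, printing the types of the intermediate variables from the rewrite above makes the hand-offs visible (a quick illustrative check, assuming the statements above have run):

    print(type(call_type_df))            # DataFrame: select() returns a DataFrame
    print(type(non_null_call_type))      # DataFrame: where() filters, type unchanged
    print(type(groupings))               # GroupedData: groupBy() changes the type
    print(type(counts_by_call_type_df))  # DataFrame again: GroupedData.count() aggregates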

    Chained calls form what is referred to as a fluent API, which minimizes verbosity. But that does not change the fact that each chained method must be applicable to the type returned by the preceding call; every operation in the chain is meant to apply to the value produced by the one before it.
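
    The same rule shows up in any fluent API, not just PySpark's. Here is a minimal, self-contained sketch (the Query class and its methods are invented for illustration, not part of PySpark): methods that return an object keep the chain alive, while methods that return a plain value end it.

    class Query:
        """Toy fluent API over a list of rows."""
        def __init__(self, rows):
            self.rows = rows

        def where(self, predicate):
            # Returns a new Query, so further methods can be chained.
            return Query([r for r in self.rows if predicate(r)])

        def count(self):
            # Returns a plain int, so the chain ends here.
            return len(self.rows)

    n = Query([1, None, 2]).where(lambda r: r is not None).count()  # Query -> Query -> int
    # Query([1, None, 2]).count().where(...)  # AttributeError: 'int' object has no attribute 'where'

    This is exactly why count() belongs after groupBy() in your example: at that point it is GroupedData.count(), which returns a DataFrame, so the chain can continue.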