Pyspark Crosstab Pivot Challenge / Problem


Unfortunately I could not find a solution to my exact problem. It is related to pivot and crosstab, but I could not solve it with those functions alone. I have the feeling I am missing an intermediate table, but I somehow cannot come up with a solution.

Problem description:

A table lists customers and the categories they have bought products from. If a customer bought a product from a category, the category ID appears next to the customer's name.

There are 4 categories (1-4) and 3 customers (A, B, C).

+--------+----------+
|customer| category |
+--------+----------+
|       A|         1|
|       A|         2|
|       A|         3|
|       B|         1|
|       B|         4|
|       C|         1|
|       C|         3|
|       C|         4|
+--------+----------+

The table is DISTINCT, meaning each combination of customer and category appears only once.
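
Here is a minimal snippet to build this sample table (assuming an active SparkSession named spark):

    data = [("A", 1), ("A", 2), ("A", 3),
            ("B", 1), ("B", 4),
            ("C", 1), ("C", 3), ("C", 4)]
    df = spark.createDataFrame(data, ["customer", "category"])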

What I want is a crosstab by category from which I can easily read, e.g., how many of the customers who bought from category 1 also bought from category 4.

Desired results table:

+--------+---+---+---+---+
|        | 1 | 2 | 3 | 4 |
+--------+---+---+---+---+
|       1|  3|  1|  2|  2|
|       2|  1|  1|  1|  0|
|       3|  2|  1|  2|  1|
|       4|  2|  0|  1|  2|
+--------+---+---+---+---+

Reading examples:

  • row 1, column 1: total number of customers who bought from category 1 (A, B, C)
  • row 1, column 2: number of customers who bought from categories 1 and 2 (A)
  • row 1, column 3: number of customers who bought from categories 1 and 3 (A, C)
  • etc.

As you can see, the table is symmetric about its diagonal.

Any suggestions on how to create the desired table?

Additional challenge: how can I get the results as percentages? For the first row the results would then be: | 100% | 33% | 66% | 66% |

Many thanks in advance!


Solution

  • You can join the input data with itself using customer as the join criterion. This returns all combinations of categories that exist for a given customer (the intermediate result is sketched after the output below). After that you can use crosstab to get the result.

    # Self-join on customer so every category a customer bought is paired
    # with every other category they bought, then cross-tabulate the pairs
    df2 = df.withColumnRenamed("category", "cat1") \
      .join(df.withColumnRenamed("category", "cat2"), "customer") \
      .crosstab("cat1", "cat2") \
      .orderBy("cat1_cat2")
    df2.show()
    

    Output:

    +---------+---+---+---+---+
    |cat1_cat2|  1|  2|  3|  4|
    +---------+---+---+---+---+
    |        1|  3|  1|  2|  2|
    |        2|  1|  1|  1|  0|
    |        3|  2|  1|  2|  1|
    |        4|  2|  0|  1|  2|
    +---------+---+---+---+---+
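
    For illustration, the intermediate self-join on its own looks like this; for customer A, who bought categories 1, 2 and 3, it yields all nine pairs (1,1), (1,2), ..., (3,3):

    df.withColumnRenamed("category", "cat1") \
      .join(df.withColumnRenamed("category", "cat2"), "customer") \
      .show()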
    

    To get the relative frequencies you can sum over each row and then divide each element by this sum. Note that this normalises by the row total; a variant that reproduces the percentages asked for in the question is sketched after the output below.

    from pyspark.sql import functions as F

    # Append a row total, then divide every category column by it and round
    df2.withColumn("sum", sum(df2[col] for col in df2.columns if col != "cat1_cat2")) \
      .select("cat1_cat2", *(F.round(df2[col]/F.col("sum"), 2).alias(col) for col in df2.columns if col != "cat1_cat2")) \
      .show()
    

    Output:

    +---------+----+----+----+----+
    |cat1_cat2|   1|   2|   3|   4|
    +---------+----+----+----+----+
    |        1|0.38|0.13|0.25|0.25|
    |        2|0.33|0.33|0.33| 0.0|
    |        3|0.33|0.17|0.33|0.17|
    |        4| 0.4| 0.0| 0.2| 0.4|
    +---------+----+----+----+----+
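
    To reproduce the percentages asked for in the question, where the diagonal is 100%, divide each row by its diagonal entry instead of the row total. A sketch, relying on crosstab filling cat1_cat2 with the string values of cat1 (which match the generated column names):

    # Pick each row's diagonal entry: the number of customers
    # who bought from that row's category
    cat_cols = [c for c in df2.columns if c != "cat1_cat2"]
    diag = F.coalesce(*[F.when(F.col("cat1_cat2") == c, F.col(c)) for c in cat_cols])
    df2.select(
        "cat1_cat2",
        *(F.round(100 * df2[col] / diag).alias(col) for col in cat_cols)
    ).show()

    For the first row this gives 100, 33, 67 and 67 (the question's 66% comes from truncating 2/3 rather than rounding it).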