Tags: python, apache-spark, pyspark, emr

Using Spark to get names of all columns that have a value over some threshold


Background

We are unloading data from Redshift into S3 and then loading it into a dataframe like so:

df = spark.read.csv(path, schema=schema, sep='|')

We are using PySpark and AWS EMR (version 5.4.0) with Spark 2.1.0.
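
For reference, the schema might be built like this (a sketch; the actual column names and count come from the Redshift table):

from pyspark.sql.types import StructType, StructField, StringType, FloatType

# Hypothetical schema matching the record layout described below:
# a VARCHAR url followed by FLOAT category scores.
schema = StructType(
    [StructField('url', StringType())] +
    [StructField('category%d' % i, FloatType()) for i in range(1, 5)]
)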

Problem

I have a Redshift table that is being read into PySpark as CSV. The records are in this sort of format:

url,category1,category2,category3,category4
http://example.com,0.6,0.0,0.9,0.3

url is VARCHAR and the category values are FLOAT between 0.0 and 1.0.

What I want to do is generate a new DataFrame with one row per (url, category) pair where the category's value in the original dataset was above some threshold X. For example, if the threshold were set to 0.5, I would want my new dataset to look like this:

url,category
http://example.com,category1
http://example.com,category3

I'm new to Spark/PySpark, so I'm not sure how (or even if) this is possible; any help would be appreciated!

EDIT:

I wanted to add my solution (based on Pushkr's code). We have a TON of categories to load, so to avoid hardcoding every single select I did the following:

from pyspark.sql.functions import when, col

parsed_df = None
for column in column_list:
    # For this category, keep the url and either the category name
    # (if over the threshold) or an empty placeholder.
    selected = df.select(df.url, when(df[column] > threshold, column).otherwise('').alias('cat'))
    if parsed_df is None:
        parsed_df = selected
    else:
        # union() matches columns by position, so the 'cat' alias carries through.
        parsed_df = parsed_df.union(selected)
if parsed_df is not None:
    parsed_df = parsed_df.filter(col('cat') != '')

where column_list is a previously generated list of category column names and threshold is the minimum value required to select the category.
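
The same loop can be written more compactly with functools.reduce (a sketch, using the same column_list and threshold as above):

from functools import reduce
from pyspark.sql.functions import when, col

# One select per category column, then union them all at once.
selects = [
    df.select(df.url, when(df[c] > threshold, c).otherwise('').alias('cat'))
    for c in column_list
]
parsed_df = reduce(lambda left, right: left.union(right), selects)
parsed_df = parsed_df.filter(col('cat') != '')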

Thanks again!


Solution

  • Here is something that I tried:

    from pyspark.sql.functions import when, col

    data = [('http://example.com',0.6,0.0,0.9,0.3),('http://example1.com',0.6,0.0,0.9,0.3)]

    df = spark.createDataFrame(data)\
         .toDF('url','category1','category2','category3','category4')

    df\
        .select(df.url,when(df.category1>0.5,'category1').otherwise('').alias('category'))\
        .union(\
        df.select(df.url,when(df.category2>0.5,'category2').otherwise('')))\
        .union(\
        df.select(df.url,when(df.category3>0.5,'category3').otherwise('')))\
        .union(\
        df.select(df.url,when(df.category4>0.5,'category4').otherwise('')))\
        .filter(col('category')!= '')\
        .show()
    

    Output:

    +-------------------+---------+
    |                url| category|
    +-------------------+---------+
    | http://example.com|category1|
    |http://example1.com|category1|
    | http://example.com|category3|
    |http://example1.com|category3|
    +-------------------+---------+
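
    If the number of categories is large, a variant that avoids chaining one union per column is to unpivot with explode over an array of structs (a sketch, assuming the same column names and threshold):

    from pyspark.sql.functions import array, struct, lit, col, explode

    cats = ['category1', 'category2', 'category3', 'category4']

    # Pack each (name, value) pair into a struct, explode the array so every
    # category becomes its own row, then keep only rows over the threshold.
    df.select(
        df.url,
        explode(array(*[
            struct(lit(c).alias('category'), col(c).alias('value')) for c in cats
        ])).alias('kv')
    ).filter(col('kv.value') > 0.5)\
     .select('url', col('kv.category').alias('category'))\
     .show()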