Background
We are unloading data from Redshift into S3 and then loading it into a dataframe like so:
df = spark.read.csv(path, schema=schema, sep='|')
We are using PySpark and AWS EMR (version 5.4.0) with Spark 2.1.0.
Problem
I have a Redshift table that is being read into PySpark as CSV. The records are in this sort of format:
url,category1,category2,category3,category4
http://example.com,0.6,0.0,0.9,0.3
url is VARCHAR and the category values are FLOAT between 0.0 and 1.0.
What I want to do is generate a new DataFrame with a single row per category where the value in the original dataset was above some threshold X. For example, if the threshold were set to 0.5 then I would want my new dataset to look like this:
url,category
http://example.com,category1
http://example.com,category3
I'm new to Spark/PySpark, so I'm not sure whether this is possible; any help would be appreciated!
EDIT:
Wanted to add my solution (based on Pushkr's code). We have a ton of categories to load, so to avoid hardcoding every single select I did the following:
from pyspark.sql.functions import col, when

parsed_df = None
for column in column_list:
    if parsed_df is None:
        parsed_df = df.select(df.url, when(df[column] > threshold, column).otherwise('').alias('cat'))
    else:
        parsed_df = parsed_df.union(df.select(df.url, when(df[column] > threshold, column).otherwise('')))

if parsed_df is not None:
    parsed_df = parsed_df.filter(col('cat') != '')
where column_list is a previously generated list of category column names and threshold is the minimum value required to select the category.
Thanks again!
Here is something that I tried -
from pyspark.sql.functions import col, when

data = [('http://example.com', 0.6, 0.0, 0.9, 0.3),
        ('http://example1.com', 0.6, 0.0, 0.9, 0.3)]
df = spark.createDataFrame(data)\
    .toDF('url', 'category1', 'category2', 'category3', 'category4')

df\
    .select(df.url, when(df.category1 > 0.5, 'category1').otherwise('').alias('category'))\
    .union(df.select(df.url, when(df.category2 > 0.5, 'category2').otherwise('')))\
    .union(df.select(df.url, when(df.category3 > 0.5, 'category3').otherwise('')))\
    .union(df.select(df.url, when(df.category4 > 0.5, 'category4').otherwise('')))\
    .filter(col('category') != '')\
    .show()
Output:
+-------------------+---------+
| url| category|
+-------------------+---------+
| http://example.com|category1|
|http://example1.com|category1|
| http://example.com|category3|
|http://example1.com|category3|
+-------------------+---------+