PySpark: sorting a DataFrame by a column


I have two questions that are probably basic for anyone experienced with PySpark, but I can't seem to solve them.

Sample entries from my CSV file:

"dfg.AAIXpWU4Q","1"
"cvbc.AAU3aXfQ","1"
"T-L5aw0L1uT_OfFyzbk","1"
"D9TOXY7rA_LsnvwQa-awVk","2"
"JWg8_0lGDA7OCwWcH_9aDc","2"
"ewrq.AAbRaACr2tVh5wA","1"
"ewrq.AALJWAAC-Qku3heg","1"
"ewrq.AADStQqmhJ7A","2"
"ewrq.AAEAABh36oHUNA","1"
"ewrq.AALJABfV5u-7Yg","1"

I create the following DataFrame:

>>> df2.show(3)
+-------+----+
|user_id|hits|
+-------+----+
|"aYk...| "7"|
|"yDQ...| "1"|
|"qUU...|"13"|
+-------+----+
only showing top 3 rows

First, is this the right way to convert the hits column to IntegerType()? Why do all the values become null?

>>> df2 = df2.withColumn("hits", df2["hits"].cast(IntegerType()))
>>> df2.show(3)
+-------+----+
|user_id|hits|
+-------+----+
|"aYk...|null|
|"yDQ...|null|
|"qUU...|null|
+-------+----+
only showing top 3 rows
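
A likely cause of the nulls: show() prints the values with their double quotes, so the column seems to hold the text "7" with the quotes included, not 7. Under Spark's default (non-ANSI) cast behaviour, a string that is not a valid integer becomes null instead of raising an error. The sketch below first mimics that in plain Python, then strips the quotes with regexp_replace before casting. The names to_int_like_spark and cast_hits are hypothetical helpers, not part of PySpark, and the plain-Python function is only a rough stand-in for the cast.

```python
def to_int_like_spark(value):
    """Rough plain-Python stand-in for the cast: None when the text is not an integer."""
    try:
        return int(value)
    except ValueError:
        return None


raw = '"7"'  # what the column appears to hold, quotes included
print(to_int_like_spark(raw))                   # None: the quotes make it non-numeric
print(to_int_like_spark(raw.replace('"', "")))  # 7


def cast_hits(df):
    """Strip double quotes from both columns, then cast hits to an integer column."""
    # Imported lazily so the plain-Python lines above run without Spark installed.
    from pyspark.sql import functions as F

    return (
        df.withColumn("user_id", F.regexp_replace("user_id", '"', ""))
          .withColumn("hits", F.regexp_replace("hits", '"', "").cast("int"))
    )
```

With the DataFrame above this would be called as df2 = cast_hits(df2). It does not help with rows that have extra commas; those still need to be handled when the file is read.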

Second, I need to sort this DataFrame in descending order by the hits column. I tried this:

>>> df1 = df2.sort(col('hits').desc())
>>> df1.show(20)

But I get the following error:

java.lang.IllegalStateException: Input row doesn't have expected number of values required by the schema. 2 fields are required while 18 values are provided.

I'm guessing it's because of how I create my DataFrame:

>>> rdd = sc.textFile("/path/to/file/*")
>>> rdd.take(2)
['"7wAfdgdfgd","7"', '"1x3Qdfgdf","1"']
>>> my_df = rdd.map(lambda x: (x.split(","))).toDF()

>>> df2 = my_df.selectExpr("_1 as user_id", "_2 as hits")
>>> df2.show(3)
+-------+----+
|user_id|hits|
+-------+----+
|"aYk...| "7"|
|"yDQ...| "1"|
|"qUU...|"13"|
+-------+----+
only showing top 3 rows

And I'm guessing there are extra commas in some rows. How do I avoid this, or what is the best way to read this file?
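
The error is consistent with that guess. split(",") also splits on commas inside quoted fields, so some lines produce more than two values, while toDF() inferred a two-column schema from the first rows. show(3) only touches the first few rows, whereas a sort has to read every row, which is probably why the failure only appears then. Below is a minimal sketch of the difference, plus one way to let Spark's CSV reader handle the quoting. The sample line and the names naive_fields, csv_fields and read_hits are hypothetical, and the DROPMALFORMED mode is an assumption about how bad lines should be treated.

```python
import csv


def naive_fields(line):
    """Split on every comma, including commas inside quoted fields."""
    return line.split(",")


def csv_fields(line):
    """Split with the csv module, which respects double-quoted fields."""
    return next(csv.reader([line]))


sample = '"a,b,c","1"'  # hypothetical line with commas inside the first field
print(len(naive_fields(sample)))  # 4
print(csv_fields(sample))         # ['a,b,c', '1']


def read_hits(spark, path):
    """Read the files with an explicit schema instead of textFile + split."""
    # Assumption: lines that do not fit the schema (wrong field count,
    # non-numeric hits) should be dropped rather than kept as nulls.
    return spark.read.csv(
        path,
        schema="user_id STRING, hits INT",
        quote='"',
        mode="DROPMALFORMED",
    )
```

The resulting DataFrame already has an integer hits column, so it could be sorted directly with df.sort(col('hits').desc()).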


Solution

  • Following @SanBan's answer, I came up with the following:

    >>> rdd = sc.textFile("/home/jsanghvi/work/buffer/*")
    
    >>> schema =  StructType([StructField ("user_id", StringType(), True), StructField ("hits", StringType(), True)])
    
    >>> my_rdd = rdd.map(lambda x: x.replace("'","")).map(lambda x: x.split(",")).map(lambda x: (x[0],x[1]))
    
    >>> my_rdd2 = my_rdd.map(lambda x: str(x).replace("'","").replace("(", "").replace(")", "")).map(lambda x: x.split(",")).map(lambda x: (x[0],x[1]))
    
    >>> df1 = spark.createDataFrame(my_rdd2, schema)
    
    >>> dfx = df1.sort(col('hits').desc())
    
    >>> dfx.show(5)
    +----------------+--------------------+                                     
    |         user_id|                hits|
    +----------------+--------------------+
    |"AUDIO_AUTO_PLAY| EXPANDABLE_AUTOM...|
    |       "user_id"|             "_col1"|
    | "AAESjk66lDk...|              "9999"|
    | "ABexsk6sLlc...|              "9999"|
    | "AAgb1k65pHI...|              "9999"|
    +----------------+--------------------+
    
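    # removing garbage rows (the header line and malformed lines)
    # isin() matches whole values exactly, so these strings must be the full stored
    # values (including any quotes or leading space), not the truncated show() output
    >>> dfx = dfx.filter(~col("hits").isin(["_col1", "EXPANDABLE_AUTOM..."]))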
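
Note that the schema above keeps hits as StringType, so the sort is lexicographic rather than numeric (for example, "9999" ranks above "10000"), and the garbage rows have to be filtered out afterwards. As an alternative sketch, not the approach from @SanBan's answer, each line can be parsed with Python's csv module and rejected up front if it does not have exactly two fields or if hits is not an integer. The names parse_line and build_sorted_df are hypothetical.

```python
import csv


def parse_line(line):
    """Return (user_id, hits) with hits as int, or None for a malformed line."""
    try:
        fields = next(csv.reader([line]))
    except (csv.Error, StopIteration):
        return None
    if len(fields) != 2:   # extra commas or missing columns
        return None
    user_id, hits = fields
    try:
        return (user_id, int(hits))
    except ValueError:     # e.g. a header line such as "user_id","_col1"
        return None


def build_sorted_df(spark, path):
    """Build the DataFrame from cleaned rows and sort by hits, highest first."""
    # Imported lazily so parse_line can be tried without Spark installed.
    from pyspark.sql import functions as F

    rows = (
        spark.sparkContext.textFile(path)
        .map(parse_line)
        .filter(lambda row: row is not None)
    )
    df = spark.createDataFrame(rows, "user_id STRING, hits INT")
    return df.orderBy(F.col("hits").desc())
```

Because the quotes are removed and hits is converted during parsing, no separate cast or garbage-row filter should be needed afterwards.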