pyspark, apache-spark-sql

Creating a range of dates in PySpark


I want to create a range of dates on a Spark DataFrame, but there is no built-in function to do this. So I wrote the following:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.appName('test').getOrCreate()

data_frame = spark.range(1, 10).withColumn('date_start', F.to_date(F.lit('2018-01-01'), 'yyyy-MM-dd'))

The result is

+---+----------+
| id|date_start|
+---+----------+
|  1|2018-01-01|
|  2|2018-01-01|
|  3|2018-01-01|
|  4|2018-01-01|
|  5|2018-01-01|
+---+----------+

Now I want to add the 'id' column to the 'date_start' column to create a column of dates ranging from start to end:

data_frame.withColumn('date_window', F.date_add(F.col('date_start'), F.col('id')))

But I got a TypeError:

TypeError                                 Traceback (most recent call last)
<ipython-input-151-9e46a2ad88a2> in <module>
----> 1 data_frame.withColumn('date_window', F.date_add(F.col('date_start'), F.col('id')))

~\AppData\Local\Continuum\anaconda3\lib\site-packages\pyspark\sql\functions.py in date_add(start, days)
   1039     """
   1040     sc = SparkContext._active_spark_context
-> 1041     return Column(sc._jvm.functions.date_add(_to_java_column(start), days))
   1042 
   1043 

~\AppData\Local\Continuum\anaconda3\lib\site-packages\py4j\java_gateway.py in __call__(self, *args)
   1246 
   1247     def __call__(self, *args):
-> 1248         args_command, temp_args = self._build_args(*args)
   1249 
   1250         command = proto.CALL_COMMAND_NAME +\

~\AppData\Local\Continuum\anaconda3\lib\site-packages\py4j\java_gateway.py in _build_args(self, *args)
   1210     def _build_args(self, *args):
   1211         if self.converters is not None and len(self.converters) > 0:
-> 1212             (new_args, temp_args) = self._get_args(args)
   1213         else:
   1214             new_args = args

~\AppData\Local\Continuum\anaconda3\lib\site-packages\py4j\java_gateway.py in _get_args(self, args)
   1197                 for converter in self.gateway_client.converters:
   1198                     if converter.can_convert(arg):
-> 1199                         temp_arg = converter.convert(arg, self.gateway_client)
   1200                         temp_args.append(temp_arg)
   1201                         new_args.append(temp_arg)

~\AppData\Local\Continuum\anaconda3\lib\site-packages\py4j\java_collections.py in convert(self, object, gateway_client)
    498         ArrayList = JavaClass("java.util.ArrayList", gateway_client)
    499         java_list = ArrayList()
--> 500         for element in object:
    501             java_list.add(element)
    502         return java_list

~\AppData\Local\Continuum\anaconda3\lib\site-packages\pyspark\sql\column.py in __iter__(self)
    342 
    343     def __iter__(self):
--> 344         raise TypeError("Column is not iterable")
    345 
    346     # string methods

TypeError: Column is not iterable

For some reason, I could solve this problem using the Spark function expr:

data_frame.withColumn("date_window", F.expr("date_add(date_start, id)"))

And voilà! It seems to work:

+---+----------+-----------+
| id|date_start|date_window|
+---+----------+-----------+
|  1|2018-01-01| 2018-01-02|
|  2|2018-01-01| 2018-01-03|
|  3|2018-01-01| 2018-01-04|
|  4|2018-01-01| 2018-01-05|
|  5|2018-01-01| 2018-01-06|
+---+----------+-----------+

My question is: how is the expr function different from the function call that I wrote?


Solution

  • Based on the timing of this question, I assume you were using PySpark v2.4.0 or older.

    Judging from the description, I believe this error was caused by the PySpark date_add() implementation not properly accepting Columns as input, hence the TypeError: Column is not iterable message.

    Checking the v2.4.0 source of the Python date_add(), the function appears to just call the Scala backend of the same function:

    @since(1.5)
    def date_add(start, days):
        """
        Returns the date that is `days` days after `start`
    
        >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
        >>> df.select(date_add(df.dt, 1).alias('next_date')).collect()
        [Row(next_date=datetime.date(2015, 4, 9))]
        """
        sc = SparkContext._active_spark_context
        return Column(sc._jvm.functions.date_add(_to_java_column(start), days))
    

    So, checking the v2.4.0 source of the Scala date_add(), the function appears to accept an Integer, not a Column:

      /**
       * Returns the date that is `days` days after `start`
       *
       * @param start A date, timestamp or string. If a string, the data must be in a format that
       *              can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
       * @param days  The number of days to add to `start`, can be negative to subtract days
       * @return A date, or null if `start` was a string that could not be cast to a date
       * @group datetime_funcs
       * @since 1.5.0
       */
      def date_add(start: Column, days: Int): Column = withExpr { DateAdd(start.expr, Literal(days)) }
    
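    This explains the TypeError: in v2.4.0 the days argument had to be a plain Python int that Py4J could forward to the Scala Int parameter. Here is a minimal sketch of the difference, reusing the data_frame from the question:

    import pyspark.sql.functions as F

    # Works on v2.4.0: `days` is a plain Python int, matching the
    # Scala signature date_add(start: Column, days: Int)
    data_frame.withColumn('plus_one', F.date_add(F.col('date_start'), 1))

    # Fails on v2.4.0: Py4J tries to serialize the Column by iterating
    # over it, which raises "TypeError: Column is not iterable"
    # data_frame.withColumn('date_window', F.date_add(F.col('date_start'), F.col('id')))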

    In fact, I could no longer reproduce your error (as of 2023-12-07) with PySpark v3.4.1.
    Instead, I now get a different error - see Footnote 1 below.

    Checking the newer v3.4.1 source of the Scala date_add(), the function is now overloaded to accept both an Integer and a Column as input:

      /**
       * Returns the date that is `days` days after `start`
       *
       * @param start A date, timestamp or string. If a string, the data must be in a format that
       *              can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
       * @param days  The number of days to add to `start`, can be negative to subtract days
       * @return A date, or null if `start` was a string that could not be cast to a date
       * @group datetime_funcs
       * @since 1.5.0
       */
      def date_add(start: Column, days: Int): Column = date_add(start, lit(days))
    
      /**
       * Returns the date that is `days` days after `start`
       *
       * @param start A date, timestamp or string. If a string, the data must be in a format that
       *              can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
       * @param days  A column of the number of days to add to `start`, can be negative to subtract days
       * @return A date, or null if `start` was a string that could not be cast to a date
       * @group datetime_funcs
       * @since 3.0.0
       */
      def date_add(start: Column, days: Column): Column = withExpr { DateAdd(start.expr, days.expr) }
    
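    So on v3.4.1, both calling conventions are accepted. A minimal sketch, again reusing data_frame (the cast is explained in Footnote 1 below):

    import pyspark.sql.functions as F

    # Int overload: days as a Python literal
    data_frame.withColumn('plus_three', F.date_add(F.col('date_start'), 3))

    # Column overload (since 3.0.0): days as a Column, cast down to INT
    data_frame.withColumn('date_window',
                          F.date_add(F.col('date_start'), F.col('id').cast('int')))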

    Whereas using expr() appears to convert your string into a SQL expression and execute it on the SQL side. Current source:

    def expr(str: str) -> Column:
        return Column(SQLExpression(str))
    
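    Because the string is parsed as SQL, you can also spell out the cast from Footnote 1 directly inside expr(); this is a hedged equivalent of the cast() fix shown below:

    import pyspark.sql.functions as F

    # Same result as the Column-API version, with the cast written in SQL
    data_frame.withColumn('date_window',
                          F.expr('date_add(date_start, cast(id as int))'))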

    Footnote 1 - Reproduction Error

    For my current (PySpark v3.4.1) setup, passing columns as input does work, but I run into a datatype issue instead. Additionally, using expr("date_add(date_start, id)") results in the exact same error.

    import pyspark.sql.functions as F
    
    spark.range(1,10)\
      .withColumn('date_start', F.to_date(F.lit('2018-01-01'), 'yyyy-MM-dd'))\
      .withColumn('date_window', F.date_add(F.col('date_start'), F.col('id')))\
      .show()
    
    AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "date_add(date_start, id)" due to data type mismatch: Parameter 2 requires the ("INT" or "SMALLINT" or "TINYINT") type, however "id" has the type "BIGINT".;
    'Project [id#324806L, date_start#324808, date_add(date_start#324808, id#324806L) AS date_window#324811]
    +- Project [id#324806L, to_date(2018-01-01, Some(yyyy-MM-dd), Some(Asia/Bangkok), false) AS date_start#324808]
       +- Range (1, 10, step=1, splits=Some(64))
    

    Instead, I can solve this error by using cast() to change the datatype to INT.

    df = spark.range(1,10)\
      .withColumn('date_start', F.to_date(F.lit('2018-01-01'), 'yyyy-MM-dd'))\
      .withColumn('date_window', F.date_add(F.col('date_start'), F.col('id').cast('int') ))
    df.show()
    
    +---+----------+-----------+
    | id|date_start|date_window|
    +---+----------+-----------+
    |  1|2018-01-01| 2018-01-02|
    |  2|2018-01-01| 2018-01-03|
    |  3|2018-01-01| 2018-01-04|
    |  4|2018-01-01| 2018-01-05|
    |  5|2018-01-01| 2018-01-06|
    |  6|2018-01-01| 2018-01-07|
    |  7|2018-01-01| 2018-01-08|
    |  8|2018-01-01| 2018-01-09|
    |  9|2018-01-01| 2018-01-10|
    +---+----------+-----------+
    

    I believe the behaviour of the PySpark range() function changed between v2.4.0 and v3.4.1, so that it now produces a column of BIGINT instead of just INT. However, I am unable to find evidence of this change in the source code, so this is still just a hypothesis.
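
    You can at least verify the current behaviour directly: on v3.4.1, the id column produced by range() reports as long (i.e. BIGINT):

    spark.range(1, 10).printSchema()
    # root
    #  |-- id: long (nullable = false)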


    Footnote 2 - Alternate Approach for Generating Date Ranges

    Now, for an unrelated tangent: I'm surprised there still isn't a built-in function to generate date ranges in PySpark in 2023. Here's my alternative method using Pandas date_range(), which I find to be much more powerful overall.

    You simply generate a DataFrame with your date range in Pandas, then convert it to a PySpark DataFrame:

    import pandas as pd
    import pyspark.sql.functions as F
    
    pandas_df = pd.DataFrame(pd.date_range("1999-12-30", "2000-01-02"), columns=["DATE_RANGE"])
    spark_df = spark.createDataFrame(pandas_df)
    spark_df.show()
    
    +-------------------+
    |         DATE_RANGE|
    +-------------------+
    |1999-12-30 00:00:00|
    |1999-12-31 00:00:00|
    |2000-01-01 00:00:00|
    |2000-01-02 00:00:00|
    +-------------------+
    
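    Note that date_range() produces midnight timestamps; if you want plain dates, you can cast the column on the Spark side, as in this small follow-up sketch:

    spark_df.select(F.col('DATE_RANGE').cast('date').alias('date')).show()
    # +----------+
    # |      date|
    # +----------+
    # |1999-12-30|
    # |1999-12-31|
    # |2000-01-01|
    # |2000-01-02|
    # +----------+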

    The advantage is that you can use the freq argument to specify much more complex date ranges (see the list of Pandas frequency abbreviations).

    So generating something like "5 end-of-month periods" is automatically handled for you:

    pd.date_range(start='1/1/2018', periods=5, freq='M')
    DatetimeIndex(['2018-01-31', '2018-02-28', '2018-03-31', '2018-04-30',
                   '2018-05-31'],
                  dtype='datetime64[ns]', freq='M')
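
    The same conversion trick applies here too: wrap the index in a Pandas DataFrame and hand it to Spark (the column name month_end is just illustrative):

    import pandas as pd

    pdf = pd.DataFrame(pd.date_range(start='1/1/2018', periods=5, freq='M'),
                       columns=['month_end'])
    spark.createDataFrame(pdf).show()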