
How to use index in filter in a "dask-sql" SQL query


I create a sample Dask DataFrame with the timestamp as the index:

import dask

df = dask.datasets.timeseries()

df.head()
                       id      name         x         y
timestamp                                              
2000-01-01 00:00:00   915   Norbert -0.989381  0.974546
2000-01-01 00:00:01  1026     Zelda  0.919731  0.656581
2000-01-01 00:00:02  1003  Patricia -0.128303 -0.354592
2000-01-01 00:00:03   986     Jerry  0.557732  0.160812

Now I want to use dask-sql with a filter on the index in an SQL query. However, this does not work:

from dask_sql import Context

c = Context()
c.create_table("mytab", df)

result = c.sql("""
        SELECT
            count(*)
        FROM mytab
        WHERE "timestamp" > '2000-01-01 00:00:00'
    """)
print(result.compute())

The Error Message is:

Traceback (most recent call last):
  File "/opt/dask_sql/startup_script.py", line 15, in <module>
    result = c.sql("""
  File "/opt/dask_sql/dask_sql/context.py", line 458, in sql
    rel, select_names, _ = self._get_ral(sql)
  File "/opt/dask_sql/dask_sql/context.py", line 892, in _get_ral
    raise ParsingException(sql, str(e.message())) from None
dask_sql.utils.ParsingException: Can not parse the given SQL: From line 4, column 15 to line 4, column 25: Column 'timestamp' not found in any table

The problem is probably somewhere here:

    
            SELECT count(*)
            FROM timeseries
            WHERE "timestamp" > '2000-01-01'
                  ^^^^^^^^^^^

I am using the Docker image nbraun/dask-sql:2022.1.0.

Is there an efficient way to get all the rows based on an index filter? It is important that this works in dask-sql, because I need to execute the SQL via the Presto endpoint provided by the dask-sql server.


Solution

  • dask-sql does not seem to recognize "timestamp" as the name of the index column, so one workaround is to promote the index to a regular column with reset_index:

    import dask
    import dask.dataframe as dd
    
    from dask_sql import Context
    
    
    ddf = dask.datasets.timeseries()
    
    
    c = Context()
    c.create_table("mytab", ddf.reset_index())
    
    
    result = c.sql("""
            SELECT
                count(*)
            FROM mytab
            WHERE "timestamp" > '2000-01-01 00:00:00'
        """)
    
    print(result.compute())
    
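The key point is that reset_index turns the index into an ordinary column, which dask-sql can then see by name. A minimal sketch of that effect, illustrated with plain pandas (which each Dask partition wraps) on a small hypothetical frame mirroring dask.datasets.timeseries():

    ```python
    import pandas as pd

    # A small frame indexed by a DatetimeIndex named "timestamp"
    idx = pd.date_range("2000-01-01", periods=4, freq="s", name="timestamp")
    pdf = pd.DataFrame({"x": [0.1, 0.2, 0.3, 0.4]}, index=idx)

    print(pdf.columns.tolist())                # ['x'] -- "timestamp" is only the index
    print(pdf.reset_index().columns.tolist())  # ['timestamp', 'x'] -- now a real column
    ```

After reset_index, "timestamp" is a regular column, so a WHERE clause can reference it by name.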

    In this specific example, the query then raises TypeError('Invalid comparison between dtype=datetime64[ns] and datetime') because pandas/Dask store timestamps in the datetime64[ns] format. You can convert the "timestamp" column to strings using something like:

    import datetime
    
    c.create_table("mytab", ddf.reset_index().assign(timestamp = lambda df: df["timestamp"].apply(lambda x: x.strftime('%Y-%m-%d'), meta=('timestamp', 'object'))))
    
    

    This is equivalent to the more explicit:

    ddf_new = ddf.reset_index()
    
    ddf_new["timestamp"] = ddf_new["timestamp"].apply(lambda x: x.strftime('%Y-%m-%d'), meta=('timestamp', 'object'))
    
    c.create_table("mytab", ddf_new)
    
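    As a possible alternative to the element-wise apply, the vectorized .dt.strftime accessor should perform the same conversion; the formatting semantics are plain pandas, sketched here on a small made-up Series:

    ```python
    import pandas as pd

    # Three timestamps within the same day, as in the timeseries example
    s = pd.Series(pd.date_range("2000-01-01", periods=3, freq="s"), name="timestamp")

    # Vectorized conversion to date strings -- same result as the
    # apply/strftime lambda, without a per-element Python call
    formatted = s.dt.strftime("%Y-%m-%d")
    print(formatted.tolist())  # ['2000-01-01', '2000-01-01', '2000-01-01']
    ```

    Dask DataFrames expose the same .dt accessor on datetime columns, so ddf_new["timestamp"].dt.strftime('%Y-%m-%d') should work there as well.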

    I'd also encourage you to open relevant issues on the dask-sql issue tracker to reach the team directly. :)