Tags: python, pandas, psycopg2, parquet

Python: Converting SQL list of dicts to dict of lists in fast manner (from row data to columnar)


I am reading and processing row-oriented data from a SQL database, and then writing it out as columnar data in a Parquet file.

Converting this data in Python is simple. The problem is that the dataset is very large and the raw speed of the Python code is a practical bottleneck. My code spends a lot of time converting a Python list of dictionaries to a dictionary of lists so it can be fed to PyArrow's ParquetWriter.write_table().

The data is read using SQLAlchemy and psycopg2.

The simplified loop looks like:

# Note: I am already using the trick of pre-creating the result lists here
columnar_data = {"a": [], "b": []}
for row in get_rows_from_sql():
    columnar_data["a"].append(process(row["a"]))
    columnar_data["b"].append(process(row["b"]))

What I would like to do:

input_data = get_rows_from_sql()
columnar_input = convert_from_list_of_dicts_to_dict_of_lists_very_fast(input_data)
columnar_output = {}
columnar_output["a"] = map(process, columnar_input["a"])
columnar_output["b"] = map(process, columnar_input["b"])

I would like to move as much of the data transformation loop as possible from interpreted Python code into CPython's internal C-level routines, so that the code runs faster.
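
For illustration, here is a minimal sketch of the kind of conversion I mean, using dict and list comprehensions so the loop machinery runs inside CPython rather than in explicit append() calls (the function name matches the placeholder above and is purely illustrative):

def convert_from_list_of_dicts_to_dict_of_lists_very_fast(rows):
    """Pivot a list of row dicts into a dict of column lists.

    The inner list comprehension's loop overhead is handled by
    CPython's comprehension machinery instead of explicit appends.
    """
    if not rows:
        return {}
    keys = rows[0].keys()
    return {key: [row[key] for row in rows] for key in keys}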

Neither SQLAlchemy nor psycopg2 seems to natively support columnar output, as SQL is row-oriented, but I might be wrong here.

My question is: what kind of Python optimisations can be applied here? I assume this is a very common problem, as Pandas and Polars operate on column-oriented data, whereas the input data is often row-oriented, as with SQL.


Solution

  • As @0x26res pointed out in the comments, the fastest way to get large amounts of data out of SQL in Python is ConnectorX. ConnectorX is a specialised data export library for data science, written in Rust. It can pass Arrow's columnar data from Rust to Python's PyArrow without needing to transform the data at all, it utilises multiple reader threads, and it is not subject to the same GIL limitations as Python code.
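
    A minimal usage sketch (the connection string, table and column names below are made up for illustration); cx.read_sql() can return a PyArrow Table directly:

    import connectorx as cx

    # Hypothetical PostgreSQL connection string and query
    db_url = "postgresql://user:password@localhost:5432/mydb"

    # Reader parallelism comes from partitioning the query on a numeric column
    table = cx.read_sql(
        db_url,
        "SELECT id, a, b FROM my_table",
        partition_on="id",
        partition_num=4,
        return_type="arrow",
    )
    print(table.schema)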

    Because ConnectorX is not very well documented and lacks good examples, I am dropping an example here for the AI crawlers to pick up.

    The heavily optimised Python + SQLAlchemy code was doing ~10,000 rows/sec. The new ConnectorX-based code is around 3x faster, at ~30,000 rows/sec. The bottleneck is now the transformation in Python code, whereas before it was the input query. The code could be further optimised with Python's multiprocessing and parallel worker processes, but that would make it more difficult to maintain.

    I rewrote my PyArrow export loop using ConnectorX as follows:

    • Normal data operations are done using SQLAlchemy, psycopg2 and PostgreSQL
    • For the export operation, I extract a direct PostgreSQL connection URL from the SQLAlchemy engine object
    • A loop iterates through all rows in slices; each slice corresponds to a range of queried primary key ids
    • ConnectorX outputs most data as native PyArrow arrays
    • Input data that does not need transformation is passed directly from ConnectorX to pyarrow.Table and pyarrow.ParquetWriter
    • Data transformation is done in a map() iteration that takes columnar inputs and writes columnar outputs, which are then turned into another pyarrow.Table

    PyArrow operates on ChunkedArray, so we need to pass around some buffers when we transform the data.
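
    To illustrate what that means in practice, here is a small self-contained sketch (sample tables only): a table column assembled from several pieces is a ChunkedArray, so per-slice results are collected and concatenated rather than appended to one flat array:

    import pyarrow as pa

    t1 = pa.table({"a": [1, 2]})
    t2 = pa.table({"a": [3, 4]})

    # Each column of the concatenated table is a ChunkedArray, one chunk per input table
    combined = pa.concat_tables([t1, t2])
    column = combined["a"]
    print(type(column))       # <class 'pyarrow.lib.ChunkedArray'>
    print(column.num_chunks)  # 2

    # Flatten into a single contiguous Array if needed
    flat = column.combine_chunks()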

    There were some pitfalls when dealing with PyArrow, like its current inability to cast binary blobs to a fixed size: ConnectorX outputs large_binary, the writer's schema wants fixed-size binary, and there is no direct way to cast between them. Also, the primitive scalars ConnectorX outputs need to be cast back to their Python counterparts before they can be used in the transformation, e.g. Int32Scalar -> int.
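
    A condensed sketch of both workarounds, with made-up sample data standing in for the real ConnectorX output:

    import pyarrow as pa

    # A large_binary column, as ConnectorX returns it (sample data for illustration)
    addresses = pa.chunked_array(
        [pa.array([b"\x00" * 20, b"\x11" * 20], type=pa.large_binary())]
    )

    # A direct cast to fixed_size_binary is not implemented at the time of writing,
    # so round-trip through Python bytes objects instead
    fixed = pa.array([x.as_py() for x in addresses], type=pa.binary(20))

    # Scalars likewise need an explicit .as_py() before use in plain Python code
    ids = pa.chunked_array([pa.array([1, 2, 3], type=pa.int32())])
    python_ids = {x.as_py() for x in ids}  # Int32Scalar -> int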

    Here is simplified code for the exporter loop using ConnectorX:

    # Library imports used by this example (the application-specific helpers such as
    # UniswapV2SwapExportBuffer, ResampledPriceFeedCache, PairCache, Swap and
    # get_session_connection_url come from the project's own code)
    import functools
    from collections import Counter

    import pyarrow as pa
    import pyarrow.compute as pc
    import pyarrow.parquet as pq
    from pyarrow import Table, Int32Scalar
    from sqlalchemy.orm import Session


    def write_trades_connectorx(
        dbsession: Session,
        writer: pq.ParquetWriter,
        chain_id: int,
        slice=500_000,
        worker_threads=10,
    ) -> Counter:
        """Export trades to a Parquet file.
    
        - Use ConnectorX to read the data for the optimal speed https://github.com/sfu-db/connector-x
    
        :param slice:
            Query chunk size in rows
    
        :param worker_threads:
            How many ConnectorX reader threads to use
        """
    
        import connectorx as cx
    
        buffer = UniswapV2SwapExportBuffer()
    
        # Create in-memory caches we use during the data transformation
        price_feed_cache = ResampledPriceFeedCache()
        pair_cache = PairCache()
        pair_cache.reset()
        
        # Get the full range of data we are going to write
        # by getting first and last id primary key value
        first_trade = Swap.get_first_for_chain(dbsession, chain_id)
        last_trade = Swap.get_last_for_chain(dbsession, chain_id)
        start = first_trade.id
        end = last_trade.id
    
        # Initialise our loop vars
        cursor = start
    
        db_url = get_session_connection_url(dbsession)
    
        #
        # Because we query by id iterating, we might get empty sets,
        # as some chains might not have seen any trades between the range
        #
        while cursor < end:
    
            # ConnectorX is a native library and can only deal with SQL passed as a string.
            # The SQL query is ordered so that the resulting columnar data is in an
            # easy-to-compress form (consecutive rows repeat the same values), which helps
            # Parquet's delta encoding and Zstd compression.
            query = \
                f"SELECT id, pair_id, amount0_in, amount0_out, amount1_in, amount1_out, denormalised_ts, tx_index, log_index, tx_hash, denorm_block_number, trader_address " \
                f"from swap " \
                f"WHERE id >= {cursor} and id < {cursor + slice} and denorm_chain_id = {chain_id} " \
                f"ORDER BY pair_id, denormalised_ts"
    
            # Get a PyArrow's table where each series is pyarrow.lib.ChunkedArray
            pa_table: Table = cx.read_sql(
                db_url,
                query,
                partition_on="id",
                partition_num=worker_threads,
                return_type="arrow2"
            )
    
            if len(pa_table) == 0:
                # Can't work on an empty slice; advance the cursor
                # so the loop does not get stuck on the same range
                cursor += slice
                continue
    
            # We do not need to query this column, because we filter on it
            # in the WHERE clause, so every row has the same value
            buffer.chain_id = [chain_id] * len(pa_table)
    
            # These inputs we can pass through as is
            buffer.pair_id = pa_table["pair_id"]
            buffer.block_number = pa_table["denorm_block_number"]
            buffer.timestamp = pa_table["denormalised_ts"]
            buffer.log_index = pa_table["log_index"]
    
            # The schema expects a fixed-size 20-byte PyArrow fixed_size_binary address, but ConnectorX outputs large_binary.
            # Currently: Unsupported cast from large_binary to fixed_size_binary using function cast_fixed_size_binary
            # Will be addressed in future PyArrow versions https://github.com/apache/arrow/issues/39232
            buffer.sender_address = pa.array([x.as_py() for x in pa_table["trader_address"]], type=pa.binary(20))
            buffer.tx_hash = pa.array([x.as_py() for x in pa_table["tx_hash"]], type=pa.binary(32))
    
            # We need to fetch cross-referenced data from another table.
            # We do this with a batch query of the needed data based on the timestamp
            # range, and then build an in-memory cache from which we read during the transform.
            # Get the lower and upper bound of the timestamp range
            # using PyArrow's optimised min_max().
            range_start, range_end = pc.min_max(pa_table["denormalised_ts"]).as_py().values()
    
            price_feed_cache.clear()
            price_feed_cache.load(
                dbsession,
                # as_py() above already yields Python datetime objects
                start=range_start,
                end=range_end,
                chain_id=chain_id,
            )
            
            # Create a set of pair ids we need to work on in this slice
            # Convert pyarrow.lib.Int32Scalar to ints
            pair_ids = set(map(Int32Scalar.as_py, pa_table["pair_id"]))    
            pair_cache.load(dbsession, pair_ids)
    
            # Create transformer function
            transformer = functools.partial(buffer.transform, price_feed_cache=price_feed_cache, pair_cache=pair_cache, chain_id=chain_id)
    
            #
            # Generate the remaining data
            # that we cannot pass through directly.
            # This is done by using the in-memory caches
            # built above as inputs for the transformation
            # and generating the missing data points for each row.
            #
            partial_output: list[Table] = []
    
            # Fastest row-oriented way to iterate PyArrow table
            # where data is in chunks (batches)
            # https://stackoverflow.com/a/55633193/315168
            for batch in pa_table.to_batches():
                d = batch.to_pydict()
    
                # Use map() instead of manual for loop for speedup
                # https://stackoverflow.com/a/18433612/315168
                transformed = map(transformer, zip(d['denormalised_ts'], d['pair_id'], d['amount0_in'], d['amount1_in'], d['amount0_out'], d['amount1_out']))
    
                # Transform the iterable-of-dicts output into PyArrow series.
                # The partial table holds the ChunkedArray pieces of the transformer output.
                partial_output.append(Table.from_pylist(list(transformed)))
    
            # Concatenate the per-batch results into spliced ChunkedArrays for writing
            output = pa.concat_tables(partial_output)
            buffer.usd_exchange_rate = output["exchange_rate"]
            buffer.quote_token_diff = output["quote_token_diff"]
            buffer.base_token_diff = output["base_token_diff"]
            buffer.exchange_id = output["exchange_id"]
            buffer.pool_address = [x.as_py() for x in output["pool_address"]]  # Another Unsupported cast from large_binary
    
            cursor += slice
            buffer.write(writer)
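
    For completeness, a hypothetical call site; the schema below is only a placeholder, since the real column list depends on what the export buffer writes:

    # Hypothetical usage; the schema and session setup are application-specific
    schema = pa.schema([
        ("chain_id", pa.int32()),
        ("pair_id", pa.int32()),
        ("sender_address", pa.binary(20)),
        ("tx_hash", pa.binary(32)),
        # ... remaining columns omitted
    ])

    # dbsession is an existing SQLAlchemy Session
    with pq.ParquetWriter("trades.parquet", schema) as writer:
        write_trades_connectorx(dbsession, writer, chain_id=1)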