How to extract column names from SQL query using Python

I would like to extract the column names of a resulting table directly from the SQL statement:

query = """
select
    sales.order_id as id,
    p.product_name,
    sum(p.price) as sales_volume
from sales
right join products as p
    on sales.product_id=p.product_id
group by id, p.product_name;
"""

column_names = parse_sql(query)
# column_names:
# ['id', 'product_name', 'sales_volume']

Any idea what to do in parse_sql()? The resulting function should be able to recognize aliases and remove the table aliases/identifiers (e.g. "sales." or "p.").

Thanks in advance!


  • I've done something like this using the sqlparse library. It tokenizes your SQL query; you can then look for the SELECT keyword token and parse the tokens that follow it. In code, that reads like:

    import sqlparse

    def find_selected_columns(query) -> list[str]:
        # Top-level tokens of the first statement in the query string
        tokens = sqlparse.parse(query)[0].tokens
        found_select = False
        for token in tokens:
            if found_select:
                # The first identifier list after SELECT holds the selected columns
                if isinstance(token, sqlparse.sql.IdentifierList):
                    return [
                        # Keep the last word (the alias, if any), then drop quotes and table prefix
                        col.value.split(" ")[-1].strip("`").rpartition('.')[-1]
                        for col in token.tokens
                        if isinstance(col, sqlparse.sql.Identifier)
                    ]
            else:
                found_select = token.match(sqlparse.tokens.Keyword.DML, ["select", "SELECT"])
        raise Exception("Could not find a select statement. Weird query :)")

    This code should also work for queries with common table expressions, i.e. it only returns the columns of the final select. Depending on the SQL dialect and the quote characters you are using, you might have to adapt the line col.value.split(" ")[-1].strip("`").rpartition('.')[-1]
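
    As a rough sketch of such an adaptation, the string handling can be pulled out into a small helper. The name output_name and the quote_chars parameter are made up for illustration and are not part of sqlparse. The helper uses only plain string methods, and it removes the table qualifier before stripping quotes, so that a reference like `p`.`product_name` does not keep a stray backtick.

```python
def output_name(select_item: str, quote_chars: str = '`"[]') -> str:
    """Reduce one select-list item to its output column name.

    Hypothetical helper: a variant of
    split(" ")[-1].strip("`").rpartition('.')[-1]
    that accepts several quote styles.
    """
    # Last whitespace-separated word: the alias if there is one,
    # otherwise the (possibly qualified) column reference itself.
    last_word = select_item.split()[-1]
    # Drop any table or schema qualifier in front of the last dot.
    unqualified = last_word.rpartition(".")[-1]
    # Remove surrounding quote characters (backticks, double quotes, brackets).
    return unqualified.strip(quote_chars)


items = [
    "sales.order_id as id",
    "p.product_name",
    "sum(p.price) as sales_volume",
]
column_names = [output_name(item) for item in items]
```

    This still assumes that aliases contain no spaces or dots and that computed columns carry an alias; an unaliased sum(p.price) would not yield a sensible name.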