Tags: python, python-3.x, postgresql, sqlalchemy, array-agg

Multiple array_agg in sqlalchemy


I am working with PostgreSQL and want to fetch multiple fields using array_agg in SQLAlchemy, but I couldn't find an example of this anywhere. I wrote a query, but I can't process the result of array_agg. I'd like to get a list of strings, or better yet a list of tuples.
It would also be nice to get rid of func.distinct; it's only there because I can't write func.array_agg((Task.id, Task.user_id)).

My query:

data = session.query(
    Status.id, func.array_agg(func.distinct(Task.id, Task.user_id), type_=TEXT)
).join(Task).group_by(Status.id).limit(5).all()

I got:

(100, '{"(91,1)","(92,1)","(93,1)","(94,1)"}')
(200, '{"(95,1)","(96,1)","(97,1)","(98,1)","(99,1)"}')

But I want:

(100, ["(91,1)","(92,1)","(93,1)","(94,1)"])
(200, ["(95,1)","(96,1)","(97,1)","(98,1)","(99,1)"])

Or better:

(100, [(91,1),(92,1),(93,1),(94,1)])
(200, [(95,1),(96,1),(97,1),(98,1),(99,1)])

I also tried:

func.array_agg(func.distinct(Task.id, Task.user_id), type_=ARRAY(TEXT))

I got:

(100, ['{', '"', '(', '9', '1', ',', '1', ')', '"', ',', '"', '(', '9', '2', ',', '1', ')', '"', ',', '"', '(', '9', '3', ',', '1', ')', '"', ',', '"', '(', '9', '4', ',', '1', ')', '"', '}'])
(200, ['{', '"', '(', '9', '5', ',', '1', ')', '"', ',', '"', '(', '9', '6', ',', '1', ')', '"', ',', '"', '(', '9', '7', ',', '1', ')', '"', ',', '"', '(', '9', '8', ',', '1', ')', '"', ',', '"', '(', '9', '9', ',', '1', ')', '"', '}'])
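A plausible explanation for this character-by-character output (my reading, not checked against SQLAlchemy's source): with type_=ARRAY(TEXT), SQLAlchemy's result processing expects the driver to have already returned a Python list and rebuilds it item by item. psycopg2 returned the literal as a single str, and iterating over a str yields its characters. The stdlib-only sketch below reproduces the effect; the variable names are illustrative.

```python
# psycopg2 handed back the record[] literal as one plain string.
raw = '{"(91,1)","(92,1)","(93,1)","(94,1)"}'

# Code that assumes it received a list and rebuilds it item by item
# ends up walking the string one character at a time.
rebuilt = [item for item in raw]

print(rebuilt[:6])  # ['{', '"', '(', '9', '1', ',']
```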

Solution

  • The problem here is that PostgreSQL's array_agg, applied to a row value, returns an array of anonymous records (record[]), a type psycopg2 has no typecaster for; in that situation psycopg2's default behaviour is to return the array literal as a plain string.

    A bug report about this has existed since 2016. SQLAlchemy's maintainer (SO user zzzeek) proposed creating a custom type to handle this case. I have modified that solution slightly so that it converts the tuple elements to integers and works with SQLAlchemy 1.4:

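    import re
    
    import sqlalchemy as sa
    from sqlalchemy.types import TypeDecorator
    
    class ArrayOfRecord(TypeDecorator):
        """Parse a record[] literal such as '{"(91,1)","(92,1)"}' into int tuples."""
    
        impl = sa.String
        # cache_ok = True seems to work, but I haven't tested extensively
        cache_ok = True
    
        def process_result_value(self, value, dialect):
            if value is None:  # e.g. array_agg over zero rows returns NULL
                return None
            # Strip the outer braces; group(1) is None for an empty array "{}"
            elems = re.match(r"^\{(\".+?\")*\}$", value).group(1)
            if elems is None:
                return []
            # Keep the text inside each pair of quotes: ['(91,1)', '(92,1)', ...]
            elems = [e for e in re.split(r'"(.*?)",?', elems) if e]
            # Pull the fields out of each "(a,b)" and convert them to int
            return [tuple(
                map(int, re.findall(r'[^\(\),]+', e))
            ) for e in elems]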
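To see what each regular expression in process_result_value does, here are the same three steps applied outside SQLAlchemy to one of the literals from the question. This is a stdlib-only sketch; value simply stands in for the string psycopg2 returned. Note that the first pattern captures the whole run of quoted elements in one piece rather than element by element, which is why the second split is needed.

```python
import re

value = '{"(91,1)","(92,1)","(93,1)","(94,1)"}'

# Step 1: strip the outer braces; group(1) holds the entire quoted run.
quoted = re.match(r"^\{(\".+?\")*\}$", value).group(1)
# quoted == '"(91,1)","(92,1)","(93,1)","(94,1)"'

# Step 2: split on each quoted element, keeping the text inside the quotes.
elems = [e for e in re.split(r'"(.*?)",?', quoted) if e]
# elems == ['(91,1)', '(92,1)', '(93,1)', '(94,1)']

# Step 3: pull out the comma-separated fields and cast them to int.
records = [tuple(map(int, re.findall(r"[^\(\),]+", e))) for e in elems]
print(records)  # [(91, 1), (92, 1), (93, 1), (94, 1)]
```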
    

    Using it like this (with sa.func.ROW building each tuple, so func.distinct is no longer needed for that purpose):

    with Session() as session:
        data = (
            session.query(
                Status.id,
                sa.func.array_agg(
                    sa.func.ROW(Task.id, Task.user_id), type_=ArrayOfRecord
                ).label('agg')
            )
            .join(Task)
            .group_by(Status.id)
        )
        print()
        for row in data:
            print(row)
        print()
    

    outputs

    (100, [(91, 1), (92, 1), (93, 1), (94, 1)])
    (200, [(95, 1), (96, 1), (97, 1), (98, 1)])
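The regex parser above assumes every field is an integer and that no field contains quotes or commas. If the aggregated rows include text columns, a sketch like the following, using only Python's csv module, may be a more tolerant starting point. It assumes a one-dimensional record[] literal in PostgreSQL's default text output (array elements double-quoted with backslash escapes; record fields double-quoted when needed, e.g. when they contain commas) with no embedded newlines. parse_record_array is a hypothetical helper name, and it cannot tell a NULL field from an empty string.

```python
import csv


def parse_record_array(value):
    """Split a PostgreSQL record[] text literal into tuples of strings.

    Hypothetical helper, a sketch only: assumes a one-dimensional array of
    records with no embedded newlines. NULL fields come back as ''.
    """
    if value is None:
        return None
    body = value[1:-1]  # drop the surrounding { and }
    if not body:
        return []  # '{}' is an empty array
    # Array level: elements are comma-separated; quoted elements use
    # backslash escapes for embedded quotes and backslashes.
    elements = next(csv.reader([body], escapechar="\\", doublequote=False))
    records = []
    for element in elements:
        # Record level: strip "(" and ")" and split the fields on commas.
        # Quoted fields double any embedded quote; backslashes are escaped too.
        fields = next(csv.reader([element[1:-1]], escapechar="\\"))
        records.append(tuple(fields))
    return records


print(parse_record_array('{"(91,1)","(92,1)"}'))
# [('91', '1'), ('92', '1')]
print(parse_record_array(r'{"(1,\"a,b\")"}'))
# [('1', 'a,b')]
```

For integer-only rows, process_result_value could then return [tuple(map(int, r)) for r in parse_record_array(value)] for a non-NULL value; I haven't tested this variant against a live database.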