Tags: python, memory-management, sqlalchemy

How to keep the memory used by SQLAlchemy constant


I have this Python script:

from typing import Any
import sys

def find_memory_usage(insertion: bool):
    from sqlalchemy import create_engine, select, String
    from sqlalchemy.orm import MappedAsDataclass, DeclarativeBase, declared_attr
    from sqlalchemy.orm import Mapped, mapped_column
    import random
    import string
    from sqlalchemy.orm import sessionmaker
    from datetime import datetime
    import gc
    import psutil

    class Base(MappedAsDataclass, DeclarativeBase):
        __abstract__ = True

        @declared_attr
        def __tablename__(cls) -> Any:
            return cls.__name__

    class User(Base):
        id: Mapped[int] = mapped_column(init=False, name="ID", primary_key=True)
        name: Mapped[str] = mapped_column(String(16), name="Name", nullable=False)
        description: Mapped[str] = mapped_column(
            String(256), name="Description", nullable=False
        )

    engine = create_engine("postgresql+psycopg2://postgres:password@localhost/postgres")
    Base.metadata.create_all(engine, tables=[User.__table__])

    # Function to generate a random string
    def random_string(length):
        letters = string.ascii_letters
        return "".join(random.choice(letters) for i in range(length))

    # Set up the session maker
    Session = sessionmaker(bind=engine)

    with Session() as session:
        if insertion:
            # Insert 100,000 User rows with random data
            for _ in range(100000):
                user = User(
                    name=random_string(16),  # 16-character names
                    description=random_string(128),  # 128-character descriptions
                )
                session.add(user)

            session.commit()
            print("Data insertion complete.")
            return

        print(datetime.now(), psutil.Process().memory_info().rss / (1024 * 1024))
        number = (int(sys.argv[1]) if len(sys.argv) > 1 else 100000)
        for row in session.scalars(select(User).fetch(number)):
            pass
        session.commit()
        session.expire_all()
        session.expunge_all()
        session.close()
        del session
    gc.collect()
    print(datetime.now(), psutil.Process().memory_info().rss / (1024 * 1024))

# find_memory_usage(True) # I have done this once already to insert data.
find_memory_usage(False)

If I run python test.py 1000, I get this output (the second column is RSS in MiB):

2024-01-21 22:59:37.291944 48.23828125
2024-01-21 22:59:37.321317 49.86328125

If I run python test.py 100000, I get this output:

2024-01-21 22:59:51.152666 47.8984375
2024-01-21 22:59:52.477458 73.06640625

We can see that the increase in memory usage depends on the value of number passed to the script.

The relevant section of code is this:

        number = (int(sys.argv[1]) if len(sys.argv) > 1 else 100000)
        for row in session.scalars(select(User).fetch(number)):
            pass

As it processes the result row by row, I would expect the memory usage to stay constant.

How (if possible) should I change session.scalars(select(User).fetch(number)) to keep the memory usage constant?


Solution

  • This is a complicated topic, and it involves Python itself, SQLAlchemy, the DB driver (i.e. psycopg2), and the DB server itself (i.e. PostgreSQL). Depending on what your driver supports, you can use yield_per as described here: orm-queryguide-yield-per (a batched variant using partitions() is sketched at the end of this answer).

    Specifically, in your case, my understanding is that the ORM usually fetches the entire result set at once, because that is usually more efficient than making many round trips to the server, even for larger data sets. Obviously at some point it isn't, which is why yield_per and the machinery that supports it exist. (Under the hood, yield_per also enables streaming at the driver level; see the Core-level sketch after the example below.)

    Your example would be changed to this (adjusting COUNT_PER_YIELD as needed):

    # ...
    COUNT_PER_YIELD = 1000
    q = select(User).fetch(number).execution_options(yield_per=COUNT_PER_YIELD)
    for row in session.scalars(q):
        pass
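
    For reference, yield_per implies stream_results at the Core level, which is what requests a server-side cursor from psycopg2. Below is a minimal Core-level sketch of the same streaming behavior without the ORM; it assumes the engine and User table from your question:

    with engine.connect() as conn:
        # stream_results=True requests a server-side cursor, so rows are
        # fetched from the server in chunks instead of being buffered
        # client-side all at once.
        result = conn.execution_options(stream_results=True).execute(
            select(User.__table__)
        )
        for row in result:
            pass  # rows arrive incrementally; memory stays roughly constant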
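
    If you would rather process the rows in explicit batches, yield_per also pairs with ScalarResult.partitions(), which yields lists of at most COUNT_PER_YIELD objects at a time. A sketch under the same assumptions (number and Session come from your script):

    q = select(User).fetch(number).execution_options(yield_per=COUNT_PER_YIELD)
    with Session() as session:
        for partition in session.scalars(q).partitions():
            # only one batch of ORM objects is materialized at a time
            for user in partition:
                pass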