Tags: postgresql, view, sqlalchemy, format, numeric

Limit decimal digits of a field in a view created with SQLAlchemy


I created this materialized view:

class OpenPositionMetric(Base):
    stmt = (
        select(
            [
                OpenPosition.belongs_to.label("belongs_to"),
                OpenPosition.account_number.label("account_number"),
                OpenPosition.exchange.label("exchange"),
                OpenPosition.symbol.label("symbol"),
                # Python's built-in round(); SQLAlchemy does not turn this into SQL ROUND()
                round(OpenPosition.actual_shares * OpenPosition.avg_cost_per_share, 3).label(
                    "cost_value"
                ),
            ]
        )
        .select_from(OpenPosition)
        .order_by("belongs_to", "account_number", "exchange", "symbol")
    )

    view = create_materialized_view(
        name="vw_open_positions_metrics",
        selectable=stmt,
        metadata=Base.metadata,
        indexes=None,
    )
    __table__ = view

For the field cost_value I get values such as 1067.2500060000000000.

Is there a way to limit the number of decimal digits for that view field?

The round() function doesn't work. Maybe this is because round is a Python function, and SQLAlchemy expects an SQL expression language function such as func.sum?
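This hypothesis can be checked directly. Python's built-in round() looks up a __round__ method on the argument's type, and SQLAlchemy column expressions do not appear to define one, so the call should fail with TypeError before any SQL is generated; sa.func.round(), like func.sum, instead builds an SQL function call. A minimal check, assuming SQLAlchemy 1.4 or later and using hypothetical sa.column() stand-ins instead of the real OpenPosition model:

    import sqlalchemy as sa

    # Hypothetical stand-ins for the two ORM columns, for illustration only
    shares = sa.column("actual_shares", sa.Numeric)
    avg_cost = sa.column("avg_cost_per_share", sa.Numeric)
    product = shares * avg_cost  # an SQL expression object, not a number

    try:
        round(product, 3)  # Python's built-in round()
        builtin_failed = False
    except TypeError as exc:
        # The expression's type defines no __round__ method
        builtin_failed = True
        print("built-in round() failed:", exc)

    # func.round builds an SQL function call that is rendered into the query
    sql_round = sa.func.round(product, 3)
    print(sql_round)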

Update:

I've found a solution, but it isn't pretty. I'm sure there is a better one...

Replacing the round(...) element of the select list with:

    text("ROUND(operations.tb_open_positions.actual_shares * operations.tb_open_positions.avg_cost_per_share, 3) AS cost_value"),

The value above is now displayed in the view as 1067.250.
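The text() workaround can probably be expressed with sa.func.round() instead of a raw SQL string, which keeps the column references tied to the model. This is a sketch, assuming SQLAlchemy 1.4+ and that both columns are NUMERIC in PostgreSQL: PostgreSQL defines the two-argument round() for numeric but not for double precision, and the fact that ROUND(..., 3) worked above suggests numeric columns. The sa.table() stand-in is hypothetical and reuses the schema and table names from the text() string; in the real view the OpenPosition columns would be used instead:

    import sqlalchemy as sa
    from sqlalchemy.dialects import postgresql

    # Hypothetical stand-in for the mapped table, named as in the text() workaround
    positions = sa.table(
        "tb_open_positions",
        sa.column("actual_shares", sa.Numeric),
        sa.column("avg_cost_per_share", sa.Numeric),
        schema="operations",
    )

    # SQL ROUND(..., 3) built from column objects instead of a raw string
    cost_value = sa.func.round(
        positions.c.actual_shares * positions.c.avg_cost_per_share, 3
    ).label("cost_value")

    stmt = sa.select(cost_value).select_from(positions)

    # literal_binds inlines the 3, since a view definition cannot hold bind parameters
    sql = str(
        stmt.compile(
            dialect=postgresql.dialect(),
            compile_kwargs={"literal_binds": True},
        )
    )
    print(sql)

In the view class, this expression would take the place of the round(...) element in the select list.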


Solution

  • One way to limit the number of decimal places would be to cast the result to Numeric:

    import sqlalchemy as sa
    import sqlalchemy.orm  # makes sa.orm.Session available
    
    # …
    
    class OpenPosition(Base):
        __tablename__ = "open_position"
        id = sa.Column(sa.Integer, primary_key=True, autoincrement=False)
        actual_shares = sa.Column(sa.Float)
        avg_cost_per_share = sa.Column(sa.Float)
    
    
    Base.metadata.drop_all(engine, checkfirst=True)
    Base.metadata.create_all(engine)
    
    with sa.orm.Session(engine, future=True) as session:
        session.add(
            OpenPosition(id=1, actual_shares=1, avg_cost_per_share=1067.250606)
        )
        session.commit()
        result = session.query(
            (OpenPosition.actual_shares * OpenPosition.avg_cost_per_share).label(
                "cost_value"
            )
        ).all()
        print(result)  # [(1067.250606,)]
    
        result = session.query(
            # Label the cast itself, so the result column keeps the name cost_value
            sa.cast(
                OpenPosition.actual_shares * OpenPosition.avg_cost_per_share,
                sa.Numeric(10, 3),
            ).label("cost_value")
        ).all()
        print(result)  # [(Decimal('1067.251'),)]
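Applied to the view from the question, the cast would replace the round(...) entry of the select list. The sketch below, assuming SQLAlchemy 1.4+, compiles such a select for PostgreSQL so the generated SQL can be inspected without a database; the table is a hypothetical sa.table() stand-in for OpenPosition, and the create_materialized_view() call would stay as it is. The label goes on the cast, not inside it, so the view column is still named cost_value:

    import sqlalchemy as sa
    from sqlalchemy.dialects import postgresql

    # Hypothetical lightweight stand-in for the OpenPosition table
    positions = sa.table(
        "tb_open_positions",
        sa.column("symbol", sa.String),
        sa.column("actual_shares", sa.Float),
        sa.column("avg_cost_per_share", sa.Float),
        schema="operations",
    )

    # Cast first, then label, so the view column is named cost_value
    cost_value = sa.cast(
        positions.c.actual_shares * positions.c.avg_cost_per_share,
        sa.Numeric(10, 3),  # at most 10 digits in total, 3 after the decimal point
    ).label("cost_value")

    stmt = (
        sa.select(positions.c.symbol, cost_value)
        .select_from(positions)
        .order_by(positions.c.symbol)
    )

    # Render the SQL that the view definition would contain
    sql = str(
        stmt.compile(
            dialect=postgresql.dialect(),
            compile_kwargs={"literal_binds": True},
        )
    )
    print(sql)

Unlike the two-argument round(), the cast also works when the columns are double precision. The precision of 10 is the answer's choice: NUMERIC(10, 3) holds absolute values below 10,000,000, and larger results make PostgreSQL raise a numeric field overflow error, so the precision should be sized to the data.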