Tags: python, decimal, fixed-point, exponent

Python decimal context for fixed point


I want to calculate with numbers having 3 places before and 2 places after the decimal point (of course, 2 and 3 are configurable). I think it will be easiest to explain by examples:

0.01 and 999.99 are the lower and upper limits for positive numbers. Of course, there is also 0.00, and negative numbers from -999.99 to -0.01. The distance between every two consecutive numbers is 0.01.

7.80 + 1.20 should be 9.00, and 999.00 + 1.00 should be OverflowError. 0.20 * 0.40 should be 0.08, and 0.34 * 0.20 should be 0.07 (this can set a flag to indicate it was rounded, but it mustn't raise any exceptions). 0.34 * 0.01 should be 0.00 (same condition as the previous one).

In fact, I want "ints" from 0 to 99999, just written with a dot after the third digit, scaled down 100 times when multiplying, and up 100 times when dividing. It should be possible to find a context for exactly that, right?
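
The "ints scaled by 100" model described here can be sketched without the decimal module at all. This is a hypothetical illustration (the helper names add, mul, div and fmt are made up, not from any library): each value is an int counting hundredths, products are scaled back down by 100 and dividends up by 100 with round-half-even, and the ±999.99 limit is checked after every operation. Rounding flags are not modelled.

```python
# Hypothetical helpers: values are plain ints counting hundredths,
# so 123.45 is stored as 12345 (limits taken from the question).
SCALE = 100      # 10 ** (places after the point)
LIMIT = 99_999   # largest magnitude, i.e. 999.99


def _check(n: int) -> int:
    """Raise OverflowError if n is outside -999.99 .. 999.99."""
    if abs(n) > LIMIT:
        raise OverflowError(n)
    return n


def _div_half_even(n: int, d: int) -> int:
    """Integer n / d rounded to nearest, ties to even (requires d > 0)."""
    q, r = divmod(n, d)          # floor division, 0 <= r < d
    if 2 * r > d or (2 * r == d and q % 2 == 1):
        q += 1
    return q


def add(x: int, y: int) -> int:
    return _check(x + y)


def mul(x: int, y: int) -> int:
    # hundredths * hundredths = ten-thousandths: scale down by 100
    return _check(_div_half_even(x * y, SCALE))


def div(x: int, y: int) -> int:
    if y == 0:
        raise ZeroDivisionError("division by zero")
    if y < 0:                    # keep the divisor positive for _div_half_even
        x, y = -x, -y
    # scale the dividend up by 100 so the quotient is in hundredths
    return _check(_div_half_even(x * SCALE, y))


def fmt(n: int) -> str:
    sign = "-" if n < 0 else ""
    whole, frac = divmod(abs(n), SCALE)
    return f"{sign}{whole}.{frac:02d}"


print(fmt(add(780, 120)))   # 9.00
print(fmt(mul(20, 40)))     # 0.08
print(fmt(mul(34, 20)))     # 0.07
print(fmt(mul(34, 1)))      # 0.00
```

Division scales the dividend up before dividing because scaling the integer quotient afterwards would already have thrown the fractional digits away.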

The problem is, I can't find the right setting for Emin, Emax, clamp, and prec that will do what I want. For example, I tried setting Emin and Emax to 0, but this raised too many InvalidOperations. The only thing I know is that rounding should be ROUND_HALF_EVEN. :-)


Solution

  • From the decimal module's FAQ in the Python documentation:

    Q. Once I have valid two place inputs, how do I maintain that invariant throughout an application?

    A. Some operations like addition, subtraction, and multiplication by an integer will automatically preserve fixed point. Other operations, like division and non-integer multiplication, will change the number of decimal places and need to be followed up with a quantize() step:

    >>> TWOPLACES = Decimal(10) ** -2   # same as Decimal('0.01')
    >>> a = Decimal('102.72')           # Initial fixed-point values
    >>> b = Decimal('3.17')
    >>> a + b                           # Addition preserves fixed-point
    Decimal('105.89')
    >>> a - b
    Decimal('99.55')
    >>> a * 42                          # So does integer multiplication
    Decimal('4314.24')
    >>> (a * b).quantize(TWOPLACES)     # Must quantize non-integer multiplication
    Decimal('325.62')
    >>> (b / a).quantize(TWOPLACES)     # And quantize division
    Decimal('0.03')
    

    In developing fixed-point applications, it is convenient to define functions to handle the quantize() step:

    >>> def mul(x, y, fp=TWOPLACES):
    ...     return (x * y).quantize(fp)
    >>> def div(x, y, fp=TWOPLACES):
    ...     return (x / y).quantize(fp)    
    >>> mul(a, b)                       # Automatically preserve fixed-point
    Decimal('325.62')
    >>> div(b, a)
    Decimal('0.03')
    

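    It seems the solution is to set the precision to 5 (three digits before the point plus two after) and Emax to 2 (so that any result whose magnitude reaches 1000 overflows), and to use those quantising functions. The code below also sets Emin to 0.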

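    import decimal
    from decimal import Decimal
    
    TWOPLACES = Decimal(10) ** -2   # same as Decimal('0.01')
    
    def div(x, y, fp=TWOPLACES):
        return (x / y).quantize(fp)
    
    con = decimal.getcontext()
    con.prec = 5    # 3 digits before the point + 2 after
    con.Emax = 2    # largest adjusted exponent: magnitudes >= 1000 overflow
    con.Emin = 0    # nonzero results below 1 are subnormal; with prec 5 they keep
                    # at most 4 decimal places (Subnormal/Underflow are not trapped)
    
    try:
        Decimal(1) * 1000
    except decimal.Overflow as e:
        print(e)
    else:
        raise AssertionError("expected decimal.Overflow")
    
    assert Decimal("0.99") * 1000 == Decimal("990.00")
    assert div(Decimal(1), 3) == Decimal("0.33")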
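
    Applied to the examples from the question, the quantize approach can be checked as follows. This is a sketch that assumes the same prec=5 and Emax=2 with ROUND_HALF_EVEN, but sets them inside localcontext() so the global context is left alone, and leaves Emin at its default. Because quantize() signals Inexact whenever it discards non-zero digits (and Inexact is not trapped by default), the context's flags provide the "it was rounded" indicator the question asks for without raising.

```python
from decimal import Decimal, Inexact, Overflow, ROUND_HALF_EVEN, localcontext

TWOPLACES = Decimal("0.01")


def mul(x, y, fp=TWOPLACES):
    # non-integer multiplication must be quantized back to 2 places
    return (x * y).quantize(fp)


with localcontext() as ctx:
    ctx.prec = 5
    ctx.Emax = 2
    ctx.rounding = ROUND_HALF_EVEN

    total = Decimal("7.80") + Decimal("1.20")
    exact = mul(Decimal("0.20"), Decimal("0.40"))
    ctx.clear_flags()
    rounded = mul(Decimal("0.34"), Decimal("0.20"))   # 0.0680 -> 0.07
    rounded_was_inexact = bool(ctx.flags[Inexact])
    tiny = mul(Decimal("0.34"), Decimal("0.01"))      # 0.0034 -> 0.00

    try:
        Decimal("999.00") + Decimal("1.00")          # 1000.0 exceeds Emax
        overflowed = False
    except Overflow:
        overflowed = True

print(total, exact, rounded, tiny, rounded_was_inexact, overflowed)
# 9.00 0.08 0.07 0.00 True True
```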
    

    Creating a Fixed Point Decimal Class

    It turns out to be surprisingly easy to modify the decimal module to become fixed point (at the expense of losing floating-point decimals). This is because the pure-Python implementation looks up the Decimal class by its global name whenever it creates a result, so we can drop in our own compatible subclass and things will work just fine. On Python 3.5 and later that implementation lives in the private module _pydecimal, which decimal falls back to when _decimal is unavailable, so the global name has to be rebound in _pydecimal itself. First you need to prevent Python from importing the C module _decimal so that decimal uses the pure-Python implementation (which lets us override a private method of Decimal). Once you've done that, you only need to override one method, _fix. It is called on the result of almost every arithmetic operation to make that result conform to the current decimal context.

    module setup

    # setup python to not import _decimal (c implementation of Decimal) if present
    import sys
    
    if "_decimal" in sys.modules or "decimal" in sys.modules:
        raise ImportError("fixedpointdecimal and the original decimal module do not work"
            " together")
    
    import builtins
    _original_import = __import__
    def _import(name, *args, **kwargs):
        if name == "_decimal":
            raise ImportError
        return _original_import(name, *args, **kwargs)
    builtins.__import__ = _import
    
    # import pure-python implementation of decimal
    import decimal
    
    # clean up
    builtins.__import__ = _original_import # restore original __import__
    del sys, builtins, _original_import, _import # clean up namespace
    

    main Decimal class

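    import _pydecimal  # Python 3.5+: where the pure-Python arithmetic lives
    from decimal import *
    
    class FixedPointDecimal(Decimal):
    
        def _fix(self, context):
            # fit every result to the context's decimal_places, falling back to
            # 2dp for contexts without that attribute (e.g. a plain Context)
            places = getattr(context, "decimal_places", 2)
            return super()._fix(context)._rescale(-places, context.rounding)
    
    # setup decimal module to use FixedPointDecimal; results are created through
    # the global name Decimal in _pydecimal, so it has to be rebound there too
    _pydecimal.Decimal = FixedPointDecimal
    decimal.Decimal = FixedPointDecimal
    Decimal = FixedPointDecimal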
    

    tests

    getcontext().prec = 5
    getcontext().Emax = 2
    a = Decimal("0.34")
    b = Decimal("0.20")
    assert a * b == Decimal("0.07")
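
    A condensed, self-contained version of the same idea, run against the examples from the question. It is a sketch assuming CPython 3.5+ (importing _pydecimal directly instead of blocking _decimal) and a fixed 2 places; the patch applies to every user of _pydecimal in the process.

```python
import _pydecimal as pd  # CPython's pure-Python decimal (3.5+)


class FixedPointDecimal(pd.Decimal):
    def _fix(self, context):
        # conform to the context first (may raise Overflow), then force 2dp
        return super()._fix(context)._rescale(-2, context.rounding)


# results are created through this module-global name (process-wide patch)
pd.Decimal = FixedPointDecimal
pd.setcontext(pd.Context(prec=5, Emax=2, rounding=pd.ROUND_HALF_EVEN))

D = FixedPointDecimal
results = [
    D("7.80") + D("1.20"),
    D("0.20") * D("0.40"),
    D("0.34") * D("0.20"),
    D("0.34") * D("0.01"),
]

try:
    D("999.00") + D("1.00")
    overflowed = False
except pd.Overflow:
    overflowed = True

print([str(r) for r in results], overflowed)
# ['9.00', '0.08', '0.07', '0.00'] True
```

    The final rounding to 2 places is done with _rescale(), which is quiet, so unlike quantize() it does not set the Inexact or Rounded flags.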
    

    Using a customisable context

    The context class is used to keep track of the settings that control how new decimals are created. This way each program, or even each thread, can set the number of decimal places it wants to use for its decimals. Modifying the Context class is a bit more long-winded. Below is the full class to create a compatible Context.

    class FixedPointContext(Context):
    
        def __init__(self, prec=None, rounding=None, Emin=None, Emax=None,
                           capitals=None, clamp=None, flags=None, traps=None,
                           _ignored_flags=None, decimal_places=None):
            super().__init__(prec, rounding, Emin, Emax, capitals, clamp, flags, 
                    traps, _ignored_flags)
            # inherit decimal_places from DefaultContext (2 if it has none)
            if decimal_places is None:
                decimal_places = getattr(DefaultContext, "decimal_places", 2)
            self.decimal_places = decimal_places
    
        def __setattr__(self, name, value):
            if name == "decimal_places":
                object.__setattr__(self, name, value)
            else:
                super().__setattr__(name, value)
    
        def __reduce__(self):
            flags = [sig for sig, v in self.flags.items() if v]
            traps = [sig for sig, v in self.traps.items() if v]
            return (self.__class__,
                    (self.prec, self.rounding, self.Emin, self.Emax,
                     self.capitals, self.clamp, flags, traps, self._ignored_flags,
                     self.decimal_places))
    
        def __repr__(self):
            """Show the current context."""
            s = []
            s.append('Context(prec=%(prec)d, rounding=%(rounding)s, '
                     'Emin=%(Emin)d, Emax=%(Emax)d, capitals=%(capitals)d, '
                     'clamp=%(clamp)d, decimal_places=%(decimal_places)d'
                     % vars(self))
            names = [f.__name__ for f, v in self.flags.items() if v]
            s.append('flags=[' + ', '.join(names) + ']')
            names = [t.__name__ for t, v in self.traps.items() if v]
            s.append('traps=[' + ', '.join(names) + ']')
            return ', '.join(s) + ')'
    
        def _shallow_copy(self):
            """Returns a shallow copy from self."""
            # type(self) keeps the subclass; a plain Context takes no decimal_places
            return type(self)(self.prec, self.rounding, self.Emin, self.Emax,
                              self.capitals, self.clamp, self.flags, self.traps,
                              self._ignored_flags, self.decimal_places)
    
        def copy(self):
            """Returns a deep copy from self."""
            return type(self)(self.prec, self.rounding, self.Emin, self.Emax,
                              self.capitals, self.clamp,
                              self.flags.copy(), self.traps.copy(),
                              self._ignored_flags, self.decimal_places)
        __copy__ = copy
        __copy__ = copy
    
    # reinitialise default context
    DefaultContext = FixedPointContext(decimal_places=2)
    
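    # copy changes over to the decimal module (and, on Python 3.5+, to
    # _pydecimal, whose getcontext() creates new contexts from its own globals)
    import _pydecimal
    for _module in (decimal, _pydecimal):
        _module.Context = FixedPointContext
        _module.DefaultContext = DefaultContext
    Context = FixedPointContext
    
    # test; needs FixedPointDecimal._fix to take its number of places from
    # context.decimal_places. A context created before the patch (e.g. by the
    # earlier getcontext() call) is a plain Context, so install a new one.
    decimal.setcontext(FixedPointContext(prec=5, Emax=2, decimal_places=1))
    a = Decimal("0.34")
    b = Decimal("0.20")
    assert a * b == Decimal("0.1")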
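
    A reduced sketch of the context-driven variant, assuming CPython 3.5+ and importing _pydecimal directly. PlacesContext is a hypothetical minimal subclass that only adds decimal_places; it deliberately omits the copy overrides above, which is why it is installed with setcontext() rather than used through localcontext() (the inherited copy() builds a plain Context, which would drop decimal_places).

```python
import _pydecimal as pd  # CPython's pure-Python decimal (3.5+)


class PlacesContext(pd.Context):
    """Hypothetical minimal Context subclass that carries decimal_places."""

    def __init__(self, *args, decimal_places=2, **kwargs):
        super().__init__(*args, **kwargs)
        self.decimal_places = decimal_places

    def __setattr__(self, name, value):
        # Context.__setattr__ rejects unknown names, so store this one directly
        if name == "decimal_places":
            object.__setattr__(self, name, value)
        else:
            super().__setattr__(name, value)


class FixedPointDecimal(pd.Decimal):
    def _fix(self, context):
        # number of places comes from the context; 2 for plain contexts
        places = getattr(context, "decimal_places", 2)
        return super()._fix(context)._rescale(-places, context.rounding)


pd.Decimal = FixedPointDecimal  # process-wide patch of the pure-Python module

a, b = FixedPointDecimal("0.34"), FixedPointDecimal("0.20")
products = {}
for places in (2, 1):
    # setcontext() installs this exact object, so decimal_places survives
    pd.setcontext(PlacesContext(prec=5, Emax=2, decimal_places=places))
    products[places] = str(a * b)

print(products)   # {2: '0.07', 1: '0.1'}
```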