Search code examples
pythonmathdecimalroundingfixed-point

Python rounding with Decimal module to specified decimal places


Question: How could I make Python's Decimal module round to a specified decimal place instead of rounding to a specified precision (significant figure) while evaluating an arithmetic operation?

Info

I have been using Python's Decimal module to round values to a specified precision via the setcontext method. This works well until a value's magnitude changes (e.g. from a whole number to a fraction), because the context precision counts significant figures, and significant figures do not distinguish digits before the decimal point from digits after it.

import decimal as d
from math import pi

decimal_places = 0
# The context precision counts significant digits, not decimal places, so
# decimal_places+1 only matches the intended number of decimal places when
# the integer part of the value has exactly one digit.
d.setcontext(d.Context(prec=decimal_places+1, rounding=d.ROUND_HALF_UP))

# This works fine: pi has one integer digit, so 1 significant digit == 0 places.
num = pi
print(f"Rounding {num} to {decimal_places} decimal places:")
print(f"Traditional rounding (correct): {round(num, decimal_places)}")
# Unary + applies the current context's precision and rounding mode.
print(f"Decimal rounding (correct): {+d.Decimal(num)}")

# This is where issues start to arise: for 0.314... the first significant
# digit is after the decimal point, so 1 significant digit keeps "0.3".
num = pi/10
print(f"\nRounding {num} to {decimal_places} decimal places:")
print(f"Traditional rounding (correct): {round(num, decimal_places)}")
print(f"Decimal rounding (incorrect): {+d.Decimal(num)}")
Rounding 3.141592653589793 to 0 decimal places:
Traditional rounding (correct): 3.0
Decimal rounding (correct): 3

Rounding 0.3141592653589793 to 0 decimal places:
Traditional rounding (correct): 0.0
Decimal rounding (incorrect): 0.3

Use case

Why even use the decimal module over Python's round function? Well, the advantage of the decimal module is that it applies the precision cap at every step of arithmetic evaluation (PEMDAS).

For example, if I wanted to round x as it gets evaluated in the function, I could just do:

function_str = "0.5 * (3*x) ** 2 + 3"
# Wrap every x in a context-rounded Decimal before evaluating.
# NOTE: str.replace substitutes every "x" character, including any inside
# other names (e.g. "exp"), not just the variable.
# NOTE(security): eval() executes arbitrary code; never pass it a
# user-entered string without validating or sandboxing it first.
eval(function_str.replace("x", "(+d.Decimal(x))"))

A more complete (and simpler) example:

import decimal as d

decimal_places = 0
# One significant digit; for values with a single integer digit this is
# the same as rounding to 0 decimal places.
d.setcontext(d.Context(prec=decimal_places+1, rounding=d.ROUND_HALF_UP))

numerator = 5
denominator = 1.1
num_err = 0.5
new_num = numerator + num_err

print(f"Rounding {numerator}/{denominator} to {decimal_places} decimal places:")
# round() only rounds the numerator; the quotient keeps full float precision.
print(f"Traditional rounding (incorrect): {round(new_num, decimal_places)/denominator}")
# With Decimal, the division result is rounded by the context as well.
print(f"Decimal rounding (correct): {+d.Decimal(new_num) / d.Decimal(denominator)}")
Rounding 5/1.1 to 0 decimal places:
Traditional rounding (incorrect): 5.454545454545454
Decimal rounding (correct): 5

It may still seem like round would be a simpler solution here, as it could just be placed around the output, but the more complex the function becomes, the less viable this is. In cases where the user enters the function, traditional rounding is practically unusable, while using the decimal module is as simple as function_str.replace("x", "(+d.Decimal(x))").

Note that the quantize method is not a viable option, as it only rounds the number it is called on instead of every intermediate result (which is what setting the context precision does).


Solution

  • To solve this, I ended up writing my own fixed-point arithmetic library. To help anyone else who runs into this problem in the future, I have posted the code for it below.

    import decimal
    import math
    
    
    # Global default number of decimal places for every Fixed whose
    # custom_prec is None. Fixed.prec reads it on every access, so changing
    # it also changes the precision of existing instances.
    PREC = 0
    
    
    def no_rounding(x, *_ignored_args, **_ignored_kwargs):
        """Identity "rounding": hand *x* back exactly as received.

        Any extra positional or keyword arguments are accepted and ignored so
        that this can be called with the same ``(value, places)`` signature
        as ``round``, ``ceil`` and ``floor``.
        """
        unchanged = x
        return unchanged
    
    
    def ceil(x, prec=0):
        """Round *x* up (towards +infinity) to *prec* decimal places.

        The rounding is done on the shortest decimal representation of *x*
        (``str(x)``) with ``decimal``, so binary floating-point error cannot
        push an exactly representable step over the edge. For example,
        ``1.1 * 100`` is ``110.00000000000001`` as a float, which made a
        ``math.ceil(x * 10**prec)`` approach return 1.11 for ``ceil(1.1, 2)``.

        Args:
            x: the number to round.
            prec: number of decimal places. Negative values round to tens,
                hundreds, ... as the built-in ``round`` does.

        Returns:
            float: the rounded value. Non-finite inputs (inf, nan) are
            returned unchanged, matching ``round(x, prec)``.
        """
        if not math.isfinite(x):
            return float(x)
        value = decimal.Decimal(str(x))
        with decimal.localcontext() as ctx:
            # Room for every integer digit, *prec* fractional digits and one
            # carry digit, so quantize() never signals InvalidOperation.
            ctx.prec = max(28, value.adjusted() + prec + 2)
            result = value.quantize(decimal.Decimal(1).scaleb(-prec),
                                    rounding=decimal.ROUND_CEILING)
        # "+ 0.0" turns a negative zero (e.g. ceil(-0.4)) into 0.0.
        return float(result) + 0.0
    
    
    def floor(x, prec=0):
        """Round *x* down (towards -infinity) to *prec* decimal places.

        The rounding is done on the shortest decimal representation of *x*
        (``str(x)``) with ``decimal``, so binary floating-point error cannot
        drop an exactly representable value below its step. For example,
        ``0.29 * 100`` is ``28.999999999999996`` as a float, which made a
        ``math.floor(x * 10**prec)`` approach return 0.28 for
        ``floor(0.29, 2)``.

        Args:
            x: the number to round.
            prec: number of decimal places. Negative values round to tens,
                hundreds, ... as the built-in ``round`` does.

        Returns:
            float: the rounded value. Non-finite inputs (inf, nan) are
            returned unchanged, matching ``round(x, prec)``.
        """
        if not math.isfinite(x):
            return float(x)
        value = decimal.Decimal(str(x))
        with decimal.localcontext() as ctx:
            # Room for every integer digit, *prec* fractional digits and one
            # carry digit, so quantize() never signals InvalidOperation.
            ctx.prec = max(28, value.adjusted() + prec + 2)
            result = value.quantize(decimal.Decimal(1).scaleb(-prec),
                                    rounding=decimal.ROUND_FLOOR)
        # "+ 0.0" turns a negative zero (from an input of -0.0) into 0.0.
        return float(result) + 0.0
    
    
    # Lookup table from the ``round_function`` name accepted by ``Fixed`` to
    # the callable that performs the rounding; each callable is invoked as
    # ``func(value, places)``. Insertion order: None, round, ceil, floor.
    rounding = {None: no_rounding}
    rounding.update(round=round, ceil=ceil, floor=floor)
    
    
    class Fixed:
        """Float wrapper that rounds to a fixed number of decimal places.

        The unrounded value is kept in ``val``. ``num`` is that value rounded
        with the chosen rounding function to ``prec`` decimal places. Every
        arithmetic operation works on the rounded operands and rounds its
        result again, returning a new ``Fixed`` with the same settings.

        Args:
            number: anything ``float()`` accepts (another ``Fixed``
                contributes its rounded value).
            round_function: key into the module-level ``rounding`` table
                (``None``, ``"round"``, ``"ceil"`` or ``"floor"``).
            custom_prec: decimal places for this instance. ``None`` means the
                module-level ``PREC``, which is read on every access, so
                changing ``PREC`` also affects existing instances.
        """

        def __init__(self, number, round_function="round", custom_prec=None):
            self.val = float(number)
            self.round_str = round_function
            self.round_func = rounding[round_function]
            self.custom_prec = custom_prec

        def _dup_fixed(self, number):
            """Return a new Fixed holding *number* with this instance's settings."""
            return Fixed(number, self.round_str, self.custom_prec)

        def _operation(self, op):
            """Round the raw result *op* and wrap it with this instance's settings."""
            return self._dup_fixed(self.round_func(op, self.prec))

        @property
        def prec(self):
            """Decimal places in effect: ``custom_prec`` or the global ``PREC``."""
            return int(self.custom_prec if self.custom_prec is not None else PREC)

        @property
        def num(self):
            """The stored value rounded to ``prec`` decimal places."""
            return self.round_func(self.val, self.prec)

        @property
        def real(self):
            return self

        @property
        def imag(self):
            # Keep the rounding settings, like every other derived value.
            return self._dup_fixed(0)

        def __setattr__(self, name, value):
            # The raw value is always stored as a float.
            if name == "val":
                value = float(value)
            self.__dict__[name] = value

        def __hash__(self):
            # Hash the rounded value so it agrees with __eq__.
            return hash(self.num)

        def __str__(self):
            return str(self.num)

        __repr__ = __str__

        def __format__(self, spec):
            """Format the rounded value.

            Supports the standard format-spec mini-language
            (``f"{x:.2f}"``). Printf-style specs such as ``"%.2f"``, which the
            previous implementation applied with the ``%`` operator, are
            still accepted for backward compatibility.
            """
            if spec == "":
                return str(self)
            try:
                return format(self.num, spec)
            except ValueError:
                if "%" not in spec:
                    raise
                return spec % self.num

        def __reduce__(self):
            # Pickle the rounding settings too, not only the raw value.
            return (self.__class__, (self.val, self.round_str, self.custom_prec))

        def __copy__(self):
            return self.__class__(self.val, self.round_str, self.custom_prec)

        def __deepcopy__(self, memo):
            return self.__copy__()

        def __pos__(self):
            return self

        def __neg__(self):
            return self._dup_fixed(-self.val)

        def __abs__(self):
            return self._dup_fixed(abs(self.val))

        def __round__(self, n=None):
            return self._dup_fixed(round(self.val, n))

        def __floor__(self):
            return self._dup_fixed(math.floor(self.val))

        def __ceil__(self):
            return self._dup_fixed(math.ceil(self.val))

        def __int__(self):
            return int(self.num)

        def __trunc__(self):
            return math.trunc(self.num)

        def __float__(self):
            return float(self.num)

        def __complex__(self):
            return complex(self.num)

        def conjugate(self):
            return self

        def __eq__(self, other):
            # Non-numeric operands (None, arbitrary objects) compare unequal
            # instead of raising from float().
            try:
                other_val = float(other)
            except (TypeError, ValueError):
                return NotImplemented
            return self.num == other_val

        def __ne__(self, other):
            equal = self.__eq__(other)
            return equal if equal is NotImplemented else not equal

        def __gt__(self, other):
            return self.num > float(other)

        def __ge__(self, other):
            return self.num >= float(other)

        def __lt__(self, other):
            return self.num < float(other)

        def __le__(self, other):
            return self.num <= float(other)

        def __bool__(self):
            return self.num != 0

        def __add__(self, other):
            return self._operation(self.num + float(other))

        __radd__ = __add__

        def __sub__(self, other):
            # Computed directly rather than as ``self + -other``: negating
            # first rounds the subtrahend the opposite way under
            # "ceil"/"floor" (ceil(-1.5) != -ceil(1.5)).
            return self._operation(self.num - float(other))

        def __rsub__(self, other):
            return self._operation(float(other) - self.num)

        def __mul__(self, other):
            return self._operation(self.num * float(other))

        __rmul__ = __mul__

        def __truediv__(self, other):
            return self._operation(self.num / float(other))

        def __rtruediv__(self, other):
            return self._operation(float(other) / self.num)

        def __floordiv__(self, other):
            return self._operation(self.num // float(other))

        def __rfloordiv__(self, other):
            return self._operation(float(other) // self.num)

        def __mod__(self, other):
            return self._operation(self.num % float(other))

        def __rmod__(self, other):
            return self._operation(float(other) % self.num)

        def __divmod__(self, other):
            result = divmod(self.num, float(other))
            return (self._operation(result[0]), self._operation(result[1]))

        def __rdivmod__(self, other):
            result = divmod(float(other), self.num)
            return (self._operation(result[0]), self._operation(result[1]))

        def __pow__(self, other):
            return self._operation(self.num ** float(other))

        def __rpow__(self, other):
            return self._operation(float(other) ** self.num)
    

    Let me know in the comments if you found any bugs or problems and I will be sure to update my answer.

    Usage

    A fixed number is created by passing the number to the Fixed class. The resulting fixed number can then be used much like a normal number.

    # Import the library above; assumes it was saved as fixed_point.py.
    import fixed_point as fp
    
    num = 1.6
    fixed_num = fp.Fixed(num)  # Default precision is 0 (fp.PREC)
    print("Original number:", num)
    print("Fixed number:", fixed_num)
    # .val is the raw, unrounded value, so this is plain 1.6 * 1.6.
    print("Fixed number value multiplied by original number:", fixed_num.val * num)
    # Arithmetic uses the rounded operand (2.0) and rounds the result again.
    print("Fixed number multiplied by original number:", fixed_num * num)
    print("Fixed number multiplied by itself:", fixed_num * fixed_num)
    
    Original number: 1.6
    Fixed number: 2.0
    Fixed number value multiplied by original number: 2.56
    Fixed number multiplied by original number: 3.0
    Fixed number multiplied by itself: 4.0
    

    To set the global precision, modify the PREC variable. This changes the precision (number of decimal places) not only of all new fixed-point numbers but also of existing ones, except those created with a custom precision. The precision of a specific fixed number can also be set at creation time with the custom_prec argument.

    num = 3.14159
    fixed_num = fp.Fixed(num)
    # Per-instance precision; later changes to fp.PREC do not affect it.
    custom_prec_num = fp.Fixed(num, custom_prec=4)
    print("Original number:", num)
    print("Fixed number (default precision):", fixed_num)
    print("Custom precision fixed number (4 decimals):", custom_prec_num)
    
    # Update global precision; also applies to existing numbers without custom_prec.
    fp.PREC = 2
    print("\nGlobal precision updated to", fp.PREC)
    
    print("Fixed number (new precision):", fixed_num)
    print("Custom precision fixed number (4 decimals):", custom_prec_num)
    
    Original number: 3.14159
    Fixed number (default precision): 3.0
    Custom precision fixed number (4 decimals): 3.1416
    
    Global precision updated to 2
    Fixed number (new precision): 3.14
    Custom precision fixed number (4 decimals): 3.1416
    

    Note that the original (unrounded) value of a fixed number can only be obtained with fixed_num.val; float(fixed_num) returns the number rounded to the specified number of decimal places (unless the rounding function is None).