Tags: python, algorithm, performance, math

Algorithm to calculate the count of numbers with a digit sum less than or equal to 'x' within a given range


I am looking for a Python algorithm that can efficiently count the integers between 0 and n (inclusive) whose digit sum is less than or equal to a given value x. In my case, n can be as large as 10^12.

For example, if n is 112 and x is 5, the algorithm should return the count of integers between 0 and 112 (inclusive) whose digit sum is less than or equal to 5.

The following function gives the correct result for any input. However, it is very inefficient for large ranges, since it loops through every number and computes its digit sum.

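def digitsums_lower_than_x(x, n):
    # Digit sum of a non-negative integer, e.g. 112 -> 1 + 1 + 2 = 4.
    digit_sum = lambda y: sum(int(digit) for digit in str(y))
    # Brute-force count over 1..n; the trailing +1 counts 0 (digit sum 0 <= x).
    return sum(1 for i in range(1, n+1) if digit_sum(i) <= x)+1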

Thanks in advance!


Solution

  • Let DS(n, x) be the number of integers in the range 0...x (inclusive) with digit sum at most n.

    If the top digit of x is D, then for each d from 0 to min(D, n), count the numbers of the same length as x (padding shorter numbers with leading zeros, which doesn't change their digit sum) whose top digit is d. This transforms the problem of computing DS(n, x) into up to 10 problems of computing DS(n-d, x'), where x' is x with the top digit removed when d = D, or the number one less than a power of 10 with one digit fewer than x (that is, 99...9) when d < D. This gives you an O(min(n, log x) * log x) solution if you're careful not to compute the same subproblem twice (i.e. the cases where x' is one less than a power of 10, which come up repeatedly).
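To make a single step of this decomposition concrete, here is a minimal sketch that only lists the subproblems DS(n - d, x') produced by splitting off the top digit; it does not evaluate them. The helper name `split_top_digit` is hypothetical, it assumes Python 3.9+ (for the `list[...]` annotation), and it works on integers, relying on the fact that leading zeros don't change a digit sum (so a remainder like "02" can be treated as 2).

```python
def split_top_digit(n: int, x: int) -> list[tuple[int, int]]:
    """List the subproblems (n - d, x') that DS(n, x) splits into.

    Hypothetical helper: performs one decomposition step only and does
    not evaluate the subproblems. Requires x >= 10.
    """
    s = str(x)
    if len(s) < 2:
        raise ValueError("single-digit x is the base case: DS(n, x) = min(n, x) + 1")
    top = int(s[0])                  # D, the top digit of x
    nines = 10 ** (len(s) - 1) - 1   # 99...9 with one digit fewer than x
    rest = int(s[1:])                # x with its top digit removed
    # d < D frees the remaining digits; d == D keeps them bounded by rest.
    return [(n - d, nines if d < top else rest) for d in range(min(n, top) + 1)]


print(split_top_digit(5, 112))  # [(5, 99), (4, 12)]
```

The output corresponds to DS(5, 112) = DS(5, 99) + DS(4, 12), and applying the helper again to (5, 99) or (4, 12) reproduces the next level of the split.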

    For example, to compute DS(5, 112):

    DS(5, 112) = 3-digit solutions starting with 0 + 3-digit solutions starting with 1
    DS(5, 112) = DS(5, 99) + DS(4, 12)
    
    DS(5, 99) = 2-digit solutions starting with 0, 1, 2, 3, 4, 5
              = DS(5, 9) + DS(4, 9) + DS(3, 9) + DS(2, 9) + DS(1, 9) + DS(0, 9)
              = 6 + 5 + 4 + 3 + 2 + 1 = 21 
    
    DS(4, 12) = 2-digit solutions starting with 0 + 2-digit solutions starting with 1
              = DS(4, 9) + DS(3, 2)
              = 5 + 3 = 8
    

    So DS(5, 112) = 21 + 8 = 29
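A direct count is a handy way to double-check the hand calculation. This sketch (the name `ds_brute` is hypothetical) loops over every integer in 0..x, so it is only practical for small x, but it confirms each intermediate value used above.

```python
def ds_brute(n: int, x: int) -> int:
    """Count integers in 0..x (inclusive) whose digit sum is at most n."""
    return sum(1 for i in range(x + 1) if sum(map(int, str(i))) <= n)


# DS(5, 99), DS(4, 12), DS(3, 2) and DS(5, 112) from the worked example.
print(ds_brute(5, 99), ds_brute(4, 12), ds_brute(3, 2), ds_brute(5, 112))  # 21 8 3 29
```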

    Here's an example program. It's not optimal, since each step does some O(log x) extra work slicing and hashing strings, but that's unlikely to matter.

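    def D(x, n, cache):
        """Count integers 0..int(x) with digit sum <= n; x is a digit string."""
        if len(x) == 1:
            # Base case: the single digits 0..min(n, x[0]) all qualify.
            return min(n, int(x[0])) + 1
        if (x, n) not in cache:
            top = int(x[0])
            # Top digit d < top frees the remaining digits (any value up to 99...9).
            nines = '9' * (len(x) - 1)
            # Top digit d == top keeps the remaining digits bounded by x[1:].
            cache[x, n] = sum(D(nines if d < top else x[1:], n - d, cache)
                              for d in range(min(n, top) + 1))
        return cache[x, n]
    
    def DS(n, x):
        # Same argument order as the text: digit-sum limit n, upper bound x.
        return D(str(x), n, {})
    
    print(DS(5, 112))  # 29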
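The per-step string overhead can be avoided by recasting the same top-digit recursion in the common "digit DP" form. This is a swapped-in formulation rather than the program above: the state is (position, remaining budget, tight), where `tight` is True while every digit so far equals the upper bound's digit (the `x[1:]` branch) and False once a smaller digit has been chosen (the `nines` branch). It is a minimal sketch; the name `count_digit_sum_at_most` is hypothetical, and the limit comes first to match DS(n, x).

```python
from functools import lru_cache


def count_digit_sum_at_most(limit: int, upper: int) -> int:
    """Count integers in 0..upper (inclusive) whose digit sum is <= limit."""
    if limit < 0 or upper < 0:
        return 0
    digits = [int(c) for c in str(upper)]

    @lru_cache(maxsize=None)
    def go(pos: int, budget: int, tight: bool) -> int:
        # Every digit has been chosen without exceeding the budget.
        if pos == len(digits):
            return 1
        # While tight, this digit may not exceed upper's digit (x[1:] branch);
        # otherwise any digit 0..9 is allowed (nines branch).
        top = digits[pos] if tight else 9
        return sum(
            go(pos + 1, budget - d, tight and d == top)
            for d in range(min(top, budget) + 1)
        )

    return go(0, limit, True)


print(count_digit_sum_at_most(5, 112))     # 29
print(count_digit_sum_at_most(1, 10**12))  # 14
```

The second call counts 0 and the thirteen powers of ten up to 10^12. With at most 13 digits, two values of `tight`, and at most 10 choices per step, the number of cached states stays small even for upper bounds around 10^12.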