algorithm, dynamic-programming, combinatorics, counting, digits

Count the number of integers in range [0,k] whose sum of digits equals s


Count the number of integers in range [0,k] whose sum of digits equals s. Since k can be a very large number, the solution should not be O(k). My attempt at an O(s log(k)) solution (log(k) is proportional to the number of digits in k):

I introduce a new function cnt_sum, which counts the number of integers with n digits whose digit sum equals s. However, there seem to be duplication problems due to leading zeros. Is there a simpler approach to this question?

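# Python, with memoization and base cases; leading zeros are allowed throughout
import functools


# count number of digit strings of length n whose sum of digits equals s
# leading zeros included
@functools.cache
def cnt_sum(n: int, s: int) -> int:
    if s < 0:
        return 0
    if n == 0:
        return 1 if s == 0 else 0
    ans = 0
    for i in range(10):  # digits 0..9
        ans += cnt_sum(n - 1, s - i)
    return ans

# suppose the number is 63069, passed as the digit list [6, 3, 0, 6, 9]
def dp(loc: int, k: list[int], s: int) -> int:
    if s < 0:
        return 0
    if loc == len(k):
        return 1 if s == 0 else 0
    # Count the numbers starting with 6 whose remaining digits are at most k[1:] (3069) and sum to s - k[0] (s - 6)
    ans = dp(loc + 1, k, s - k[loc])
    # For each digit i in range [0, 5], count all strings of the remaining len(k)-loc-1 digits whose sum equals s - i
    # such as 59998, 49999
    for i in range(0, k[loc]):
        ans += cnt_sum(len(k) - loc - 1, s - i)
    return ans

def count(k: int, s: int) -> int:
    return dp(0, [int(c) for c in str(k)], s)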
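The helper cnt_sum (digit strings of length n, leading zeros allowed, summing to s) also has a closed form. The sketch below is not from the post: it swaps the recursion for the standard inclusion-exclusion count of solutions to x1 + ... + xn = s with 0 <= xi <= 9, and checks it against a brute-force enumeration on small inputs. Both `cnt_sum_closed` and `cnt_sum_brute` are hypothetical names used only for illustration.

```python
import math
from itertools import product


def cnt_sum_closed(n: int, s: int) -> int:
    """Number of length-n digit strings (leading zeros allowed) whose digits sum to s.

    Inclusion-exclusion over the j digits forced to be >= 10:
    sum over j of (-1)^j * C(n, j) * C(s - 10j + n - 1, n - 1).
    """
    if s < 0:
        return 0
    if n == 0:
        return 1 if s == 0 else 0
    total = 0
    for j in range(min(n, s // 10) + 1):
        total += (-1) ** j * math.comb(n, j) * math.comb(s - 10 * j + n - 1, n - 1)
    return total


def cnt_sum_brute(n: int, s: int) -> int:
    """Hypothetical reference helper: enumerate all 10**n digit strings."""
    return sum(1 for digits in product(range(10), repeat=n) if sum(digits) == s)


# Cross-check the closed form on small lengths and sums.
for n in range(4):
    for s in range(40):
        assert cnt_sum_closed(n, s) == cnt_sum_brute(n, s)

print(cnt_sum_closed(2, 10))  # 9: the strings 19, 28, ..., 91
```

Each term costs O(1) big-integer work, so this avoids the O(s) inner loop per call. The trade-off is that the recursive version is easier to verify by eye.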

Solution

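  • Here's a simple Python solution based on the following recurrence, where T(k, s) is the number of integers in [0, k] whose digits sum to s, and k//10 and k%10 are integer division and remainder (ranges [a, b] are inclusive):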

    T(k, s) = 0   if k < 0 or s < 0
    T(0, 0) = 1
    T(0, s) = 0   if s > 0
    T(k, s) = sum(T(k//10, s-i) for i in [0, k%10]) + sum(T(k//10 - 1, s-i) for i in [k%10+1, 9])
    

    The last line is the key one, because it encodes how T(k, s) splits into subproblems. Take, for example, T(12345, 20).

    Splitting on the last digit gives ten cases:

    T(1234, 20) #xxxx0 (with xxxx <= 1234)
    T(1234, 19) #xxxx1 (with xxxx <= 1234)
    T(1234, 18) #xxxx2 (with xxxx <= 1234)
    T(1234, 17) #xxxx3 (with xxxx <= 1234)
    T(1234, 16) #xxxx4 (with xxxx <= 1234)
    T(1234, 15) #xxxx5 (with xxxx <= 1234)
    T(1233, 14) #yyyy6 (with yyyy <= 1233)
    T(1233, 13) #yyyy7 (with yyyy <= 1233)
    T(1233, 12) #yyyy8 (with yyyy <= 1233)
    T(1233, 11) #yyyy9 (with yyyy <= 1233)
    
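The ten cases above can be checked numerically. This sketch uses a hypothetical O(k) helper `brute` (fine for small k only) to count each case separately and confirms that the ten counts add up to the direct count for [0, 12345]; it prints the per-case values rather than asserting them.

```python
def digit_sum(n: int) -> int:
    return sum(int(c) for c in str(n))


def brute(k: int, s: int) -> int:
    """Hypothetical O(k) reference: count integers in [0, k] with digit sum s."""
    if k < 0 or s < 0:
        return 0
    return sum(1 for n in range(k + 1) if digit_sum(n) == s)


k, s = 12345, 20
q, r = divmod(k, 10)  # q = 1234 (prefix bound), r = 5 (last digit of k)

cases = []
for i in range(10):
    # xxxx <= 1234 when the last digit i is at most 5, yyyy <= 1233 otherwise
    bound = q if i <= r else q - 1
    cases.append((bound, s - i, brute(bound, s - i)))

for bound, rest, c in cases:
    print(f"T({bound}, {rest}) = {c}")

# The ten cases together account for every number in [0, 12345] with digit sum 20.
assert sum(c for _, _, c in cases) == brute(k, s)
```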

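    This solution does not have to deal with the duplication problem. Every integer in [0, k] has exactly one last digit i and one prefix (0 for single-digit numbers), so the ten cases are disjoint and together cover [0, k]. Because digits are peeled off from the least significant end, leading zeros never come up.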
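The partition argument can be sanity-checked directly. The sketch below, using a hypothetical helper `in_some_case`, splits n into (prefix, last digit) with divmod and applies the same prefix bound as the recurrence. It confirms that every n in [0, k] lands in its own last-digit case and that no n above k does.

```python
def in_some_case(n: int, k: int) -> bool:
    """Hypothetical helper: does n fall into its last-digit case of the recurrence for T(k, .)?"""
    prefix, digit = divmod(n, 10)  # e.g. 7 -> prefix 0, digit 7
    # Prefix bound is k//10, or k//10 - 1 when the digit exceeds k's last digit.
    bound = k // 10 - (digit > k % 10)
    return prefix <= bound


k = 12345
# Each n has a single last digit, so it can only belong to one case;
# the check shows that case accepts n exactly when n <= k.
assert all(in_some_case(n, k) for n in range(k + 1))
assert not any(in_some_case(n, k) for n in range(k + 1, 3 * k))
```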

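    Here's the final code, using a few Python shortcuts: the boolean (i > k%10) counts as 0 or 1, so k//10 - (i > k%10) picks the right prefix bound, and functools.cache (Python 3.9+) provides the memoization.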

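    import functools
    
    @functools.cache  # memoize on (k, s); functools.cache needs Python 3.9+
    def T(k, s):
        """Count the integers in [0, k] whose digits sum to s."""
        if k < 0 or s < 0: return 0
        if k == 0: return int(s == 0)  # only the number 0 is left
        # Last digit i: the prefix ranges over [0, k//10], or [0, k//10 - 1] when i > k%10
        return sum(T(k//10 - (i > k%10), s - i) for i in range(10))
    
    print(T(12345, 20))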
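To gain some confidence in T before using it on a huge k, one option is to compare it against an O(k) brute force on random small inputs. The sketch below does that (the `brute` helper and the random test ranges are illustrative choices, not part of the answer) and then runs T on k = 10**40. That is far beyond brute force, but still cheap, because each digit level only produces about two distinct prefix values.

```python
import functools
import random


@functools.cache
def T(k: int, s: int) -> int:
    # Same recurrence as the answer's code
    if k < 0 or s < 0:
        return 0
    if k == 0:
        return int(s == 0)
    return sum(T(k // 10 - (i > k % 10), s - i) for i in range(10))


def brute(k: int, s: int) -> int:
    """Hypothetical O(k) reference; only usable for small k."""
    return sum(1 for n in range(k + 1) if sum(map(int, str(n))) == s)


random.seed(0)
for _ in range(100):
    k, s = random.randrange(2000), random.randrange(40)
    assert T(k, s) == brute(k, s)

# Numbers with digit sum 1 are the powers of ten: 10**0, ..., 10**40.
print(T(10**40, 1))  # 41
```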