python, algorithm, dynamic-programming

Counting the number of positive integers that are lexicographically smaller than a specific number


Say I have a number num and I want to count the positive integers in the range [1, n] that are lexicographically smaller than num, where n is some arbitrarily large integer. A number x is lexicographically smaller than a number y if the string str(x) is lexicographically smaller than the string str(y). I want to do this efficiently since n could be large (e.g. 10^9). My idea is to use digit dynamic programming.
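As a point of reference, the definition above translates directly into a brute-force count. This is a minimal sketch that is not part of the original post; the name `count_brute_force` is hypothetical, and because it visits every integer in [1, n] it is only practical for small n, for example to test a faster version against.

```python
def count_brute_force(num: int, n: int) -> int:
    """Hypothetical helper: count x in [1, n] with str(x) < str(num), in O(n) time."""
    target = str(num)
    # Python compares strings lexicographically, character by character
    return sum(1 for x in range(1, n + 1) if str(x) < target)


print(count_brute_force(7, 13))  # 10
```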

Essentially, my thinking is that every number in the range [1, n] can be represented as a string of len(str(n)) slots. At each slot, the upper bound is either the digit at the last position of num (this covers the case where we pick trailing zeros) or the digit at the last position of n. This is because if a previous digit is already smaller than the corresponding digit in num, then we are free to pick any digit up to the corresponding digit in n. Below is my Python code that attempts this:

from functools import cache

def count(num, n):
    num = str(num)
    n = str(n)
    max_length = len(n)
    @cache
    def dp(indx, compare_indx, tight, has_number):
        if indx == max_length:
            return int(has_number)
        ans = 0
        upper_bound = int(num[compare_indx]) if tight else int(n[indx])
        for digit in range(upper_bound + 1):
            if digit == 0 and not has_number:
                ans += dp(indx + 1, compare_indx, tight and (digit == upper_bound), False)
            else:
                ans += dp(indx + 1, min(compare_indx + 1, len(num) - 1), tight and (digit == upper_bound), True)
        return ans
    return dp(0, 0, True, False)

However, count(7, 13) outputs 35, which is not correct: the lexicographical order of [1, 13] is [1, 10, 11, 12, 13, 2, 3, 4, 5, 6, 7, 8, 9], so count(7, 13) should be 10. Can anyone help me out here?
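The expected ordering can be reproduced by sorting with `key=str`, which compares the integers as strings; the requested count is then the index of num in that sorted list. A small sketch, assuming 1 <= num <= n so that num actually appears in the list (the helper name `count_by_sorting` is hypothetical):

```python
def count_by_sorting(num: int, n: int) -> int:
    """Hypothetical helper: position of num in the lexicographic order of [1, n].

    Assumes 1 <= num <= n; otherwise list.index raises ValueError.
    """
    order = sorted(range(1, n + 1), key=str)  # compare as strings, not as ints
    return order.index(num)


print(sorted(range(1, 14), key=str))
# [1, 10, 11, 12, 13, 2, 3, 4, 5, 6, 7, 8, 9]
print(count_by_sorting(7, 13))  # 10
```

This is O(n log n) and only useful for checking small cases, but it makes the definition concrete.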


Solution

  • I couldn't follow the logic in your explanation, but this shouldn't need dynamic programming.

    In essence, you want to do a separate count for each possible width (number of digits) of an integer. For instance, when calling count(7, 13), you'd want to count:

    • integers with one digit: [1, 6] = 6 integers
    • integers with two digits: [10, 13] = 4 integers

    The outcome is the sum: 6 + 4 = 10

    Take count(86, 130) as another example, where the first argument has more than one digit:

    • integers with one digit: [1, 8] = 8 integers (note that 8 is included)
    • integers with two digits: [10, 85] = 76 integers (note that 86 is excluded)
    • integers with three digits: [100, 130] = 31 integers

    Total: 8 + 76 + 31 = 115
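To double-check the per-width figures in both examples, a brute-force tally grouped by digit count can be used. This is a sketch for small n only; `counts_by_width` is a hypothetical helper that simply compares every integer in [1, n] with num as strings:

```python
from collections import Counter


def counts_by_width(num: int, n: int) -> dict[int, int]:
    """Hypothetical helper: for each digit count, how many x in [1, n]
    satisfy str(x) < str(num). Brute force, O(n)."""
    target = str(num)
    tally = Counter(len(str(x)) for x in range(1, n + 1) if str(x) < target)
    return dict(tally)  # keys appear in increasing width, as first encountered


print(counts_by_width(7, 13))    # {1: 6, 2: 4}
print(counts_by_width(86, 130))  # {1: 8, 2: 76, 3: 31}
```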

    So some care has to be taken at the high end of each range: when that high end is a proper prefix of the first argument, it should be included (like 8 for 86); if not, it should be excluded (like 86 itself). And of course, for the last group (the one with the greatest number of digits), you should take care not to exceed the value of the second argument.
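The high end of each range can be taken from the first `width` digits of num, right-padded with zeros when num is shorter than `width` (for 86 this gives 8, 86 and 860 for widths 1 to 3). The sketch below, using the hypothetical helper names `high_end` and `check_rule`, exhaustively checks that rule for every integer of a given width against a direct string comparison:

```python
def high_end(num: int, width: int) -> int:
    """Hypothetical helper: first `width` digits of num, right-padded with zeros."""
    return int(str(num)[:width].ljust(width, "0"))


def check_rule(num: int, width: int) -> bool:
    """Hypothetical helper: True if, for every width-digit x, str(x) < str(num)
    holds exactly when x < high, or x == high and high is a proper prefix of num."""
    s = str(num)
    high = high_end(num, width)
    return all(
        (str(x) < s) == (x < high or (x == high and width < len(s)))
        for x in range(10 ** (width - 1), 10 ** width)
    )


print([high_end(86, w) for w in (1, 2, 3)])       # [8, 86, 860]
print(all(check_rule(86, w) for w in (1, 2, 3)))  # True
```

For widths longer than num, the padding works because a string that has str(num) as a proper prefix compares greater than it, so only integers strictly below the padded value qualify.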

    Here is how you could code that logic:

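    def count(num, n):
        strnum = str(num)
        lennum = len(strnum)
        max_length = len(str(n))
        strnum += "0" * (max_length - lennum)  # pad with zeroes at the right

        total = 0
        low = 1  # smallest integer with `width` digits: 1, 10, 100, ...
        for width in range(1, max_length + 1):
            # High end for this width: the first `width` digits of the padded num
            high = int(strnum[:width])
            # A proper prefix of num is lexicographically smaller: include it if <= n
            addone = width < lennum and n >= high
            # Count low .. high - 1 (capped at n), plus the prefix itself if applicable
            total += min(high, n + 1) - low + addone
            low *= 10

        return total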
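One way to gain confidence in the function is to compare it with a brute-force count on many random inputs. A minimal sketch, assuming a fixed random seed and small n so that the O(n) reference stays fast; `count_brute_force` is a hypothetical helper, and `count` repeats the answer's logic so that the block runs on its own:

```python
import random


def count(num, n):
    """The answer's per-width count of x in [1, n] with str(x) < str(num)."""
    strnum = str(num)
    lennum = len(strnum)
    max_length = len(str(n))
    strnum += "0" * (max_length - lennum)  # pad with zeroes at the right
    total = 0
    low = 1
    for width in range(1, max_length + 1):
        high = int(strnum[:width])
        addone = width < lennum and n >= high
        total += min(high, n + 1) - low + addone
        low *= 10
    return total


def count_brute_force(num, n):
    """Hypothetical O(n) reference, used only for checking."""
    target = str(num)
    return sum(1 for x in range(1, n + 1) if str(x) < target)


rng = random.Random(0)  # fixed seed so the check is reproducible
for _ in range(500):
    num = rng.randint(1, 10**5)  # num may have more digits than n
    n = rng.randint(1, 2000)
    assert count(num, n) == count_brute_force(num, n), (num, n)

print(count(7, 10**9))  # 666666667
```

The loop in `count` runs once per digit of n, so large inputs such as n = 10^9 are handled immediately, whereas the brute-force reference would have to visit a billion integers.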