Search code examples
pythonalgorithmindexingpermutation

Getting permutation index and permutation at index faster than the provided solution


Thanks to this answer, here is how I'm getting permutation index and permutation at an index:

import time


def get_Cl(distinct):
    Cl = []
    for i in range(1, distinct + 1):  # i is distincct
        c = [0] * i + [1, 0]
        C = [c]
        for l in range(2, distinct + 1):
            c = [
                    c[d] * d + c[d + 1] * (distinct - d)
                    for d in range(i + 1)
                ] + [0]
            C.append(c)
        Cl.append(C)
    return Cl


def item_index(item, distinct, n_symbols, Cl):
    """Return the rank of ``item`` by summing the branches that precede it.

    For every position, each symbol smaller than the one actually present
    contributes ``Cl[distinct][remaining][column]`` to the rank, where
    ``remaining`` is the number of positions after the current one and
    ``column`` is the number of different symbols used once that smaller
    symbol would have been placed.

    Args:
        item: sequence of integer symbols.
        distinct: key/index selecting the table inside ``Cl``.
        n_symbols: symbols are scanned over ``range(n_symbols)``.
        Cl: counting tables, indexed ``Cl[distinct][remaining][used]``.

    Returns:
        The accumulated offset (0 for an empty ``item``).
    """
    total = len(item)
    rank = 0
    used = set()
    for pos, symbol in enumerate(item):
        remaining = total - 1 - pos
        candidate = 0
        # Walk over the symbols that sort before `symbol`.
        while candidate < n_symbols and candidate != symbol:
            # Re-using a known symbol keeps the used-count; a new one bumps it.
            column = len(used) if candidate in used else len(used) + 1
            rank += Cl[distinct][remaining][column]
            candidate += 1
        used.add(symbol)
    return rank


def item_at(idx, length, distinct, n_symbols, Cl):
    """Return the sequence of ``length`` symbols found at rank ``idx``.

    At every position the symbols ``0 .. n_symbols - 1`` are tried in order.
    Each one owns a branch of ``Cl[distinct][remaining][column]`` sequences;
    whole branches are subtracted from ``idx`` until the branch containing
    the target is reached.

    Args:
        idx: rank of the wanted sequence.
        length: number of symbols to produce.
        distinct: key/index selecting the table inside ``Cl``.
        n_symbols: size of the symbol range that is scanned.
        Cl: counting tables, indexed ``Cl[distinct][remaining][used]``.

    Returns:
        A list of ``length`` integers.  A position for which no branch
        matches keeps its initial value 0.
    """
    taken = [False] * n_symbols
    result = [0] * length
    n_taken = 0
    for pos in range(length):
        remaining = length - 1 - pos
        for symbol in range(n_symbols):
            fresh = not taken[symbol]
            column = n_taken + 1 if fresh else n_taken
            weight = Cl[distinct][remaining][column]
            if weight <= idx:
                # Target lies beyond this branch: skip all of it.
                idx -= weight
                continue
            result[pos] = symbol
            if fresh:
                n_taken += 1
                taken[symbol] = True
            break
    return result


if __name__ == "__main__":
    # Small benchmark: build the tables once, then time one lookup in each
    # direction.  time.perf_counter() is used instead of time.time() because
    # it is monotonic and has the highest available resolution, which matters
    # for the millisecond-range measurements below.
    started = time.perf_counter()
    Cl = get_Cl(512)
    elapsed = time.perf_counter() - started
    print(f'{elapsed} seconds for Cl')

    started = time.perf_counter()
    item = item_at(idx=432, length=512, distinct=350, n_symbols=512, Cl=Cl)
    elapsed = time.perf_counter() - started
    print(f'{elapsed} seconds for item_at')
    print(item)

    # Only the computation is timed; printing the result is kept outside.
    started = time.perf_counter()
    index = item_index(item=item, distinct=350, n_symbols=512, Cl=Cl)
    elapsed = time.perf_counter() - started
    print(index)
    print(f'{elapsed} seconds for item_index')
356.3069865703583 seconds for Cl
2.5428783893585205 seconds for item_at  
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 
325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 351, 458]  
432
0.025868892669677734 seconds for item_index

It works fine for small inputs, but it gets very slow as the numbers get bigger. Is it possible to speed this code up, in the same way that this answer improved the similarly slow function that calculates all the permutations?

The reason I compute Cl separately is that, for a fixed distinct, there will be thousands of calls to item_at and item_index. Cl is the same whenever distinct is the same, so there is no need to recompute it for each item_at or item_index call.

Update: Test result from answer

0.008994340896606445 seconds for item_at
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 
324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 347, 348, 344, 345, 346, 349]
432
0.006995677947998047 seconds for item_index

Solution

  • In this answer I will demonstrate two modifications that improve the speed of item_at and item_index.

    Before we start let's initialize the Cl table, to handle calls with distinct=200

    def get_Cl(length, distinct):
        """Build the counting table for one fixed number of different symbols.

        Row 0 is ``[0] * distinct + [1, 0]``; every further row is derived
        from the previous one with
        ``new[d] = old[d] * d + old[d + 1] * (distinct - d)``
        for ``d`` in ``0 .. distinct`` and is terminated by a sentinel 0.

        Args:
            length: number of rows to build (at least one row is returned).
            distinct: number of different symbols the table is built for.

        Returns:
            A list of rows, each a list of ``distinct + 2`` integers.
        """
        row = [0] * distinct + [1, 0]
        table = [row]
        for _ in range(length - 1):
            following = []
            for used in range(distinct + 1):
                repeat_known = row[used] * used
                add_new = row[used + 1] * (distinct - used)
                following.append(repeat_known + add_new)
            following.append(0)  # sentinel so row[used + 1] is always valid
            row = following
            table.append(row)
        return table
    
    # Keyed by the distinct count so that the Cl[distinct] lookups performed
    # by item_index / item_at resolve to this table when distinct == 200.
    Cl = {200:get_Cl(300, 200)}
    

    Modification to item_index

    Notice that the inner loop of item_index simply increments offset by values that do not depend on d itself, only on whether d is in seen. So if we know in advance how many of the symbols below di have already been seen, the whole inner loop collapses into two multiplications. Let's rewrite the code so that we keep track of the number of seen values smaller than d in an array seen_before[d].

    import numpy as np
    def item_index_bs(item, distinct, n_symbols, Cl):
        """Return the rank of ``item`` without scanning every smaller symbol.

        For a symbol ``di`` at some position, the smaller symbols split into
        those already seen (each contributing ``row[used]``) and those not yet
        seen (each contributing ``row[used + 1]``), so the contribution of the
        position is two multiplications once the number of seen symbols below
        ``di`` is known.

        Args:
            item: sequence of integer symbols, each in ``range(n_symbols)``.
            distinct: key/index selecting the table inside ``Cl``.
            n_symbols: alphabet size; length of the bookkeeping array.
            Cl: counting tables, indexed ``Cl[distinct][remaining][used]``.

        Returns:
            The rank as a Python integer.
        """
        length = len(item)
        offset = 0
        seen = set()
        # seen_before[s] == number of already-seen symbols smaller than s.
        # A signed dtype is used and values are converted to Python int
        # before any arithmetic: mixing np.uint64 scalars with Python ints
        # promotes to float64 under the pre-NumPy-2 promotion rules.
        seen_before = np.zeros(n_symbols, dtype=np.int64)
        for i, di in enumerate(item):
            row = Cl[distinct][length - 1 - i]
            used = len(seen)
            seen_below = int(seen_before[di])
            unseen_below = int(di) - seen_below
            offset += row[used] * seen_below + row[used + 1] * unseen_below
            if di not in seen:
                seen.add(di)
                # Every larger symbol now has one more seen symbol below it.
                seen_before[di + 1:] += 1
        return offset
    

    This can be tested with

    pp = item_at(256, 300, 200, 300, Cl)
    # The optimized function defined above is named item_index_bs.
    item_index_bs(pp, 200, 300, Cl) # 1.8ms
    item_index(pp, 200, 300, Cl) # 5.39ms
    

    Modification to item_at

    For item_at we can't simply group the terms as in item_index, but we can skip some iterations. Say idx is decreased by a when the symbol has already been seen and by b otherwise; then each iteration decreases it by at most max(a, b), so it takes at least idx // max(a, b) iterations to find the symbol to use. We jump over those iterations in one step by multiplying a and b by the number of seen and unseen symbols skipped, respectively.

    def item_at_skip(idx, length, distinct, n_symbols, Cl):
        """Return the sequence at rank ``idx``, skipping branches in bulk.

        Same contract as ``item_at``.  At each position a branch costs ``a``
        (symbol already used) or ``b`` (new symbol), so at least
        ``idx // max(a, b)`` leading symbols are certain to be skipped and
        can be subtracted in one step instead of one by one.

        Args:
            idx: rank of the wanted sequence.
            length: number of symbols to produce.
            distinct: key/index selecting the table inside ``Cl``.
            n_symbols: size of the symbol range that is scanned.
            Cl: counting tables, indexed ``Cl[distinct][remaining][used]``.

        Returns:
            A list of ``length`` integers.  A position for which no branch
            matches keeps its initial value 0.
        """
        seen = [0] * n_symbols
        prefix = [0] * length
        used = 0
        for i in range(length):
            row = Cl[distinct][length - 1 - i]
            a = row[used]      # branch size when re-using a seen symbol
            b = row[used + 1]  # branch size when introducing a new symbol
            step = max(a, b)
            # Number of leading symbols that are certainly skipped.  Guard
            # against step == 0 (both branches empty, which used to raise
            # ZeroDivisionError) and clamp to [0, n_symbols] so that a
            # negative or out-of-range idx is handled like the plain scan.
            if step > 0 and idx > 0:
                skip = min(idx // step, n_symbols)
            else:
                skip = 0
            seen_skipped = sum(seen[:skip])  # how many skipped symbols cost a
            idx -= a * seen_skipped + b * (skip - seen_skipped)
            for d in range(skip, n_symbols):
                branch_count = a if seen[d] else b
                if branch_count <= idx:
                    idx -= branch_count
                else:
                    prefix[i] = d
                    if not seen[d]:
                        used += 1
                        seen[d] = 1
                    break
        return prefix
    # Sanity check: the skipping variant must produce the same item as item_at.
    assert item_at_skip(10**200, 300, 200, 300, Cl) == item_at(10**200, 300, 200, 300, Cl)
    
    # Timings reported for the two variants on the same input.
    item_at_skip(10**200, 300, 200, 300, Cl) # 3.16ms
    item_at(10**200, 300, 200, 300, Cl) # 6.25ms