
Optimization of a python code according to provided parameters to generate permutation and index faster


How can this answer be optimized to work with a given list of 256 bytes, such as the example below?
Example bytes are:

bytez = [197 215 20 156 94 67 20 100 27 208 186 248 71 48 128 75 7 165 148 223 94 163 233 15 161 104 246 66 242 142 118 165 204 0 252 22 233 28 136 197 113 122 72 229 11 91 133 142 20 204 119 211 170 104 63 39 46 68 150 123 148 95 96 95 17 133 243 35 45 66 76 19 41 200 141 120 110 215 140 230 252 182 42 166 59 249 171 97 124 8 138 59 112 191 87 170 218 31 51 74 112 23 37 13 63 96 61 200 110 189 59 18 11 99 94 63 245 107 31 11 217 51 133 35 113 36 154 179 223 92 31 239 20 51 200 102 133 183 240 88 104 29 81 122 28 246 161 90 89 6 241 241 19 40 43 248 78 6 234 40 171 23 143 70 122 246 180 148 183 67 158 198 212 41 0 98 171 81 122 114 229 193 213 212 65 72 120 191 228 32 132 172 88 100 104 119 253 166 159 242 246 6 66 190 31 57 175 105 161 1 109 8 1 50 97 60 101 25 131 93 51 243 203 41 11 140 231 59 131 68 177 58 80 142 9 21 20 106 132 161 187 21 253 234 222 190 91 106 192 149 4 70 77 139 170 172]  
distinct = 156  

I want both methods, item_at and item_index, to work with a list of 256 bytes rather than with strings or integers.

In more detail:
For item_at, a list of 256 bytes l256b will be provided as the index, and for distinct (the number of distinct bytes in the provided l256b) a value 0<=x<256 will be provided by the user of the method.
There is no need for the parameters alphabet and length, as both are always constant: the alphabet is the bytes 0<=x<256 and the length is 256.
item_at has to return a list of 256 bytes, which is the permutation at the provided index.

For item_index, a list of 256 bytes l256b will be provided as the item (permutation), and for distinct (the number of distinct bytes in the provided l256b) a value 0<=d<256 will be provided by the user of the method.
There is no need for the parameters alphabet and length, as both are always constant: the alphabet is the bytes 0<=x<256 and the length is 256.
item_index has to return a list of 256 bytes, which is the index of the provided permutation.


Solution

    1 - Handling lists

    The original solution was designed to work with strings, but it can easily be modified to handle lists.
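Since the functions only iterate over the alphabet and index into the item, a `bytes` object and a plain list of ints can be converted into each other with one call. A minimal illustration, using a few values from the question's data:

```python
raw = bytes([197, 215, 20, 156, 94])
as_list = list(raw)        # iterating over bytes yields ints in 0..255
alphabet = range(256)      # the fixed alphabet from the question

assert all(b in alphabet for b in as_list)
back = bytes(as_list)      # e.g. to turn a list returned by item_at back into bytes
print(as_list[2], raw[2])  # 20 20: indexing bytes also gives an int, not a 1-byte string
```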

    2 - Tracking only counts of the prefix

    That code constructs the entire prefix and extends it on every call, which leads to a significant amount of copying. In the function item_index the prefix is only used to know whether a given symbol has already been used. Instead, you can keep a dictionary with the number of times each symbol appears in the prefix, and replace the check d in prefix with prefixCount[d] != 0.
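To make the difference concrete, here is a small illustration of the two ways to answer "has symbol d already been used?": scanning the prefix list versus looking up a per-symbol count. The prefix is just the first few bytes of the question's data, chosen for illustration.

```python
# Hypothetical prefix, taken from the start of the question's data
prefix = [197, 215, 20, 156, 94, 67, 20]

# Membership test on the list: scans the prefix, O(len(prefix)) per check
used_by_scan = 20 in prefix

# Count per symbol, updated once per appended symbol: O(1) per check
prefix_count = {a: 0 for a in range(256)}
for s in prefix:
    prefix_count[s] += 1
used_by_count = prefix_count[20] != 0

# The number of distinct symbols used so far is the number of non-zero counts
used = sum(1 for c in prefix_count.values() if c != 0)
print(used_by_scan, used_by_count, used)  # True True 6
```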

    3 - Adjusting cache size

    You can see that the solution uses an LRU cache, which by default memorizes only the 128 most recent results. You can decorate the function with lru_cache(maxsize=None) or simply cache() (available since Python 3.9). If you know that the maximum length of the input is 256, lru_cache(maxsize=256**2) is enough.

    from functools import lru_cache

    @lru_cache(maxsize=256**2)
    def count_seq(n_symbols, length, distinct, used=0):
        # Number of ways to fill `length` more positions using exactly `distinct`
        # new symbols, given that `used` symbols have already appeared
        if distinct < 0:
            return 0
        if length == 0:
            return 1 if distinct == 0 else 0
        return (count_seq(n_symbols, length-1, distinct, used) * used +
                count_seq(n_symbols, length-1, distinct-1, used+1) * (n_symbols - used))

    def item_at(idx, alphabet, length, distinct, used=0, prefix=None):
        # Builds the sequence at position idx (lexicographic order), one symbol per call
        if prefix is None:
            prefix = []
        if distinct < 0:
            return None
        if length == 0:
            return prefix
        for d in alphabet:
            if d in prefix:
                # Repeating a symbol already in the prefix adds no new distinct symbol
                branch_count = count_seq(len(alphabet), length-1, distinct, used)
                if branch_count <= idx:
                    idx -= branch_count
                else:
                    prefix.append(d)
                    return item_at(idx, alphabet, length-1, distinct, used, prefix)
            else:
                # First occurrence of d consumes one of the remaining distinct symbols
                branch_count = count_seq(len(alphabet), length-1, distinct-1, used+1)
                if branch_count <= idx:
                    idx -= branch_count
                else:
                    prefix.append(d)
                    return item_at(idx, alphabet, length-1, distinct-1, used+1, prefix)

    def item_index(item, alphabet, length, distinct, used=0, prefixCount=None, idx=0):
        # Returns the position of item in lexicographic order; prefixCount[d] is how
        # many times d occurs in item[:idx], replacing the `d in prefix` scan
        if prefixCount is None:
            prefixCount = {a: 0 for a in alphabet}
        if distinct < 0 or length == 0:
            return 0
        offset = 0
        for d in alphabet:
            if prefixCount[d] != 0:
                if d == item[idx]:
                    prefixCount[d] += 1
                    return offset + item_index(item, alphabet, length-1, distinct,
                                               used, prefixCount, idx+1)
                # Skip every sequence that repeats the smaller symbol d here
                offset += count_seq(len(alphabet), length-1, distinct, used)
            else:
                if d == item[idx]:
                    prefixCount[d] += 1
                    return offset + item_index(item, alphabet, length-1, distinct-1,
                                               used+1, prefixCount, idx+1)
                # Skip every sequence that introduces the smaller symbol d here
                offset += count_seq(len(alphabet), length-1, distinct-1, used+1)
    

    With these changes, it runs in a few milliseconds on a modern computer.
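To see how large the cache actually needs to be, functools exposes cache_info() on any lru_cache-decorated function. A minimal sketch: it restates count_seq with an unbounded cache, checks it against brute force on a tiny case, and then reports how many states the 256-byte / 156-distinct case stores. It deliberately does not run the default maxsize=128 version on the large input, since a cache that small could keep evicting entries the recursion still needs.

```python
from functools import lru_cache
from itertools import product

@lru_cache(maxsize=None)  # unbounded; same as functools.cache on Python 3.9+
def count_seq(n_symbols, length, distinct, used=0):
    # Same recurrence as the recursive solution
    if distinct < 0:
        return 0
    if length == 0:
        return 1 if distinct == 0 else 0
    return (count_seq(n_symbols, length - 1, distinct, used) * used
            + count_seq(n_symbols, length - 1, distinct - 1, used + 1) * (n_symbols - used))

# Small case that can be brute-forced: length-4 sequences over 3 symbols
# containing exactly 2 distinct symbols
brute = sum(1 for s in product(range(3), repeat=4) if len(set(s)) == 2)
assert count_seq(3, 4, 2) == brute

count_seq.cache_clear()
total = count_seq(256, 256, 156)
info = count_seq.cache_info()
print(info.maxsize, info.currsize)  # currsize = number of distinct states kept
```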

    Iterative implementation

    Below is a class that you instantiate with an alphabet, a length and the number of distinct symbols you want. In this setting distinct + used is invariant across all the recurrences, so count_seq only depends on the remaining length and on used; its results are precomputed in the matrix C on construction. The methods item_at and item_index are iterative implementations that compute their results from C.

    In my opinion this is less readable than the recursive implementation, where everything is expressed in terms of function calls that have a clear conceptual meaning.
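For reference, the recurrence that fills C can be written out as follows. This is my reading of count_seq under the invariant distinct + used = D, not notation from the original answer: n is the alphabet size, D the requested number of distinct symbols, f(k, u) the number of ways to fill k more positions after u distinct symbols have been used, S_i the set of symbols in x_0, ..., x_{i-1}, u_i = |S_i|, L = len(item), "d < x_i" means "earlier in the alphabet", and [.] is 1 if the condition holds and 0 otherwise.

```latex
f(0,u) = \begin{cases} 1, & u = D \\ 0, & u \neq D \end{cases}
\qquad
f(k,u) = u\,f(k-1,u) + (n-u)\,f(k-1,u+1),
\qquad
f(k,D+1) = 0

C[k][u] = f(k,u),
\qquad
\operatorname{index}(x) = \sum_{i=0}^{L-1} \sum_{d < x_i} f\bigl(L-1-i,\; u_i + [\,d \notin S_i\,]\bigr)
```

The trailing 0 appended to every row of C plays the role of f(k, D+1) = 0, which is why item_index can safely read C[...][len(seen)+1].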

    class SequenceLookup:
        def __init__(self, alphabet, length, distinct):
            self.alphabet = list(alphabet)
            self.distinct = distinct
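            n_symbols = len(self.alphabet)
            # C[k][u]: number of ways to fill k more positions when u distinct
            # symbols have been used so far; the trailing 0 covers u = distinct + 1
            c = [0] * distinct + [1, 0]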
            C = [c]
            for l in range(2,length+1):
                c = [
                    c[d] * d + c[d+1] * (n_symbols - d)
                    for d in range(distinct+1)
                ] + [0]
                C.append(c)
            self.C = C
        
        def item_index(self, item):
            length = len(item)
            offset = 0
            seen = set()
            for i,di in enumerate(item):
                for d in self.alphabet:
                    if d == di:
                        break
                    if d in seen:
                        offset += self.C[length-1-i][len(seen)]
                    else:
                        offset += self.C[length-1-i][len(seen)+1]
                seen.add(di)
            return offset
        def item_at(self, idx, length):
            seen = set()
            prefix = []
            for i in range(length):
                for d in self.alphabet:
                    if d in seen:  # O(1) set lookup instead of scanning the prefix list
                        branch_count = self.C[length-1-i][len(seen)]
                    else:
                        branch_count = self.C[length-1-i][len(seen)+1]
                    if branch_count <= idx:
                        idx -= branch_count
                    else:
                        prefix.append(d)
                        seen.add(d)
                        break
            return prefix
    
    bytez=[197, 215, 20, 156, 94, 67, 20, 100, 27, 208, 186, 248, 
           71, 48, 128, 75, 7, 165, 148, 223, 94, 163, 233, 15,
           161, 104, 246, 66, 242, 142, 118, 165, 204, 0, 252,
           22, 233, 28, 136, 197, 113, 122, 72, 229, 11, 91, 133,
           142, 20, 204, 119, 211, 170, 104, 63, 39, 46, 68, 150,
           123, 148, 95, 96, 95, 17, 133, 243, 35, 45, 66, 76, 19,
           41, 200, 141, 120, 110, 215, 140, 230, 252, 182, 42, 
           166, 59, 249, 171, 97, 124, 8, 138, 59, 112, 191, 87, 
           170, 218, 31, 51, 74, 112, 23, 37, 13, 63, 96, 61, 200, 
           110, 189, 59, 18, 11, 99, 94, 63, 245, 107, 31, 11, 
           217, 51, 133, 35, 113, 36, 154, 179, 223, 92, 31, 239, 
           20, 51, 200, 102, 133, 183, 240, 88, 104, 29, 81, 122,
           28, 246, 161, 90, 89, 6, 241, 241, 19, 40, 43, 248, 78,
           6, 234, 40, 171, 23, 143, 70, 122, 246, 180, 148, 183,
           67, 158, 198, 212, 41, 0, 98, 171, 81, 122, 114, 229,
           193, 213, 212, 65, 72, 120, 191, 228, 32, 132, 172, 88,
           100, 104, 119, 253, 166, 159, 242, 246, 6, 66, 190, 31,
           57, 175, 105, 161, 1, 109, 8, 1, 50, 97, 60, 101, 25,
           131, 93, 51, 243, 203, 41, 11, 140, 231, 59, 131, 68,
           177, 58, 80, 142, 9, 21, 20, 106, 132, 161, 187, 21, 253, 
           234, 222, 190, 91, 106, 192, 149, 4, 70, 77, 139, 170, 172]
    v = SequenceLookup(range(256), len(bytez), len(set(bytez)))
    
    %%timeit
    v = SequenceLookup(range(256), len(bytez), len(set(bytez)))
    

    11.4 ms ± 229 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

    #%%timeit
    t = v.item_index(bytez)  # t is the index passed to item_at below
    

    7.57 ms ± 132 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

    #%%timeit
    v.item_at(t, 256)
    

    33.6 ms ± 598 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
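Timings aside, it is worth checking that the table-driven ranking really enumerates sequences in lexicographic order and that item_at inverts item_index. A minimal sketch on a toy alphabet small enough to brute-force with itertools.product; build_table, rank and unrank are hypothetical standalone restatements of the class's constructor, item_index and item_at, not part of the original answer.

```python
from itertools import product

def build_table(n_symbols, length, distinct):
    # Same recurrence as SequenceLookup.__init__
    c = [0] * distinct + [1, 0]
    C = [c]
    for _ in range(2, length + 1):
        c = [c[d] * d + c[d + 1] * (n_symbols - d) for d in range(distinct + 1)] + [0]
        C.append(c)
    return C

def rank(item, alphabet, C):
    # Mirrors SequenceLookup.item_index
    length, offset, seen = len(item), 0, set()
    for i, di in enumerate(item):
        for d in alphabet:
            if d == di:
                break
            u = len(seen) if d in seen else len(seen) + 1
            offset += C[length - 1 - i][u]
        seen.add(di)
    return offset

def unrank(idx, alphabet, length, C):
    # Mirrors SequenceLookup.item_at
    seen, prefix = set(), []
    for i in range(length):
        for d in alphabet:
            u = len(seen) if d in seen else len(seen) + 1
            branch_count = C[length - 1 - i][u]
            if branch_count <= idx:
                idx -= branch_count
            else:
                prefix.append(d)
                seen.add(d)
                break
    return prefix

alphabet, length, distinct = [0, 1, 2], 4, 2
C = build_table(len(alphabet), length, distinct)
# product() over a sorted alphabet yields tuples in lexicographic order
expected = [list(s) for s in product(alphabet, repeat=length) if len(set(s)) == distinct]
ranks = [rank(s, alphabet, C) for s in expected]
roundtrip = [unrank(i, alphabet, length, C) for i in range(len(expected))]
print(len(expected), ranks == list(range(len(expected))), roundtrip == expected)  # 42 True True
```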

    Specialized for bytes

    An implementation using the fixed alphabet [0, ..., 255]:

    class SequenceLookup:
        def __init__(self, length, distinct):
            self.distinct = distinct
            c = [0] * distinct + [1, 0]
            C = [c]
            for l in range(2,length+1):
                c = [
                    c[d] * d + c[d+1] * (256 - d)
                    for d in range(distinct+1)
                ] + [0]
                C.append(c)
            self.C = C
        
        def item_index(self, item):
            length = len(item)
            offset = 0
            seen = set()
            for i,di in enumerate(item):
                for d in range(256):
                    if d == di:
                        break
                    if d in seen:
                        offset += self.C[length-1-i][len(seen)]
                    else:
                        offset += self.C[length-1-i][len(seen)+1]
                seen.add(di)
            return offset
    
        def item_at(self, idx, length):
            seen = [0] * 256
            prefix = [0] * length
            used = 0
            for i in range(length):
                for d in range(256):
                    if seen[d] != 0:
                        branch_count = self.C[length-1-i][used]
                    else:
                        branch_count = self.C[length-1-i][used+1]
                    if branch_count <= idx:
                        idx -= branch_count
                    else:
                        prefix[i] = d
                        if seen[d] == 0:
                            used += 1
                        seen[d] = 1
                        break
            return prefix
    

    With this implementation, construction and item_index take about the same time, but item_at runs faster in my tests:

    6.32 ms ± 91.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

    Timings will of course vary between machines, so you may want to try the same algorithm with different data structures yourself.
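Along those lines, one possible variation (my own, not from the original answer) targets the inner loop of the byte-specialized item_index. Every smaller byte that was already seen contributes the same branch count row[used], and every smaller unseen byte contributes row[used + 1], so only the number of seen bytes below di matters. A sketch assuming the table layout of the class above; build_table, item_index_loop and item_index_counted are hypothetical standalone names, and I have not measured its speed on the question's data.

```python
import random

def build_table(length, distinct, n_symbols=256):
    # Same recurrence as the byte-specialized SequenceLookup.__init__
    c = [0] * distinct + [1, 0]
    C = [c]
    for _ in range(2, length + 1):
        c = [c[d] * d + c[d + 1] * (n_symbols - d) for d in range(distinct + 1)] + [0]
        C.append(c)
    return C

def item_index_loop(C, item):
    # Direct transcription of the byte-specialized item_index, for comparison
    length, offset, seen = len(item), 0, set()
    for i, di in enumerate(item):
        for d in range(di):  # same as looping over range(256) until d == di
            offset += C[length - 1 - i][len(seen) if d in seen else len(seen) + 1]
        seen.add(di)
    return offset

def item_index_counted(C, item):
    # Smaller seen bytes all add row[used]; smaller unseen bytes all add
    # row[used + 1], so only how many of each there are is needed
    length, offset, used = len(item), 0, 0
    seen = bytearray(256)  # seen[d] == 1 once byte d has appeared
    for i, di in enumerate(item):
        row = C[length - 1 - i]
        k = sum(seen[:di])  # how many already-seen bytes are smaller than di
        offset += k * row[used] + (di - k) * row[used + 1]
        if not seen[di]:
            seen[di] = 1
            used += 1
    return offset

random.seed(0)
item = [random.randrange(256) for _ in range(256)]
C = build_table(len(item), len(set(item)))
print(item_index_counted(C, item) == item_index_loop(C, item))  # True
```

The counting step is still proportional to di, but it runs as a single sum over a bytearray slice rather than as a Python-level loop with a set lookup per byte.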