
Python: Memory-efficient random sampling of list of permutations


I want to sample n random 2-permutations (ordered pairs) of a list in Python.

This is my code:

    import random
    from itertools import permutations

    # obj is a numpy.ndarray that prints as:
    # [    5     8     9 ... 45718 45719 45720]
    pairs = random.sample(list(permutations(obj, 2)), k=150)

The code does what I want, but it causes memory problems: on CPU I sometimes get a MemoryError, and on GPU my virtual machine crashes.
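For a sense of scale: itertools.permutations itself is lazy, but list(...) forces every ordered pair into memory at once before random.sample picks 150 of them. A sequence of length m has m * (m - 1) ordered pairs. The sketch below assumes, hypothetically, about 45,000 elements, because the question elides the actual length of obj.

```python
import math

# Hypothetical size: the question elides obj's contents, so assume
# about 45,000 elements purely for illustration.
m = 45_000

# Number of ordered pairs that list(permutations(obj, 2)) must hold at once.
n_pairs = math.perm(m, 2)  # equals m * (m - 1)
print(f"{n_pairs:,} ordered pairs")

# Crude lower bound: a 64-bit CPython list stores one 8-byte pointer per
# element; the tuple objects themselves come on top of that.
print(f"list pointers alone: ~{n_pairs * 8 / 1e9:.1f} GB")
```

With that assumed size it reports 2,024,955,000 pairs and about 16.2 GB for the list's pointer array alone, before counting the tuples themselves, which is consistent with the MemoryError.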

How can I make the code work in a more memory-efficient manner?


Solution

  • Building on Pablo Ruiz's excellent answer, I suggest wrapping his sampling solution in a generator function that yields unique permutations by keeping track of the ones it has already yielded:

    import numpy as np
    
    def unique_permutations(sequence, r, n):
        """Yield n unique permutations of r elements from sequence"""
        seen = set()
        while len(seen) < n:
            # This line of code adapted from Pablo Ruiz's answer:
            candidate_permutation = tuple(np.random.choice(sequence, r, replace=False))
    
            if candidate_permutation not in seen:
                seen.add(candidate_permutation)
                yield candidate_permutation
    
    obj = list(range(10))
    for permutation in unique_permutations(obj, 2, 15):
        print(permutation)  # do something with the permutation
    
    # Or, to save the result as a list:
    pairs = list(unique_permutations(obj, 2, 15))
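The same generator can also be written with only the standard library. This is a sketch that swaps np.random.choice for random.sample over the index range; the rejection-sampling logic is unchanged. One possible reason to prefer it: NumPy's legacy np.random.choice implements replace=False (without p) by permuting all len(sequence) indices on every call, whereas random.sample over a range, for a large sequence and small r, picks positions without copying the population.

```python
import random


def unique_permutations(sequence, r, n):
    """Yield n unique r-permutations of sequence (stdlib-only variant).

    Same rejection-sampling idea as the NumPy version; it likewise loops
    forever if n exceeds the number of possible permutations.
    """
    seen = set()
    while len(seen) < n:
        # Pick r distinct positions, then look up the elements. Working on
        # positions means the same code handles lists and NumPy arrays.
        positions = random.sample(range(len(sequence)), r)
        candidate = tuple(sequence[i] for i in positions)
        if candidate not in seen:
            seen.add(candidate)
            yield candidate


obj = list(range(10))
pairs = list(unique_permutations(obj, 2, 15))
print(pairs)
```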
    

    My assumption is that you are sampling a small subset of the very large number of possible permutations, in which case collisions will be rare, so the loop seldom has to retry, and the seen set only ever holds the n permutations already yielded.

    Warnings: this function loops forever if you ask for more permutations than are possible given the inputs. It also gets increasingly slow as n approaches the number of possible permutations, because collisions become increasingly frequent.
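To put rough numbers on the collision argument, here is a sketch of the expected number of draws the rejection loop makes. It assumes the elements of sequence are distinct, so each draw is uniform over all possible permutations; expected_draws is a hypothetical helper name.

```python
import math


def expected_draws(total, n):
    """Expected number of uniform random draws from `total` equally likely
    permutations needed to collect n distinct ones (rejecting repeats)."""
    if n > total:
        raise ValueError("n cannot exceed the number of possible permutations")
    # With k distinct permutations already seen, a draw is new with
    # probability (total - k) / total, so on average it takes
    # total / (total - k) draws to find the next new one.
    return sum(total / (total - k) for k in range(n))


total_small = math.perm(10, 2)  # 90 ordered pairs from range(10)
print(expected_draws(total_small, 15))  # the example above
print(expected_draws(total_small, 90))  # asking for every possible pair

# Hypothetical large case: ~45,000 elements, 150 pairs as in the question.
print(expected_draws(math.perm(45_000, 2), 150))
```

For obj = list(range(10)) and r = 2 there are 90 possible pairs: collecting 15 takes about 16.3 draws on average, while collecting all 90 takes roughly 457. At the question's scale (150 pairs out of about two billion, under the same assumed length) the result is essentially 150, so retries cost almost nothing.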

    If I were to put this function in my code base, I would add a guard at the top that calculates the number of possible permutations, raises a ValueError if n exceeds that number, and perhaps emits a warning if n exceeds one tenth of it, or something similar.
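A minimal sketch of that guard, assuming distinct elements (so math.perm gives the number of possible permutations) and treating the one-tenth threshold and the RuntimeWarning category as arbitrary choices. The helper name _sample_unique is hypothetical.

```python
import math
import warnings

import numpy as np


def unique_permutations(sequence, r, n):
    """Return a generator of n unique r-permutations drawn from sequence.

    Assumes the elements of sequence are distinct, so there are exactly
    math.perm(len(sequence), r) possible permutations.
    """
    total = math.perm(len(sequence), r)  # 0 when r > len(sequence)
    if n > total:
        raise ValueError(f"asked for {n} permutations, but only {total} exist")
    if n > total / 10:
        # Rough rule of thumb: past this point collisions get frequent.
        warnings.warn(
            f"n={n} exceeds a tenth of the {total} possible permutations; "
            "sampling may slow down as collisions become frequent",
            RuntimeWarning,
            stacklevel=2,
        )
    # The checks above run eagerly, at call time; the sampling itself is lazy.
    return _sample_unique(sequence, r, n)


def _sample_unique(sequence, r, n):
    """Rejection-sample until n distinct permutations have been yielded."""
    seen = set()
    while len(seen) < n:
        candidate = tuple(np.random.choice(sequence, r, replace=False))
        if candidate not in seen:
            seen.add(candidate)
            yield candidate


obj = list(range(10))
pairs = list(unique_permutations(obj, 2, 5))  # 5 <= 90 / 10, so no warning
print(pairs)
```

Splitting the function in two is deliberate: the body of a generator function does not run until the first next() call, so checks placed directly inside the original generator would delay the ValueError until iteration starts. Here unique_permutations is a plain function, so the checks fire as soon as it is called.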